/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.spark.rapids;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;

/**
 * A batch of columns held in a cuDF {@link Table}, paired with the Spark {@link StructType}
 * that describes those columns.
 *
 * <p>NOTE(review): this class never closes {@code table}; ownership presumably stays with the
 * code that constructed it — confirm.
 */
public class GpuColumnBatch {
    private final Table table;
    private final StructType schema;

    /**
     * Wraps an existing cuDF table.
     *
     * @param table the cuDF table holding the column data
     * @param schema the Spark schema describing the table's columns
     */
    public GpuColumnBatch(Table table, StructType schema) {
        this.table = table;
        this.schema = schema;
    }

    /** Returns the Spark schema supplied at construction. */
    public StructType getSchema() {
        return this.schema;
    }

    /** Returns the number of rows reported by the underlying cuDF table. */
    public long getNumRows() {
        return this.table.getRowCount();
    }

    /** Returns the number of columns reported by the underlying cuDF table. */
    public int getNumColumns() {
        return this.table.getNumberOfColumns();
    }

    /**
     * Returns the column at {@code index} as provided by {@link Table#getColumn(int)}.
     *
     * @param index zero-based column index
     */
    public ColumnVector getColumnVector(int index) {
        return this.table.getColumn(index);
    }

    /**
     * Returns the native cuDF column address of the column at {@code index}.
     *
     * @param index zero-based column index
     * @return the value of {@code getNativeCudfColumnAddress()} for that column
     */
    public long getColumn(int index) {
        return this.table.getColumn(index).getNativeCudfColumnAddress();
    }

    /**
     * Computes the length of every run of consecutive rows that share the same group id.
     *
     * <p>Grouping is purely positional: a group id that reappears after a different id starts
     * a new run and therefore a new count. NOTE(review): callers presumably supply rows
     * already ordered by group id.
     *
     * @param groupIndex index of the group-id column; values are read with {@code getInt}
     * @return run lengths in row order; an empty array when the column has no rows
     */
    public int[] groupByColumnWithCountHost(int groupIndex) {
        ColumnVector groups = this.getColumnVector(groupIndex);
        groups.ensureOnHost();
        long numRows = groups.getRowCount();

        List<Integer> runLengths = new ArrayList<Integer>();
        int currentGroup = 0;
        int runLength = 0;
        // long index: getRowCount() is a long, so an int counter could overflow.
        for (long row = 0; row < numRows; ++row) {
            int group = groups.getInt(row);
            // runLength == 0 marks "no run open yet" so no group id acts as a sentinel.
            if (runLength == 0 || group != currentGroup) {
                if (runLength > 0) {
                    runLengths.add(runLength);
                }
                currentGroup = group;
                runLength = 1;
            } else {
                ++runLength;
            }
        }
        if (runLength > 0) {
            runLengths.add(runLength);
        }

        int[] counts = new int[runLengths.size()];
        for (int k = 0; k < counts.length; ++k) {
            counts[k] = runLengths.get(k);
        }
        return counts;
    }

    /**
     * Collapses a float column to one value per run of consecutive rows sharing a group id.
     *
     * <p>For each run, the value from the run's first row is kept. Runs are delimited exactly
     * as in {@link #groupByColumnWithCountHost(int)}, so both methods yield the same number of
     * groups for the same group column. This includes a leading run whose id is {@code 0}.
     *
     * @param groupIndex index of the group-id column; values are read with {@code getInt}
     * @param oneIndex index of the value column; values are read with {@code getFloat}
     * @param checkEqual when true, every row in a run must carry the same value as the run's
     *     first row. The comparison uses float {@code !=}, so NaN never compares equal.
     * @return a one-element array holding the native cuDF address of a new float column with
     *     one value per run. Returns {@code null} when {@code checkEqual} is set and a run
     *     holds differing values. NOTE(review): the new column is never closed here; ownership
     *     presumably passes to whoever consumes the address — confirm.
     */
    public long[] groupByColumnWithAggregation(int groupIndex, int oneIndex, boolean checkEqual) {
        ColumnVector groups = this.getColumnVector(groupIndex);
        groups.ensureOnHost();
        ColumnVector values = this.getColumnVector(oneIndex);
        values.ensureOnHost();
        long numRows = groups.getRowCount();

        List<Float> perGroup = new ArrayList<Float>();
        boolean inRun = false;
        int currentGroup = 0;
        float runValue = 0f;
        for (long row = 0; row < numRows; ++row) {
            int group = groups.getInt(row);
            // Track an open run explicitly. Comparing against an initial id of 0 would
            // silently drop a leading group whose id is 0.
            if (!inRun || group != currentGroup) {
                if (inRun) {
                    perGroup.add(runValue);
                }
                currentGroup = group;
                runValue = values.getFloat(row);
                inRun = true;
            } else if (checkEqual && runValue != values.getFloat(row)) {
                return null;
            }
        }
        if (inRun) {
            perGroup.add(runValue);
        }

        ColumnVector result = ColumnVector.fromBoxedFloats(perGroup.toArray(new Float[0]));
        result.ensureOnDevice();
        return new long[] {result.getNativeCudfColumnAddress()};
    }

    /**
     * Maps a Spark SQL data type to the corresponding cuDF {@link DType}.
     *
     * @param type the Spark type to convert
     * @return the matching cuDF type
     * @throws IllegalArgumentException if the type has no mapping in this class
     */
    public static DType getRapidsType(DataType type) {
        DType result = GpuColumnBatch.toRapidsOrNull(type);
        if (result == null) {
            throw new IllegalArgumentException(type + " is not supported for GPU processing yet.");
        }
        return result;
    }

    /** Returns the cuDF type for a supported Spark type, or {@code null} if unsupported. */
    private static DType toRapidsOrNull(DataType type) {
        if (type instanceof LongType) {
            return DType.INT64;
        }
        if (type instanceof DoubleType) {
            return DType.FLOAT64;
        }
        if (type instanceof ByteType) {
            return DType.INT8;
        }
        if (type instanceof BooleanType) {
            return DType.BOOL8;
        }
        if (type instanceof ShortType) {
            return DType.INT16;
        }
        if (type instanceof IntegerType) {
            return DType.INT32;
        }
        if (type instanceof FloatType) {
            return DType.FLOAT32;
        }
        if (type instanceof DateType) {
            return DType.DATE32;
        }
        if (type instanceof TimestampType) {
            return DType.TIMESTAMP;
        }
        if (type instanceof StringType) {
            return DType.STRING;
        }
        return null;
    }
}

