package ml.dmlc.xgboost4j.java.spark.rapids;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import java.util.ArrayList;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/spark/rapids/GpuColumnBatch.class */
/**
 * Pairs a cuDF {@link Table} with the Spark {@link StructType} describing it and offers
 * helpers for reading columns and for summarising runs of consecutive equal group ids.
 *
 * <p>The group helpers only compare a row with the row before it, so they report
 * <em>consecutive runs</em>; rows of one group that are not adjacent are reported as
 * separate runs. NOTE(review): callers presumably pass data already ordered by group id
 * - confirm.
 */
public class GpuColumnBatch {
    private final Table table;
    private final StructType schema;

    /**
     * Creates a batch over the given table. Neither argument is copied or validated.
     *
     * @param table  the cuDF table holding the column data
     * @param schema the Spark schema associated with {@code table}
     */
    public GpuColumnBatch(Table table, StructType schema) {
        this.table = table;
        this.schema = schema;
    }

    /** Returns the Spark schema supplied at construction time. */
    public StructType getSchema() {
        return this.schema;
    }

    /** Returns the row count reported by the underlying table. */
    public long getNumRows() {
        return this.table.getRowCount();
    }

    /** Returns the column count reported by the underlying table. */
    public int getNumColumns() {
        return this.table.getNumberOfColumns();
    }

    /**
     * Returns the column at the given position of the underlying table.
     *
     * @param index zero-based column position
     */
    public ColumnVector getColumnVector(int index) {
        return this.table.getColumn(index);
    }

    /**
     * Returns the value of {@code getNativeCudfColumnAddress()} for the column at the
     * given position.
     *
     * @param index zero-based column position
     */
    public long getColumn(int index) {
        return this.table.getColumn(index).getNativeCudfColumnAddress();
    }

    /**
     * Computes the length of every run of consecutive equal int values in a column.
     *
     * <p>{@code ensureOnHost()} is called on the column before its values are read.
     *
     * @param groupColumnIndex position of the int column holding the group ids
     * @return run lengths in row order; an empty array when the column has no rows
     */
    public int[] groupByColumnWithCountHost(int groupColumnIndex) {
        ColumnVector groupColumn = getColumnVector(groupColumnIndex);
        groupColumn.ensureOnHost();
        long rowCount = groupColumn.getRowCount();

        ArrayList<Integer> runLengths = new ArrayList<>();
        int currentGroup = 0;
        int runLength = 0; // 0 means no row has been seen yet
        for (int row = 0; row < rowCount; row++) {
            int groupId = groupColumn.getInt(row);
            if (runLength == 0 || groupId != currentGroup) {
                // A new run starts here; flush the previous one if there was any.
                if (runLength > 0) {
                    runLengths.add(runLength);
                }
                currentGroup = groupId;
                runLength = 1;
            } else {
                runLength++;
            }
        }
        if (runLength > 0) {
            runLengths.add(runLength);
        }

        int[] result = new int[runLengths.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = runLengths.get(k);
        }
        return result;
    }

    /**
     * Collects, for every run of consecutive equal group ids, the float found on the
     * first row of that run, and builds a new float column from those values.
     *
     * <p>{@code ensureOnHost()} is called on both input columns before reading, and
     * {@code ensureOnDevice()} is called on the newly built column before its address is
     * taken.
     *
     * <p>NOTE(review): the {@link ColumnVector} created here is never closed by this
     * method; presumably the caller takes ownership of the native column through the
     * returned address - confirm.
     *
     * @param groupColumnIndex     position of the int column holding the group ids
     * @param valueColumnIndex     position of the float column holding the per-row values
     * @param requireUniformValues when {@code true}, every row of a run must carry the
     *                             same float as the run's first row
     * @return a one-element array holding {@code getNativeCudfColumnAddress()} of the new
     *         column, or {@code null} when {@code requireUniformValues} is set and some
     *         run contains differing values
     */
    public long[] groupByColumnWithAggregation(int groupColumnIndex, int valueColumnIndex,
            boolean requireUniformValues) {
        ColumnVector groupColumn = getColumnVector(groupColumnIndex);
        groupColumn.ensureOnHost();
        ColumnVector valueColumn = getColumnVector(valueColumnIndex);
        valueColumn.ensureOnHost();
        long rowCount = groupColumn.getRowCount();

        ArrayList<Float> runValues = new ArrayList<>();
        int currentGroup = 0;
        Float currentValue = null; // null means no row has been seen yet
        for (int row = 0; row < rowCount; row++) {
            int groupId = groupColumn.getInt(row);
            // "currentValue == null" makes the very first row open a run even when its
            // group id equals the initial value of currentGroup (0); without it a
            // leading group with id 0 would be silently dropped.
            if (currentValue == null || groupId != currentGroup) {
                if (currentValue != null) {
                    runValues.add(currentValue);
                }
                currentGroup = groupId;
                currentValue = valueColumn.getFloat(row);
            } else if (requireUniformValues
                    && currentValue.floatValue() != valueColumn.getFloat(row)) {
                return null;
            }
        }
        if (currentValue != null) {
            runValues.add(currentValue);
        }

        ColumnVector aggregated = ColumnVector.fromBoxedFloats(runValues.toArray(new Float[0]));
        aggregated.ensureOnDevice();
        return new long[]{aggregated.getNativeCudfColumnAddress()};
    }

    /**
     * Maps a Spark SQL type to the corresponding cuDF {@link DType}.
     *
     * @param dataType the Spark type to translate
     * @return the matching cuDF type
     * @throws IllegalArgumentException if {@code dataType} has no mapping
     */
    public static DType getRapidsType(DataType dataType) {
        DType rapidsType = toRapidsOrNull(dataType);
        if (rapidsType == null) {
            throw new IllegalArgumentException(dataType + " is not supported for GPU processing yet.");
        }
        return rapidsType;
    }

    /** Returns the cuDF type for a Spark type, or {@code null} when there is no mapping. */
    private static DType toRapidsOrNull(DataType dataType) {
        if (dataType instanceof LongType) {
            return DType.INT64;
        }
        if (dataType instanceof DoubleType) {
            return DType.FLOAT64;
        }
        if (dataType instanceof ByteType) {
            return DType.INT8;
        }
        if (dataType instanceof BooleanType) {
            return DType.BOOL8;
        }
        if (dataType instanceof ShortType) {
            return DType.INT16;
        }
        if (dataType instanceof IntegerType) {
            return DType.INT32;
        }
        if (dataType instanceof FloatType) {
            return DType.FLOAT32;
        }
        if (dataType instanceof DateType) {
            return DType.DATE32;
        }
        if (dataType instanceof TimestampType) {
            return DType.TIMESTAMP;
        }
        if (dataType instanceof StringType) {
            return DType.STRING;
        }
        return null;
    }
}
