/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java.spark.rapids;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import java.util.List;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;

/**
 * A batch of columnar data held in a cudf {@link Table}, paired with the Spark {@link StructType}
 * that describes its columns.
 *
 * <p>Column {@code i} of the table is treated as described by field {@code i} of the schema; the
 * constructor does not verify that the two agree. NOTE(review): this class never closes the wrapped
 * table — presumably the caller owns its lifetime; confirm.
 */
public class GpuColumnBatch {
    private final Table table;
    private final StructType schema;

    /**
     * Wraps {@code table}, whose columns are described by {@code schema}.
     *
     * @param table  cudf table holding the column data
     * @param schema Spark schema describing the table's columns, in column order
     */
    public GpuColumnBatch(Table table, StructType schema) {
        this.table = table;
        this.schema = schema;
    }

    /** Returns the Spark schema describing this batch's columns. */
    public StructType getSchema() {
        return this.schema;
    }

    /** Returns the number of rows reported by the underlying cudf table. */
    public long getNumRows() {
        return this.table.getRowCount();
    }

    /** Returns the number of columns in the underlying cudf table. */
    public int getNumColumns() {
        return this.table.getNumberOfColumns();
    }

    /**
     * Returns the cudf column at {@code index} as-is (no {@code ensureOnHost()} call).
     *
     * @param index zero-based column index
     */
    public ColumnVector getColumnVector(int index) {
        return this.table.getColumn(index);
    }

    /**
     * Returns the native cudf column address of the column at {@code index}, as reported by
     * {@link ColumnVector#getNativeCudfColumnAddress()}.
     *
     * @param index zero-based column index
     */
    public long getColumn(int index) {
        return this.table.getColumn(index).getNativeCudfColumnAddress();
    }

    /**
     * Returns the column at {@code index} after calling {@code ensureOnHost()} on it, so that the
     * per-row host accessors ({@code getInt}, {@code getFloat}, ...) used in this class can be called.
     *
     * @param index zero-based column index
     */
    public ColumnVector getColumnVectorInitHost(int index) {
        ColumnVector cv = this.table.getColumn(index);
        cv.ensureOnHost();
        return cv;
    }

    /**
     * Reads one row of a numeric column as a {@code double}, dispatching on the Spark type of
     * {@code field}.
     *
     * @param row   zero-based row index
     * @param cv    the column to read; host accessors must be usable (see
     *              {@link #getColumnVectorInitHost(int)})
     * @param field schema field describing {@code cv}; its name is used in the error message
     * @return the value widened to {@code double}
     * @throws IllegalArgumentException if the field's type is not Float/Integer/Byte/Short/Double/Long
     */
    private static double readNumericValue(long row, ColumnVector cv, StructField field) {
        DataType type = field.dataType();
        if (type instanceof FloatType) {
            return cv.getFloat(row);
        }
        if (type instanceof IntegerType) {
            return cv.getInt(row);
        }
        if (type instanceof ByteType) {
            return cv.getByte(row);
        }
        if (type instanceof ShortType) {
            return cv.getShort(row);
        }
        if (type instanceof DoubleType) {
            return cv.getDouble(row);
        }
        if (type instanceof LongType) {
            return cv.getLong(row);
        }
        throw new IllegalArgumentException("Not a numeric type in column: " + field.name());
    }

    /**
     * Reads row {@code dataIndex} of column {@code colIndex} as a {@code double}, or returns
     * {@code defVal} when the column has no rows.
     */
    private double getNumericValueInColumn(int dataIndex, int colIndex, double defVal) {
        ColumnVector cv = this.getColumnVectorInitHost(colIndex);
        return cv.getRowCount() > 0L
                ? readNumericValue(dataIndex, cv, this.getSchema().apply(colIndex))
                : defVal;
    }

    /**
     * Reads row {@code dataIndex} of numeric column {@code colIndex} and narrows it to {@code int}
     * (Java {@code (int)} cast on the {@code double} value, so out-of-range values saturate).
     *
     * @param dataIndex zero-based row index
     * @param colIndex  zero-based column index
     * @param defVal    value returned when the column has no rows
     * @return the value as an {@code int}, or {@code defVal} for an empty column
     * @throws IllegalArgumentException if the column is not of a supported numeric type
     */
    public int getIntInColumn(int dataIndex, int colIndex, int defVal) {
        return (int) this.getNumericValueInColumn(dataIndex, colIndex, defVal);
    }

    /**
     * Scans the group-id column (and optionally a weight column) of this batch on the host and
     * appends run-length group information to {@code groupInfo} / {@code weightInfo}.
     *
     * <p>Consecutive rows with the same group id form one group. The first run of this batch is
     * treated as a continuation of the previous batch's tail group (whose id is
     * {@code prevTailGid}); if {@code groupInfo} is non-empty, its last entry is overwritten with
     * the combined size instead of a new entry being appended. Every later run in this batch is
     * appended as a new group, even if its id equals {@code prevTailGid}.
     *
     * @param groupIdx    index of the group-id column
     * @param weightIdx   index of the weight column, or a negative value if there is none
     * @param prevTailGid group id of the last group seen in the previous batch
     * @param groupInfo   group sizes accumulated so far; updated in place
     * @param weightInfo  one weight per group accumulated so far; updated in place (only when
     *                    {@code weightIdx >= 0})
     * @return the group id of the last group in this batch ({@code prevTailGid} if the batch is empty)
     * @throws IllegalArgumentException if rows of one group carry different weights, or a column is
     *                                  not numeric
     */
    public int groupAndAggregateOnColumnsHost(int groupIdx, int weightIdx, int prevTailGid,
            List<Integer> groupInfo, List<Float> weightInfo) {
        final boolean hasWeight = weightIdx >= 0;
        ColumnVector weightCV = null;
        StructField weightSF = null;
        Float curWeight = null;
        if (hasWeight) {
            weightCV = this.getColumnVectorInitHost(weightIdx);
            weightSF = this.getSchema().apply(weightIdx);
            Float firstWeight = weightCV.getRowCount() > 0L
                    ? Float.valueOf((float) readNumericValue(0L, weightCV, weightSF))
                    : null;
            // Continue with the previous batch's last weight so a group spanning batches is checked.
            curWeight = weightInfo.isEmpty() ? firstWeight : weightInfo.get(weightInfo.size() - 1);
        }
        ColumnVector groupCV = this.getColumnVectorInitHost(groupIdx);
        StructField groupSF = this.getSchema().apply(groupIdx);
        int groupId = prevTailGid;
        int groupSize = groupInfo.isEmpty() ? 0 : groupInfo.get(groupInfo.size() - 1);
        // Only the first run flushed by this call may extend the previous batch's tail group;
        // a later run with the same id as prevTailGid is a distinct (non-contiguous) group.
        boolean continuesPrevTail = true;
        final long numRows = groupCV.getRowCount();
        for (long row = 0L; row < numRows; ++row) {
            // Weight is read before the group id to keep the original error precedence.
            Float weight = Float.valueOf(hasWeight ? (float) readNumericValue(row, weightCV, weightSF) : 0.0f);
            int gid = (int) readNumericValue(row, groupCV, groupSF);
            if (gid == groupId) {
                ++groupSize;
                if (hasWeight && !weight.equals(curWeight)) {
                    throw new IllegalArgumentException("The instances in the same group have to be assigned with the same weight. Unexpected weight: " + weight);
                }
            } else {
                GpuColumnBatch.addOrUpdateInfos(continuesPrevTail, groupSize, curWeight, hasWeight, groupInfo, weightInfo);
                continuesPrevTail = false;
                if (hasWeight) {
                    curWeight = weight;
                }
                groupId = gid;
                groupSize = 1;
            }
        }
        GpuColumnBatch.addOrUpdateInfos(continuesPrevTail, groupSize, curWeight, hasWeight, groupInfo, weightInfo);
        return groupId;
    }

    /**
     * Records one finished run of rows as a group.
     *
     * <p>Does nothing for an empty run. If the run continues the previous batch's tail group and
     * {@code groupInfo} already has entries, the last size is overwritten (its weight was already
     * recorded). Otherwise a new size is appended, plus its weight when weights are in use and
     * {@code curWeight} is non-null.
     */
    private static void addOrUpdateInfos(boolean continuesPrevTail, int curGroupSize, Float curWeight,
            boolean hasWeight, List<Integer> groupInfo, List<Float> weightInfo) {
        if (curGroupSize <= 0) {
            return;
        }
        if (continuesPrevTail && !groupInfo.isEmpty()) {
            groupInfo.set(groupInfo.size() - 1, curGroupSize);
        } else {
            groupInfo.add(curGroupSize);
            if (hasWeight && curWeight != null) {
                weightInfo.add(curWeight);
            }
        }
    }

    /**
     * Maps a Spark SQL data type to the corresponding cudf {@link DType}.
     *
     * @param type Spark data type
     * @return the cudf type
     * @throws IllegalArgumentException if the type has no mapping here
     */
    public static DType getRapidsType(DataType type) {
        DType result = GpuColumnBatch.toRapidsOrNull(type);
        if (result == null) {
            throw new IllegalArgumentException(type + " is not supported for GPU processing yet.");
        }
        return result;
    }

    /** Returns the cudf type for {@code type}, or {@code null} if it is not one of the mapped types. */
    private static DType toRapidsOrNull(DataType type) {
        if (type instanceof LongType) {
            return DType.INT64;
        }
        if (type instanceof DoubleType) {
            return DType.FLOAT64;
        }
        if (type instanceof ByteType) {
            return DType.INT8;
        }
        if (type instanceof BooleanType) {
            return DType.BOOL8;
        }
        if (type instanceof ShortType) {
            return DType.INT16;
        }
        if (type instanceof IntegerType) {
            return DType.INT32;
        }
        if (type instanceof FloatType) {
            return DType.FLOAT32;
        }
        if (type instanceof DateType) {
            return DType.DATE32;
        }
        if (type instanceof TimestampType) {
            return DType.TIMESTAMP;
        }
        if (type instanceof StringType) {
            return DType.STRING;
        }
        return null;
    }
}

