package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.exception.PredictWrongModelCategoryException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapper.class */
/**
 * Convenience wrapper around a {@link GenModel} that scores a single row supplied as
 * {@link RowData} (column name to value) and returns a typed prediction object.
 *
 * <p>Row values are placed into the model's raw feature vector by column name. Model columns
 * absent from the row are left as NaN; row entries whose name is not a model column, or whose
 * column index is not below {@code nfeatures()}, are ignored.
 *
 * <p>The name and domain lookup maps are populated only in the constructor and are only read
 * afterwards; unknown-categorical counters are {@link AtomicLong}s held in a
 * {@link ConcurrentHashMap}.
 */
public class EasyPredictModelWrapper implements Serializable {
    private final GenModel m;
    /** Model column name to column index (from {@code GenModel.getNames()}). */
    private final HashMap<String, Integer> modelColumnNameToIndexMap;
    /** Column index to (categorical level to level index), only for columns that have a domain. */
    private final HashMap<Integer, HashMap<String, Integer>> domainMap;
    private final boolean convertUnknownCategoricalLevelsToNa;
    /** Per-column count of unknown categorical levels that were replaced by NaN. */
    private final ConcurrentHashMap<String, AtomicLong> unknownCategoricalLevelsSeenPerColumn;

    /** Builder-style configuration for {@link EasyPredictModelWrapper}. */
    public static class Config {
        private GenModel model;
        private boolean convertUnknownCategoricalLevelsToNa = false;

        /**
         * Sets the model to wrap.
         *
         * @param genModel the model
         * @return this config, for chaining
         */
        public Config setModel(GenModel genModel) {
            this.model = genModel;
            return this;
        }

        /** @return the model to wrap, or {@code null} if none was set */
        public GenModel getModel() {
            return this.model;
        }

        /**
         * Chooses how unknown categorical levels are handled.
         *
         * @param z {@code true} to score unknown levels as NaN and count them per column;
         *     {@code false} (the default) to throw {@link PredictUnknownCategoricalLevelException}
         * @return this config, for chaining
         */
        public Config setConvertUnknownCategoricalLevelsToNa(boolean z) {
            this.convertUnknownCategoricalLevelsToNa = z;
            return this;
        }

        /** @return whether unknown categorical levels are converted to NaN */
        public boolean getConvertUnknownCategoricalLevelsToNa() {
            return this.convertUnknownCategoricalLevelsToNa;
        }
    }

    /**
     * Creates a wrapper from a configuration.
     *
     * @param config configuration; its model must be non-null (it is dereferenced immediately)
     */
    public EasyPredictModelWrapper(Config config) {
        this.m = config.getModel();
        this.modelColumnNameToIndexMap = indexByValue(this.m.getNames());
        this.unknownCategoricalLevelsSeenPerColumn = new ConcurrentHashMap<>();
        this.convertUnknownCategoricalLevelsToNa = config.getConvertUnknownCategoricalLevelsToNa();
        setupConvertUnknownCategoricalLevelsToNa();
        this.domainMap = buildDomainMap(this.m);
    }

    /**
     * Creates a wrapper with default configuration (unknown categorical levels throw).
     *
     * @param genModel the model to wrap
     */
    public EasyPredictModelWrapper(GenModel genModel) {
        this(new Config().setModel(genModel));
    }

    /** @return the sum of all per-column unknown-categorical-level counters */
    public long getTotalUnknownCategoricalLevelsSeen() {
        long total = 0;
        for (AtomicLong counter : getUnknownCategoricalLevelsSeenPerColumn().values()) {
            total += counter.get();
        }
        return total;
    }

    /**
     * Returns the live per-column counters. Entries exist only when conversion of unknown
     * categorical levels to NaN is enabled.
     *
     * @return the internal counter map (not a copy)
     */
    public ConcurrentHashMap<String, AtomicLong> getUnknownCategoricalLevelsSeenPerColumn() {
        return this.unknownCategoricalLevelsSeenPerColumn;
    }

    /**
     * Scores a row using the prediction type matching the model's category.
     *
     * @param rowData column name to value
     * @return a category-specific prediction
     * @throws PredictException if the category is unknown/unhandled or the row cannot be scored
     */
    public AbstractPrediction predict(RowData rowData) throws PredictException {
        switch (this.m.getModelCategory()) {
            case AutoEncoder:
                return predictAutoEncoder(rowData);
            case Binomial:
                return predictBinomial(rowData);
            case Multinomial:
                return predictMultinomial(rowData);
            case Clustering:
                return predictClustering(rowData);
            case Regression:
                return predictRegression(rowData);
            case DimReduction:
                return predictDimReduction(rowData);
            case Unknown:
                throw new PredictException("Unknown model category");
            default:
                throw new PredictException("Unhandled model category (" + this.m.getModelCategory() + ") in switch statement");
        }
    }

    /**
     * Auto-encoder prediction is not implemented. The category is still validated and the row
     * scored first, so category and input errors surface as {@link PredictException}.
     *
     * @throws UnsupportedOperationException always, after successful scoring
     */
    public AutoEncoderModelPrediction predictAutoEncoder(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.AutoEncoder, rowData);
        throw new UnsupportedOperationException("Unimplemented " + preds.length);
    }

    /**
     * Scores a row with a dimensionality-reduction model.
     *
     * @return prediction whose {@code dimensions} is the raw score vector (length {@code nclasses()})
     */
    public DimReductionModelPrediction predictDimReduction(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.DimReduction, rowData);
        DimReductionModelPrediction prediction = new DimReductionModelPrediction();
        prediction.dimensions = preds;
        return prediction;
    }

    /**
     * Scores a row with a binomial model.
     *
     * @return prediction with label index taken from score slot 0 and class probabilities
     *     from slots {@code 1..getNumResponseClasses()}
     */
    public BinomialModelPrediction predictBinomial(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Binomial, rowData);
        BinomialModelPrediction prediction = new BinomialModelPrediction();
        prediction.labelIndex = (int) preds[0];
        prediction.label = responseLabel(prediction.labelIndex);
        prediction.classProbabilities = classProbabilities(preds);
        return prediction;
    }

    /**
     * Scores a row with a multinomial model.
     *
     * @return prediction with label index taken from score slot 0 and class probabilities
     *     from slots {@code 1..getNumResponseClasses()}
     */
    public MultinomialModelPrediction predictMultinomial(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Multinomial, rowData);
        MultinomialModelPrediction prediction = new MultinomialModelPrediction();
        prediction.labelIndex = (int) preds[0];
        prediction.label = responseLabel(prediction.labelIndex);
        prediction.classProbabilities = classProbabilities(preds);
        return prediction;
    }

    /** Returns the response-domain label at {@code labelIndex}. */
    private String responseLabel(int labelIndex) {
        return this.m.getDomainValues(this.m.getResponseIdx())[labelIndex];
    }

    /** Copies the class probabilities that follow the label index (slot 0) in a score vector. */
    private double[] classProbabilities(double[] preds) {
        double[] probabilities = new double[this.m.getNumResponseClasses()];
        System.arraycopy(preds, 1, probabilities, 0, probabilities.length);
        return probabilities;
    }

    /** Pairs class names with probabilities and sorts them by descending natural order. */
    private SortedClassProbability[] sortByDescendingClassProbability(String[] names, double[] probabilities) {
        assert probabilities.length == names.length : "names/probabilities length mismatch";
        SortedClassProbability[] sorted = new SortedClassProbability[names.length];
        for (int i = 0; i < names.length; i++) {
            SortedClassProbability entry = new SortedClassProbability();
            entry.name = names[i];
            entry.probability = probabilities[i];
            sorted[i] = entry;
        }
        Arrays.sort(sorted, Collections.reverseOrder());
        return sorted;
    }

    /** Returns the binomial prediction's classes sorted by descending probability. */
    public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction binomialModelPrediction) {
        return sortByDescendingClassProbability(this.m.getDomainValues(this.m.getResponseIdx()), binomialModelPrediction.classProbabilities);
    }

    /** Returns the multinomial prediction's classes sorted by descending probability. */
    public SortedClassProbability[] sortByDescendingClassProbability(MultinomialModelPrediction multinomialModelPrediction) {
        return sortByDescendingClassProbability(this.m.getDomainValues(this.m.getResponseIdx()), multinomialModelPrediction.classProbabilities);
    }

    /**
     * Scores a row with a clustering model.
     *
     * @return prediction whose cluster is score slot 0 truncated to int
     */
    public ClusteringModelPrediction predictClustering(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Clustering, rowData);
        ClusteringModelPrediction prediction = new ClusteringModelPrediction();
        prediction.cluster = (int) preds[0];
        return prediction;
    }

    /**
     * Scores a row with a regression model.
     *
     * @return prediction whose value is score slot 0
     */
    public RegressionModelPrediction predictRegression(RowData rowData) throws PredictException {
        double[] preds = preamble(ModelCategory.Regression, rowData);
        RegressionModelPrediction prediction = new RegressionModelPrediction();
        prediction.value = preds[0];
        return prediction;
    }

    /** @return the wrapped model's category */
    public ModelCategory getModelCategory() {
        return this.m.getModelCategory();
    }

    /** @return the domain (level names) of the model's response column */
    public String[] getResponseDomainValues() {
        return this.m.getDomainValues(this.m.getResponseIdx());
    }

    /** @return the wrapped model's header */
    public String getHeader() {
        return this.m.getHeader();
    }

    /** Maps each array element to its index; for duplicates, the last occurrence wins. */
    private static HashMap<String, Integer> indexByValue(String[] values) {
        HashMap<String, Integer> index = new HashMap<>();
        for (int i = 0; i < values.length; i++) {
            index.put(values[i], i);
        }
        return index;
    }

    /** Builds level-to-index maps for every column in {@code [0, getNumCols())} that has a domain. */
    private static HashMap<Integer, HashMap<String, Integer>> buildDomainMap(GenModel model) {
        HashMap<Integer, HashMap<String, Integer>> result = new HashMap<>();
        for (int col = 0; col < model.getNumCols(); col++) {
            String[] levels = model.getDomainValues(col);
            if (levels != null) {
                result.put(col, indexByValue(levels));
            }
        }
        return result;
    }

    /** Creates a zeroed counter for every categorical column when NaN conversion is enabled. */
    private void setupConvertUnknownCategoricalLevelsToNa() {
        if (!this.convertUnknownCategoricalLevelsToNa) {
            this.unknownCategoricalLevelsSeenPerColumn.clear();
            return;
        }
        String[] names = this.m.getNames();
        for (int col = 0; col < this.m.getNumCols(); col++) {
            if (this.m.getDomainValues(col) != null) {
                this.unknownCategoricalLevelsSeenPerColumn.put(names[col], new AtomicLong());
            }
        }
    }

    /** @throws PredictWrongModelCategoryException if the model is not of {@code modelCategory} */
    private void validateModelCategory(ModelCategory modelCategory) throws PredictException {
        if (this.m.getModelCategory() != modelCategory) {
            throw new PredictWrongModelCategoryException("Prediction type unsupported by model of category " + this.m.getModelCategory());
        }
    }

    /**
     * Validates the category and scores the row into a fresh output buffer: {@code nclasses()}
     * long for dimensionality reduction, {@code getPredsSize()} long otherwise.
     */
    private double[] preamble(ModelCategory modelCategory, RowData rowData) throws PredictException {
        validateModelCategory(modelCategory);
        double[] preds = modelCategory == ModelCategory.DimReduction
                ? new double[this.m.nclasses()]
                : new double[this.m.getPredsSize()];
        return predict(rowData, preds);
    }

    /**
     * Writes each recognised row value into {@code rawData} at its model column index.
     * Unrecognised names and indices not below {@code rawData.length} are skipped.
     *
     * @throws PredictUnknownTypeException for a null value or an unsupported value type
     * @throws PredictUnknownCategoricalLevelException for an unknown level when NaN conversion is off
     */
    private void fillRawData(RowData rowData, double[] rawData) throws PredictException {
        for (String columnName : rowData.keySet()) {
            Integer index = this.modelColumnNameToIndexMap.get(columnName);
            if (index == null || index.intValue() >= rawData.length) {
                continue;
            }
            Object value = rowData.get(columnName);
            if (value == null) {
                // Previously this fell through to value.getClass() and raised an NPE.
                throw new PredictUnknownTypeException("Null value for column " + columnName);
            }
            int col = index.intValue();
            rawData[col] = this.m.getDomainValues(col) == null
                    ? toNumericValue(value)
                    : toCategoricalValue(col, columnName, value);
        }
    }

    /**
     * Converts a numeric-column value: a String is parsed with {@link Double#parseDouble}
     * (which may throw {@link NumberFormatException}); any {@link Number} is widened to double.
     */
    private static double toNumericValue(Object value) throws PredictException {
        if (value instanceof String) {
            return Double.parseDouble((String) value);
        }
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new PredictUnknownTypeException("Unknown object type " + value.getClass().getName());
    }

    /**
     * Converts a categorical-column value (must be a String) to its level index, or to NaN for
     * an unknown level when NaN conversion is enabled (incrementing that column's counter).
     */
    private double toCategoricalValue(int col, String columnName, Object value) throws PredictException {
        if (!(value instanceof String)) {
            throw new PredictUnknownTypeException("Unknown object type " + value.getClass().getName());
        }
        String level = (String) value;
        // NOTE(review): domainMap and the counters only cover indices < getNumCols(); this assumes
        // every categorical index below nfeatures() is in that range — confirm against GenModel.
        Integer levelIndex = this.domainMap.get(col).get(level);
        if (levelIndex != null) {
            return levelIndex.intValue();
        }
        if (!this.convertUnknownCategoricalLevelsToNa) {
            throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + columnName + "," + level + ")", columnName, level);
        }
        this.unknownCategoricalLevelsSeenPerColumn.get(columnName).incrementAndGet();
        return Double.NaN;
    }

    /**
     * Builds the raw feature vector (NaN-initialised, length {@code nfeatures()}), fills it from
     * the row, and delegates to {@code score0}. Columns absent from the row stay NaN —
     * presumably interpreted as missing by the model; confirm against GenModel.score0.
     */
    private double[] predict(RowData rowData, double[] preds) throws PredictException {
        double[] rawData = new double[this.m.nfeatures()];
        Arrays.fill(rawData, Double.NaN);
        fillRawData(rowData, rawData);
        return this.m.score0(rawData, preds);
    }
}
