package hex.tree.xgboost;

import hex.DataInfo;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMTask;
import hex.tree.SharedTree;
import hex.tree.TreeStats;
import hex.tree.xgboost.XGBoostModel;
import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import water.ExtensionManager;
import water.H2O;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.FileUtils;
import water.util.Log;

/* loaded from: input_file:hex/tree/xgboost/XGBoost.class */
public class XGBoost extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {
    /**
     * Decompiled form of the compiler-generated assertion flag: {@code true} when assertions are
     * disabled for this class. It is assigned in this class's static initializer and read by the
     * explicit {@code if (!$assertionsDisabled && ...) throw new AssertionError()} checks.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.tree.xgboost.XGBoost$1, reason: invalid class name */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$1.class */
    /**
     * Decompiled form of the "switch map" javac generates for a {@code switch} over
     * {@link DistributionFamily}.
     *
     * <p>{@code $SwitchMap...[d.ordinal()]} yields the case label (1-11) assigned below, or 0 (the
     * array default) for constants not listed here. Each assignment sits in its own
     * {@code try}/{@code catch (NoSuchFieldError)}. A constant that is missing from the runtime
     * version of the enum therefore leaves its slot at 0 instead of failing class initialization.
     *
     * <p>The class's type is also used as the otherwise unused marker parameter of
     * {@code XGBoostDriver}'s synthetic constructor.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$hex$genmodel$utils$DistributionFamily = new int[DistributionFamily.values().length];

        static {
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.bernoulli.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.modified_huber.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.multinomial.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.huber.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.poisson.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gamma.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.tweedie.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gaussian.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.laplace.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.quantile.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.AUTO.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.tree.xgboost.XGBoost$1SparseItem, reason: invalid class name */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$1SparseItem.class */
    public class C1SparseItem {
        int pos;
        double val;

        C1SparseItem() {
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$XGBoostDriver.class */
    public class XGBoostDriver extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>.Driver {
        /** File name of the XGBoost feature map written into the per-model temporary directory. */
        private static final String FEATURE_MAP_FILENAME = "featureMap.txt";
        /** Wall-clock time (ms) of the first {@code doScoring} call; 0 until then. */
        long _firstScore;
        /** Wall-clock time (ms) at which the most recent scoring pass started; 0 before any. */
        long _timeLastScoreStart;
        /** Wall-clock time (ms) at which the most recent scoring pass ended; 0 before any. */
        long _timeLastScoreEnd;

        /**
         * Creates a driver bound to the enclosing {@link XGBoost} builder. The timing fields start
         * at their default value 0.
         */
        private XGBoostDriver() {
            // The enclosing XGBoost builder is passed implicitly as the outer instance of the
            // inner superclass ModelBuilder.Driver. The decompiled "super(XGBoost.this)" spelled
            // that hidden argument out explicitly, which does not match any Driver constructor.
            super();
        }

        /**
         * Driver entry point. Re-runs parameter validation with {@code init(true)} and computes the
         * parameter checksum. If any validation errors were recorded, it aborts with them;
         * otherwise it builds the model.
         *
         * @throws H2OModelBuilderIllegalArgumentException if validation recorded errors
         */
        public void computeImpl() {
            final XGBoost builder = XGBoost.this;
            builder.init(true);
            builder._parms.checksum();
            final boolean hasErrors = builder.error_count() > 0;
            if (hasErrors) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }
            buildModel();
        }

        /**
         * Trains the XGBoost model for this job.
         *
         * <p>The method:
         * <ul>
         *   <li>chooses a sparse or dense DMatrix layout;</li>
         *   <li>converts the training frame, and the validation frame if present;</li>
         *   <li>writes the training feature map into a temporary directory;</li>
         *   <li>creates an initial booster (0 rounds) and runs {@link #scoreAndBuildTrees};</li>
         *   <li>performs a final scoring pass;</li>
         *   <li>copies the native booster into the Java model via {@code nativeToJava()}.</li>
         * </ul>
         *
         * <p>The model is write-locked for the duration and unlocked in a {@code finally} block.
         * The native matrices and the temporary feature-map directory are also released when
         * training fails. XGBoost errors are escalated through {@code H2O.fail}.
         */
        final void buildModel() {
            final XGBoostModel.XGBoostParameters parms = XGBoost.this._parms;
            final XGBoostModel model = new XGBoostModel(XGBoost.this._result, parms,
                    new XGBoostOutput(XGBoost.this), XGBoost.this._train, XGBoost.this._valid);
            model.write_lock(XGBoost.this._job);
            try {
                model._output._sparse = useSparseMatrix(parms);
                DMatrix trainMat = null;
                DMatrix validMat = null;
                File featureMapDir = null;
                try {
                    // Slot that convertFrametoDMatrix fills with the training frame's feature map.
                    final String[] featureMap = {""};
                    trainMat = XGBoost.convertFrametoDMatrix(model.model_info()._dataInfoKey, XGBoost.this._train,
                            parms._response_column, parms._weights_column, parms._fold_column,
                            featureMap, model._output._sparse);
                    if (XGBoost.this._valid != null) {
                        // Pass no feature-map slot here. The map must describe the training
                        // frame; previously the validation conversion overwrote it.
                        validMat = XGBoost.convertFrametoDMatrix(model.model_info()._dataInfoKey, XGBoost.this._valid,
                                parms._response_column, parms._weights_column, parms._fold_column,
                                null, model._output._sparse);
                    }
                    featureMapDir = writeFeatureMap(featureMap[0]);
                    // Zero boosting rounds: an initial booster, presumably grown by scoreAndBuildTrees.
                    model.model_info()._booster = ml.dmlc.xgboost4j.java.XGBoost.train(trainMat, model.createParams(), 0,
                            new HashMap<String, DMatrix>(), (IObjective) null, (IEvaluation) null);
                    scoreAndBuildTrees(model, trainMat, validMat, featureMapDir);
                    doScoring(model, model.model_info()._booster, trainMat, validMat, true, featureMapDir);
                    model.model_info().nativeToJava();
                } catch (XGBoostError e) {
                    e.printStackTrace();
                    H2O.fail("XGBoost failure", e);
                } finally {
                    releaseMatrix(trainMat);
                    releaseMatrix(validMat);
                    deleteFeatureMapDir(featureMapDir);
                }
                model._output._boosterBytes = model.model_info()._boosterBytes;
            } finally {
                model.unlock(XGBoost.this._job);
            }
        }

        /**
         * Decides whether XGBoost should receive a sparse matrix.
         *
         * <p>An explicit {@code dmatrix_type} of {@code sparse} or {@code dense} wins. Otherwise the
         * matrix is sparse when the average non-zero fraction of the predictor columns is below
         * 0.5, or when a dense matrix would exceed {@code Integer.MAX_VALUE} cells. The predictor
         * columns are all columns except the response, weights, fold and offset columns.
         */
        private boolean useSparseMatrix(XGBoostModel.XGBoostParameters parms) {
            if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.sparse) {
                return true;
            }
            if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.dense) {
                return false;
            }
            final Frame train = XGBoost.this._train;
            final long nrows = train.numRows();
            double fillRatioSum = 0.0d;
            int predictors = 0;
            for (int col = 0; col < train.numCols(); col++) {
                final String name = train.name(col);
                if (name.equals(parms._response_column) || name.equals(parms._weights_column)
                        || name.equals(parms._fold_column) || name.equals(parms._offset_column)) {
                    continue;
                }
                // Floating-point division: the former long division truncated every column that
                // was not completely filled to 0, so "auto" almost always picked sparse.
                fillRatioSum += nrows == 0 ? 0.0d : (double) train.vec(col).nzCnt() / nrows;
                predictors++;
            }
            // With no predictor columns this is 0/0 = NaN, which fails the "< 0.5" test (dense).
            final float fillRatio = (float) (fillRatioSum / predictors);
            Log.info(new Object[]{"fill ratio: " + fillRatio});
            return fillRatio < 0.5d || train.numRows() * (long) train.numCols() > Integer.MAX_VALUE;
        }

        /**
         * Writes {@code featureMap} (UTF-8) as {@value #FEATURE_MAP_FILENAME} into a new temporary
         * directory and returns that directory.
         *
         * <p>On I/O failure the error is escalated via {@code H2O.fail}. If that call returns, the
         * directory (possibly {@code null}) is returned as is.
         */
        private File writeFeatureMap(String featureMap) {
            File dir = null;
            try {
                dir = Files.createTempDirectory("xgboost-model-" + XGBoost.this._result.toString()).toFile();
                try (FileOutputStream out = new FileOutputStream(new File(dir, FEATURE_MAP_FILENAME))) {
                    out.write(featureMap.getBytes(StandardCharsets.UTF_8));
                }
            } catch (IOException e) {
                H2O.fail("Cannot generate featureMap.txt", e);
            }
            return dir;
        }

        /** Best-effort removal of the temporary feature-map file and its directory; failures are only logged. */
        private void deleteFeatureMapDir(File dir) {
            if (dir == null) {
                return;
            }
            final File mapFile = new File(dir, FEATURE_MAP_FILENAME);
            if (mapFile.exists() && !mapFile.delete()) {
                Log.warn(new Object[]{"Could not delete temporary XGBoost feature map " + mapFile.getAbsolutePath()});
            }
            if (!dir.delete()) {
                Log.warn(new Object[]{"Could not delete temporary XGBoost directory " + dir.getAbsolutePath()});
            }
        }

        /** Frees the native handle of {@code matrix} via {@link DMatrix#dispose()}; a {@code null} matrix is ignored. */
        private void releaseMatrix(DMatrix matrix) {
            if (matrix != null) {
                matrix.dispose();
            }
        }

        /*
         * Decompilation note: JADX could not recover this method. The original dump reported
         * 541 bytecode instructions that were skipped. The body below is a placeholder that always
         * throws UnsupportedOperationException. buildModel() calls this method, so as decompiled
         * this class cannot finish training a model until the body is restored from the original
         * source.
         *
         * While restructuring, JADX reported two lost blocks (B:37 at 0x020a and B:38 at 0x021c)
         * that form the tail of the original code:
         *     doScoring(r10, r10.model_info()._booster, r11, r12, true, r13);
         *     return;
         * That is, a final scoring pass (final flag = true) on the model's booster, followed by a
         * return.
         *
         * Parameters, as passed by buildModel():
         *   r10 - the model being built
         *   r11 - the training DMatrix
         *   r12 - the validation DMatrix, or null when there is no validation frame
         *   r13 - the temporary directory that holds featureMap.txt
         */
        protected final void scoreAndBuildTrees(hex.tree.xgboost.XGBoostModel r10, ml.dmlc.xgboost4j.java.DMatrix r11, ml.dmlc.xgboost4j.java.DMatrix r12, java.io.File r13) throws ml.dmlc.xgboost4j.java.XGBoostError {
            /*
                Method dump skipped, instructions count: 541
                To view this dump add '--comments-level debug' option
            */
            throw new UnsupportedOperationException("Method not decompiled: hex.tree.xgboost.XGBoost.XGBoostDriver.scoreAndBuildTrees(hex.tree.xgboost.XGBoostModel, ml.dmlc.xgboost4j.java.DMatrix, ml.dmlc.xgboost4j.java.DMatrix, java.io.File):void");
        }

        /**
         * Scores the model if the scoring schedule allows it, then refreshes the model's summary,
         * scoring history and variable importances.
         *
         * <p>Scoring happens when any of the following holds:
         * <ul>
         *   <li>{@code score_each_iteration} is set;</li>
         *   <li>{@code z} (final scoring) is true;</li>
         *   <li>a positive {@code score_tree_interval} divides the current tree count;</li>
         *   <li>no {@code score_tree_interval} is configured and the time heuristic allows it.</li>
         * </ul>
         * The time heuristic always allows scoring during the first {@code initial_score_interval}
         * ms. After that it allows scoring once more than {@code score_interval} ms have passed
         * since the previous pass started, provided that pass took under 10% of that time.
         * The job's progress message is updated on every call.
         *
         * @return {@code true} if a scoring pass was performed
         */
        private boolean doScoring(XGBoostModel xGBoostModel, Booster booster, DMatrix dMatrix, DMatrix dMatrix2, boolean z, File file) throws XGBoostError {
            final XGBoostModel.XGBoostParameters parms = XGBoost.this._parms;
            final long now = System.currentTimeMillis();
            if (this._firstScore == 0) {
                this._firstScore = now;
            }
            final long sinceLastScoreStart = now - this._timeLastScoreStart;
            XGBoost.this._job.update(0L, "Built " + xGBoostModel._output._ntrees + " trees so far (out of " + parms._ntrees + ").");

            final long lastScoringDuration = this._timeLastScoreEnd - this._timeLastScoreStart;
            final boolean timeToScore = now - this._firstScore < parms._initial_score_interval
                    || (sinceLastScoreStart > parms._score_interval
                        && (double) lastScoringDuration / (double) sinceLastScoreStart < 0.1d);
            final boolean treeIntervalReached = parms._score_tree_interval > 0
                    && xGBoostModel._output._ntrees % parms._score_tree_interval == 0;
            final boolean mustScore = parms._score_each_iteration
                    || z
                    || (timeToScore && parms._score_tree_interval == 0)
                    || treeIntervalReached;
            if (!mustScore) {
                return false;
            }

            this._timeLastScoreStart = now;
            xGBoostModel.doScoring(booster, dMatrix, parms.train(), dMatrix2, parms.valid());
            this._timeLastScoreEnd = System.currentTimeMillis();
            final String featureMapPath = new File(file, FEATURE_MAP_FILENAME).getAbsolutePath();
            xGBoostModel.computeVarImp(booster.getFeatureScore(featureMapPath));

            final XGBoostOutput output = xGBoostModel._output;
            output._model_summary = SharedTree.createModelSummaryTable(output._ntrees, (TreeStats) null);
            output._scoring_history = SharedTree.createScoringHistoryTable(output, output._scored_train, output._scored_valid, XGBoost.this._job, output._training_time_ms);
            output._variable_importances = ModelMetrics.calcVarImp(output._varimp);
            xGBoostModel.update(XGBoost.this._job);
            Log.info(new Object[]{xGBoostModel});
            return true;
        }

        /**
         * Synthetic accessor constructor that lets the enclosing class reach the private no-arg
         * constructor. The body only delegates to {@code this()}; both explicit parameters are
         * unused. {@code XGBoost#m5trainModelImpl} calls it as {@code new XGBoostDriver(this, null)}.
         */
        /* synthetic */ XGBoostDriver(XGBoost xGBoost, AnonymousClass1 anonymousClass1) {
            this();
        }
    }

    /**
     * Always returns {@code true}. Presumably this reports that models built here can be exported
     * as a MOJO.
     */
    public boolean haveMojo() {
        final boolean mojoAvailable = true;
        return mojoAvailable;
    }

    /**
     * Reports this builder as {@code Stable} when the XGBoost extension is enabled as a core
     * extension, and as {@code Experimental} otherwise.
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final boolean coreExtensionEnabled = ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME);
        if (coreExtensionEnabled) {
            return ModelBuilder.BuilderVisibility.Stable;
        }
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /** Initial capacity, in elements, of the growable buffers used while converting a frame. */
    private static final int INITIAL_BUFFER_LENGTH = 1048576;

    /** Largest buffer length the conversion grows to (Integer.MAX_VALUE - 10, as in the original code). */
    private static final long MAX_BUFFER_LENGTH = 2147483637L;

    /** Message (wrapped in technote 11) reported when the data cannot fit into 32-bit Java arrays. */
    private static final String DATA_TOO_LARGE_MESSAGE = "Data is too large to fit into the 32-bit Java float[] array that needs to be passed to the XGBoost C++ backend. Use H2O GBM instead.";

    /**
     * Converts {@code frame} into an XGBoost {@link DMatrix}. The matrix gets labels and, when a
     * weights column is given, per-row weights.
     *
     * <p>Columns of {@code frame} are read positionally using the layout of the {@link DataInfo}
     * stored under {@code dataInfoKey}:
     * <ul>
     *   <li>columns {@code [0, _cats)} are the categorical predictors, one-hot encoded through
     *       {@link DataInfo#getCategoricalId};</li>
     *   <li>columns {@code [_cats, _cats + _nums)} are the numeric predictors.</li>
     * </ul>
     * NOTE(review): this assumes the caller's frame is ordered like that layout. Confirm at the
     * call sites.
     *
     * <p>Rows whose weight is exactly 0 are dropped from the matrix, the labels and the weights.
     *
     * @param dataInfoKey key of the {@link DataInfo} describing the predictor expansion
     * @param frame       frame to convert
     * @param response    name of the response column; its values become the labels
     * @param weight      name of the weights column, or {@code null}
     * @param fold        name of the fold column; not used by the conversion
     * @param featureMap  if non-null, {@code featureMap[0]} receives the XGBoost feature-map text
     * @param sparse      {@code true} to build a CSR sparse matrix, {@code false} for a dense one
     * @return the populated matrix
     * @throws XGBoostError if XGBoost rejects the matrix, the labels or the weights
     * @throws IllegalArgumentException if the data does not fit into 32-bit Java arrays
     */
    public static DMatrix convertFrametoDMatrix(Key<DataInfo> dataInfoKey, Frame frame, String response, String weight, String fold, String[] featureMap, boolean sparse) throws XGBoostError {
        final DataInfo dataInfo = dataInfoKey.get();
        if (featureMap != null) {
            featureMap[0] = makeFeatureMap(frame, dataInfo);
        }
        // Row indices below are ints (Java arrays and the DMatrix API are 32-bit indexed).
        // Refuse larger frames instead of silently truncating the row count.
        if (frame.numRows() > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(H2O.technote(11, DATA_TOO_LARGE_MESSAGE));
        }
        final int numRows = (int) frame.numRows();
        final Vec.Reader weights = weight == null ? null : frame.vec(weight).new Reader();
        final Vec.Reader[] columns = new Vec.Reader[frame.numCols()];
        for (int c = 0; c < columns.length; c++) {
            columns[c] = frame.vec(c).new Reader();
        }
        try {
            final int keptRows = countKeptRows(weights, numRows);
            final DMatrix matrix = sparse
                    ? buildSparseMatrix(dataInfo, columns, weights, numRows, keptRows)
                    : buildDenseMatrix(dataInfo, columns, weights, numRows, keptRows);
            matrix.setLabel(readKeptValues(frame.vec(response).new Reader(), weights, numRows, keptRows));
            if (weights != null) {
                matrix.setWeight(readKeptValues(weights, weights, numRows, keptRows));
            }
            return matrix;
        } catch (NegativeArraySizeException e) {
            // Defensive: some array size still overflowed (e.g. keptRows + 1 at Integer.MAX_VALUE).
            throw new IllegalArgumentException(H2O.technote(11, DATA_TOO_LARGE_MESSAGE), e);
        }
    }

    /**
     * Renders the XGBoost feature map: one line per expanded predictor, in the form
     * {@code "<index> <name-without-whitespace> <type>"}. The type is {@code i} for one-hot
     * categorical columns and binary numeric columns, {@code int} for integer columns, and
     * {@code q} otherwise.
     */
    private static String makeFeatureMap(Frame frame, DataInfo dataInfo) {
        final String[] coefNames = dataInfo.coefNames();
        final int fullN = dataInfo.fullN();
        if (!$assertionsDisabled && coefNames.length != fullN) {
            throw new AssertionError();
        }
        final int catCols = dataInfo._catOffsets[dataInfo._catOffsets.length - 1];
        final StringBuilder sb = new StringBuilder();
        for (int i = 0; i < fullN; i++) {
            sb.append(i).append(' ').append(coefNames[i].replaceAll("\\s+", "")).append(' ');
            if (i < catCols) {
                sb.append('i');
            } else {
                // The matrix builders read numeric predictor k = i - catCols from frame column
                // _cats + k, so describe that same column. The former frame.vec(k) pointed at a
                // different column whenever there were categorical predictors.
                final Vec vec = frame.vec(dataInfo._cats + (i - catCols));
                if (vec.isBinary()) {
                    sb.append('i');
                } else if (vec.isInt()) {
                    sb.append("int");
                } else {
                    sb.append('q');
                }
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /** Returns whether row {@code row} is kept: there is no weights column, or its weight is not 0 (NaN weights are kept). */
    private static boolean isKept(Vec.Reader weights, int row) {
        return weights == null || weights.at(row) != 0.0d;
    }

    /** Counts the rows among the first {@code numRows} that {@link #isKept} keeps. */
    private static int countKeptRows(Vec.Reader weights, int numRows) {
        if (weights == null) {
            return numRows;
        }
        int kept = 0;
        for (int r = 0; r < numRows; r++) {
            if (isKept(weights, r)) {
                kept++;
            }
        }
        return kept;
    }

    /** Copies the values of {@code source} for the kept rows, in row order, into a float array of length {@code keptRows}. */
    private static float[] readKeptValues(Vec.Reader source, Vec.Reader weights, int numRows, int keptRows) {
        final float[] out = new float[keptRows];
        int next = 0;
        for (int r = 0; r < numRows; r++) {
            if (isKept(weights, r)) {
                out[next++] = (float) source.at(r);
            }
        }
        if (!$assertionsDisabled && next != keptRows) {
            throw new AssertionError();
        }
        return out;
    }

    /**
     * Returns the doubled length for a growing buffer, capped at {@link #MAX_BUFFER_LENGTH}.
     *
     * @throws IllegalArgumentException if the buffer is already at the cap
     */
    private static int grownLength(int currentLength, String kind) {
        // Widen before shifting: an int shift overflows to a negative length at 2^30 elements.
        final int newLength = (int) Math.min((long) currentLength << 1, MAX_BUFFER_LENGTH);
        Log.info(new Object[]{"Enlarging " + kind + " data structure from " + currentLength + " bytes to " + newLength + " bytes."});
        if (currentLength == newLength) {
            throw new IllegalArgumentException(H2O.technote(11, DATA_TOO_LARGE_MESSAGE));
        }
        return newLength;
    }

    /**
     * Builds a CSR matrix with one row per kept input row. Each row is filled as follows:
     * <ul>
     *   <li>a non-NA categorical predictor stores 1.0 at its one-hot column;</li>
     *   <li>a numeric predictor stores its value unless it is NaN or 0;</li>
     *   <li>a row left with no stored entry gets an explicit 0 at column 0.</li>
     * </ul>
     */
    private static DMatrix buildSparseMatrix(DataInfo dataInfo, Vec.Reader[] columns, Vec.Reader weights, int numRows, int keptRows) throws XGBoostError {
        Log.info(new Object[]{"Treating matrix as sparse."});
        final int numericOffset = dataInfo._catOffsets[dataInfo._catOffsets.length - 1];
        // Worst-case entries per row; at least 1 because an empty row gets a placeholder entry.
        final int maxEntriesPerRow = Math.max(1, dataInfo._cats + dataInfo._nums);
        final long[] rowHeaders = new long[keptRows + 1];
        float[] values = new float[INITIAL_BUFFER_LENGTH];
        int[] colIndices = new int[INITIAL_BUFFER_LENGTH];
        int nnz = 0;
        int row = 0;
        for (int r = 0; r < numRows; r++) {
            if (!isKept(weights, r)) {
                continue;
            }
            while (values.length < (long) nnz + maxEntriesPerRow) {
                final int newLength = grownLength(values.length, "sparse");
                values = Arrays.copyOf(values, newLength);
                colIndices = Arrays.copyOf(colIndices, newLength);
            }
            final int rowStart = nnz;
            for (int c = 0; c < dataInfo._cats; c++) {
                if (!columns[c].isNA(r)) {
                    values[nnz] = 1.0f;
                    colIndices[nnz] = dataInfo.getCategoricalId(c, columns[c].at8(r));
                    nnz++;
                }
            }
            for (int n = 0; n < dataInfo._nums; n++) {
                final float value = (float) columns[dataInfo._cats + n].at(r);
                if (!Float.isNaN(value) && value != 0.0f) {
                    values[nnz] = value;
                    colIndices[nnz] = numericOffset + n;
                    nnz++;
                }
            }
            if (nnz == rowStart) {
                // Kept from the original conversion: never leave a row without a stored entry.
                values[nnz] = 0.0f;
                colIndices[nnz] = 0;
                nnz++;
            }
            row++;
            rowHeaders[row] = nnz;
        }
        if (!$assertionsDisabled && row != keptRows) {
            throw new AssertionError();
        }
        final DMatrix matrix = new DMatrix(rowHeaders, Arrays.copyOf(colIndices, nnz), Arrays.copyOf(values, nnz), DMatrix.SparseType.CSR, dataInfo.fullN());
        if (!$assertionsDisabled && matrix.rowNum() != row) {
            throw new AssertionError();
        }
        return matrix;
    }

    /**
     * Builds a row-major dense matrix with {@code fullN()} cells per kept input row. Each row is
     * filled as follows:
     * <ul>
     *   <li>a categorical predictor stores 1.0 at the column {@code getCategoricalId} returns
     *       (for NAs, the column returned for {@code Double.NaN});</li>
     *   <li>a numeric NA is stored as NaN, which is also XGBoost's missing marker here.</li>
     * </ul>
     */
    private static DMatrix buildDenseMatrix(DataInfo dataInfo, Vec.Reader[] columns, Vec.Reader weights, int numRows, int keptRows) throws XGBoostError {
        Log.info(new Object[]{"Treating matrix as dense."});
        final int fullN = dataInfo.fullN();
        final int numericOffset = dataInfo._catOffsets[dataInfo._catOffsets.length - 1];
        float[] cells = new float[INITIAL_BUFFER_LENGTH];
        int row = 0;
        for (int r = 0; r < numRows; r++) {
            if (!isKept(weights, r)) {
                continue;
            }
            // Long arithmetic: (row + 1) * fullN overflows int for large matrices, which used to
            // skip the growth and write out of bounds instead of reporting "too large".
            while (cells.length < (long) (row + 1) * fullN) {
                cells = Arrays.copyOf(cells, grownLength(cells.length, "dense"));
            }
            final int rowOffset = row * fullN;
            for (int c = 0; c < dataInfo._cats; c++) {
                if (columns[c].isNA(r)) {
                    cells[rowOffset + dataInfo.getCategoricalId(c, Double.NaN)] = 1.0f;
                } else {
                    cells[rowOffset + dataInfo.getCategoricalId(c, columns[c].at8(r))] = 1.0f;
                }
            }
            for (int n = 0; n < dataInfo._nums; n++) {
                final Vec.Reader column = columns[dataInfo._cats + n];
                cells[rowOffset + numericOffset + n] = column.isNA(r) ? Float.NaN : (float) column.at(r);
            }
            if (!$assertionsDisabled && numericOffset + dataInfo._nums != fullN) {
                throw new AssertionError();
            }
            row++;
        }
        if (!$assertionsDisabled && row != keptRows) {
            throw new AssertionError();
        }
        final DMatrix matrix = new DMatrix(Arrays.copyOf(cells, row * fullN), row, fullN, Float.NaN);
        if (!$assertionsDisabled && matrix.rowNum() != row) {
            throw new AssertionError();
        }
        return matrix;
    }

    /** Creates a builder for {@code xGBoostParameters} and runs {@code init(false)}. */
    public XGBoost(XGBoostModel.XGBoostParameters xGBoostParameters) {
        super(xGBoostParameters);
        init(false);
    }

    /** Creates a builder for {@code xGBoostParameters} with result key {@code key} and runs {@code init(false)}. */
    public XGBoost(XGBoostModel.XGBoostParameters xGBoostParameters, Key<XGBoostModel> key) {
        super(xGBoostParameters, key);
        init(false);
    }

    /**
     * Creates a builder with default {@link XGBoostModel.XGBoostParameters}. The flag {@code z}
     * is forwarded unchanged to the {@code ModelBuilder} constructor; {@code init} is not run.
     */
    public XGBoost(boolean z) {
        super(new XGBoostModel.XGBoostParameters(), z);
    }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return regression, binomial and multinomial
     */
    public ModelCategory[] can_build() {
        final ModelCategory[] supported = {ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
        return supported;
    }

    /** Always {@code true}: this builder is supervised. */
    public boolean isSupervised() {
        final boolean supervised = true;
        return supervised;
    }

    /** Returns the constant 2. Presumably the number of models built concurrently in parallel searches; TODO confirm. */
    protected int nModelsInParallel() {
        final int parallelModels = 2;
        return parallelModels;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: trainModelImpl, reason: merged with bridge method [inline-methods] */
    /**
     * Creates the driver that trains the model for this builder, via the driver's synthetic
     * accessor constructor.
     *
     * <p>JADX renamed this method from {@code trainModelImpl} because it merged it with the
     * compiler's bridge method of that name.
     */
    public XGBoostDriver m5trainModelImpl() {
        return new XGBoostDriver(this, null);
    }

    /**
     * Restores the {@code trainModelImpl} override that the decompiler's rename removed.
     *
     * <p>The bridge method JADX merged away only existed because this method overrides a
     * supertype {@code trainModelImpl()} with the narrower return type {@link XGBoostDriver}.
     * It is declared {@code public}, which is at least as visible as any overridden declaration.
     * Delegates to {@link #m5trainModelImpl()}.
     */
    public XGBoostDriver trainModelImpl() {
        return m5trainModelImpl();
    }

    /**
     * Validates the builder's parameters, recording problems through {@code error(...)}.
     *
     * <p>With {@code z == true} it additionally:
     * <ul>
     *   <li>rejects a response containing NAs;</li>
     *   <li>throws if any errors have been recorded so far;</li>
     *   <li>flags an offset column as unsupported.</li>
     * </ul>
     *
     * <p>Regardless of {@code z}, it requires:
     * <ul>
     *   <li>a single-node cloud;</li>
     *   <li>a working GPU when the GPU backend is selected;</li>
     *   <li>a distribution that matches the response type;</li>
     *   <li>{@code learn_rate} and {@code col_sample_rate} in (0, 1];</li>
     *   <li>{@code tree_method=hist} whenever {@code grow_policy=lossguide}.</li>
     * </ul>
     *
     * @param z forwarded to {@code super.init}; also enables the response/offset checks above
     * @throws H2OModelBuilderIllegalArgumentException if {@code z} is set and errors were recorded
     * @throws IllegalArgumentException if the H2O cloud has more than one node
     */
    public void init(boolean z) {
        super.init(z);
        if (z) {
            // NOTE(review): guard in case _response is unset (e.g. a missing response column that
            // super.init already reported), so that error surfaces instead of an NPE here.
            if (this._response != null && this._response.naCnt() > 0) {
                error("_response_column", "Response contains missing values (NAs) - not supported by XGBoost.");
            }
            if (error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
            }
            if (hasOffsetCol()) {
                error("_offset_column", "Offset is not supported for XGBoost.");
            }
        }
        if (H2O.CLOUD.size() > 1) {
            throw new IllegalArgumentException("XGBoost is currently only supported in single-node mode.");
        }
        final XGBoostModel.XGBoostParameters parms = this._parms;
        if (parms._backend == XGBoostModel.XGBoostParameters.Backend.gpu && !hasGPU(parms._gpu_id)) {
            error("_backend", "GPU backend (gpu_id: " + parms._gpu_id + ") is not functional. Check CUDA_PATH and/or GPU installation.");
        }
        // Plain enum switch instead of the decompiled ordinal-based switch map.
        switch (parms._distribution) {
            case bernoulli:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Binomial requires the response to be a 2-class categorical"));
                }
                break;
            case modified_huber:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Modified Huber requires the response to be a 2-class categorical."));
                }
                break;
            case multinomial:
                if (!isClassifier()) {
                    error("_distribution", H2O.technote(2, "Multinomial requires an categorical response."));
                }
                break;
            case huber:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Huber requires the response to be numeric."));
                }
                break;
            case poisson:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Poisson requires the response to be numeric."));
                }
                break;
            case gamma:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gamma requires the response to be numeric."));
                }
                break;
            case tweedie:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Tweedie requires the response to be numeric."));
                }
                break;
            case gaussian:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gaussian requires the response to be numeric."));
                }
                break;
            case laplace:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Laplace requires the response to be numeric."));
                }
                break;
            case quantile:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Quantile requires the response to be numeric."));
                }
                break;
            case AUTO:
                break;
            default:
                error("_distribution", "Invalid distribution: " + parms._distribution);
                break;
        }
        // Negated range tests, so a NaN value is rejected as well.
        if (!(parms._learn_rate > 0.0d && parms._learn_rate <= 1.0d)) {
            error("_learn_rate", "learn_rate must be between 0 and 1");
        }
        if (!(parms._col_sample_rate > 0.0d && parms._col_sample_rate <= 1.0d)) {
            error("_col_sample_rate", "col_sample_rate must be between 0 and 1");
        }
        if (parms._grow_policy == XGBoostModel.XGBoostParameters.GrowPolicy.lossguide
                && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.hist) {
            error("_grow_policy", "must use tree_method=hist for grow_policy=lossguide");
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds the {@link DataInfo} used by XGBoost for the given training and validation frames.
     *
     * <p>The {@code DataInfo} is configured with:
     * <ul>
     *   <li>both {@code TransformType} arguments set to {@code NONE};</li>
     *   <li>row skipping enabled when {@code missing_values_handling == Skip};</li>
     *   <li>weight, offset and fold flags set according to which of those columns are named.</li>
     * </ul>
     *
     * <p>A {@link GLMTask.YMUTask} pass over the adapted frame supplies the weighted statistics.
     * With a weights column but no offset column, those statistics replace the predictor
     * sigma/mean, and also the response sigma/mean when {@code nclasses == 1}.
     *
     * @param nclasses passed to {@code YMUTask}; presumably the number of response classes
     *                 (1 for regression). Verify against callers.
     * @throws H2OIllegalArgumentException if skipping missing values leaves a total weight of 0
     */
    public static DataInfo makeDataInfo(Frame train, Frame valid, XGBoostModel.XGBoostParameters parms, int nclasses) {
        final boolean skipMissing = parms._missing_values_handling == XGBoostModel.XGBoostParameters.MissingValuesHandling.Skip;
        final boolean hasWeights = parms._weights_column != null;
        final boolean hasOffset = parms._offset_column != null;
        final boolean hasFold = parms._fold_column != null;
        final DataInfo dinfo = new DataInfo(train, valid, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, skipMissing, false, true, hasWeights, hasOffset, hasFold);
        final GLMTask.YMUTask stats = new GLMTask.YMUTask(dinfo, nclasses, nclasses == 1, skipMissing, true, true).doAll(dinfo._adaptedFrame);
        if (stats.wsum() == 0.0d && skipMissing) {
            throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
        }
        if (hasWeights) {
            if (hasOffset) {
                Log.warn(new Object[]{"Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms."});
            } else {
                dinfo.updateWeightedSigmaAndMean(stats.predictorSDs(), stats.predictorMeans());
                if (nclasses == 1) {
                    dinfo.updateWeightedSigmaAndMeanForResponse(stats.responseSDs(), stats.responseMeans());
                }
            }
        }
        return dinfo;
    }

    /**
     * Computes {@code learn_rate * learn_rate_annealing^(ntrees - 1)}, where {@code ntrees} is the
     * number of trees currently recorded in the model output.
     */
    private double effective_learning_rate(XGBoostModel xGBoostModel) {
        final XGBoostModel.XGBoostParameters parms = this._parms;
        final double exponent = xGBoostModel._output._ntrees - 1;
        final double annealing = Math.pow(parms._learn_rate_annealing, exponent);
        return parms._learn_rate * annealing;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Probes whether XGBoost can train with the GPU updater on device {@code i}. It runs one
     * boosting round with {@code updater=grow_gpu_hist} on a tiny 2x2 matrix.
     *
     * <p>If the probe matrix cannot even be created, the probe fails and returns {@code false}.
     * Previously execution continued with a {@code null} matrix. The probe matrix and booster are
     * disposed before returning, so repeated probes (this is called from {@code init}) do not
     * keep native handles alive.
     *
     * @param i GPU device id, passed to XGBoost as {@code gpu_id}
     * @return {@code true} if the probe trained successfully, {@code false} if XGBoost raised an error
     */
    public static boolean hasGPU(int i) {
        DMatrix probe = null;
        Booster booster = null;
        try {
            probe = new DMatrix(new float[]{1.0f, 2.0f, 1.0f, 2.0f}, 2, 2);
            probe.setLabel(new float[]{1.0f, 0.0f});
            final HashMap<String, Object> params = new HashMap<String, Object>();
            params.put("updater", "grow_gpu_hist");
            params.put("silent", 1);
            params.put("gpu_id", Integer.valueOf(i));
            final HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
            watches.put("train", probe);
            booster = ml.dmlc.xgboost4j.java.XGBoost.train(probe, params, 1, watches, (IObjective) null, (IEvaluation) null);
            return true;
        } catch (XGBoostError e) {
            return false;
        } finally {
            if (booster != null) {
                booster.dispose();
            }
            if (probe != null) {
                probe.dispose();
            }
        }
    }

    /**
     * Counts the distinct values in {@code iArr} without modifying it. It sorts a copy and counts
     * the positions where the value changes.
     *
     * @param iArr values to inspect; must not be {@code null}
     * @return the number of distinct values; 0 for an empty array
     */
    public static int countUnique(int[] iArr) {
        final int[] sorted = iArr.clone();
        Arrays.sort(sorted);
        int distinct = 0;
        for (int k = 0; k < sorted.length; k++) {
            final boolean startsNewValue = k == 0 || sorted[k] != sorted[k - 1];
            if (startsNewValue) {
                distinct++;
            }
        }
        return distinct;
    }

    static {
        $assertionsDisabled = !XGBoost.class.desiredAssertionStatus();
    }
}
