package hex.tree.xgboost;

import hex.DataInfo;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMTask;
import hex.tree.SharedTree;
import hex.tree.TreeStats;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.rabit.RabitTrackerH2O;
import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.BoosterHelper;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostCleanupTask;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import ml.dmlc.xgboost4j.java.XGBoostSetupTask;
import ml.dmlc.xgboost4j.java.XGBoostUpdateTask;
import water.DKV;
import water.DTask;
import water.ExtensionManager;
import water.H2O;
import water.H2ONode;
import water.Job;
import water.Key;
import water.RPC;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.Log;
import water.util.Timer;

/* loaded from: input_file:hex/tree/xgboost/XGBoost.class */
public class XGBoost extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {
    /** Fill ratio (non-zero cells / total cells) below which the training matrix is treated as sparse. */
    private static final double FILL_RATIO_THRESHOLD = 0.25d;
    /**
     * Ids of GPUs on which a probe training run has already succeeded. It is only touched from the
     * static synchronized {@code hasGPU(int)}, so a plain {@link HashSet} is sufficient.
     */
    private static final Set<Integer> GPUS = new HashSet<>();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.tree.xgboost.XGBoost$1, reason: invalid class name */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$1.class */
    /**
     * Decompiler-materialized switch map for {@link DistributionFamily}: maps each enum ordinal to
     * the 1-based case label used by the distribution switch in {@code init(boolean)}
     * (bernoulli=1 ... AUTO=11). Ordinals not listed here stay 0 and therefore reach the
     * {@code default} branch of that switch.
     *
     * <p>Each assignment is wrapped in its own try/catch so that an enum constant missing at
     * runtime ({@link NoSuchFieldError}) only drops that one mapping instead of failing class
     * initialization. The numbering must not change independently of the switch that reads it.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by DistributionFamily.ordinal(); value = case label, 0 = unmapped.
        static final /* synthetic */ int[] $SwitchMap$hex$genmodel$utils$DistributionFamily = new int[DistributionFamily.values().length];

        static {
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.bernoulli.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.modified_huber.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.multinomial.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.huber.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.poisson.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gamma.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.tweedie.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.gaussian.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.laplace.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.quantile.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$hex$genmodel$utils$DistributionFamily[DistributionFamily.AUTO.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
        }
    }

    /**
     * Keeps the model's serialized booster in sync with the most recent {@link XGBoostUpdateTask}.
     *
     * <p>The constructor copies the booster bytes of the initial task into the model info right
     * away. {@link #reset} only swaps in a newer task. {@link #updateBooster} copies the bytes of
     * whichever task is current.
     */
    public static final class BoosterProvider {
        XGBoostModelInfo _modelInfo;
        XGBoostUpdateTask _updateTask;

        BoosterProvider(XGBoostModelInfo modelInfo, XGBoostUpdateTask initialTask) {
            _modelInfo = modelInfo;
            _updateTask = initialTask;
            publishBoosterBytes();
        }

        /** Makes {@code latestTask} the source for the next {@link #updateBooster()} call. */
        final void reset(XGBoostUpdateTask latestTask) {
            _updateTask = latestTask;
        }

        /**
         * Copies the booster bytes of the current update task into the model info.
         *
         * @throws IllegalStateException if no update task is currently set
         */
        final void updateBooster() {
            if (_updateTask == null) {
                throw new IllegalStateException("Booster can be retrieved only once!");
            }
            publishBoosterBytes();
        }

        /** Pushes the current task's booster bytes into the model info. */
        private void publishBoosterBytes() {
            _modelInfo.setBoosterBytes(_updateTask.getBoosterBytes());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$HasGPUTask.class */
    public static class HasGPUTask extends DTask<HasGPUTask> {
        private final int _gpu_id;
        private boolean _hasGPU;

        private HasGPUTask(int i) {
            this._gpu_id = i;
        }

        public void compute2() {
            this._hasGPU = XGBoost.hasGPU(this._gpu_id);
            tryComplete();
        }

        /* synthetic */ HasGPUTask(int i, AnonymousClass1 anonymousClass1) {
            this(i);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/tree/xgboost/XGBoost$XGBoostDriver.class */
    public class XGBoostDriver extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>.Driver {
        /** File name of the temporary feature map, randomized with a UUID. */
        private final String featureMapFileName = "featureMap" + UUID.randomUUID().toString() + ".txt";
        /** Temporary feature-map file created during buildModelImpl; null until it is written. */
        private File featureMapFile = null;
        /** Time (ms) of the first doScoring call; 0 means "not yet recorded". */
        long _firstScore = 0L;
        /** Time (ms) at which the most recent scoring round started. */
        long _timeLastScoreStart = 0L;
        /** Time (ms) at which the most recent scoring round finished. */
        long _timeLastScoreEnd = 0L;
        static final /* synthetic */ boolean $assertionsDisabled;

        XGBoostDriver() {
            super(XGBoost.this);
        }

        /**
         * Runs the expensive validation of the enclosing builder and, if it recorded no errors,
         * builds the model.
         *
         * @throws H2OModelBuilderIllegalArgumentException if validation recorded any error
         */
        @Override
        public void computeImpl() {
            XGBoost builder = XGBoost.this;
            builder.init(true);
            if (builder.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }
            buildModel();
        }

        /**
         * Builds the model, serializing on the per-GPU lock when the run will use a GPU.
         *
         * <p>The lock is taken only when the backend is {@code auto} or {@code gpu}, the configured
         * GPU is usable, the cluster has exactly one node and no GPU-incompatible parameters are
         * set. Otherwise the model is built without any lock.
         */
        final void buildModel() {
            XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            boolean gpuBackendRequested = XGBoostModel.XGBoostParameters.Backend.auto.equals(parms._backend)
                    || XGBoostModel.XGBoostParameters.Backend.gpu.equals(parms._backend);
            // Same evaluation order as a plain && chain: hasGPU() is only probed for auto/gpu backends.
            boolean runsOnGpu = gpuBackendRequested
                    && XGBoost.hasGPU(parms._gpu_id)
                    && H2O.getCloudSize() == 1
                    && parms.gpuIncompatibleParams().isEmpty();
            if (runsOnGpu) {
                synchronized (XGBoostGPULock.lock(parms._gpu_id)) {
                    buildModelImpl();
                }
            } else {
                buildModelImpl();
            }
        }

        /**
         * Builds the model end-to-end.
         *
         * <p>The steps are:
         * <ol>
         *   <li>create and write-lock the output model;</li>
         *   <li>choose the sparse or dense matrix layout;</li>
         *   <li>start the Rabit tracker;</li>
         *   <li>write the feature map;</li>
         *   <li>set up the distributed XGBoost state and grow the trees.</li>
         * </ol>
         *
         * <p>Every exit path releases what was acquired:
         * <ul>
         *   <li>the tracker is stopped once it has been started;</li>
         *   <li>the setup task is cleaned up if training did not already do so;</li>
         *   <li>the temporary feature-map file is removed;</li>
         *   <li>the model is always unlocked.</li>
         * </ul>
         *
         * @throws IllegalArgumentException if the Rabit tracker cannot be started
         * @throws RuntimeException wrapping an {@link XGBoostError} raised during training
         */
        final void buildModelImpl() {
            XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            XGBoostModel xGBoostModel = new XGBoostModel(XGBoost.this._result, parms, new XGBoostOutput(XGBoost.this), XGBoost.this._train, XGBoost.this._valid);
            xGBoostModel.write_lock(XGBoost.this._job);
            XGBoostOutput output = (XGBoostOutput) xGBoostModel._output;
            // Non-null only while native state exists that nobody has cleaned up yet.
            XGBoostSetupTask setupTask = null;
            try {
                // Inside the try so that a failure here still unlocks the model.
                if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.sparse) {
                    output._sparse = true;
                } else if (parms._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.dense) {
                    output._sparse = false;
                } else {
                    output._sparse = isTrainDatasetSparse();
                }
                XGBoostSetupTask.FrameNodes frameNodes = XGBoostSetupTask.findFrameNodes(XGBoost.this._train);
                RabitTrackerH2O rabitTracker = new RabitTrackerH2O(frameNodes.getNumNodes());
                if (!startRabitTracker(rabitTracker)) {
                    throw new IllegalArgumentException("Cannot start XGboost rabit tracker, please, make sure you have python installed!");
                }
                try {
                    DataInfo dataInfo = xGBoostModel.model_info()._dataInfoKey.get();
                    if (!$assertionsDisabled && dataInfo == null) {
                        throw new AssertionError();
                    }
                    String featureMap = XGBoostUtils.makeFeatureMap(XGBoost.this._train, dataInfo);
                    xGBoostModel.model_info().setFeatureMap(featureMap);
                    this.featureMapFile = createFeatureMapFile(featureMap);
                    BoosterParms boosterParms = XGBoostModel.createParams(parms, output.nclasses());
                    output._native_parameters = boosterParms.toTwoDimTable();
                    setupTask = new XGBoostSetupTask(xGBoostModel, parms, boosterParms, getWorkerEnvs(rabitTracker), frameNodes).run();
                    XGBoostUpdateTask initialUpdate = new XGBoostUpdateTask(setupTask, 0).run();
                    scoreAndBuildTrees(setupTask, new BoosterProvider(xGBoostModel.model_info(), initialUpdate), xGBoostModel);
                    // Regular shutdown of the native state; clear the reference so the finally below skips it.
                    XGBoostCleanupTask.cleanUp(setupTask);
                    setupTask = null;
                    waitOnRabitWorkers(rabitTracker);
                } finally {
                    rabitTracker.stop();
                }
            } catch (XGBoostError e) {
                Log.err("XGBoost failure", e);
                throw new RuntimeException("XGBoost failure", e);
            } finally {
                if (setupTask != null) {
                    // Training failed before the regular clean-up: release native memory on a best-effort basis.
                    try {
                        XGBoostCleanupTask.cleanUp(setupTask);
                    } catch (Exception e) {
                        Log.err("XGBoost clean-up failed - this could leak memory!", e);
                    }
                }
                if (this.featureMapFile != null) {
                    // The feature map is stored in the model info; the temp file and its directory are no longer needed.
                    File featureMapDir = this.featureMapFile.getParentFile();
                    if (!this.featureMapFile.delete() || (featureMapDir != null && !featureMapDir.delete())) {
                        Log.warn("Could not delete temporary feature map file " + this.featureMapFile.getAbsolutePath());
                    }
                }
                xGBoostModel.unlock(XGBoost.this._job);
            }
        }

        /**
         * Decides whether the training data should be handed to XGBoost as a sparse matrix.
         *
         * <p>The response, weights, fold and offset columns are skipped. Each categorical column
         * contributes {@code cardinality} one-hot columns and one non-zero per row. Each other
         * column contributes one column and {@code nzCnt()} non-zeros.
         *
         * <p>The data is considered sparse when the resulting fill ratio is below
         * {@link XGBoost#FILL_RATIO_THRESHOLD} or the matrix would exceed
         * {@link Integer#MAX_VALUE} cells. For a frame with no feature cells the ratio is NaN and
         * the dense layout is chosen (unless the size check applies).
         *
         * @return {@code true} if the sparse layout should be used
         */
        private boolean isTrainDatasetSparse() {
            XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            Frame train = XGBoost.this._train;
            long numRows = train.numRows();
            long nonZeroCount = 0;
            long numericColumns = 0;
            long oneHotColumns = 0;
            for (int col = 0; col < train.numCols(); col++) {
                String name = train.name(col);
                if (name.equals(parms._response_column) || name.equals(parms._weights_column)
                        || name.equals(parms._fold_column) || name.equals(parms._offset_column)) {
                    continue;
                }
                Vec vec = train.vec(col);
                if (vec.isCategorical()) {
                    nonZeroCount += numRows;
                    oneHotColumns += vec.cardinality();
                } else {
                    nonZeroCount += vec.nzCnt();
                    numericColumns++;
                }
            }
            long totalColumns = oneHotColumns + numericColumns;
            // Floating-point division: long division would truncate the ratio to 0 (making nearly every
            // frame "sparse") and throw ArithmeticException for an empty frame.
            double fillRatio = (double) nonZeroCount / ((double) totalColumns * (double) numRows);
            Log.info("fill ratio: " + fillRatio);
            return fillRatio < XGBoost.FILL_RATIO_THRESHOLD || numRows * totalColumns > Integer.MAX_VALUE;
        }

        /**
         * Writes {@code str} (the feature map) to {@link #featureMapFileName} inside a freshly created
         * temporary directory named after the model key.
         *
         * <p>The content is encoded as UTF-8 so that non-ASCII feature names survive independently
         * of the JVM's default charset.
         *
         * @param str feature-map text to write
         * @return the written file
         * @throws RuntimeException wrapping the {@link IOException} if the directory or file cannot be written
         */
        private File createFeatureMapFile(String str) {
            try {
                File tempDir = Files.createTempDirectory("xgboost-model-" + XGBoost.this._result.toString()).toFile();
                File file = new File(tempDir, this.featureMapFileName);
                // try-with-resources closes the stream exactly once, on success and on failure.
                try (FileOutputStream out = new FileOutputStream(file)) {
                    out.write(str.getBytes(StandardCharsets.UTF_8));
                }
                return file;
            } catch (IOException e) {
                throw new RuntimeException("Cannot generate feature map file " + this.featureMapFileName, e);
            }
        }

        /**
         * Main training loop.
         *
         * <p>Before each tree it may score the model and stop early. After each tree it updates job
         * progress and the per-tree history arrays. It honors cancellation (job cancelled without
         * timeout) and timeouts (stop quietly). At the end it scores the final model and reports
         * the trees that were never built as done.
         *
         * @throws XGBoostError propagated from scoring
         * @throws Job.JobCancelledException if a stop was requested without a timeout
         */
        private void scoreAndBuildTrees(XGBoostSetupTask xGBoostSetupTask, BoosterProvider boosterProvider, XGBoostModel xGBoostModel) throws XGBoostError {
            XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            for (int tid = 0; tid < parms._ntrees; tid++) {
                boolean scored = doScoring(xGBoostModel, boosterProvider, false);
                if (scored && ScoreKeeper.stopEarly(((XGBoostOutput) xGBoostModel._output).scoreKeepers(), parms._stopping_rounds, XGBoost.this._nclass > 1, parms._stopping_metric, parms._stopping_tolerance, "model's last", true)) {
                    Log.info("Early stopping triggered - stopping XGBoost training");
                    break;
                }
                Timer treeTimer = new Timer();
                boosterProvider.reset(new XGBoostUpdateTask(xGBoostSetupTask, tid).run());
                Log.info((tid + 1) + ". tree was built in " + treeTimer.toString());
                XGBoost.this._job.update(1L);

                // Record the new tree and grow the history arrays to ntrees + 1 slots.
                XGBoostOutput out = (XGBoostOutput) xGBoostModel._output;
                out._ntrees++;
                out._scored_train = (ScoreKeeper[]) ArrayUtils.copyAndFillOf(out._scored_train, out._ntrees + 1, new ScoreKeeper());
                if (out._scored_valid != null) {
                    out._scored_valid = (ScoreKeeper[]) ArrayUtils.copyAndFillOf(out._scored_valid, out._ntrees + 1, new ScoreKeeper());
                }
                out._training_time_ms = ArrayUtils.copyAndFillOf(out._training_time_ms, out._ntrees + 1, System.currentTimeMillis());

                if (XGBoost.this.stop_requested() && !XGBoost.this.timeout()) {
                    throw new Job.JobCancelledException();
                }
                if (XGBoost.this.timeout()) {
                    Log.info("Stopping XGBoost training because of timeout");
                    break;
                }
            }
            XGBoost.this._job.update(0L, "Scoring the final model");
            doScoring(xGBoostModel, boosterProvider, true);
            XGBoost.this._job.update(parms._ntrees - ((XGBoostOutput) xGBoostModel._output)._ntrees);
        }

        /**
         * Starts the tracker on multi-node clusters. On a single node no tracker is needed and the
         * method reports success without starting it.
         *
         * @return {@code false} only if a multi-node start was attempted and failed
         */
        private boolean startRabitTracker(IRabitTracker iRabitTracker) {
            boolean singleNode = H2O.CLOUD.size() <= 1;
            return singleNode || iRabitTracker.start(0L);
        }

        /** Waits for the Rabit workers to finish; a no-op on single-node clusters. */
        private void waitOnRabitWorkers(IRabitTracker iRabitTracker) {
            if (H2O.CLOUD.size() <= 1) {
                return;
            }
            iRabitTracker.waitFor(0L);
        }

        /**
         * Returns the tracker's worker environment on multi-node clusters, or a new empty mutable
         * map on a single node.
         */
        private Map<String, String> getWorkerEnvs(IRabitTracker iRabitTracker) {
            if (H2O.CLOUD.size() <= 1) {
                return new HashMap<>();
            }
            return iRabitTracker.getWorkerEnvs();
        }

        /**
         * Scores the model if it is time to do so, and refreshes variable importances, model summary
         * and scoring history.
         *
         * <p>Scoring happens when any of the following holds:
         * <ul>
         *   <li>{@code _score_each_iteration} is set;</li>
         *   <li>{@code z} (final scoring) is true;</li>
         *   <li>the tree count hits {@code _score_tree_interval};</li>
         *   <li>no tree interval is set and the time-based heuristic fires. It fires during the initial
         *       score interval, or after {@code _score_interval} ms have passed since the last scoring
         *       started, provided that scoring took under 10% of that time.</li>
         * </ul>
         *
         * @param z {@code true} to force scoring (used for the final model)
         * @return whether scoring was performed
         * @throws XGBoostError if the booster cannot be deserialized or queried
         */
        private boolean doScoring(XGBoostModel xGBoostModel, BoosterProvider boosterProvider, boolean z) throws XGBoostError {
            XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) XGBoost.this._parms;
            long now = System.currentTimeMillis();
            if (this._firstScore == 0) {
                this._firstScore = now;
            }
            long sinceLastScoreStart = now - this._timeLastScoreStart;
            XGBoost.this._job.update(0L, "Built " + ((XGBoostOutput) xGBoostModel._output)._ntrees + " trees so far (out of " + parms._ntrees + ").");
            boolean timeToScore = now - this._firstScore < ((long) parms._initial_score_interval)
                    || (sinceLastScoreStart > ((long) parms._score_interval)
                        && ((double) (this._timeLastScoreEnd - this._timeLastScoreStart)) / ((double) sinceLastScoreStart) < 0.1d);
            boolean treeIntervalReached = parms._score_tree_interval > 0
                    && ((XGBoostOutput) xGBoostModel._output)._ntrees % parms._score_tree_interval == 0;
            boolean shouldScore = parms._score_each_iteration || z
                    || (timeToScore && parms._score_tree_interval == 0) || treeIntervalReached;
            if (!shouldScore) {
                return false;
            }
            this._timeLastScoreStart = now;
            boosterProvider.updateBooster();
            xGBoostModel.doScoring(XGBoost.this._train, parms.train(), XGBoost.this._valid, parms.valid());
            this._timeLastScoreEnd = System.currentTimeMillis();
            XGBoostOutput xGBoostOutput = (XGBoostOutput) xGBoostModel._output;

            Map<String, Integer> featureScores;
            Booster booster = null;
            try {
                booster = xGBoostModel.model_info().deserializeBooster();
                featureScores = (Map<String, Integer>) BoosterHelper.doWithLocalRabit(new BoosterHelper.BoosterOp<Map<String, Integer>>() {
                    @Override
                    public Map<String, Integer> apply(Booster localBooster) throws XGBoostError {
                        return localBooster.getFeatureScore(XGBoostDriver.this.featureMapFile.getAbsolutePath());
                    }
                }, booster);
            } finally {
                // The native booster is released exactly once, as soon as the feature scores are read;
                // later failures must not dispose it a second time.
                if (booster != null) {
                    BoosterHelper.dispose(new Object[]{booster});
                }
            }

            xGBoostOutput._varimp = xGBoostModel.computeVarImp(featureScores);
            xGBoostOutput._model_summary = SharedTree.createModelSummaryTable(xGBoostOutput._ntrees, (TreeStats) null);
            xGBoostOutput._scoring_history = SharedTree.createScoringHistoryTable(xGBoostOutput, ((XGBoostOutput) xGBoostModel._output)._scored_train, xGBoostOutput._scored_valid, XGBoost.this._job, xGBoostOutput._training_time_ms, parms._custom_metric_func != null);
            xGBoostOutput._variable_importances = ModelMetrics.calcVarImp(xGBoostOutput._varimp);
            xGBoostModel.update(XGBoost.this._job);
            Log.info(xGBoostModel);
            return true;
        }

        /*
         * Decompiler-materialized assertion flag: $assertionsDisabled is true unless assertions are
         * enabled for XGBoost. It guards the explicit AssertionError check on the DataInfo in
         * buildModelImpl.
         */
        static {
            $assertionsDisabled = !XGBoost.class.desiredAssertionStatus();
        }
    }

    /** XGBoost models can always be exported as MOJOs. */
    @Override
    public boolean haveMojo() {
        final boolean mojoSupported = true;
        return mojoSupported;
    }

    /**
     * Reports the builder as {@code Stable} when the XGBoost core extension is enabled, and as
     * {@code Experimental} otherwise.
     */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        boolean extensionEnabled = ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME);
        if (extensionEnabled) {
            return ModelBuilder.BuilderVisibility.Stable;
        }
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /** Supported problem types: regression, binomial and multinomial classification (new array per call). */
    @Override
    public ModelCategory[] can_build() {
        ModelCategory[] supported = {ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
        return supported;
    }

    /**
     * Creates a builder for the given parameters and runs the cheap (non-expensive) validation.
     *
     * @param parms model parameters
     */
    public XGBoost(XGBoostModel.XGBoostParameters parms) {
        super(parms);
        this.init(false);
    }

    /**
     * Creates a builder that writes its model under {@code key}, then runs the cheap validation.
     *
     * @param parms model parameters
     * @param key   destination key of the model
     */
    public XGBoost(XGBoostModel.XGBoostParameters parms, Key<XGBoostModel> key) {
        super(parms, key);
        this.init(false);
    }

    /**
     * Creates a builder with default parameters, passing the flag through to the superclass
     * without running {@code init}.
     *
     * <p>NOTE(review): presumably the registration-time constructor used when the algorithm is
     * announced to the framework; confirm against ModelBuilder.
     */
    public XGBoost(boolean startupOnce) {
        super(new XGBoostModel.XGBoostParameters(), startupOnce);
    }

    /** XGBoost always needs a response column. */
    @Override
    public boolean isSupervised() {
        final boolean needsResponse = true;
        return needsResponse;
    }

    /**
     * Number of models this builder allows to be trained in parallel (fixed at 2).
     *
     * <p>NOTE(review): presumably consulted by the framework when building cross-validation models;
     * confirm against ModelBuilder.
     */
    protected int nModelsInParallel() {
        final int parallelModels = 2;
        return parallelModels;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: trainModelImpl, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a fresh training driver for this builder.
     *
     * <p>NOTE(review): the decompiler renamed this from {@code trainModelImpl}. Restoring that name
     * (protected) would make it override the framework hook again; confirm before relying on it.
     */
    public XGBoostDriver m6trainModelImpl() {
        final XGBoostDriver driver = new XGBoostDriver();
        return driver;
    }

    /**
     * Validates parameters and training data, recording problems via {@code error(...)}.
     *
     * <p>Always checked:
     * <ul>
     *   <li>multi-node clusters with security enabled are rejected immediately (XGBoost has no
     *       SSL support);</li>
     *   <li>GPU-backend requirements;</li>
     *   <li>distribution/response compatibility;</li>
     *   <li>{@code learn_rate} and {@code col_sample_rate} ranges;</li>
     *   <li>the {@code grow_policy}/{@code tree_method} combination.</li>
     * </ul>
     *
     * <p>When {@code expensive} is true it additionally:
     * <ul>
     *   <li>checks the response for NAs;</li>
     *   <li>checks that the XGBoost extension is enabled on all nodes;</li>
     *   <li>throws if any error has been recorded so far;</li>
     *   <li>rejects offset columns.</li>
     * </ul>
     *
     * @param expensive whether to run the checks that need the data or a cluster-wide task
     * @throws H2OIllegalArgumentException on a secured cluster with more than one node
     * @throws H2OModelBuilderIllegalArgumentException in expensive mode, if errors were recorded
     *     before the offset check
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        if (H2O.CLOUD.size() > 1 && H2O.SELF.getSecurityManager().securityEnabled) {
            throw new H2OIllegalArgumentException("Cannot run XGBoost on an SSL enabled cluster larger than 1 node. XGBoost does not support SSL encryption.");
        }
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) this._parms;
        if (expensive) {
            // _response can be null when the response column could not be resolved. Presumably
            // super.init recorded an error for that; the error_count() check below then reports it
            // instead of this line failing with a NullPointerException.
            if (this._response != null && this._response.naCnt() > 0) {
                error("_response_column", "Response contains missing values (NAs) - not supported by XGBoost.");
            }
            if (!((XGBoostExtensionCheck) new XGBoostExtensionCheck().doAllNodes()).enabled) {
                error("XGBoost", "XGBoost is not available on all nodes!");
            }
            if (error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
            }
            if (hasOffsetCol()) {
                error("_offset_column", "Offset is not supported for XGBoost.");
            }
        }
        if (parms._backend == XGBoostModel.XGBoostParameters.Backend.gpu) {
            if (!hasGPU(parms._gpu_id)) {
                error("_backend", "GPU backend (gpu_id: " + parms._gpu_id + ") is not functional. Check CUDA_PATH and/or GPU installation.");
            }
            if (H2O.getCloudSize() > 1) {
                error("_backend", "GPU backend is not supported in distributed mode.");
            }
            for (Map.Entry<String, Object> entry : parms.gpuIncompatibleParams().entrySet()) {
                error("_backend", "GPU backend is not available for parameter setting '" + entry.getKey() + " = " + entry.getValue() + "'. Use CPU backend instead.");
            }
        }
        DistributionFamily distribution = parms._distribution;
        switch (distribution) {
            case quasibinomial:
                // Its own case so the user gets one specific message rather than an additional
                // generic "Invalid distribution" error from the default branch.
                error("_distribution", "Quasibinomial is not supported for XGBoost in current H2O.");
                break;
            case bernoulli:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Binomial requires the response to be a 2-class categorical"));
                }
                break;
            case modified_huber:
                if (this._nclass != 2) {
                    error("_distribution", H2O.technote(2, "Modified Huber requires the response to be a 2-class categorical."));
                }
                break;
            case multinomial:
                if (!isClassifier()) {
                    error("_distribution", H2O.technote(2, "Multinomial requires an categorical response."));
                }
                break;
            case huber:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Huber requires the response to be numeric."));
                }
                break;
            case poisson:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Poisson requires the response to be numeric."));
                }
                break;
            case gamma:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gamma requires the response to be numeric."));
                }
                break;
            case tweedie:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Tweedie requires the response to be numeric."));
                }
                break;
            case gaussian:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Gaussian requires the response to be numeric."));
                }
                break;
            case laplace:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Laplace requires the response to be numeric."));
                }
                break;
            case quantile:
                if (isClassifier()) {
                    error("_distribution", H2O.technote(2, "Quantile requires the response to be numeric."));
                }
                break;
            case AUTO:
                break;
            default:
                error("_distribution", "Invalid distribution: " + distribution);
                break;
        }
        if (0.0d >= parms._learn_rate || parms._learn_rate > 1.0d) {
            error("_learn_rate", "learn_rate must be between 0 and 1");
        }
        if (0.0d >= parms._col_sample_rate || parms._col_sample_rate > 1.0d) {
            error("_col_sample_rate", "col_sample_rate must be between 0 and 1");
        }
        if (parms._grow_policy == XGBoostModel.XGBoostParameters.GrowPolicy.lossguide
                && parms._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.hist) {
            error("_grow_policy", "must use tree_method=hist for grow_policy=lossguide");
        }
    }

    /**
     * Builds the {@link DataInfo} used to encode frames for XGBoost, with no predictor or response
     * transformation.
     *
     * <p>When missing values are skipped and no rows remain (zero total weight), an exception is
     * thrown. With a weights column but no offset, the predictor means/SDs, and for a single
     * response also the response statistics, are replaced by their weighted versions.
     *
     * @param frame       training frame
     * @param frame2      frame passed through as the second DataInfo argument
     * @param xGBoostParameters model parameters (missing-value handling, special columns)
     * @param i           number of responses; statistics for the response are updated only when 1
     * @throws H2OIllegalArgumentException if row filtering leaves no data
     */
    public static DataInfo makeDataInfo(Frame frame, Frame frame2, XGBoostModel.XGBoostParameters xGBoostParameters, int i) {
        boolean skipMissing = xGBoostParameters._missing_values_handling == XGBoostModel.XGBoostParameters.MissingValuesHandling.Skip;
        boolean hasWeights = xGBoostParameters._weights_column != null;
        boolean hasOffset = xGBoostParameters._offset_column != null;
        boolean hasFold = xGBoostParameters._fold_column != null;
        boolean singleResponse = i == 1;

        DataInfo dinfo = new DataInfo(frame, frame2, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, skipMissing, false, true, hasWeights, hasOffset, hasFold);
        GLMTask.YMUTask ymu = new GLMTask.YMUTask(dinfo, i, singleResponse, skipMissing, true, true).doAll(dinfo._adaptedFrame);
        if (skipMissing && ymu.wsum() == 0.0d) {
            throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
        }
        if (hasWeights) {
            if (hasOffset) {
                Log.warn("Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms.");
            } else {
                dinfo.updateWeightedSigmaAndMean(ymu.predictorSDs(), ymu.predictorMeans());
                if (singleResponse) {
                    dinfo.updateWeightedSigmaAndMeanForResponse(ymu.responseSDs(), ymu.responseMeans());
                }
            }
        }
        return dinfo;
    }

    /**
     * Serializes {@code booster} to bytes inside a thread-local Rabit session.
     *
     * <p>Rabit is shut down again even when serialization fails, so the calling thread is not left
     * with an initialized Rabit context.
     *
     * @param booster booster to serialize; may be {@code null}
     * @return the serialized booster, or {@code null} if {@code booster} is {@code null}
     * @throws IllegalStateException wrapping the {@link XGBoostError} if Rabit init, serialization
     *     or Rabit shutdown fails
     */
    public static byte[] getRawArray(Booster booster) {
        if (booster == null) {
            return null;
        }
        try {
            Rabit.init(new HashMap<>());
            try {
                return booster.toByteArray();
            } finally {
                Rabit.shutdown();
            }
        } catch (XGBoostError e) {
            throw new IllegalStateException("Failed to initialize Rabit or serialize the booster.", e);
        }
    }

    /**
     * Reports whether GPU {@code i} is usable on {@code h2ONode}.
     *
     * <p>The check runs locally when the node is this JVM. Otherwise it runs a {@link HasGPUTask}
     * on that node over RPC and blocks until it completes. The outcome is logged at debug level.
     *
     * @param h2ONode node to check
     * @param i       GPU id
     * @return whether the GPU is usable on that node
     */
    public static boolean hasGPU(H2ONode h2ONode, int i) {
        final boolean available;
        if (!H2O.SELF.equals(h2ONode)) {
            HasGPUTask probe = new HasGPUTask(i, null);
            new RPC(h2ONode, probe).call().get();
            available = probe._hasGPU;
        } else {
            available = hasGPU(i);
        }
        Log.debug("Availability of GPU (id=" + i + ") on node " + h2ONode + ": " + available);
        return available;
    }

    /**
     * Checks whether GPU {@code i} is usable on this node.
     *
     * <p>The check trains a single round of {@code grow_gpu_hist} on a 2x2 probe matrix. Successes
     * are cached in {@code GPUS}, so each GPU is probed successfully at most once. Failures are
     * not cached and are re-probed on the next call. The probe matrix and the trained booster are
     * always disposed, so repeated failed probes do not leak native memory.
     *
     * @param i GPU id
     * @return {@code false} if GPU support is disabled or the probe training fails
     * @throws IllegalStateException if the probe matrix cannot be created
     */
    static synchronized boolean hasGPU(int i) {
        if (!XGBoostExtension.isGpuSupportEnabled()) {
            return false;
        }
        if (GPUS.contains(i)) {
            return true;
        }
        DMatrix probeMatrix = null;
        try {
            try {
                probeMatrix = new DMatrix(new float[]{1.0f, 2.0f, 1.0f, 2.0f}, 2, 2);
                probeMatrix.setLabel(new float[]{1.0f, 0.0f});
            } catch (XGBoostError e) {
                throw new IllegalStateException("Couldn't prepare training matrix for XGBoost.", e);
            }
            return probeGpuTraining(probeMatrix, i);
        } finally {
            if (probeMatrix != null) {
                probeMatrix.dispose();
            }
        }
    }

    /**
     * Trains one round on GPU {@code gpuId} using {@code probeMatrix} inside a local Rabit session.
     * On success it records the id in {@code GPUS}. Only called from the synchronized
     * {@code hasGPU(int)}.
     *
     * @return {@code true} if training succeeded
     */
    private static boolean probeGpuTraining(DMatrix probeMatrix, int gpuId) {
        Map<String, Object> params = new HashMap<>();
        params.put("updater", "grow_gpu_hist");
        params.put("silent", 1);
        params.put("gpu_id", gpuId);
        Map<String, DMatrix> watches = new HashMap<>();
        watches.put("train", probeMatrix);
        Booster booster = null;
        try {
            Rabit.init(new HashMap<>());
            booster = ml.dmlc.xgboost4j.java.XGBoost.train(probeMatrix, params, 1, watches, (IObjective) null, (IEvaluation) null);
            GPUS.add(gpuId);
            return true;
        } catch (XGBoostError e) {
            // An unusable GPU is an expected outcome of the probe, not an error.
            Log.debug("GPU probe training failed (gpu_id=" + gpuId + ")", e);
            return false;
        } finally {
            if (booster != null) {
                booster.dispose();
            }
            try {
                Rabit.shutdown();
            } catch (XGBoostError e) {
                Log.warn("Cannot shutdown XGBoost Rabit for current thread.");
            }
        }
    }

    /**
     * Prepares the cross-validation main model after the CV models are built.
     *
     * <p>If early stopping or a runtime limit was configured, both are disabled. {@code _ntrees}
     * is then set to the truncated average tree count of the CV models, and a warning is recorded
     * for each changed parameter. Does nothing when neither limit was set.
     *
     * @param modelBuilderArr builders of the cross-validation models
     */
    public void cv_computeAndSetOptimalParameters(ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>[] modelBuilderArr) {
        XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) this._parms;
        boolean noLimitsConfigured = parms._stopping_rounds == 0 && parms._max_runtime_secs == 0.0d;
        if (noLimitsConfigured) {
            return;
        }
        parms._stopping_rounds = 0;
        parms._max_runtime_secs = 0.0d;
        int totalTrees = 0;
        for (ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> cvBuilder : modelBuilderArr) {
            totalTrees += ((XGBoostOutput) DKV.getGet(cvBuilder.dest())._output)._ntrees;
        }
        parms._ntrees = totalTrees / modelBuilderArr.length;
        warn("_ntrees", "Setting optimal _ntrees to " + parms._ntrees + " for cross-validation main model based on early stopping of cross-validation models.");
        warn("_stopping_rounds", "Disabling convergence-based early stopping for cross-validation main model.");
        warn("_max_runtime_secs", "Disabling maximum allowed runtime for cross-validation main model.");
    }
}
