package hex.gbm;

import hex.Model;
import hex.gbm.SharedTreeModel;
import hex.schemas.GBMModelV2;
import java.util.Arrays;
import water.H2O;
import water.Key;
import water.api.ModelSchema;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.ModelUtils;

/* loaded from: input_file:hex/gbm/GBMModel.class */
/**
 * Gradient Boosting Machine model built on {@link SharedTreeModel}.
 *
 * <p>The superclass produces the raw per-class tree scores; this class turns them into final
 * predictions. For bernoulli loss it applies a logistic link. For other multi-class models it
 * applies a numerically stable softmax. For regression it adds the initial prediction.
 */
public class GBMModel extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {

    /** Training output specific to GBM. */
    public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
        /**
         * Constant offset added to the raw tree score before the logistic link (bernoulli loss)
         * or to the regression prediction. It is not used by the multi-class softmax path.
         */
        double initialPrediction;

        /** Returns the number of entries in {@code _names}. */
        @Override
        public int nfeatures() {
            return this._names.length;
        }

        /** Not implemented yet: always throws the exception produced by {@code H2O.unimpl()}. */
        @Override
        public Model.ModelCategory getModelCategory() {
            throw H2O.unimpl();
        }
    }

    /** User-tunable GBM parameters. */
    public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
        /** Loss family; {@link Family#bernoulli} switches scoring to a logistic link. */
        public Family _loss = Family.AUTO;
        /** Learning rate; must lie in the half-open interval (0, 1]. */
        public float _learn_rate = 0.1f;
        /** Not read in this file; presumably an RNG seed used during training (TODO confirm). */
        public long _seed;

        /** Supported loss families. */
        public enum Family {
            AUTO,
            bernoulli
        }

        /**
         * Validates GBM-specific parameters on top of the shared tree checks.
         *
         * @return the accumulated validation error count
         * @throws IllegalArgumentException if bernoulli loss is requested but the response is not
         *     a 2-class classification target
         */
        @Override
        public int sanityCheckParameters() {
            super.sanityCheckParameters();
            // Written positively so that a NaN learning rate is also rejected.
            if (!(this._learn_rate > 0.0d && this._learn_rate <= 1.0d)) {
                validation_error("learn_rate", "learn_rate must be between 0 and 1");
            }
            final boolean binaryClassification = this._classification && this._nclass == 2;
            if (this._loss == Family.bernoulli && !binaryClassification) {
                throw new IllegalArgumentException("Bernoulli requires the response to be a 2-class categorical");
            }
            return this._validation_error_count;
        }
    }

    /**
     * Creates a GBM model.
     *
     * @param key model key
     * @param frame training frame
     * @param gBMParameters model parameters
     * @param gBMOutput model output
     * @param i ignored; not passed to the superclass
     */
    public GBMModel(Key key, Frame frame, GBMParameters gBMParameters, GBMOutput gBMOutput, int i) {
        super(key, frame, gBMParameters, gBMOutput);
    }

    /** Returns a fresh REST schema object for this model. */
    public ModelSchema schema() {
        return new GBMModelV2();
    }

    /**
     * Scores a single row taken from a set of chunks.
     *
     * <p>Copies the value at {@code row} of each of the first {@code _output._names.length}
     * chunks into {@code tmp}, then delegates to {@link #score0(double[], float[])}.
     *
     * @param chks one chunk per column; must hold at least {@code _names.length} entries
     * @param row row index within the chunks
     * @param tmp scratch buffer receiving the row values
     * @param preds prediction buffer handed to the array-based scorer
     * @return the prediction array produced by {@link #score0(double[], float[])}
     */
    protected float[] score0(Chunk[] chks, int row, double[] tmp, float[] preds) {
        final int ncols = ((GBMOutput) this._output)._names.length;
        assert chks.length >= ncols;
        for (int col = 0; col < ncols; col++) {
            tmp[col] = chks[col].at0(row);
        }
        return score0(tmp, preds);
    }

    /**
     * Converts the raw tree scores from the superclass into final predictions.
     *
     * <ul>
     *   <li>Bernoulli loss: {@code p[2] = sigmoid(p[1] + initialPrediction)} and
     *       {@code p[1] = 1 - p[2]}; {@code p[0]} receives the predicted class.
     *   <li>More than one class: softmax over {@code p[1..]} (for two classes {@code p[2]} is
     *       first set to {@code -p[1]}); {@code p[0]} receives the predicted class.
     *   <li>Otherwise (regression): {@code p[0] += initialPrediction}.
     * </ul>
     *
     * @param data the row's input values
     * @param preds prediction buffer passed to the superclass
     * @return the array returned by the superclass, updated in place
     */
    @Override
    public float[] score0(double[] data, float[] preds) {
        final float[] p = super.score0(data, preds);
        final GBMOutput out = (GBMOutput) this._output;

        if (this._parms._loss == GBMParameters.Family.bernoulli) {
            final double fx = p[1] + out.initialPrediction;
            p[2] = 1.0f / ((float) (1.0d + Math.exp(-fx)));
            p[1] = 1.0f - p[2];
            p[0] = ModelUtils.getPrediction(p, data);
            return p;
        }

        if (out.nclasses() > 1) {
            if (out.nclasses() == 2) {
                p[2] = -p[1];
            }
            // Subtract the max score before exponentiating so Math.exp cannot overflow to
            // Infinity (which would later turn into NaN after normalisation).
            float max = Float.NEGATIVE_INFINITY;
            for (int k = 1; k < p.length; k++) {
                max = Math.max(max, p[k]);
            }
            assert !Float.isInfinite(max)
                    : "Something is wrong with GBM trees since returned prediction is " + Arrays.toString(p);
            float sum = 0.0f;
            for (int k = 1; k < p.length; k++) {
                final float e = (float) Math.exp(p[k] - max);
                p[k] = e;
                sum += e;
            }
            ArrayUtils.div(p, sum);
            p[0] = ModelUtils.getPrediction(p, data);
        } else {
            // Regression: the prediction starts from the initial prediction. Update the array we
            // actually return (the original wrote into `preds`, which may not be the same array).
            p[0] = (float) (p[0] + out.initialPrediction);
        }
        return p;
    }
}
