package hex.tree.gbm;

import hex.genmodel.GenModel;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Chunk;
import water.util.SB;

/* loaded from: input_file:hex/tree/gbm/GBMModel.class */
public class GBMModel extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {

    /**
     * Output of a GBM build. Adds nothing beyond {@link SharedTreeModel.SharedTreeOutput};
     * all constructor arguments are forwarded unchanged to the superclass.
     */
    public static class GBMOutput extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param gbm the {@code GBM} builder producing this output
         * @param d   forwarded to the superclass constructor (first double argument)
         * @param d2  forwarded to the superclass constructor (second double argument)
         */
        public GBMOutput(GBM gbm, double d, double d2) {
            super(gbm, d, d2);
        }
    }

    /** GBM-specific parameters layered on top of the shared tree parameters. */
    public static class GBMParameters extends SharedTreeModel.SharedTreeParameters {
        /** Distribution family; defaults to {@link Family#AUTO}. */
        public Family _distribution = Family.AUTO;
        /** Learning rate; defaults to {@code 0.1f}. */
        public float _learn_rate = 0.1f;

        /**
         * Distribution families. In this class only {@code bernoulli} selects a dedicated
         * scoring path (logistic link); all other values fall through to the generic path.
         */
        public enum Family {
            AUTO,
            bernoulli,
            multinomial,
            gaussian
        }
    }

    /**
     * @param key          model key, forwarded to the superclass
     * @param gBMParameters parameters used to build this model
     * @param gBMOutput    build output
     */
    public GBMModel(Key key, GBMParameters gBMParameters, GBMOutput gBMOutput) {
        super(key, gBMParameters, gBMOutput);
    }

    /**
     * Scores row {@code i} of a set of chunks.
     *
     * <p>Copies the value of row {@code i} from each of the first {@code dArr.length} chunks into
     * {@code dArr}, then delegates to {@link #score0(double[], double[])}.
     *
     * @param chunkArr input columns; must contain at least {@code dArr.length} chunks
     * @param i        row index within the chunks
     * @param dArr     scratch buffer receiving the row's feature values
     * @param dArr2    prediction buffer, filled in and returned
     * @return {@code dArr2}
     */
    public double[] score0(Chunk[] chunkArr, int i, double[] dArr, double[] dArr2) {
        assert chunkArr.length >= dArr.length;
        for (int col = 0; col < dArr.length; col++) {
            dArr[col] = chunkArr[col].atd(i);
        }
        return score0(dArr, dArr2);
    }

    /**
     * Scores a single data row.
     *
     * <p>The superclass first writes raw scores into {@code dArr2}; this method then adds the
     * initial prediction {@code _output._init_f} and turns the result into final predictions:
     * <ul>
     *   <li>bernoulli: logistic link on {@code dArr2[1] + _init_f}, stored as class-2 / class-1
     *       probabilities in {@code dArr2[2]} / {@code dArr2[1]};</li>
     *   <li>one class (regression): {@code dArr2[0] += _init_f} and return;</li>
     *   <li>otherwise: optional two-class mirroring, then {@code GenModel.GBM_rescale}.</li>
     * </ul>
     * Classification paths finish with optional class-balance correction and
     * {@code GenModel.getPrediction} into {@code dArr2[0]}.
     *
     * @param dArr  the data row
     * @param dArr2 prediction buffer, modified in place
     * @return {@code dArr2}
     */
    @Override // hex.tree.SharedTreeModel
    public double[] score0(double[] dArr, double[] dArr2) {
        super.score0(dArr, dArr2);
        final double initF = this._output._init_f;

        if (this._parms._distribution == GBMParameters.Family.bernoulli) {
            dArr2[2] = 1.0d / (1.0d + Math.exp(-(dArr2[1] + initF)));
            dArr2[1] = 1.0d - dArr2[2];
            return applyClassBalanceAndPredict(dArr, dArr2);
        }
        if (this._output.nclasses() == 1) {
            // Regression: only the initial-prediction offset applies; no rescaling.
            dArr2[0] += initF;
            return dArr2;
        }
        if (this._output.nclasses() == 2) {
            // Non-bernoulli two-class case: class 2 is the negation of class 1 before rescaling.
            dArr2[1] += initF;
            dArr2[2] = -dArr2[1];
        }
        GenModel.GBM_rescale(dArr2);
        return applyClassBalanceAndPredict(dArr, dArr2);
    }

    /**
     * Shared tail of the classification scoring paths: applies
     * {@code GenModel.correctProbabilities} when class balancing was enabled, then stores
     * {@code GenModel.getPrediction(...)} in {@code preds[0]}.
     */
    private double[] applyClassBalanceAndPredict(double[] data, double[] preds) {
        if (this._parms._balance_classes) {
            GenModel.correctProbabilities(preds, this._output._priorClassDist, this._output._modelClassDist);
        }
        preds[0] = GenModel.getPrediction(preds, data, defaultThreshold());
        return preds;
    }

    /**
     * Emits generated-Java statements that mirror {@link #score0(double[], double[])}'s
     * post-processing of {@code preds}, so exported code matches in-cluster scoring.
     *
     * @param sb  builder receiving the generated statements
     * @param sb2 unused here
     */
    @Override // hex.tree.SharedTreeModel
    protected void toJavaUnifyPreds(SB sb, SB sb2) {
        if (this._parms._distribution == GBMParameters.Family.bernoulli) {
            sb.ip("double fx = preds[1] + ").p(this._output._init_f).p(";").nl();
            sb.ip("preds[2] = 1.0/(1.0+Math.exp(-fx));").nl();
            sb.ip("preds[1] = 1.0-preds[2];").nl();
            emitClassBalanceAndPredict(sb);
            return;
        }
        if (this._output.nclasses() == 1) {
            // Terminate the line like every other emitted statement (was missing .nl()).
            sb.ip("preds[0] += ").p(this._output._init_f).p(";").nl();
            return;
        }
        if (this._output.nclasses() == 2) {
            sb.ip("preds[1] += ").p(this._output._init_f).p(";").nl();
            sb.ip("preds[2] = - preds[1];").nl();
        }
        sb.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
        emitClassBalanceAndPredict(sb);
    }

    /**
     * Code-generation counterpart of {@link #applyClassBalanceAndPredict}: emits the optional
     * class-balance correction and the final {@code getPrediction} assignment.
     */
    private void emitClassBalanceAndPredict(SB sb) {
        if (this._parms._balance_classes) {
            sb.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        sb.ip("preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, " + defaultThreshold() + ");").nl();
    }
}
