package hex.genmodel.algos.gbm;

import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;

/* loaded from: input_file:hex/genmodel/algos/gbm/GbmMojoModel.class */
/**
 * MOJO scoring model for GBM (per the class and package name).
 *
 * <p>Scoring first delegates to the superclass to fill the prediction array with raw
 * scores and then post-processes that array in {@link #b(double[], double, double[])}
 * according to the distribution family {@link #r} and link function {@link #s}.
 *
 * <p>NOTE(review): member names are obfuscated; the inherited fields {@code j}, {@code k},
 * {@code l}, {@code m}, {@code n} and the {@code GenModel} static helpers are described
 * below only by how this class uses them — confirm their meaning against the superclasses.
 */
public final class GbmMojoModel extends SharedTreeMojoModelWithContributions implements SharedTreeGraphConverter {
    /**
     * Smallest magnitude allowed for the denominator of the {@code inverse} link,
     * so that {@code 1 / x} stays finite for scores at or near zero.
     */
    private static final double MIN_INVERSE_DENOMINATOR = 1.0E-5d;

    /** Upper bound applied to every exponential computed by {@link #f(double)}. */
    private static final double MAX_EXP = 1.0E19d;

    /** Distribution family; selects the post-processing branch in {@link #b(double[], double, double[])}. */
    public DistributionFamily r;

    /** Link function whose inverse is applied to the shifted raw score. */
    public LinkFunctionType s;

    /**
     * Constant added to the raw score (together with the offset) before the inverse link
     * is applied; also returned by {@link #l()}. Presumably the model's initial
     * prediction — TODO confirm.
     */
    public double t;

    /**
     * Creates the model by forwarding all arguments to the superclass constructor.
     *
     * @param strArr  forwarded unchanged to the superclass
     * @param strArr2 forwarded unchanged to the superclass
     * @param str     forwarded unchanged to the superclass
     */
    public GbmMojoModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
    }

    /**
     * Wraps the given TreeSHAP predictor in a {@code ContributionsPredictor} bound to this model.
     *
     * @param treeSHAPPredictor the predictor to wrap
     * @return a new contributions predictor for this model
     */
    @Override // hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions
    protected final SharedTreeMojoModelWithContributions.ContributionsPredictor a(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new SharedTreeMojoModelWithContributions.ContributionsPredictor(this, treeSHAPPredictor);
    }

    /**
     * Returns the constant {@link #t}.
     *
     * @return the current value of {@link #t}
     */
    @Override // hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions
    public final double l() {
        return this.t;
    }

    /**
     * Scores one row with an explicit offset.
     *
     * <p>Calls the superclass {@code b(dArr, dArr2)} (presumably accumulating the raw tree
     * scores into {@code dArr2} — TODO confirm) and then post-processes the array with
     * {@link #b(double[], double, double[])}.
     *
     * @param dArr  the input row
     * @param d     offset added to the raw score before the inverse link
     * @param dArr2 prediction array, modified in place
     * @return {@code dArr2}
     */
    @Override // hex.genmodel.GenModel
    public final double[] a(double[] dArr, double d, double[] dArr2) {
        super.b(dArr, dArr2);
        return b(dArr, d, dArr2);
    }

    /**
     * Converts the raw scores in {@code dArr2} into final predictions, in place.
     *
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber: slot 2 becomes the inverse link of
     *       {@code dArr2[1] + t + d} and slot 1 becomes its complement.</li>
     *   <li>multinomial: when the inherited field {@code j} equals 2, slot 1 is shifted by
     *       {@code t + d} and slot 2 is set to its negation; the array is then passed to
     *       {@code GenModel.b}.</li>
     *   <li>any other family: slot 0 becomes the inverse link of {@code dArr2[0] + t + d}
     *       and the method returns immediately.</li>
     * </ul>
     *
     * <p>For the first two cases the array is additionally passed through
     * {@code GenModel.a(dArr2, m, n)} when the inherited flag {@code k} is set, and slot 0
     * is finally assigned the result of {@code GenModel.a(dArr2, m, dArr, l)}.
     *
     * @param dArr  the input row, forwarded to the final {@code GenModel.a} call
     * @param d     offset added to the raw score
     * @param dArr2 prediction array, modified in place
     * @return {@code dArr2}
     */
    @Override // hex.genmodel.algos.tree.SharedTreeMojoModel
    public final double[] b(double[] dArr, double d, double[] dArr2) {
        final DistributionFamily family = this.r;
        if (family == DistributionFamily.bernoulli
                || family == DistributionFamily.quasibinomial
                || family == DistributionFamily.modified_huber) {
            dArr2[2] = a(this.s, dArr2[1] + this.t + d);
            dArr2[1] = 1.0d - dArr2[2];
        } else if (family == DistributionFamily.multinomial) {
            if (this.j == 2) {
                // Summation order is kept as (score + t) + d so floating-point results do not change.
                dArr2[1] = dArr2[1] + this.t + d;
                dArr2[2] = -dArr2[1];
            }
            GenModel.b(dArr2);
        } else {
            // Single-output case: only slot 0 is produced, no further post-processing.
            dArr2[0] = a(this.s, dArr2[0] + this.t + d);
            return dArr2;
        }
        if (this.k) {
            GenModel.a(dArr2, this.m, this.n);
        }
        dArr2[0] = GenModel.a(dArr2, this.m, dArr, this.l);
        return dArr2;
    }

    /**
     * Applies the inverse of the given link function to a raw score.
     *
     * @param linkFunctionType the link function; must not be {@code null}
     * @param d                the raw score
     * @return the score mapped back through the inverse link; link types without an
     *         explicit case return {@code d} unchanged
     */
    private static double a(LinkFunctionType linkFunctionType, double d) {
        switch (linkFunctionType) {
            case log:
                return f(d);
            case logit:
            case ologit:
                return 1.0d / (1.0d + f(-d));
            case ologlog:
                return 1.0d - f((-1.0d) * f(d));
            case oprobit:
                // NOTE(review): constant result regardless of the input; kept as found.
                return 0.0d;
            case inverse:
                // Keep the denominator at least MIN_INVERSE_DENOMINATOR away from zero on
                // BOTH sides. The non-negative branch previously used -1e-5 as its floor,
                // which never clamped anything and let d == 0 produce an infinite result.
                return 1.0d / (d < 0.0d
                        ? Math.min(-MIN_INVERSE_DENOMINATOR, d)
                        : Math.max(MIN_INVERSE_DENOMINATOR, d));
            default:
                return d;
        }
    }

    /**
     * Computes {@code e^d}, capped at {@link #MAX_EXP}.
     *
     * @param d the exponent
     * @return {@code min(MAX_EXP, exp(d))}
     */
    private static double f(double d) {
        return Math.min(MAX_EXP, Math.exp(d));
    }

    /**
     * Scores one row with a zero offset.
     *
     * @param dArr  the input row
     * @param dArr2 prediction array, modified in place
     * @return {@code dArr2}
     */
    @Override // hex.genmodel.GenModel
    public final double[] a(double[] dArr, double[] dArr2) {
        return a(dArr, 0.0d, dArr2);
    }
}
