package hex;

import hex.Model;
import water.H2O;
import water.Iced;
import water.persist.PersistManager;

/**
 * Per-family loss helpers: link function and its inverse, per-row deviance, the negative half
 * gradient of that deviance, and the numerator/denominator terms returned by
 * {@link #initFNum}/{@link #initFDenom} and {@link #gammaNum}/{@link #gammaDenom}.
 *
 * <p>All instance fields are final except {@link #huberDelta}, which can be changed through
 * {@link #setHuberDelta(double)}. Unsupported family/method combinations throw the exception
 * produced by {@code H2O.unimpl()}.
 */
public class Distribution extends Iced {
    /** Lower clamp used by {@link #log(double)}; also its result for non-positive input. */
    public static double MIN_LOG = -19.0;
    /** Upper clamp used by {@link #exp(double)} and embedded in {@link #expString(String)}. */
    public static double MAX = 1e19;

    /** Loss family this instance computes. */
    public final Family distribution;
    /** Tweedie variance power; asserted to lie strictly between 1 and 2 where it is used. */
    public final double tweediePower;
    /** Quantile level used by the {@code quantile} family. */
    public final double quantileAlpha;
    /** Huber transition point between the quadratic and linear regimes. */
    public double huberDelta;

    /**
     * Supported loss families.
     *
     * <p>NOTE(review): keep the constant order unchanged — ordinals may be relied on elsewhere
     * (e.g. serialized models); confirm before reordering.
     */
    public enum Family {
        AUTO,
        bernoulli,
        modified_huber,
        multinomial,
        gaussian,
        poisson,
        gamma,
        tweedie,
        huber,
        laplace,
        quantile
    }

    /**
     * Creates a distribution for a family that needs no extra parameters.
     *
     * <p>Uses tweedie power 1.5, quantile alpha 0.5 and an unset ({@code NaN}) huber delta.
     * With assertions enabled, {@code tweedie}, {@code quantile} and {@code huber} are rejected
     * because they need parameters this constructor does not receive.
     *
     * @param family the loss family
     */
    public Distribution(Family family) {
        this.distribution = family;
        assert family != Family.tweedie;
        assert family != Family.quantile;
        assert family != Family.huber;
        this.tweediePower = 1.5;
        this.quantileAlpha = 0.5;
        this.huberDelta = Double.NaN;
    }

    /**
     * Creates a distribution from model parameters.
     *
     * <p>Huber delta starts at 1.0 and can be changed later with {@link #setHuberDelta(double)}.
     * With assertions enabled, the tweedie power must lie in the open interval (1, 2) regardless
     * of the family (NaN is rejected too).
     *
     * @param parameters source of the family, tweedie power and quantile alpha
     */
    public Distribution(Model.Parameters parameters) {
        this.distribution = parameters._distribution;
        this.tweediePower = parameters._tweedie_power;
        this.quantileAlpha = parameters._quantile_alpha;
        this.huberDelta = 1.0;
        assert tweediePower > 1.0 && tweediePower < 2.0
            : "tweedie power must lie in (1, 2): " + tweediePower;
    }

    /**
     * Sets the Huber transition point.
     *
     * @param d new huber delta
     */
    public void setHuberDelta(double d) {
        this.huberDelta = d;
    }

    /**
     * Overflow-safe exponential.
     *
     * @param d exponent
     * @return {@code Math.exp(d)} capped from above at {@link #MAX}
     */
    public static double exp(double d) {
        return Math.min(MAX, Math.exp(d));
    }

    /**
     * Underflow-safe natural logarithm.
     *
     * @param d argument
     * @return {@link #MIN_LOG} when {@code d <= 0}; otherwise {@code Math.log(d)} clamped from
     *     below at {@link #MIN_LOG}
     */
    public static double log(double d) {
        double x = Math.max(0.0, d);
        return x == 0.0 ? MIN_LOG : Math.max(MIN_LOG, Math.log(x));
    }

    /**
     * Java source text equivalent to {@link #exp(double)} applied to {@code str}.
     *
     * @param str Java expression text for the exponent
     * @return expression text computing the capped exponential of {@code str}
     */
    public static String expString(String str) {
        return "Math.min(" + MAX + ", Math.exp(" + str + "))";
    }

    /**
     * Weighted per-row deviance.
     *
     * @param w row weight
     * @param y observed response (0/1 for bernoulli and modified_huber — assumed, confirm with
     *     callers)
     * @param f prediction on the response scale; converted with {@link #link(double)} first
     * @return the deviance contribution of this row
     */
    public double deviance(double w, double y, double f) {
        final double eta = link(f); // prediction on the link scale
        switch (distribution) {
            case AUTO:
            case gaussian:
                return w * (y - eta) * (y - eta);
            case huber: {
                final double absResidual = Math.abs(y - eta);
                if (absResidual <= huberDelta) {
                    return w * (y - eta) * (y - eta);
                }
                // Linear tail 2*delta*|r| - delta^2: equals r^2 at |r| == delta, so the loss is
                // continuous (the previous 2*delta*(|r| - delta) jumped by delta^2 there).
                return 2.0 * w * (absResidual - 0.5 * huberDelta) * huberDelta;
            }
            case laplace:
                return w * Math.abs(y - eta);
            case quantile:
                return y > eta
                    ? w * quantileAlpha * (y - eta)
                    : w * (1.0 - quantileAlpha) * (eta - y);
            case bernoulli:
                return -2.0 * w * (y * eta - log(1.0 + exp(eta)));
            case poisson:
                return -2.0 * w * (y * eta - exp(eta));
            case gamma:
                return 2.0 * w * (y * exp(-eta) + eta);
            case tweedie:
                assert tweediePower > 1.0 && tweediePower < 2.0;
                return 2.0 * w * (Math.pow(y, 2.0 - tweediePower) / ((1.0 - tweediePower) * (2.0 - tweediePower))
                    - y * exp(eta * (1.0 - tweediePower)) / (1.0 - tweediePower)
                    + exp(eta * (2.0 - tweediePower)) / (2.0 - tweediePower));
            case modified_huber: {
                // Margin m = (2y - 1) * f; loss is -4m (m < -1), (1 - m)^2 (-1 <= m <= 1), 0 (m > 1).
                final double margin = (2.0 * y - 1.0) * eta;
                if (margin < -1.0) {
                    return -w * 4.0 * margin;
                }
                if (margin > 1.0) {
                    return 0.0;
                }
                return w * (1.0 - margin) * (1.0 - margin);
            }
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Negative half of the derivative of the (unit-weight) deviance with respect to the
     * link-scale prediction.
     *
     * @param y observed response
     * @param f prediction on the link scale
     * @return {@code -0.5 * d(deviance)/df}
     */
    public double negHalfGradient(double y, double f) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case bernoulli:
            case poisson:
                return y - linkInv(f);
            case huber:
                if (Math.abs(y - f) <= huberDelta) {
                    return y - f;
                }
                return f >= y ? -huberDelta : huberDelta;
            case laplace:
                return f > y ? -0.5 : 0.5;
            case quantile:
                return y > f ? 0.5 * quantileAlpha : 0.5 * (quantileAlpha - 1.0);
            case gamma:
                return y * exp(-f) - 1.0;
            case tweedie:
                assert tweediePower > 1.0 && tweediePower < 2.0;
                return y * exp(f * (1.0 - tweediePower)) - exp(f * (2.0 - tweediePower));
            case modified_huber: {
                final double ySigned = 2.0 * y - 1.0;
                final double margin = ySigned * f;
                if (margin < -1.0) {
                    return 2.0 * ySigned;
                }
                if (margin > 1.0) {
                    return 0.0;
                }
                // -0.5 * d/df (1 - m)^2 = (2y - 1) * (1 - m); matches the outer branches at m = +/-1.
                return ySigned * (1.0 - margin);
            }
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Maps a response-scale value to the link scale: identity, logit ({@code bernoulli},
     * {@code modified_huber}) or log ({@code poisson}, {@code gamma}, {@code tweedie},
     * {@code multinomial}). Uses the clamped {@link #log(double)}.
     *
     * @param f value on the response scale
     * @return value on the link scale
     */
    public double link(double f) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case huber:
            case laplace:
            case quantile:
                return f;
            case bernoulli:
            case modified_huber:
                return log(f / (1.0 - f));
            case poisson:
            case gamma:
            case tweedie:
            case multinomial:
                return log(f);
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Inverse of {@link #link(double)}: identity, logistic, or capped exponential.
     *
     * @param f value on the link scale
     * @return value on the response scale
     */
    public double linkInv(double f) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case huber:
            case laplace:
            case quantile:
                return f;
            case bernoulli:
            case modified_huber:
                return 1.0 / (1.0 + exp(-f));
            case poisson:
            case gamma:
            case tweedie:
            case multinomial:
                return exp(f);
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Java source text equivalent to {@link #linkInv(double)} applied to {@code str}.
     *
     * @param str Java expression text on the link scale
     * @return expression text on the response scale
     */
    public String linkInvString(String str) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case huber:
            case laplace:
            case quantile:
                return str;
            case bernoulli:
            case modified_huber:
                return "1/(1+" + expString("-" + str) + ")";
            case poisson:
            case gamma:
            case tweedie:
            case multinomial:
                return expString(str);
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Per-row numerator term for the initial prediction.
     *
     * <p>NOTE(review): {@code o} looks like a link-scale offset — confirm with callers.
     * Not implemented for huber, laplace and quantile.
     *
     * @param w row weight
     * @param o offset (assumed link scale)
     * @param y observed response
     * @return numerator contribution of this row
     */
    public double initFNum(double w, double o, double y) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case bernoulli:
            case multinomial:
                return w * (y - o);
            case poisson:
                return w * y;
            case gamma:
                return w * y * linkInv(-o);
            case tweedie:
                return w * y * exp(o * (1.0 - tweediePower));
            case modified_huber:
                return y == 1.0 ? w : 0.0;
            case huber:
            case laplace:
            case quantile:
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Per-row denominator term for the initial prediction; pairs with
     * {@link #initFNum(double, double, double)}.
     *
     * @param w row weight
     * @param o offset (assumed link scale — confirm with callers)
     * @param y observed response
     * @return denominator contribution of this row
     */
    public double initFDenom(double w, double o, double y) {
        switch (distribution) {
            case AUTO:
            case gaussian:
            case bernoulli:
            case gamma:
            case multinomial:
                return w;
            case poisson:
                return w * linkInv(o);
            case tweedie:
                return w * exp(o * (2.0 - tweediePower));
            case modified_huber:
                return y == 1.0 ? 0.0 : w;
            case huber:
            case laplace:
            case quantile:
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Per-row numerator term for a leaf value ("gamma").
     *
     * <p>NOTE(review): {@code z} appears to be the residual / working response (for bernoulli,
     * {@link #gammaDenom} treats {@code y - z} as the predicted probability) — confirm with
     * callers. Not implemented for AUTO, huber, laplace and quantile.
     *
     * @param w row weight
     * @param y observed response
     * @param z residual (assumed)
     * @param f current prediction on the link scale
     * @return numerator contribution of this row
     */
    public double gammaNum(double w, double y, double z, double f) {
        switch (distribution) {
            case gaussian:
            case bernoulli:
            case multinomial:
                return w * z;
            case poisson:
                return w * y;
            case gamma:
                return w * (z + 1.0);
            case tweedie:
                return w * y * exp(f * (1.0 - tweediePower));
            case modified_huber: {
                final double ySigned = 2.0 * y - 1.0;
                final double margin = ySigned * f;
                if (margin < -1.0) {
                    return w * 4.0 * ySigned;
                }
                if (margin > 1.0) {
                    return 0.0;
                }
                return w * 2.0 * ySigned * (1.0 - margin);
            }
            case AUTO:
            case huber:
            case laplace:
            case quantile:
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Per-row denominator term for a leaf value; pairs with
     * {@link #gammaNum(double, double, double, double)}.
     *
     * @param w row weight
     * @param y observed response
     * @param z residual (assumed — confirm with callers)
     * @param f current prediction on the link scale
     * @return denominator contribution of this row
     */
    public double gammaDenom(double w, double y, double z, double f) {
        switch (distribution) {
            case gaussian:
            case gamma:
                return w;
            case bernoulli: {
                final double p = y - z;
                return w * p * (1.0 - p);
            }
            case poisson:
                return w * (y - z);
            case tweedie:
                return w * exp(f * (2.0 - tweediePower));
            case modified_huber: {
                final double margin = (2.0 * y - 1.0) * f;
                if (margin < -1.0) {
                    return -w * 4.0 * margin;
                }
                if (margin > 1.0) {
                    return 0.0;
                }
                return w * (1.0 - margin) * (1.0 - margin);
            }
            case multinomial: {
                final double absZ = Math.abs(z);
                return w * absZ * (1.0 - absZ);
            }
            case AUTO:
            case huber:
            case laplace:
            case quantile:
            default:
                throw H2O.unimpl();
        }
    }
}
