package hex.glrm;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsUnsupervised;
import hex.glrm.GLRM;
import water.H2O;
import water.Key;
import water.fvec.Frame;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/glrm/GLRMModel.class */
/**
 * Generalized Low Rank Model (GLRM). Bundles the model's parameters ({@link GLRMParameters}),
 * its training output ({@link GLRMOutput}) and its metrics type ({@link ModelMetricsGLRM}).
 *
 * <p>Row scoring through {@link #score0(double[], double[])} is not implemented in this class.
 */
public class GLRMModel extends Model<GLRMModel, GLRMModel.GLRMParameters, GLRMModel.GLRMOutput> {
    // The nested types are qualified in the extends clause above: member types are only in scope
    // inside the class body, so the bare names GLRMParameters/GLRMOutput would not resolve there.

    /**
     * Output of GLRM training. This class only declares the fields; they are populated outside
     * this file (presumably by the {@link GLRM} builder — confirm there).
     */
    public static class GLRMOutput extends Model.Output {
        public int _iterations;
        public double _objective;
        public double _avg_change_obj;
        public double[][] _archetypes;
        public double _step_size;
        public double[][] _eigenvectors_raw;
        public TwoDimTable _eigenvectors;
        public double[] _std_deviation;
        public TwoDimTable _pc_importance;
        public Key<Frame> _loading_key;
        public double[] _normSub;
        public double[] _normMul;

        /**
         * Creates the output object for the given builder.
         *
         * @param glrm the builder, forwarded to {@code Model.Output}
         */
        public GLRMOutput(GLRM glrm) {
            super(glrm);
        }

        /** Returns the number of features, i.e. the length of {@code _names}. */
        public int nfeatures() {
            return this._names.length;
        }

        /** Always reports {@link ModelCategory#DimReduction}. */
        public ModelCategory getModelCategory() {
            return ModelCategory.DimReduction;
        }
    }

    /**
     * User-facing GLRM parameters, together with the loss, regularizer and proximal-gradient
     * functions selected by them.
     *
     * <p>In the loss methods, {@code u} is the model's value and {@code a} the observed data value
     * (naming inferred from the loss formulas — confirm against the GLRM caller).
     */
    public static class GLRMParameters extends Model.Parameters {
        public Key<Frame> _user_points;
        public Key<Frame> _loading_key;
        public int _k = 1;
        public Loss _loss = Loss.L2;
        public MultiLoss _multi_loss = MultiLoss.Categorical;
        public Regularizer _regularization_x = Regularizer.L2;
        public Regularizer _regularization_y = Regularizer.L2;
        public double _gamma_x = 0.0d;
        public double _gamma_y = 0.0d;
        public int _max_iterations = 1000;
        public double _init_step_size = 1.0d;
        public double _min_step_size = 1.0E-4d;
        public long _seed = System.nanoTime();
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        public GLRM.Initialization _init = GLRM.Initialization.PlusPlus;
        public boolean _recover_pca = false;

        /** Univariate loss functions supported by {@link #loss} and {@link #lgrad}. */
        public enum Loss {
            L2,
            L1,
            Huber,
            Poisson,
            Hinge,
            Logistic
        }

        /** Multidimensional loss functions supported by {@link #mloss} and {@link #mlgrad}. */
        public enum MultiLoss {
            Categorical,
            Ordinal
        }

        /** Regularizers supported by {@link #regularize} and {@link #rproxgrad}. */
        public enum Regularizer {
            L2,
            L1
        }

        /**
         * Evaluates the configured univariate loss {@code _loss} at {@code (u, a)}.
         *
         * @param u model value
         * @param a observed value
         * @return the loss value
         * @throws RuntimeException if {@code _loss} is not a known loss
         */
        public final double loss(double u, double a) {
            switch (this._loss) {
                case L2:
                    return (u - a) * (u - a);
                case L1:
                    return Math.abs(u - a);
                case Huber:
                    return Math.abs(u - a) <= 1.0d ? 0.5d * (u - a) * (u - a) : Math.abs(u - a) - 0.5d;
                case Poisson:
                    // exp(u) - a*u + a*log(a) - a, using the convention 0*log(0) = 0. Without it, a
                    // zero observation evaluates 0 * -Infinity and the whole loss becomes NaN.
                    return ((Math.exp(u) - (a * u)) + (a == 0.0d ? 0.0d : a * Math.log(a))) - a;
                case Hinge:
                    return Math.max(1.0d - (a * u), 0.0d);
                case Logistic:
                    // log(1 + exp(-a*u)), evaluated without overflowing exp() for large -a*u.
                    return log1pExp((-a) * u);
                default:
                    throw new RuntimeException("Unknown loss function " + this._loss);
            }
        }

        /**
         * Computes {@code log(1 + exp(z))} in a numerically stable way: the exponent passed to
         * {@code exp} is never positive, so it cannot overflow to Infinity.
         */
        private static double log1pExp(double z) {
            return Math.max(z, 0.0d) + Math.log1p(Math.exp(-Math.abs(z)));
        }

        /**
         * Returns the (sub)gradient of the configured loss {@code _loss} with respect to {@code u}.
         *
         * @param u model value
         * @param a observed value
         * @return d loss / d u (a subgradient at non-differentiable points)
         * @throws RuntimeException if {@code _loss} is not a known loss
         */
        public final double lgrad(double u, double a) {
            switch (this._loss) {
                case L2:
                    return 2.0d * (u - a);
                case L1:
                    return Math.signum(u - a);
                case Huber:
                    return Math.abs(u - a) <= 1.0d ? u - a : Math.signum(u - a);
                case Poisson:
                    return Math.exp(u) - a;
                case Hinge:
                    // Subgradient -a on the active side of the hinge (including the kink), else 0.
                    return a * u <= 1.0d ? -a : 0.0d;
                case Logistic:
                    return (-a) / (1.0d + Math.exp(a * u));
                default:
                    throw new RuntimeException("Unknown loss function " + this._loss);
            }
        }

        /**
         * Evaluates the configured multidimensional loss {@code _multi_loss}.
         *
         * <ul>
         *   <li>Categorical: {@code max(1 - u[a], 0) + sum over i != a of max(1 + u[i], 0)}.</li>
         *   <li>Ordinal: sums, over {@code i < u.length - 1}, {@code max(1 - u[i], 0)} when
         *       {@code a > i} and the constant {@code 1} otherwise.</li>
         * </ul>
         *
         * @param u model values, one per level
         * @param a index of the observed level
         * @return the loss value
         * @throws IllegalArgumentException if {@code a} is not a valid index into {@code u}
         * @throws RuntimeException if {@code _multi_loss} is not a known loss
         */
        public final double mloss(double[] u, int a) {
            if (a < 0 || a > u.length - 1) {
                throw new IllegalArgumentException("Index must be between 0 and " + (u.length - 1));
            }
            double sum = 0.0d;
            switch (this._multi_loss) {
                case Categorical:
                    for (double ui : u) {
                        sum += Math.max(1.0d + ui, 0.0d);
                    }
                    // Replace the observed level's term max(1 + u[a], 0) with max(1 - u[a], 0).
                    return sum + (Math.max(1.0d - u[a], 0.0d) - Math.max(1.0d + u[a], 0.0d));
                case Ordinal:
                    for (int i = 0; i < u.length - 1; i++) {
                        // Levels at or above a add a constant 1, which has zero gradient (see mlgrad).
                        sum += Math.max(a > i ? 1.0d - u[i] : 1.0d, 0.0d);
                    }
                    return sum;
                default:
                    throw new RuntimeException("Unknown multidimensional loss function " + this._multi_loss);
            }
        }

        /**
         * Returns the (sub)gradient of {@link #mloss} with respect to each entry of {@code u}.
         *
         * @param u model values, one per level
         * @param a index of the observed level
         * @return a new array of the same length as {@code u}
         * @throws IllegalArgumentException if {@code a} is not a valid index into {@code u}
         * @throws RuntimeException if {@code _multi_loss} is not a known loss
         */
        public final double[] mlgrad(double[] u, int a) {
            if (a < 0 || a > u.length - 1) {
                throw new IllegalArgumentException("Index must be between 0 and " + (u.length - 1));
            }
            double[] grad = new double[u.length];
            switch (this._multi_loss) {
                case Categorical:
                    for (int i = 0; i < u.length; i++) {
                        grad[i] = 1.0d + u[i] > 0.0d ? 1.0d : 0.0d;
                    }
                    grad[a] = 1.0d - u[a] > 0.0d ? -1.0d : 0.0d;
                    return grad;
                case Ordinal:
                    // The last entry is never part of the ordinal loss, so it stays 0.
                    for (int i = 0; i < u.length - 1; i++) {
                        grad[i] = (a > i && 1.0d - u[i] > 0.0d) ? -1.0d : 0.0d;
                    }
                    return grad;
                default:
                    throw new RuntimeException("Unknown multidimensional loss function " + this._multi_loss);
            }
        }

        /** Applies the X regularizer {@code _regularization_x} to a single value. */
        public final double regularize_x(double u) {
            return regularize(u, this._regularization_x);
        }

        /** Applies the Y regularizer {@code _regularization_y} to a single value. */
        public final double regularize_y(double u) {
            return regularize(u, this._regularization_y);
        }

        /**
         * Evaluates a regularizer at a single value: {@code u*u} for L2, {@code |u|} for L1.
         *
         * @throws RuntimeException if {@code regularizer} is not a known regularizer
         */
        public final double regularize(double u, Regularizer regularizer) {
            switch (regularizer) {
                case L2:
                    return u * u;
                case L1:
                    return Math.abs(u);
                default:
                    throw new RuntimeException("Unknown regularization function " + regularizer);
            }
        }

        /** Sums the X regularizer over every entry of {@code u}; {@code null} gives 0. */
        public final double regularize_x(double[][] u) {
            return regularize(u, this._regularization_x);
        }

        /** Sums the Y regularizer over every entry of {@code u}; {@code null} gives 0. */
        public final double regularize_y(double[][] u) {
            return regularize(u, this._regularization_y);
        }

        /**
         * Sums {@link #regularize(double, Regularizer)} over every entry of a 2-D array.
         *
         * @param u the matrix to regularize; {@code null} is treated as empty
         * @param regularizer the regularizer to apply element-wise
         * @return the total regularization value (0 for {@code null} or empty input)
         */
        public final double regularize(double[][] u, Regularizer regularizer) {
            if (u == null) {
                return 0.0d;
            }
            double sum = 0.0d;
            for (double[] row : u) {
                // Bound by this row's own length so jagged arrays neither throw nor skip entries.
                for (double v : row) {
                    sum += regularize(v, regularizer);
                }
            }
            return sum;
        }

        /** Applies {@link #rproxgrad} with {@code _gamma_x} and {@code _regularization_x}. */
        public final double rproxgrad_x(double u, double step) {
            return rproxgrad(u, step, this._gamma_x, this._regularization_x);
        }

        /** Applies {@link #rproxgrad} with {@code _gamma_y} and {@code _regularization_y}. */
        public final double rproxgrad_y(double u, double step) {
            return rproxgrad(u, step, this._gamma_y, this._regularization_y);
        }

        /**
         * Proximal step for the regularizer (presumably the proximal operator of
         * {@code step * gamma * r} — confirm against the GLRM update code).
         *
         * <ul>
         *   <li>L2: {@code u / (1 + 2 * step * gamma)}.</li>
         *   <li>L1: soft-thresholds {@code u} by {@code step * gamma} (for a non-negative threshold).</li>
         * </ul>
         *
         * @throws RuntimeException if {@code regularizer} is not a known regularizer
         */
        public final double rproxgrad(double u, double step, double gamma, Regularizer regularizer) {
            switch (regularizer) {
                case L2:
                    return u / (1.0d + ((2.0d * step) * gamma));
                case L1:
                    return Math.max(u - (step * gamma), 0.0d) + Math.min(u + (step * gamma), 0.0d);
                default:
                    throw new RuntimeException("Unknown regularization function " + regularizer);
            }
        }
    }

    /** Unsupervised metrics for a GLRM model. */
    public static class ModelMetricsGLRM extends ModelMetricsUnsupervised {

        /**
         * Metric builder for GLRM. {@link #perRow} returns its input unchanged and accumulates
         * nothing here. {@link #makeModelMetrics} creates a {@link ModelMetricsGLRM} and registers
         * it with the model output.
         */
        public static class GLRMModelMetrics extends ModelMetricsUnsupervised.MetricBuilderUnsupervised {
            /**
             * @param dims size of the inherited {@code _work} buffer (the model's {@code _k} when
             *     created via {@code makeMetricBuilder})
             */
            public GLRMModelMetrics(int dims) {
                this._work = new double[dims];
            }

            /** Returns {@code ds} as-is; no per-row statistics are accumulated. */
            public double[] perRow(double[] ds, float[] yact, Model model) {
                return ds;
            }

            /** Builds a {@link ModelMetricsGLRM} and adds it via {@code model._output.addModelMetrics}. */
            public ModelMetrics makeModelMetrics(Model model, Frame frame, double sigma) {
                return model._output.addModelMetrics(new ModelMetricsGLRM(model, frame));
            }
        }

        /**
         * Creates metrics for the given model and frame. The third superclass argument is always
         * {@code Double.NaN} (NOTE(review): presumably a statistic GLRM does not compute — confirm).
         */
        public ModelMetricsGLRM(Model model, Frame frame) {
            super(model, frame, Double.NaN);
        }
    }

    /**
     * Creates a GLRM model.
     *
     * @param key the model key
     * @param parms training parameters
     * @param output training output
     */
    public GLRMModel(Key key, GLRMParameters parms, GLRMOutput output) {
        super(key, parms, output);
    }

    /** Row scoring is not implemented for GLRM; always throws {@code H2O.unimpl()}. */
    protected double[] score0(double[] data, double[] preds) {
        throw H2O.unimpl();
    }

    /** Returns a {@link ModelMetricsGLRM.GLRMModelMetrics} sized by {@code _k}; {@code domain} is unused. */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        return new ModelMetricsGLRM.GLRMModelMetrics(((GLRMParameters) this._parms)._k);
    }
}
