package hex.glrm;

import Jama.Matrix;
import Jama.QRDecomposition;
import Jama.SingularValueDecomposition;
import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.glrm.GLRMModel;
import hex.gram.Gram;
import hex.kmeans.KMeans;
import hex.kmeans.KMeansModel;
import hex.schemas.GLRMV3;
import hex.schemas.ModelBuilderSchema;
import hex.svd.SVD;
import hex.svd.SVDModel;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/glrm/GLRM.class */
public class GLRM extends ModelBuilder<GLRMModel, GLRMModel.GLRMParameters, GLRMModel.GLRMOutput> {
    private final double TOLERANCE = 1.0E-6d;
    private transient int _ncolA;
    private transient int _ncolX;

    /* loaded from: input_file:hex/glrm/GLRM$BMulTask.class */
    private static class BMulTask extends FrameTask<BMulTask> {
        /** Coefficient matrix; {@link #processRow} reads it as {@code _yt[outputColumn][inputColumn]}. */
        double[][] _yt;

        BMulTask(Key jobKey, DataInfo dinfo, double[][] yt) {
            super(jobKey, dinfo);
            _yt = yt;
        }

        /**
         * For each output column {@code k}, appends the dot product of the row's numeric
         * values with {@code _yt[k]}, reading {@code _yt[k]} from offset
         * {@code _dinfo.numStart()} onwards.
         */
        @Override // hex.FrameTask
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            final double[] nums = row.numVals;
            assert nums.length == _yt[0].length;

            for (int k = 0; k < _yt[0].length; k++) {
                int col = _dinfo.numStart();
                double dot = 0.0d;
                for (int n = 0; n < nums.length; n++) {
                    dot += nums[n] * _yt[k][col++];
                }
                // Every column of _yt[k] must have been consumed.
                assert col == _yt[0].length;
                outputs[k].addNum(dot);
            }
        }
    }

    /* loaded from: input_file:hex/glrm/GLRM$GLRMDriver.class */
    class GLRMDriver extends H2O.H2OCountedCompleter<GLRMDriver> {
        /** Creates a driver with no state of its own; everything is read from the enclosing {@code GLRM}. */
        GLRMDriver() {
        }

        /**
         * Builds the initial archetype matrix, {@code _k} rows by {@code _ncolA} columns.
         *
         * <ul>
         *   <li>If {@code _user_points} is set, its first {@code _k} rows are copied.</li>
         *   <li>Else if {@code _init} is {@link Initialization#SVD}, an SVD model is trained and
         *       the transpose of its {@code _v} output is used.</li>
         *   <li>Otherwise a k-means++ model is trained and its raw centers are standardized
         *       with the model's {@code _normSub}/{@code _normMul}.</li>
         * </ul>
         *
         * The helper job and model are always removed in a {@code finally} block, after their
         * output has been read. If the resulting matrix has a squared Frobenius norm of zero or
         * NaN, a warning is recorded and a Gaussian random matrix is returned instead.
         *
         * @return the initial Y matrix
         */
        public double[][] initialY() {
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) GLRM.this._parms;
            double[][] centers;

            if (null != parms._user_points) {
                // User-supplied starting points: copy row r of the frame into centers[r].
                Vec[] userVecs = parms._user_points.get().vecs();
                centers = new double[parms._k][GLRM.this._ncolA];
                for (int r = 0; r < parms._k; r++) {
                    for (int c = 0; c < GLRM.this._ncolA; c++) {
                        centers[r][c] = userVecs[c].at(r);
                    }
                }
            } else if (parms._init == Initialization.SVD) {
                SVDModel.SVDParameters svdParms = new SVDModel.SVDParameters();
                svdParms._train = parms._train;
                svdParms._nv = parms._k;
                svdParms._max_iterations = parms._max_iterations;
                svdParms._transform = parms._transform;
                svdParms._seed = parms._seed;
                svdParms._only_v = true;

                SVDModel svdModel = null;
                SVD svdJob = null;
                try {
                    svdJob = new SVD(svdParms);
                    svdModel = (SVDModel) svdJob.trainModel().get();
                    // Read the result before the model is removed in the finally block.
                    centers = ArrayUtils.transpose(((SVDModel.SVDOutput) svdModel._output)._v);
                } finally {
                    if (svdJob != null) {
                        svdJob.remove();
                    }
                    if (svdModel != null) {
                        svdModel.remove();
                    }
                }
            } else {
                KMeansModel.KMeansParameters kmParms = new KMeansModel.KMeansParameters();
                kmParms._train = parms._train;
                kmParms._ignored_columns = parms._ignored_columns;
                kmParms._dropConsCols = parms._dropConsCols;
                kmParms._drop_na20_cols = parms._drop_na20_cols;
                kmParms._score_each_iteration = parms._score_each_iteration;
                kmParms._init = KMeans.Initialization.PlusPlus;
                kmParms._k = parms._k;
                kmParms._max_iterations = parms._max_iterations;
                kmParms._standardize = true;
                kmParms._seed = parms._seed;

                KMeansModel kmModel = null;
                KMeans kmJob = null;
                try {
                    kmJob = new KMeans(kmParms);
                    kmModel = (KMeansModel) kmJob.trainModel().get();
                    // Read the result before the model is removed in the finally block.
                    centers = GLRM.transform(kmModel._output._centers_raw, 0, kmModel._output._normSub, kmModel._output._normMul);
                } finally {
                    if (kmJob != null) {
                        kmJob.remove();
                    }
                    if (kmModel != null) {
                        kmModel.remove();
                    }
                }
            }

            // Guard against a degenerate start (all zeros or containing NaN).
            double frob = GLRM.frobenius2(centers);
            if (frob == 0.0d || Double.isNaN(frob)) {
                GLRM.this.warn("_init", "Initialization failed. Setting initial Y to standard normal random matrix instead...");
                centers = ArrayUtils.gaussianArray(parms._k, GLRM.this._ncolA);
            }
            return centers;
        }

        /**
         * Decides whether the alternating-minimization loop should stop.
         *
         * @param gLRMModel model whose output carries the iteration count and last objective change
         * @param i         the caller's streak counter (positive after accepted steps,
         *                  non-positive after rejected ones)
         * @param d         current step size
         * @return {@code true} when the job is no longer running, the iteration budget is used
         *         up, the step size is at or below {@code _min_step_size}, or the objective
         *         has converged
         */
        private boolean isDone(GLRMModel gLRMModel, int i, double d) {
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) GLRM.this._parms;
            final GLRMModel.GLRMOutput output = (GLRMModel.GLRMOutput) gLRMModel._output;

            // Stopped or cancelled.
            if (!GLRM.this.isRunning()) {
                return true;
            }
            // Iteration budget exhausted. _iterations counts completed iterations, so the
            // comparison is >= to run exactly _max_iterations of them (not one extra).
            if (output._iterations >= parms._max_iterations) {
                return true;
            }
            // Step size has shrunk to (or below) the configured minimum.
            if (d <= parms._min_step_size) {
                return true;
            }
            // Converged: enough iterations, a streak of accepted steps, and a negligible
            // average change in the objective.
            return output._iterations > 10 && i > 3 && Math.abs(output._avg_change_obj) < GLRM.this.TOLERANCE;
        }

        /**
         * Computes a Cholesky factorization of {@code gram}, adding an increasing value to
         * its diagonal (1e-5, then multiplied by 10 on each retry) while the result is not SPD.
         *
         * @param gram Gram matrix; its diagonal is modified in place on each retry
         * @param i    maximum number of retries
         * @return the first factorization reported as SPD
         * @throws Gram.NonSPDMatrixException if the matrix is still not SPD after {@code i} retries
         */
        public Gram.Cholesky regularizedCholesky(Gram gram, int i) {
            final int maxAttempts = i;
            Gram.Cholesky chol = gram.cholesky(null);
            double addedL2 = 0.0d;

            for (int attempt = 0; !chol.isSPD() && attempt < maxAttempts; attempt++) {
                if (addedL2 == 0.0d) {
                    addedL2 = 1.0E-5d;
                } else {
                    addedL2 *= 10.0d;
                }
                gram.addDiag(addedL2);
                Log.info(new Object[]{"Added L2 regularization = " + addedL2 + " to diagonal of X Gram matrix"});
                gram.cholesky(chol);
            }

            if (!chol.isSPD()) {
                throw new Gram.NonSPDMatrixException();
            }
            return chol;
        }

        /** Same as {@link #regularizedCholesky(Gram, int)} with at most 10 regularization retries. */
        public Gram.Cholesky regularizedCholesky(Gram gram) {
            final int defaultMaxAttempts = 10;
            return regularizedCholesky(gram, defaultMaxAttempts);
        }

        /**
         * Derives PCA-style outputs from the fitted model and stores them on
         * {@code gLRMModel._output}: {@code _eigenvectors_raw}, the "Rotation" table,
         * {@code _std_deviation} and the "Importance of components" table.
         *
         * Steps, as performed below: Cholesky-factor the Gram matrix of {@code dataInfo}'s
         * frame, scale the transposed factor by sqrt(number of training rows), QR-decompose
         * the archetypes, then take the SVD of the product of the scaled factor with R'.
         *
         * @param gLRMModel model whose {@code _archetypes} are already set
         * @param dataInfo  DataInfo over the frame whose Gram matrix is factored
         *                  (presumably the X/loading frame; the caller passes it that way)
         */
        public void recoverPCA(GLRMModel gLRMModel, DataInfo dataInfo) {
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) GLRM.this._parms;
            final GLRMModel.GLRMOutput output = (GLRMModel.GLRMOutput) gLRMModel._output;

            // Cholesky factor of the Gram matrix, transposed and scaled by sqrt(nrows).
            Gram gram = ((Gram.GramTask) new Gram.GramTask(self(), dataInfo).doAll(dataInfo._adaptedFrame))._gram;
            Matrix scaledFactor = new Matrix(regularizedCholesky(gram).getL()).transpose().times(Math.sqrt(GLRM.this._train.numRows()));

            // QR of the archetypes, then SVD of (scaled factor) * R'.
            QRDecomposition archetypesQr = new QRDecomposition(new Matrix(output._archetypes));
            SingularValueDecomposition svd = new SingularValueDecomposition(scaledFactor.times(archetypesQr.getR().transpose()));
            output._eigenvectors_raw = archetypesQr.getQ().times(svd.getV()).getArray();

            // Column metadata shared by both tables: PC1..PCk, all doubles.
            String[] colTypes = new String[parms._k];
            String[] colFormats = new String[parms._k];
            String[] colHeaders = new String[parms._k];
            Arrays.fill(colTypes, "double");
            Arrays.fill(colFormats, "%5f");
            for (int c = 0; c < colHeaders.length; c++) {
                colHeaders[c] = "PC" + String.valueOf(c + 1);
            }
            output._eigenvectors = new TwoDimTable("Rotation", (String) null, GLRM.this._train.names(), colHeaders, colTypes, colFormats, "",
                    new String[GLRM.this._train.numCols()][], output._eigenvectors_raw);

            // Standard deviations are singular values scaled by 1/sqrt(nrows - 1).
            double[] singularValues = svd.getSingularValues();
            double[] sdev = new double[singularValues.length];
            double[] variance = new double[singularValues.length];
            double totalVariance = 0.0d;
            double scale = 1.0d / Math.sqrt(GLRM.this._train.numRows() - 1.0d);
            for (int c = 0; c < singularValues.length; c++) {
                sdev[c] = scale * singularValues[c];
                variance[c] = sdev[c] * sdev[c];
                totalVariance += variance[c];
            }
            output._std_deviation = sdev;

            // Proportion of variance per component and its running total.
            double[] propVar = new double[singularValues.length];
            double[] cumVar = new double[singularValues.length];
            for (int c = 0; c < singularValues.length; c++) {
                propVar[c] = variance[c] / totalVariance;
                cumVar[c] = c == 0 ? propVar[0] : cumVar[c - 1] + propVar[c];
            }
            output._pc_importance = new TwoDimTable("Importance of components", (String) null,
                    new String[]{"Standard deviation", "Proportion of Variance", "Cumulative Proportion"},
                    colHeaders, colTypes, colFormats, "", new String[3][], new double[][]{sdev, propVar, cumVar});
        }

        /**
         * Main training routine.
         *
         * Locks the input frames, validates parameters, builds the working frame
         * (training columns followed by two blocks of {@code _ncolX} scratch columns for the
         * old and new X), then alternates X and Y updates until {@link #isDone} says stop.
         * A step is accepted (and the step size grown by 5%) only when the objective
         * decreased; otherwise the step size is shrunk and the previous Y is kept.
         * Finally the new-X columns are published under {@code _loading_key} and, if
         * requested, PCA outputs are recovered.
         *
         * Cleanup (unlocking frames and model, removing DataInfo objects and the old-X
         * scratch vecs) runs in a {@code finally} block on every path. Any failure that is
         * not a user cancellation is reported through {@code failed(...)} and rethrown.
         */
        @Override
        protected void compute2() {
            final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) GLRM.this._parms;
            GLRMModel model = null;
            DataInfo dinfo = null;
            DataInfo xinfo = null;
            Frame fr = null;

            try {
                parms.read_lock_frames(GLRM.this);
                GLRM.this.init(true);
                if (GLRM.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + GLRM.this.validationErrors());
                }

                // The model being built.
                model = new GLRMModel(GLRM.this.dest(), parms, new GLRMModel.GLRMOutput(GLRM.this));
                model.delete_and_lock(GLRM.this._key);
                final GLRMModel.GLRMOutput output = (GLRMModel.GLRMOutput) model._output;

                // Read after init(true), which sets these.
                final int ncolA = GLRM.this._ncolA;
                final int ncolX = GLRM.this._ncolX;

                // Number of cells in the training frame; used to average the objective change.
                final double ncells = GLRM.this._train.numRows() * GLRM.this._train.numCols();
                double[][] yt = ArrayUtils.transpose(initialY());

                // Working frame layout: [A columns | old X columns | new X columns].
                Vec[] vecs = new Vec[ncolA + (2 * ncolX)];
                for (int c = 0; c < ncolA; c++) {
                    vecs[c] = GLRM.this._train.vec(c);
                }
                for (int c = ncolA; c < vecs.length; c++) {
                    vecs[c] = GLRM.this._train.anyVec().makeRand(parms._seed);
                }
                fr = new Frame((String[]) null, vecs);
                dinfo = new DataInfo(Key.make(), fr, (Frame) null, 0, false, parms._transform, DataInfo.TransformType.NONE, true, false);
                DKV.put(dinfo._key, dinfo);

                // Standardization vectors for the A columns; identity (0 / 1) when absent.
                output._normSub = dinfo._normSub == null ? new double[ncolA] : Arrays.copyOf(dinfo._normSub, ncolA);
                if (dinfo._normMul == null) {
                    output._normMul = new double[ncolA];
                    Arrays.fill(output._normMul, 1.0d);
                } else {
                    output._normMul = Arrays.copyOf(dinfo._normMul, ncolA);
                }

                // Objective at the starting point.
                ObjCalc initialObj = (ObjCalc) new ObjCalc(parms, yt, ncolA, ncolX, output._normSub, output._normMul).doAll(dinfo._adaptedFrame);
                output._objective = initialObj._loss + (parms._gamma * parms.regularize(yt));
                output._iterations = 0;
                // Start above the convergence tolerance so the first checks cannot report convergence.
                output._avg_change_obj = 2 * GLRM.this.TOLERANCE;

                boolean overwriteX = false;   // copy new X over old X only after an accepted step
                double step = parms._init_step_size;
                int stepsInRow = 0;           // > 0: consecutive accepted steps, < 0: consecutive rejected

                while (!isDone(model, stepsInRow, step)) {
                    // 1) Update X with Y fixed.
                    UpdateX xTask = new UpdateX(parms, yt, step / ncolA, overwriteX, ncolA, ncolX, output._normSub, output._normMul);
                    xTask.doAll(dinfo._adaptedFrame);

                    // 2) Update Y with the new X fixed.
                    UpdateY yTask = new UpdateY(parms, yt, step / ncolA, ncolA, ncolX, output._normSub, output._normMul);
                    double[][] ytNew = ((UpdateY) yTask.doAll(dinfo._adaptedFrame))._ytnew;

                    // 3) Objective with the candidate X and Y.
                    ObjCalc objTask = (ObjCalc) new ObjCalc(parms, ytNew, ncolA, ncolX, output._normSub, output._normMul).doAll(dinfo._adaptedFrame);
                    double objNew = objTask._loss + (parms._gamma * (xTask._xreg + yTask._yreg));

                    output._avg_change_obj = (output._objective - objNew) / ncells;
                    output._iterations++;

                    if (output._avg_change_obj > 0.0d) {
                        // Objective decreased: accept the step and grow the step size.
                        yt = ytNew;
                        output._objective = objNew;
                        step *= 1.05d;
                        stepsInRow = Math.max(1, stepsInRow + 1);
                        overwriteX = true;
                    } else {
                        // Objective did not decrease: reject and shrink the step size,
                        // more aggressively the longer the streak of rejections.
                        step /= Math.max(1.5d, -stepsInRow);
                        stepsInRow = Math.min(0, stepsInRow - 1);
                        overwriteX = false;
                    }
                    model.update(GLRM.this._key);
                    GLRM.this.update(1L);
                }

                // Publish the final X (the "new X" block) as the loading frame.
                Vec[] xvecs = new Vec[ncolX];
                for (int c = 0; c < ncolX; c++) {
                    xvecs[c] = fr.vec(GLRM.idx_xnew(c, ncolA, ncolX));
                }
                Frame x = new Frame(parms._loading_key, (String[]) null, xvecs);
                xinfo = new DataInfo(Key.make(), x, (Frame) null, 0, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false);
                DKV.put(x._key, x);
                DKV.put(xinfo._key, xinfo);

                output._loading_key = parms._loading_key;
                output._archetypes = yt;
                output._step_size = step;
                if (parms._recover_pca) {
                    recoverPCA(model, xinfo);
                }
                GLRM.this.done();
            } catch (Throwable t) {
                Job<?> thisJob = DKV.getGet(GLRM.this._key);
                if (thisJob != null && thisJob._state == Job.JobState.CANCELLED) {
                    Log.info(new Object[]{"Job cancelled by user."});
                } else {
                    t.printStackTrace();
                    GLRM.this.failed(t);
                    throw t;
                }
            } finally {
                parms.read_unlock_frames(GLRM.this);
                if (model != null) {
                    model.unlock(GLRM.this._key);
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
                if (xinfo != null) {
                    xinfo.remove();
                }
                // Drop the old-X scratch columns; the new-X columns live on in the loading frame.
                if (fr != null) {
                    for (int c = 0; c < GLRM.this._ncolX; c++) {
                        fr.vec(GLRM.idx_xold(c, GLRM.this._ncolA)).remove();
                    }
                }
            }
            tryComplete();
        }

        /** Returns the key of the enclosing {@code GLRM} job. */
        Key self() {
            final Key jobKey = GLRM.this._key;
            return jobKey;
        }
    }

    /* loaded from: input_file:hex/glrm/GLRM$Initialization.class */
    /**
     * Strategy for choosing the initial Y matrix (see {@code GLRMDriver.initialY()}).
     */
    public enum Initialization {
        /** Train an SVD model and start from the transpose of its {@code _v} output. */
        SVD,
        /**
         * k-means++ start. NOTE: initialY() takes the k-means path for any value other than
         * SVD whenever no user points are supplied; it does not test for this constant.
         */
        PlusPlus,
        /**
         * User-supplied start. NOTE: this file never compares against this constant;
         * user points are detected via a non-null {@code _user_points} parameter.
         */
        User
    }

    /* loaded from: input_file:hex/glrm/GLRM$ObjCalc.class */
    private static class ObjCalc extends MRTask<ObjCalc> {
        GLRMModel.GLRMParameters _parms;
        final double[][] _yt;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        double _loss;
        static final /* synthetic */ boolean $assertionsDisabled;

        ObjCalc(GLRMModel.GLRMParameters gLRMParameters, double[][] dArr, int i, int i2, double[] dArr2, double[] dArr3) {
            if (!$assertionsDisabled && (dArr == null || dArr.length != i || dArr[0].length != i2)) {
                throw new AssertionError();
            }
            this._parms = gLRMParameters;
            this._yt = dArr;
            this._ncolA = i;
            this._ncolX = i2;
            this._normSub = dArr2;
            this._normMul = dArr3;
            this._loss = 0.0d;
        }

        public void map(Chunk[] chunkArr) {
            if (!$assertionsDisabled && this._ncolA + (2 * this._ncolX) != chunkArr.length) {
                throw new AssertionError();
            }
            for (int i = 0; i < chunkArr[0]._len; i++) {
                for (int i2 = 0; i2 < this._ncolA; i2++) {
                    double atd = chunkArr[i2].atd(i);
                    if (!Double.isNaN(atd)) {
                        double d = 0.0d;
                        for (int i3 = 0; i3 < this._ncolX; i3++) {
                            d += GLRM.chk_xnew(chunkArr, i3, this._ncolA, this._ncolX).atd(i) * this._yt[i2][i3];
                        }
                        this._loss += this._parms.loss(d, (atd - this._normSub[i2]) * this._normMul[i2]);
                    }
                }
            }
        }

        static {
            $assertionsDisabled = !GLRM.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/glrm/GLRM$UpdateX.class */
    private static class UpdateX extends MRTask<UpdateX> {
        GLRMModel.GLRMParameters _parms;
        final double _alpha;
        final boolean _update;
        final double[][] _yt;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        double _loss;
        double _xreg;
        static final /* synthetic */ boolean $assertionsDisabled;

        UpdateX(GLRMModel.GLRMParameters gLRMParameters, double[][] dArr, double d, boolean z, int i, int i2, double[] dArr2, double[] dArr3) {
            if (!$assertionsDisabled && (dArr == null || dArr.length != i || dArr[0].length != i2)) {
                throw new AssertionError();
            }
            this._ncolA = i;
            this._ncolX = i2;
            this._normSub = dArr2;
            this._normMul = dArr3;
            this._parms = gLRMParameters;
            this._alpha = d;
            this._update = z;
            this._yt = dArr;
        }

        public void map(Chunk[] chunkArr) {
            if (!$assertionsDisabled && this._ncolA + (2 * this._ncolX) != chunkArr.length) {
                throw new AssertionError();
            }
            double[] dArr = new double[this._ncolA];
            this._xreg = 0.0d;
            this._loss = 0.0d;
            for (int i = 0; i < chunkArr[0]._len; i++) {
                double[] dArr2 = new double[this._ncolX];
                double[] dArr3 = new double[this._ncolX];
                if (this._update) {
                    for (int i2 = 0; i2 < this._ncolX; i2++) {
                        GLRM.chk_xold(chunkArr, i2, this._ncolA).set(i, GLRM.chk_xnew(chunkArr, i2, this._ncolA, this._ncolX).atd(i));
                    }
                }
                for (int i3 = 0; i3 < this._ncolA; i3++) {
                    dArr[i3] = chunkArr[i3].atd(i);
                    if (!Double.isNaN(dArr[i3])) {
                        double d = 0.0d;
                        for (int i4 = 0; i4 < this._ncolX; i4++) {
                            d += GLRM.chk_xold(chunkArr, i4, this._ncolA).atd(i) * this._yt[i3][i4];
                        }
                        double lgrad = this._parms.lgrad(d, (dArr[i3] - this._normSub[i3]) * this._normMul[i3]);
                        for (int i5 = 0; i5 < this._ncolX; i5++) {
                            int i6 = i5;
                            dArr2[i6] = dArr2[i6] + (lgrad * this._yt[i3][i5]);
                        }
                    }
                }
                for (int i7 = 0; i7 < this._ncolX; i7++) {
                    dArr3[i7] = this._parms.rproxgrad(GLRM.chk_xold(chunkArr, i7, this._ncolA).atd(i) - (this._alpha * dArr2[i7]), this._alpha);
                    GLRM.chk_xnew(chunkArr, i7, this._ncolA, this._ncolX).set(i, dArr3[i7]);
                    this._xreg += this._parms.regularize(dArr3[i7]);
                }
                for (int i8 = 0; i8 < this._ncolA; i8++) {
                    if (!Double.isNaN(dArr[i8])) {
                        this._loss += this._parms.loss(ArrayUtils.innerProduct(dArr3, this._yt[i8]), dArr[i8]);
                    }
                }
            }
        }

        public void reduce(UpdateX updateX) {
            this._loss += updateX._loss;
            this._xreg += updateX._xreg;
        }

        static {
            $assertionsDisabled = !GLRM.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/glrm/GLRM$UpdateY.class */
    /**
     * Map/reduce task performing one proximal-gradient step on Y with X held fixed.
     * {@code map}/{@code reduce} accumulate the loss gradient into {@code _ytnew};
     * {@code postGlobal} then turns it into the updated Y and sums its regularization.
     */
    private static class UpdateY extends MRTask<UpdateY> {
        GLRMModel.GLRMParameters _parms;
        /** Step size used for both the gradient step and the proximal operator. */
        final double _alpha;
        /** Current Y transposed: indexed {@code _ytold[column of A][column of X]}. */
        final double[][] _ytold;
        final int _ncolA;
        final int _ncolX;
        final double[] _normSub;
        final double[] _normMul;
        /** Gradient during map/reduce; the updated Y (transposed) after postGlobal. */
        double[][] _ytnew;
        /** Sum of the regularizer over the updated Y entries (output). */
        double _yreg;

        UpdateY(GLRMModel.GLRMParameters parms, double[][] yt, double alpha, int ncolA, int ncolX, double[] normSub, double[] normMul) {
            assert yt != null && yt.length == ncolA && yt[0].length == ncolX;
            _parms = parms;
            _alpha = alpha;
            _ncolA = ncolA;
            _ncolX = ncolX;
            _normSub = normSub;
            _normMul = normMul;
            _ytold = yt;
            _yreg = 0.0d;
        }

        public void map(Chunk[] cs) {
            assert _ncolA + (2 * _ncolX) == cs.length;
            _ytnew = new double[_ncolA][_ncolX];

            for (int j = 0; j < _ncolA; j++) {
                for (int row = 0; row < cs[0]._len; row++) {
                    final double a = cs[j].atd(row);
                    if (Double.isNaN(a)) {
                        continue;   // skip missing observations
                    }

                    // Inner product of this row of the new X with column j of the old Y.
                    double xy = 0.0d;
                    for (int k = 0; k < _ncolX; k++) {
                        xy += GLRM.chk_xnew(cs, k, _ncolA, _ncolX).atd(row) * _ytold[j][k];
                    }

                    final double weight = _parms.lgrad(xy, (a - _normSub[j]) * _normMul[j]);
                    for (int k = 0; k < _ncolX; k++) {
                        _ytnew[j][k] += weight * GLRM.chk_xnew(cs, k, _ncolA, _ncolX).atd(row);
                    }
                }
            }
        }

        public void reduce(UpdateY other) {
            ArrayUtils.add(_ytnew, other._ytnew);
        }

        protected void postGlobal() {
            for (int j = 0; j < _ncolA; j++) {
                for (int k = 0; k < _ncolX; k++) {
                    final double stepped = _ytold[j][k] - (_alpha * _ytnew[j][k]);
                    _ytnew[j][k] = _parms.rproxgrad(stepped, _alpha);
                    _yreg += _parms.regularize(_ytnew[j][k]);
                }
            }
        }
    }

    /** Returns a fresh {@link GLRMV3} schema instance for this builder. */
    public ModelBuilderSchema schema() {
        final GLRMV3 v3Schema = new GLRMV3();
        return v3Schema;
    }

    /** Starts a new {@link GLRMDriver} via {@code start(...)} with a work amount of 0 and returns the job. */
    public Job<GLRMModel> trainModel() {
        final GLRMDriver driver = new GLRMDriver();
        return start(driver, 0L);
    }

    /** Reports the model categories this builder produces: clustering only. */
    public Model.ModelCategory[] can_build() {
        final Model.ModelCategory[] categories = new Model.ModelCategory[1];
        categories[0] = Model.ModelCategory.Clustering;
        return categories;
    }

    /** Marks this builder as experimental. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Experimental;
        return visibility;
    }

    /**
     * Creates a GLRM builder for the given parameters and runs the cheap
     * ({@code init(false)}) parameter validation.
     *
     * @param gLRMParameters model parameters
     */
    public GLRM(GLRMModel.GLRMParameters gLRMParameters) {
        super("GLRM", gLRMParameters);
        // TOLERANCE is a final field initialized at its declaration; it must not be
        // assigned again here.
        init(false);
    }

    /**
     * Validates the parameters, recording each problem through {@code error(...)}, and
     * caches the A and X column counts.
     *
     * A loading key is generated when none was supplied. Checks that depend on the
     * training frame are skipped when {@code _train} is null.
     *
     * @param z forwarded unchanged to {@code super.init}
     */
    @Override
    public void init(boolean z) {
        super.init(z);
        final GLRMModel.GLRMParameters parms = (GLRMModel.GLRMParameters) this._parms;

        if (parms._loading_key == null) {
            parms._loading_key = Key.make("GLRMLoading_" + Key.rand());
        }
        if (parms._gamma < 0.0d) {
            error("_gamma", "gamma must be a non-negative number");
        }
        if (parms._max_iterations < 1 || parms._max_iterations > 1000000.0d) {
            error("_max_iterations", "max_iterations must be between 1 and 1e6 inclusive");
        }
        if (parms._init_step_size <= 0.0d) {
            error("_init_step_size", "init_step_size must be a positive number");
        }
        if (parms._min_step_size < 0.0d || parms._min_step_size > parms._init_step_size) {
            error("_min_step_size", "min_step_size must be between 0 and " + parms._init_step_size);
        }

        // Everything below needs the training frame.
        if (this._train == null) {
            return;
        }
        if (this._train.numCols() < 2) {
            error("_train", "_train must have more than one column");
        }

        // The rank cannot exceed either dimension of the data.
        int maxRank = (int) Math.min(this._train.numCols(), this._train.numRows());
        if (parms._k < 1 || parms._k > maxRank) {
            error("_k", "_k must be between 1 and " + maxRank);
        }

        if (null != parms._user_points) {
            Frame userPoints = parms._user_points.get();
            if (userPoints.numCols() != this._train.numCols()) {
                error("_user_points", "The user-specified points must have the same number of columns (" + this._train.numCols() + ") as the training observations");
            } else if (userPoints.numRows() != parms._k) {
                error("_user_points", "The user-specified points must have k = " + parms._k + " rows");
            } else {
                // Reject a start where every column is constant zero.
                int zeroCols = 0;
                Vec[] userVecs = userPoints.vecs();
                for (int c = 0; c < this._train.numCols(); c++) {
                    if (userVecs[c].isConst() && userVecs[c].max() == 0.0d) {
                        zeroCols++;
                    }
                }
                if (zeroCols == this._train.numCols()) {
                    error("_user_points", "The user-specified points cannot all be zero");
                }
            }
        }

        // Only numeric columns are supported.
        for (Vec vec : this._train.vecs()) {
            if (!vec.isNumeric()) {
                throw H2O.unimpl();
            }
        }

        this._ncolA = this._train.numCols();
        this._ncolX = parms._k;
    }

    /**
     * Returns the squared Frobenius norm of a matrix, i.e. the sum of the squares of all
     * its entries.
     *
     * Each row is summed over its own length, so jagged arrays are handled correctly
     * rather than being truncated (or overrun) to the first row's length.
     *
     * @param dArr the matrix; may be {@code null} or empty
     * @return the sum of squared entries, or 0 for a {@code null} or empty matrix
     */
    public static double frobenius2(double[][] dArr) {
        if (dArr == null) {
            return 0.0d;
        }
        double sumSquares = 0.0d;
        for (double[] row : dArr) {
            for (double value : row) {
                sumSquares += value * value;
            }
        }
        return sumSquares;
    }

    /**
     * Returns a standardized copy of a matrix: for every column {@code c >= i}, each entry
     * becomes {@code (value - dArr2[c]) * dArr3[c]}. Columns before {@code i} are copied
     * unchanged. The input is not modified.
     *
     * A {@code null} subtraction vector means "subtract 0" and a {@code null} multiplier
     * vector means "multiply by 1", so passing both as {@code null} yields a plain copy.
     *
     * @param dArr  rectangular matrix to transform (row length taken from the first row)
     * @param i     index of the first column to standardize
     * @param dArr2 per-column values to subtract, or {@code null}
     * @param dArr3 per-column multipliers, or {@code null}
     * @return a new matrix of the same shape; an empty matrix when {@code dArr} has no rows
     */
    public static double[][] transform(double[][] dArr, int i, double[] dArr2, double[] dArr3) {
        final int nrows = dArr.length;
        if (nrows == 0) {
            return new double[0][0];
        }
        final int ncols = dArr[0].length;
        final double[][] result = new double[nrows][ncols];

        for (int r = 0; r < nrows; r++) {
            System.arraycopy(dArr[r], 0, result[r], 0, ncols);
            for (int c = i; c < ncols; c++) {
                // Missing vectors act as the identity transform.
                final double shift = dArr2 == null ? 0.0d : dArr2[c];
                final double scale = dArr3 == null ? 1.0d : dArr3[c];
                result[r][c] = (result[r][c] - shift) * scale;
            }
        }
        return result;
    }

    /**
     * Index, within the working frame, of old-X column {@code c}: the old-X block starts
     * right after the {@code ncolA} training columns.
     */
    protected static int idx_xold(int c, int ncolA) {
        final int blockStart = ncolA;
        return blockStart + c;
    }

    /**
     * Index, within the working frame, of new-X column {@code c}: the new-X block starts
     * after the {@code ncolA} training columns and the {@code ncolX} old-X columns.
     */
    protected static int idx_xnew(int c, int ncolA, int ncolX) {
        final int blockStart = ncolA + ncolX;
        return blockStart + c;
    }

    /** Returns the chunk holding old-X column {@code c}, located right after the {@code ncolA} training chunks. */
    protected static Chunk chk_xold(Chunk[] chks, int c, int ncolA) {
        final int position = ncolA + c;
        return chks[position];
    }

    /**
     * Returns the chunk holding new-X column {@code c}, located after the {@code ncolA}
     * training chunks and the {@code ncolX} old-X chunks.
     */
    protected static Chunk chk_xnew(Chunk[] chks, int c, int ncolA, int ncolX) {
        final int position = ncolA + ncolX + c;
        return chks[position];
    }
}
