package hex.svd;

import hex.DataInfo;
import hex.FrameTask;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.gram.Gram;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.SVDV99;
import hex.svd.SVDModel;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Builds an {@link SVDModel}: a truncated singular value decomposition of the training frame.
 *
 * <p>Right singular vectors are extracted one at a time by power iteration ({@link #powerLoop})
 * on a Gram matrix of the transformed training rows. After vector {@code v_k} is found, the
 * projector {@code I - sum_j v_j v_j'} is updated ({@link #updateIVVSum}). The Gram matrix of
 * the rows multiplied by that projector is then recomputed ({@link GramUpdate}), so the next
 * power iteration converges to a new direction.
 *
 * <p>Unless {@code _only_v} is set, each step also computes the column {@code A x_k} and its
 * L2 norm ({@link CalcSigmaU}), stored in {@code _d[k]}. Here {@code x_k} is {@code v_k}
 * multiplied by the projector in effect before step {@code k}. When {@code _keep_u} is set,
 * those columns are saved as a frame and each is divided by its {@code _d} entry.
 */
public class SVD extends ModelBuilder<SVDModel, SVDModel.SVDParameters, SVDModel.SVDOutput> {
    /** Power iteration stops once successive unit-norm iterates differ (in L2) by at most this. */
    private static final double TOLERANCE = 1.0E-6d;

    /** Above this many expanded (one-hot) columns, {@link #init(boolean)} warns that training may be slow. */
    private static final int MAX_COLS_EXPANDED = 5000;

    /** Column count of the training frame after categorical expansion; set by {@link #init(boolean)}. */
    private transient int _ncolExp;

    /**
     * One pass over the data computing {@code u = A x} for a fixed coefficient vector {@code x}
     * (emitted as a single output column) together with {@code ||u||_2}.
     */
    private static class CalcSigmaU extends FrameTask<CalcSigmaU> {
        final double[] _svec;
        /** Sum of squares of the emitted values; replaced by its square root in {@link #postGlobal()}. */
        public double _sval;
        /** Number of rows processed. */
        public long _nobs;

        public CalcSigmaU(Key key, DataInfo dinfo, double[] svec) {
            super(key, dinfo);
            _svec = svec;
            _sval = 0.0d;
        }

        @Override // hex.FrameTask
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            final double ax = row.innerProduct(_svec);
            outputs[0].addNum(ax);
            _sval += ax * ax;
            _nobs++;
        }

        @Override
        public void reduce(CalcSigmaU other) {
            _nobs += other._nobs;
            _sval += other._sval;
        }

        @Override
        protected void postGlobal() {
            _sval = Math.sqrt(_sval);
        }
    }

    /**
     * Computes the Gram matrix of the training rows after each row is multiplied by the square
     * matrix {@code _ivv} (the current deflation projector), averaged over the number of rows.
     */
    private static class GramUpdate extends FrameTask<GramUpdate> {
        final double[][] _ivv;
        public Gram _gram;
        /** Number of rows accumulated into {@link #_gram}. */
        public long _nobs;

        public GramUpdate(Key key, DataInfo dinfo, double[][] ivv) {
            super(key, dinfo);
            assert ivv != null && ivv.length == ivv[0].length;
            _ivv = ivv;
        }

        @Override // hex.FrameTask
        protected boolean chunkInit() {
            _gram = new Gram(_dinfo.fullN(), 0, _ivv.length, 0, false);
            return true;
        }

        @Override // hex.FrameTask
        protected void processRow(long gid, DataInfo.Row row) {
            final double[] projected = new double[_ivv.length];
            for (int i = 0; i < _ivv.length; i++) {
                projected[i] = row.innerProduct(_ivv[i]);
            }
            _gram.addRow(_dinfo.newDenseRow(projected), 1.0d);
            _nobs++;
        }

        @Override // hex.FrameTask
        protected void chunkDone(long n) {
            // An empty chunk (e.g. every row skipped) has an all-zero Gram; scaling it by 1/0
            // would turn it into NaNs that poison the reduction.
            if (_nobs > 0) {
                _gram.mul(1.0d / _nobs);
            }
        }

        /** Merges two row-averaged Grams into the average over all of their rows. */
        @Override
        public void reduce(GramUpdate other) {
            if (other._nobs == 0) {
                return; // nothing to merge
            }
            final long total = _nobs + other._nobs;
            // Floating-point weights: long/long division would truncate both weights to 0.
            _gram.mul((double) _nobs / total);
            other._gram.mul((double) other._nobs / total);
            _gram.add(other._gram);
            _nobs = total;
        }
    }

    /** Background driver that performs the training described in the class documentation. */
    class SVDDriver extends H2O.H2OCountedCompleter<SVDDriver> {

        /**
         * Trains the model. Input frames are read-locked while training. The model lock, the
         * temporary {@link DataInfo} and the frame read-locks are released in the
         * {@code finally} block. A cancelled job is logged and completes normally; any other
         * failure marks the job failed and is rethrown.
         */
        @Override
        protected void compute2() {
            SVDModel model = null;
            DataInfo dinfo = null;
            Frame u = null;
            try {
                SVD.this.init(true);
                SVD.this._parms.read_lock_frames(SVD.this);
                if (SVD.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + SVD.this.validationErrors());
                }
                final SVDModel.SVDParameters parms = SVD.this._parms;
                final int ncol = SVD.this._ncolExp;

                model = new SVDModel(SVD.this.dest(), parms, new SVDModel.SVDOutput(SVD.this));
                model.delete_and_lock(self());
                final SVDModel.SVDOutput out = model._output;

                // Transform/expand the training data and keep the transformation for scoring.
                dinfo = new DataInfo(Key.make(), SVD.this._train, SVD.this._valid, 0, parms._use_all_factor_levels,
                        parms._transform, DataInfo.TransformType.NONE, true, false, false, false, false, false);
                DKV.put(dinfo._key, dinfo);
                saveAdaptedFrameInfo(out, dinfo);

                final Gram.GramTask gtsk = new Gram.GramTask(self(), dinfo).doAll(dinfo._adaptedFrame);
                final Gram gram = gtsk._gram;
                assert gram.fullN() == ncol;
                out._nobs = gtsk._nobs;
                out._v = new double[parms._nv][ncol];
                // NOTE(review): the n/(n-1) factor suggests GramTask yields X'X/n and this converts
                // the trace to a sample variance -- confirm against Gram.GramTask.
                out._total_variance = (gram.diagSum() * gtsk._nobs) / (gtsk._nobs - 1);
                model.update(self());
                SVD.this.update(1L);

                // Projector onto the complement of the vectors found so far; starts as identity.
                final double[][] ivvSum = new double[ncol][ncol];
                for (int i = 0; i < ncol; i++) {
                    ivvSum[i][i] = 1.0d;
                }

                Vec[] uvecs = null;
                if (!parms._only_v) {
                    out._d = new double[parms._nv];
                    out._u_key = Key.make(parms._u_name);
                    uvecs = new Vec[parms._nv];
                }

                Gram current = gram;
                for (int k = 0; k < parms._nv; k++) {
                    if (k > 0) {
                        if (!SVD.this.isRunning()) {
                            break; // job no longer running: stop extracting components
                        }
                        // Deflate only when another component is needed; saves a full data pass.
                        current = deflate(dinfo, ivvSum, out._v[k - 1]);
                    }
                    out._v[k] = SVD.this.powerLoop(current, parms._seed);
                    if (uvecs != null) {
                        computeSigmaU(dinfo, out, ivvSum, uvecs, k);
                    }
                    model.update(self());
                    SVD.this.update(1L);
                }
                out._v = ArrayUtils.transpose(out._v);

                if (uvecs != null && parms._keep_u) {
                    // NOTE(review): if the loop stopped early, trailing uvecs entries are still null
                    // and the matching _d entries are 0, so building U here would fail -- confirm.
                    u = new Frame(out._u_key, (String[]) null, uvecs);
                    DKV.put(u._key, u);
                    final double[] d = out._d;
                    // Divide column i of U by d[i].
                    new MRTask() {
                        @Override
                        public void map(Chunk[] cs) {
                            for (int col = 0; col < cs.length; col++) {
                                for (int row = 0; row < cs[0].len(); row++) {
                                    cs[col].set(row, cs[col].atd(row) / d[col]);
                                }
                            }
                        }
                    }.doAll(u);
                }
                // NOTE(review): when _keep_u is false, the per-component vecs in uvecs are never
                // removed in this method -- confirm whether they leak.
                model.update(self());
                SVD.this.done();
            } catch (Throwable t) {
                final Job<?> job = DKV.getGet(SVD.this._key);
                if (job._state != Job.JobState.CANCELLED) {
                    t.printStackTrace();
                    SVD.this.failed(t);
                    throw t;
                }
                Log.info("Job cancelled by user.");
            } finally {
                if (model != null) {
                    model.unlock(SVD.this._key);
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
                // u is only built when _keep_u is set, so as written this delete never fires.
                if (u != null && !SVD.this._parms._keep_u) {
                    u.delete();
                }
                SVD.this._parms.read_unlock_frames(SVD.this);
            }
            tryComplete();
        }

        /** Copies {@code dinfo}'s transformation and categorical-expansion metadata into {@code out}. */
        private void saveAdaptedFrameInfo(SVDModel.SVDOutput out, DataInfo dinfo) {
            // A missing transformation is stored as "subtract 0, multiply by 1".
            out._normSub = dinfo._normSub == null ? new double[dinfo._nums] : dinfo._normSub;
            if (dinfo._normMul == null) {
                out._normMul = new double[dinfo._nums];
                Arrays.fill(out._normMul, 1.0d);
            } else {
                out._normMul = dinfo._normMul;
            }
            out._permutation = dinfo._permutation;
            out._nnums = dinfo._nums;
            out._ncats = dinfo._cats;
            out._catOffsets = dinfo._catOffsets;
            out._names_expanded = dinfo.coefNames();
        }

        /**
         * Computes the k-th U column {@code A x} with {@code x = ivvSum * v_k}. Stores its L2 norm
         * in {@code out._d[k]} and its vec in {@code uvecs[k]}.
         */
        private void computeSigmaU(DataInfo dinfo, SVDModel.SVDOutput out, double[][] ivvSum, Vec[] uvecs, int k) {
            final double[] x = ArrayUtils.multArrVec(ivvSum, out._v[k]);
            final CalcSigmaU ctsk = new CalcSigmaU(self(), dinfo, x).doAll(1, dinfo._adaptedFrame);
            out._d[k] = ctsk._sval;
            assert ctsk._nobs == out._nobs : "Processed " + ctsk._nobs + " rows but expected " + out._nobs;
            final Frame tmp = ctsk.outputFrame();
            uvecs[k] = tmp.vec(0);
            tmp.unlock(self());
        }

        /**
         * Subtracts {@code v v'} from {@code ivvSum} in place. Returns the Gram of the rows
         * multiplied by the updated projector.
         */
        private Gram deflate(DataInfo dinfo, double[][] ivvSum, double[] v) {
            SVD.updateIVVSum(ivvSum, v);
            return new GramUpdate(self(), dinfo, ivvSum).doAll(dinfo._adaptedFrame)._gram;
        }

        /** Key of the enclosing builder job; passed to tasks and to model lock/update calls. */
        Key self() {
            return SVD.this._key;
        }
    }

    public ModelBuilderSchema schema() {
        return new SVDV99();
    }

    /** Starts the background {@link SVDDriver} with the given amount of work. */
    public Job<SVDModel> trainModelImpl(long work) {
        return start(new SVDDriver(), work);
    }

    /** One progress unit for the initial Gram pass plus one per extracted component. */
    public long progressUnits() {
        return _parms._nv + 1;
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.DimReduction};
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public SVD(SVDModel.SVDParameters parms) {
        super("SVD", parms);
        init(false);
    }

    /**
     * Validates parameters and derives {@link #_ncolExp}.
     *
     * <p>It assigns a random U-frame name if none was given. It rejects
     * {@code _max_iterations < 1} and {@code _nv} outside {@code [1, _ncolExp]}. It warns when
     * expansion yields more than {@link #MAX_COLS_EXPANDED} columns.
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        final SVDModel.SVDParameters parms = _parms;
        if (parms._u_name == null || parms._u_name.length() == 0) {
            parms._u_name = "SVDUMatrix_" + Key.rand();
        }
        if (parms._max_iterations < 1) {
            error("_max_iterations", "max_iterations must be at least 1");
        }
        if (_train == null) {
            return;
        }
        _ncolExp = _train.numColsExp(parms._use_all_factor_levels, false);
        if (_ncolExp > MAX_COLS_EXPANDED) {
            warn("_train", "_train has " + _ncolExp + " columns when categoricals are expanded. Algorithm may be slow.");
        }
        if (parms._nv < 1 || parms._nv > _ncolExp) {
            error("_nv", "Number of right singular values must be between 1 and " + _ncolExp);
        }
    }

    /** Runs {@link #powerLoop(Gram, double[])} from a Gaussian start vector drawn with {@code seed}. */
    public double[] powerLoop(Gram gram, long seed) {
        return powerLoop(gram, ArrayUtils.gaussianVector(gram.fullN(), seed));
    }

    /**
     * Power iteration: repeatedly multiplies by {@code gram} and normalizes to unit L2 norm.
     *
     * <p>The loop stops after {@code _max_iterations} steps or once successive iterates differ by
     * at most {@link #TOLERANCE}. {@code start} is not modified.
     * NOTE(review): if {@code gram * v} is the zero vector, the normalization divides by 0 and
     * the result is NaN.
     *
     * @param gram  matrix to iterate with
     * @param start initial vector; its length must equal {@code gram.fullN()}
     * @return the last iterate (unit norm after at least one step)
     */
    public double[] powerLoop(Gram gram, double[] start) {
        assert start.length == gram.fullN();
        final double[] v = start.clone();
        final double[] gv = new double[v.length];
        double change = 2 * TOLERANCE; // any value > TOLERANCE so the first step runs
        for (int iter = 0; iter < _parms._max_iterations && change > TOLERANCE; iter++) {
            gram.mul(v, gv);
            final double norm = ArrayUtils.l2norm(gv);
            double sq = 0.0d;
            for (int i = 0; i < v.length; i++) {
                gv[i] /= norm;
                final double diff = v[i] - gv[i];
                sq += diff * diff;
                v[i] = gv[i];
            }
            change = Math.sqrt(sq);
        }
        return v;
    }

    /**
     * Subtracts the outer product {@code v v'} from {@code ivvSum} in place.
     *
     * <p>Only the lower triangle and diagonal are read, and each result is mirrored into both
     * triangles, so {@code ivvSum} is assumed symmetric. Only the leading {@code v.length}
     * rows/columns are touched.
     *
     * @return {@code ivvSum}, for chaining
     */
    public static double[][] updateIVVSum(double[][] ivvSum, double[] v) {
        for (int i = 0; i < v.length; i++) {
            for (int j = 0; j < i; j++) {
                final double updated = ivvSum[i][j] - v[i] * v[j];
                ivvSum[j][i] = updated;
                ivvSum[i][j] = updated;
            }
            ivvSum[i][i] -= v[i] * v[i];
        }
        return ivvSum;
    }
}
