package hex.svd;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsUnsupervised;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.codegen.CodeGeneratorPipeline;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.JCodeGen;
import water.util.SBPrintStream;

/**
 * Model produced by {@link SVD}.
 *
 * <p>Scoring projects each adapted input row onto the first {@code _parms._nv} columns of
 * {@link SVDOutput#_v}, producing output columns {@code PC1 .. PC<nv>}. The generated POJO
 * ({@link #toJavaPredictBody}) performs the same projection as {@link #score0(double[], double[])}.
 */
public class SVDModel extends Model<SVDModel, SVDModel.SVDParameters, SVDModel.SVDOutput> {

    /** Metrics holder for SVD models; scoring registers one of these on the model output. */
    public static class ModelMetricsSVD extends ModelMetricsUnsupervised {

        /** Metric builder that accumulates nothing per row; it only yields a {@link ModelMetricsSVD}. */
        public static class SVDModelMetrics extends ModelMetricsUnsupervised.MetricBuilderUnsupervised {
            /**
             * @param dims size of the work buffer (callers pass the number of components)
             */
            public SVDModelMetrics(int dims) {
                this._work = new double[dims];
            }

            /** Returns the row unchanged; no per-row statistics are collected. */
            public double[] perRow(double[] dataRow, float[] preds, Model m) {
                return dataRow;
            }

            /** Creates a {@link ModelMetricsSVD} for {@code frame} and registers it on the model output. */
            public ModelMetrics makeModelMetrics(Model m, Frame frame, Frame adaptedFrame, Frame preds) {
                return m._output.addModelMetrics(new ModelMetricsSVD(m, frame));
            }
        }

        /** Builds metrics for {@code frame}, passing {@code 0} and {@code NaN} to the unsupervised base. */
        public ModelMetricsSVD(Model model, Frame frame) {
            super(model, frame, 0L, Double.NaN);
        }
    }

    /** Training output of an SVD model. */
    public static class SVDOutput extends Model.Output {
        public int _iterations;
        /** Projection matrix used by scoring, indexed {@code [expanded feature][component]}. */
        public double[][] _v;
        public Key<Frame> _v_key;
        public double[] _d;
        public Key<Frame> _u_key;
        /** Number of categorical input columns (they come first in {@link #_permutation}). */
        public int _ncats;
        /** Number of numeric input columns (they follow the categoricals in {@link #_permutation}). */
        public int _nnums;
        public long _nobs;
        public double _total_variance;
        /**
         * Start row in {@link #_v} of each categorical column's levels; the last entry is the row
         * where the numeric features begin.
         */
        public int[] _catOffsets;
        /** Per-numeric-column offset; scoring uses {@code (x - _normSub[j]) * _normMul[j]}. */
        public double[] _normSub;
        /** Per-numeric-column scale; see {@link #_normSub}. */
        public double[] _normMul;
        /** Maps model column order (categoricals, then numerics) to indices in the scored row. */
        public int[] _permutation;
        public String[] _names_expanded;

        public SVDOutput(SVD b) {
            super(b);
        }

        /** SVD is a dimensionality-reduction model. */
        public ModelCategory getModelCategory() {
            return ModelCategory.DimReduction;
        }
    }

    /** User-facing SVD parameters. */
    public static class SVDParameters extends Model.Parameters {
        public String _u_name;
        public String _v_name;
        public DataInfo.TransformType _transform = DataInfo.TransformType.NONE;
        public Method _svd_method = Method.GramSVD;
        /** Number of components; scoring emits this many {@code PC} columns. */
        public int _nv = 1;
        public int _max_iterations = 1000;
        public boolean _keep_u = true;
        public boolean _save_v_frame = true;
        public boolean _only_v = false;
        /** When {@code false}, categorical level indices are shifted down by one during scoring. */
        public boolean _use_all_factor_levels = true;
        public boolean _impute_missing = false;

        /** Available SVD computation methods. */
        public enum Method {
            GramSVD,
            Power,
            Randomized
        }

        public String algoName() {
            return "SVD";
        }

        public String fullName() {
            return "Singular Value Decomposition";
        }

        public String javaName() {
            return SVDModel.class.getName();
        }

        /** Number of progress units the training job reports, depending on {@link #_svd_method}. */
        public long progressUnits() {
            switch (_svd_method) {
                case GramSVD:
                    return 2L;
                case Power:
                    return 1 + _nv;
                case Randomized:
                    return 5 + _max_iterations;
                default:
                    return _nv;
            }
        }
    }

    public SVDModel(Key selfKey, SVDParameters parms, SVDOutput output) {
        super(selfKey, parms, output);
    }

    /** Removes the U and V frames (when present) before the base-class cleanup. */
    protected Futures remove_impl(Futures fs) {
        if (_output._u_key != null) {
            _output._u_key.remove(fs);
        }
        if (_output._v_key != null) {
            _output._v_key.remove(fs);
        }
        return super.remove_impl(fs);
    }

    /** Writes the U and V keys, then the base-class state. */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(_output._u_key);
        ab.putKey(_output._v_key);
        return super.writeAll_impl(ab);
    }

    /** Reads the U and V keys in the order {@link #writeAll_impl} wrote them, then the base state. */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(_output._u_key, fs);
        ab.getKey(_output._v_key, fs);
        return super.readAll_impl(ab, fs);
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        return new ModelMetricsSVD.SVDModelMetrics(_parms._nv);
    }

    /**
     * Projects every row of {@code adaptedFr} onto the model's components.
     *
     * @param orig original (un-adapted) frame; used for the metrics record
     * @param adaptedFr frame already adapted to the training layout
     * @param destinationKey name of the result frame, or {@code null} for a generated key
     * @param job job to report progress to / poll for stop requests; may be {@code null}
     * @return a new frame, put into the DKV, holding columns {@code PC1 .. PC<nv>}
     */
    protected Frame predictScoreImpl(Frame orig, Frame adaptedFr, String destinationKey, final Job job) {
        Frame adaptFrm = new Frame(adaptedFr);
        for (int c = 0; c < _parms._nv; c++) {
            adaptFrm.add("PC" + String.valueOf(c + 1), adaptFrm.anyVec().makeZero());
        }
        new MRTask() {
            public void map(Chunk[] chks) {
                if (isCancelled() || (job != null && job.stop_requested())) {
                    return;
                }
                // Output columns are appended after the model's input columns.
                final int ninputs = SVDModel.this._output._names.length;
                double[] tmp = new double[ninputs];
                double[] preds = new double[SVDModel.this._parms._nv];
                for (int row = 0; row < chks[0]._len; row++) {
                    double[] p = SVDModel.this.score0(chks, row, tmp, preds);
                    for (int c = 0; c < preds.length; c++) {
                        chks[ninputs + c].set(row, p[c]);
                    }
                }
                if (job != null) {
                    job.update(1L);
                }
            }
        }.doAll(adaptFrm);

        // Keep only the appended PC columns.
        Frame pcs = adaptFrm.extractFrame(_output._names.length, adaptFrm.numCols());
        Frame result = new Frame(destinationKey == null ? Key.make() : Key.make(destinationKey),
                pcs.names(), pcs.vecs());
        DKV.put(result);
        makeMetricBuilder(null).makeModelMetrics(this, orig, (Frame) null, (Frame) null);
        return result;
    }

    /**
     * Projects one row onto the first {@code _parms._nv} columns of {@code _output._v}.
     *
     * <p>A categorical column contributes the {@code _v} row of its level. A missing (NaN)
     * categorical maps to the column's last level slot, and levels outside
     * {@code [0, last]} are skipped. Numeric columns contribute
     * {@code (x - normSub) * normMul * v}.
     *
     * @param data row values; must have {@code _output._permutation.length} entries
     * @param preds output buffer of at least {@code _parms._nv} entries; overwritten
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds) {
        final SVDOutput out = _output;
        final int numStart = out._catOffsets[out._catOffsets.length - 1];
        assert data.length == out._permutation.length;
        final int levelShift = _parms._use_all_factor_levels ? 0 : 1;
        for (int k = 0; k < _parms._nv; k++) {
            double sum = 0.0d;
            for (int j = 0; j < out._ncats; j++) {
                double d = data[out._permutation[j]];
                int lastLevel = out._catOffsets[j + 1] - out._catOffsets[j] - 1;
                int level = Double.isNaN(d) ? lastLevel : ((int) d) - levelShift;
                if (level >= 0 && level <= lastLevel) {
                    sum += out._v[out._catOffsets[j] + level][k];
                }
            }
            for (int j = 0; j < out._nnums; j++) {
                sum += (data[out._permutation[out._ncats + j]] - out._normSub[j])
                        * out._normMul[j] * out._v[numStart + j][k];
            }
            preds[k] = sum;
        }
        return preds;
    }

    /** Emits the POJO's static state: metadata methods plus the arrays the predict body reads. */
    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        SBPrintStream init = super.toJavaInit(sb, fileCtx);
        final SVDOutput out = _output;
        init.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
        init.ip("public int nfeatures() { return " + out.nfeatures() + "; }").nl();
        init.ip("public int nclasses() { return " + _parms._nv + "; }").nl();
        // NORMMUL/NORMSUB exist only when there are numeric columns; see toJavaPredictBody.
        if (out._nnums > 0) {
            JCodeGen.toStaticVar(init, "NORMMUL", out._normMul, "Standardization/Normalization scaling factor for numerical variables.");
            JCodeGen.toStaticVar(init, "NORMSUB", out._normSub, "Standardization/Normalization offset for numerical variables.");
        }
        JCodeGen.toStaticVar(init, "CATOFFS", out._catOffsets, "Categorical column offsets.");
        JCodeGen.toStaticVar(init, "PERMUTE", out._permutation, "Permutation index vector.");
        JCodeGen.toStaticVar(init, "EIGVECS", out._v, "Eigenvector matrix.");
        return init;
    }

    /** Emits POJO code equivalent to {@link #score0(double[], double[])}. */
    protected void toJavaPredictBody(SBPrintStream bodySb, CodeGeneratorPipeline classCtx,
                                     CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final int cats = _output._ncats;
        final int nums = _output._nnums;
        bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
        bodySb.i().p("final int nstart = CATOFFS[CATOFFS.length-1];").nl();
        bodySb.i().p("for(int i = 0; i < ").p(_parms._nv).p("; i++) {").nl();

        // Categorical columns.
        bodySb.i(1).p("for(int j = 0; j < ").p(cats).p("; j++) {").nl();
        bodySb.i(2).p("double d = data[PERMUTE[j]];").nl();
        bodySb.i(2).p("int last = CATOFFS[j+1]-CATOFFS[j]-1;").nl();
        bodySb.i(2).p("int c = Double.isNaN(d) ? last : (int)d").p(_parms._use_all_factor_levels ? ";" : "-1;").nl();
        bodySb.i(2).p("if(c < 0 || c > last) continue;").nl();
        bodySb.i(2).p("preds[i] += EIGVECS[CATOFFS[j]+c][i];").nl();
        bodySb.i(1).p("}").nl();

        // Numeric columns. Only emitted when present: toJavaInit declares NORMSUB/NORMMUL only
        // for nums > 0, and a reference inside even a zero-trip loop would break POJO compilation.
        if (nums > 0) {
            bodySb.i(1).p("for(int j = 0; j < ").p(nums).p("; j++) {").nl();
            bodySb.i(2).p("preds[i] += (data[PERMUTE[j" + (cats > 0 ? "+" + cats : "") + "]]-NORMSUB[j])*NORMMUL[j]*EIGVECS[j" + (cats > 0 ? "+ nstart" : "") + "][i];").nl();
            bodySb.i(1).p("}").nl();
        }
        bodySb.i().p("}").nl();
    }
}
