package hex.glm;

import hex.FrameTask;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import hex.glm.GLMValidation;
import hex.gram.Gram;
import java.util.ArrayList;
import java.util.Arrays;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/glm/GLMTask.class */
public abstract class GLMTask<T extends GLMTask<T>> extends FrameTask<T> {
    /** GLM parameters (family, link/variance functions, ...) consulted by the GLMTask subclasses below. */
    protected final GLMModel.GLMParameters _glm;

    /* loaded from: input_file:hex/glm/GLMTask$ColGradientTask.class */
    public static class ColGradientTask extends MRTask<ColGradientTask> {
        final GLMModel.GLMParameters _params;
        final double[][] _beta;
        final FrameTask.DataInfo _dinfo;
        final double _reg;
        double[][] _gradient;
        double[] _objVals;

        public ColGradientTask(FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, double[][] dArr, double d) {
            this._dinfo = dataInfo;
            this._params = gLMParameters;
            this._beta = dArr;
            this._reg = d;
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r1v6, types: [double[], double[][]] */
        public void map(Chunk[] chunkArr) {
            double[] dArr = new double[this._beta.length];
            double[] malloc8d = MemoryManager.malloc8d(this._beta.length);
            boolean[] mallocZ = MemoryManager.mallocZ(chunkArr[0]._len);
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] = MemoryManager.malloc8d(chunkArr[0]._len);
            }
            this._gradient = new double[this._beta.length];
            for (int i2 = 0; i2 < this._gradient.length; i2++) {
                this._gradient[i2] = MemoryManager.malloc8d(this._beta[i2].length);
            }
            int length = chunkArr.length - 1;
            Chunk chunk = chunkArr[length];
            Chunk chunk2 = null;
            if (this._dinfo._offset) {
                length--;
                chunk2 = chunkArr[length];
            }
            for (int i3 = 0; i3 < this._dinfo._cats; i3++) {
                Chunk chunk3 = chunkArr[i3];
                for (int i4 = 0; i4 < chunk3._len; i4++) {
                    if (mallocZ[i4] || chunk3.isNA(i4)) {
                        mallocZ[i4] = true;
                    } else {
                        int at8 = ((int) chunk3.at8(i4)) + this._dinfo._catOffsets[i3];
                        if (!this._dinfo._useAllFactorLevels) {
                            if (at8 != this._dinfo._catOffsets[i3]) {
                                at8--;
                            }
                        }
                        for (int i5 = 0; i5 < dArr.length; i5++) {
                            double[] dArr2 = dArr[i5];
                            int i6 = i4;
                            dArr2[i6] = dArr2[i6] + this._beta[i5][at8];
                        }
                    }
                }
            }
            for (int i7 = this._dinfo._cats; i7 < length; i7++) {
                Chunk chunk4 = chunkArr[i7];
                for (int i8 = 0; i8 < chunk4._len; i8++) {
                    if (mallocZ[i8] || chunk4.isNA(i8)) {
                        mallocZ[i8] = true;
                        System.out.println("skipping row " + i8);
                    } else {
                        double atd = chunk4.atd(i8);
                        if (this._dinfo._normMul != null) {
                            atd = (atd - this._dinfo._normSub[i7 - this._dinfo._cats]) * this._dinfo._normMul[i7 - this._dinfo._cats];
                        }
                        int numStart = this._dinfo.numStart() - this._dinfo._cats;
                        for (int i9 = 0; i9 < dArr.length; i9++) {
                            double[] dArr3 = dArr[i9];
                            int i10 = i8;
                            dArr3[i10] = dArr3[i10] + (this._beta[i9][numStart + i7] * atd);
                        }
                    }
                }
            }
            for (int i11 = 0; i11 < chunkArr[0]._len; i11++) {
                if (!mallocZ[i11] && !chunk.isNA(i11)) {
                    double atd2 = this._dinfo._offset ? chunk2.atd(i11) : 0.0d;
                    double atd3 = chunk.atd(i11);
                    for (int i12 = 0; i12 < dArr.length; i12++) {
                        double linkInv = this._params.linkInv(dArr[i12][i11] + atd2 + (this._dinfo._intercept ? this._beta[i12][this._beta[i12].length - 1] : 0.0d));
                        int i13 = i12;
                        malloc8d[i13] = malloc8d[i13] + this._params.deviance(atd3, linkInv);
                        double variance = this._params.variance(linkInv);
                        if (variance < 1.0E-6d) {
                            variance = 1.0E-6d;
                        }
                        dArr[i12][i11] = (linkInv - atd3) / (variance * this._params.linkDeriv(linkInv));
                        if (this._dinfo._intercept) {
                            double[] dArr4 = this._gradient[i12];
                            int length2 = this._gradient[i12].length - 1;
                            dArr4[length2] = dArr4[length2] + dArr[i12][i11];
                        }
                    }
                }
            }
            for (int i14 = 0; i14 < this._dinfo._cats; i14++) {
                Chunk chunk5 = chunkArr[i14];
                for (int i15 = 0; i15 < chunk5._len; i15++) {
                    if (!mallocZ[i15]) {
                        int at82 = ((int) chunk5.at8(i15)) + this._dinfo._catOffsets[i14];
                        if (!this._dinfo._useAllFactorLevels) {
                            if (at82 != this._dinfo._catOffsets[i14]) {
                                at82--;
                            }
                        }
                        for (int i16 = 0; i16 < dArr.length; i16++) {
                            double[] dArr5 = this._gradient[i16];
                            int i17 = at82;
                            dArr5[i17] = dArr5[i17] + dArr[i16][i15];
                        }
                    }
                }
            }
            for (int i18 = this._dinfo._cats; i18 < length; i18++) {
                Chunk chunk6 = chunkArr[i18];
                for (int i19 = 0; i19 < chunk6._len; i19++) {
                    if (mallocZ[i19] || chunk6.isNA(i19)) {
                        mallocZ[i19] = true;
                    } else {
                        double atd4 = chunk6.atd(i19);
                        if (this._dinfo._normMul != null) {
                            atd4 = (atd4 - this._dinfo._normSub[i18 - this._dinfo._cats]) * this._dinfo._normMul[i18 - this._dinfo._cats];
                        }
                        int numStart2 = this._dinfo.numStart() - this._dinfo._cats;
                        for (int i20 = 0; i20 < dArr.length; i20++) {
                            double[] dArr6 = this._gradient[i20];
                            int i21 = numStart2 + i18;
                            dArr6[i21] = dArr6[i21] + (dArr[i20][i19] * atd4);
                        }
                    }
                }
            }
            for (int i22 = 0; i22 < this._beta.length; i22++) {
                int i23 = i22;
                malloc8d[i23] = malloc8d[i23] * this._reg;
                for (int i24 = 0; i24 < this._beta[i22].length; i24++) {
                    double[] dArr7 = this._gradient[i22];
                    int i25 = i24;
                    dArr7[i25] = dArr7[i25] * this._reg;
                }
            }
            this._objVals = malloc8d;
        }

        public void reduce(ColGradientTask colGradientTask) {
            ArrayUtils.add(this._objVals, colGradientTask._objVals);
            for (int i = 0; i < this._beta.length; i++) {
                ArrayUtils.add(this._beta[i], colGradientTask._beta[i]);
            }
        }
    }

    /**
     * One pass over the data that accumulates the iteratively-reweighted least-squares quantities
     * for coefficients {@code _beta}. Per row it forms the working response
     * {@code z = eta + (y - mu) * g'(mu)} and weight {@code w = 1 / (var(mu) * g'(mu)^2)}
     * (z = y, w = 1 for gaussian). It then accumulates {@code _yy = sum w*z*z}, {@code _xy = X'Wz},
     * and optionally the Gram matrix, the gradient {@code sum w*g'(mu)*(mu - y)*x} and validation
     * statistics. All sums are multiplied by {@code _reg} in {@link #chunkDone}.
     *
     * <p>For binomial validation it also keeps a small per-class sample of predicted values in
     * {@code _newThresholds} (at most {@link #N_THRESHOLDS} per class after reduce).
     */
    public static class GLMIterationTask extends GLMTask<GLMIterationTask> {
        final double[] _beta;
        protected Gram _gram;
        double[] _xy;
        protected double[] _grad;
        double _yy;
        GLMValidation _val;
        final double _ymu;
        protected final double _reg;
        long _nobs;
        final boolean _validate;
        final float[] _thresholds;
        float[][] _newThresholds;
        int[] _ti;
        final boolean _computeGradient;
        final boolean _computeGram;
        public static final int N_THRESHOLDS = 50;
        /** Per-class buffer size for predicted values collected within one chunk. */
        private static final int THRESHOLD_BUFFER = 4 * N_THRESHOLDS;

        public GLMIterationTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, boolean z, boolean z2, boolean z3, double[] dArr, double d, double d2, float[] fArr, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(key, dataInfo, gLMParameters, h2OCountedCompleter);
            this._beta = dArr;
            this._ymu = d;
            this._reg = d2;
            this._computeGram = z;
            this._validate = z2;
            assert !(gLMParameters._family == GLMModel.GLMParameters.Family.binomial && fArr == null);
            this._thresholds = this._validate ? fArr : null;
            this._computeGradient = z3;
            // The gradient is only computed together with validation.
            assert !(this._computeGradient && !z2);
        }

        /**
         * Buffer for class {@code yi} is full: sort it and keep every 4th value, so the buffer is
         * reduced to a quarter of its capacity and collection resumes after the kept values.
         */
        private void sampleThresholds(int yi) {
            final float[] buf = this._newThresholds[yi];
            this._ti[yi] = buf.length >> 2;
            try {
                Arrays.sort(buf);
            } catch (RuntimeException e) {
                // Best-effort sampling: on failure keep the (unsorted) prefix instead of aborting.
                System.out.println("got AIOOB during sort?! ary = " + Arrays.toString(buf));
                return;
            }
            // Fixed: the bound used to be _newThresholds.length (number of classes, i.e. 2).
            for (int i = 0; i < buf.length; i += 4) {
                buf[i >> 2] = buf[i];
            }
        }

        @Override
        public void processRow(long gid, double[] nums, int ncats, int[] cats, double[] responses) {
            this._nobs++;
            final double y = responses[0];
            assert !(this._glm._family == GLMModel.GLMParameters.Family.gamma && y <= 0.0d)
                : "illegal response column, y must be > 0  for family=Gamma.";
            assert !(this._glm._family == GLMModel.GLMParameters.Family.binomial && (0.0d > y || y > 1.0d))
                : "illegal response column, y must be <0,1>  for family=Binomial. got " + y;
            final int numStart = this._dinfo.numStart();
            final double mu;
            final double w;
            final double z;
            double d = 1.0d; // link derivative g'(mu); stays 1 for gaussian
            if (this._glm._family == GLMModel.GLMParameters.Family.gaussian) {
                w = 1.0d;
                z = y;
                mu = (this._validate || this._computeGradient) ? computeEta(ncats, cats, nums, this._beta) : 0.0d;
            } else {
                final double eta;
                if (this._beta == null) {
                    mu = this._glm.mustart(y, this._ymu);
                    eta = this._glm.link(mu);
                } else {
                    eta = computeEta(ncats, cats, nums, this._beta);
                    mu = this._glm.linkInv(eta);
                }
                // floor the variance to avoid numerical problems near 0
                final double var = Math.max(1.0E-5d, this._glm.variance(mu));
                d = this._glm.linkDeriv(mu);
                z = eta + ((y - mu) * d);
                w = 1.0d / ((var * d) * d);
            }
            if (this._validate) {
                this._val.add(y, mu);
                if (this._glm._family == GLMModel.GLMParameters.Family.binomial) {
                    final int yi = (int) y;
                    if (this._ti[yi] == this._newThresholds[yi].length) {
                        sampleThresholds(yi);
                    }
                    this._newThresholds[yi][this._ti[yi]++] = (float) mu;
                }
            }
            // NaN weights are tolerated here.
            assert !(w < 0.0d && !Double.isNaN(w)) : "invalid weight " + w;
            final double wz = w * z;
            this._yy += wz * z;
            if (this._computeGradient || this._computeGram) {
                final double grad = this._computeGradient ? w * d * (mu - y) : 0.0d;
                for (int k = 0; k < ncats; k++) {
                    final int c = cats[k];
                    if (this._computeGradient) {
                        this._grad[c] += grad;
                    }
                    this._xy[c] += wz;
                }
                for (int k = 0; k < nums.length; k++) {
                    final int c = numStart + k;
                    this._xy[c] += wz * nums[k];
                    if (this._computeGradient) {
                        this._grad[c] += grad * nums[k];
                    }
                }
                // last slot (numStart + _nums) is the intercept
                final int icpt = numStart + this._dinfo._nums;
                if (this._computeGradient) {
                    this._grad[icpt] += grad;
                }
                this._xy[icpt] += wz;
                if (this._computeGram) {
                    this._gram.addRow(nums, ncats, cats, w);
                }
            }
        }

        @Override
        public void chunkInit() {
            if (this._computeGram) {
                this._gram = new Gram(this._dinfo.fullN(), this._dinfo.largestCat(), this._dinfo._nums, this._dinfo._cats, true);
            }
            this._xy = MemoryManager.malloc8d(this._dinfo.fullN() + 1);
            // number of non-zero coefficients, passed to GLMValidation
            int rank = 0;
            if (this._beta != null) {
                for (double b : this._beta) {
                    if (b != 0.0d) {
                        rank++;
                    }
                }
            }
            if (this._validate) {
                this._val = new GLMValidation(null, this._ymu, this._glm, rank, this._thresholds);
                if (this._glm._family == GLMModel.GLMParameters.Family.binomial) {
                    this._ti = new int[2];
                    this._newThresholds = new float[2][THRESHOLD_BUFFER];
                }
            }
            if (this._computeGradient) {
                this._grad = MemoryManager.malloc8d(this._dinfo.fullN() + 1);
            }
        }

        @Override
        protected void chunkDone(long gid) {
            if (this._computeGram) {
                this._gram.mul(this._reg);
            }
            for (int i = 0; i < this._xy.length; i++) {
                this._xy[i] *= this._reg;
            }
            if (this._grad != null) {
                for (int i = 0; i < this._grad.length; i++) {
                    this._grad[i] *= this._reg;
                }
            }
            this._yy *= this._reg;
            if (this._validate && this._glm._family == GLMModel.GLMParameters.Family.binomial) {
                assert this._val != null;
                // trim each class buffer to the values actually collected and sort it
                for (int yi = 0; yi < 2; yi++) {
                    this._newThresholds[yi] = Arrays.copyOf(this._newThresholds[yi], this._ti[yi]);
                    Arrays.sort(this._newThresholds[yi]);
                }
            }
        }

        @Override
        public void reduce(GLMIterationTask git) {
            if (this._jobKey != null && !Job.isRunning(this._jobKey)) {
                return; // job no longer running: skip the work
            }
            ArrayUtils.add(this._xy, git._xy);
            if (this._computeGram) {
                this._gram.add(git._gram);
            }
            this._yy += git._yy;
            this._nobs += git._nobs;
            if (this._validate) {
                this._val.add(git._val);
            }
            if (this._computeGradient) {
                ArrayUtils.add(this._grad, git._grad);
            }
            if (this._validate && this._glm._family == GLMModel.GLMParameters.Family.binomial) {
                this._newThresholds[0] = mergeThresholds(this._newThresholds[0], git._newThresholds[0]);
                this._newThresholds[1] = mergeThresholds(this._newThresholds[1], git._newThresholds[1]);
            }
            super.reduce(git);
        }

        /**
         * Merges two threshold samples into a sorted array of at most {@link #N_THRESHOLDS}
         * values. Larger merges keep evenly spaced order statistics across the whole range; the
         * previous version sampled only the first 100 entries of an unsorted concatenation,
         * which favored the left operand's lowest values.
         */
        private static float[] mergeThresholds(float[] a, float[] b) {
            final float[] merged = Arrays.copyOf(a, a.length + b.length);
            System.arraycopy(b, 0, merged, a.length, b.length);
            Arrays.sort(merged);
            if (merged.length <= N_THRESHOLDS) {
                return merged;
            }
            final float[] sample = new float[N_THRESHOLDS];
            for (int k = 0; k < N_THRESHOLDS; k++) {
                sample[k] = merged[(int) ((long) k * merged.length / N_THRESHOLDS)];
            }
            return sample;
        }

        @Override
        protected void postGlobal() {
            if (this._val != null) {
                this._val.computeAIC();
                this._val.computeAUC();
            }
        }

        /**
         * Returns a copy of the accumulated gradient plus {@code (1 - d) * d2 * beta[i]} for every
         * coefficient except the last (intercept) one; no penalty term is added when
         * {@code _beta} is null.
         */
        public double[] gradient(double d, double d2) {
            final double[] g = this._grad.clone();
            if (this._beta != null) {
                final double l2 = (1.0d - d) * d2;
                for (int i = 0; i < g.length - 1; i++) {
                    g[i] += l2 * this._beta[i];
                }
            }
            return g;
        }
    }

    /**
     * Evaluates several candidate coefficient vectors in one pass over the data. Each candidate gets
     * its own {@link GLMIterationTask} with validation and gradient enabled and no Gram matrix.
     * All of them see the same rows.
     */
    public static class GLMLineSearchTask extends GLMTask<GLMLineSearchTask> {
        /** Upper bound on the number of halving steps generated by the step-halving constructor. */
        private static final int MAX_STEPS = 100;

        GLMIterationTask[] _glmts;

        /**
         * Builds candidates by repeatedly moving {@code dArr2} (new beta) halfway toward
         * {@code dArr} (old beta), or halving it when {@code dArr} is null. This stops once the
         * largest absolute coordinate difference is at most {@code d} (minimum step) or
         * {@link #MAX_STEPS} candidates exist.
         *
         * <p>NOTE: {@code dArr2} is modified in place and ends up equal to the last candidate.
         */
        public GLMLineSearchTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, double[] dArr, double[] dArr2, double d, double d2, long j, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(key, dataInfo, gLMParameters, h2OCountedCompleter);
            final ArrayList<double[]> betas = new ArrayList<double[]>();
            double diff = 1.0d;
            while (diff > d && betas.size() < MAX_STEPS) {
                diff = 0.0d;
                for (int i = 0; i < dArr2.length; i++) {
                    dArr2[i] = 0.5d * (dArr == null ? dArr2[i] : dArr[i] + dArr2[i]);
                    final double delta = dArr2[i] - (dArr == null ? 0.0d : dArr[i]);
                    if (delta > diff) {
                        diff = delta;
                    } else if (delta < -diff) {
                        diff = -delta;
                    }
                }
                betas.add(dArr2.clone());
            }
            this._glmts = new GLMIterationTask[betas.size()];
            for (int i = 0; i < this._glmts.length; i++) {
                this._glmts[i] = new GLMIterationTask(key, null, gLMParameters, false, true, true, betas.get(i), d2, 1.0d / j, new float[]{0.0f}, null);
            }
        }

        /** Evaluates the given candidate coefficient vectors {@code dArr} as-is. */
        public GLMLineSearchTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, double[][] dArr, double d, long j, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(key, dataInfo, gLMParameters, h2OCountedCompleter);
            this._glmts = new GLMIterationTask[dArr.length];
            for (int i = 0; i < this._glmts.length; i++) {
                this._glmts[i] = new GLMIterationTask(key, null, gLMParameters, false, true, true, dArr[i], d, 1.0d / j, new float[]{0.0f}, null);
            }
        }

        /** Sub-tasks were built with a null DataInfo; give them ours for the local phase. */
        @Override
        public void setupLocal() {
            super.setupLocal();
            for (GLMIterationTask t : this._glmts) {
                t._dinfo = this._dinfo;
            }
        }

        @Override
        public void closeLocal() {
            super.closeLocal();
            for (GLMIterationTask t : this._glmts) {
                t._dinfo = null;
            }
        }

        /** Gives this chunk private clones of the sub-tasks so per-chunk state is not shared. */
        @Override
        public void chunkInit() {
            this._glmts = this._glmts.clone();
            for (int i = 0; i < this._glmts.length; i++) {
                final GLMIterationTask copy = (GLMIterationTask) this._glmts[i].clone();
                this._glmts[i] = copy;
                copy.chunkInit();
            }
        }

        @Override
        public void chunkDone(long j) {
            for (GLMIterationTask t : this._glmts) {
                t.chunkDone(j);
            }
        }

        @Override
        public void postGlobal() {
            for (GLMIterationTask t : this._glmts) {
                t.postGlobal();
            }
        }

        @Override
        public final void processRow(long j, double[] dArr, int i, int[] iArr, double[] dArr2) {
            for (GLMIterationTask t : this._glmts) {
                t.processRow(j, dArr, i, iArr, dArr2);
            }
        }

        @Override
        public void reduce(GLMLineSearchTask other) {
            for (int i = 0; i < this._glmts.length; i++) {
                this._glmts[i].reduce(other._glmts[i]);
            }
        }
    }

    /**
     * Scores {@code _model} on every row of the input frame and accumulates validation statistics
     * into {@code _res}. The last chunk is the response and all preceding chunks are predictors.
     * Rows with an NA response or an NA predictor are skipped.
     */
    public static class GLMValidationTask<T extends GLMValidationTask<T>> extends MRTask<T> {
        protected final GLMModel _model;
        protected GLMValidation _res;
        public final double _lambda;
        public boolean _improved;
        Key _jobKey;

        public static Key makeKey() {
            return Key.make("__GLMValidation_" + Key.make().toString());
        }

        public GLMValidationTask(GLMModel gLMModel, double d) {
            this(gLMModel, d, null);
        }

        public GLMValidationTask(GLMModel gLMModel, double d, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(h2OCountedCompleter);
            this._lambda = d;
            this._model = gLMModel;
        }

        /**
         * Fixed: the decompiled loop never left the row after scoring it (a {@code while(true)}
         * without a break), so the first complete row was scored forever.
         */
        @Override
        public void map(Chunk[] chks) {
            this._res = new GLMValidation(null, this._model._ymu, this._model._parms, this._model.rank(this._lambda));
            final int nrows = chks[0]._len;
            final Chunk responseChk = chks[chks.length - 1];
            final boolean binomial = this._model._parms._family == GLMModel.GLMParameters.Family.binomial;
            final double[] row = MemoryManager.malloc8d(this._model._output._names.length);
            final float[] preds = MemoryManager.malloc4f(binomial ? 3 : 1);
            for (int r = 0; r < nrows; r++) {
                if (responseChk.isNA(r) || !readRow(chks, r, row)) {
                    continue;
                }
                this._model.score0(row, preds);
                this._res.add(responseChk.atd(r), binomial ? preds[2] : preds[0]);
            }
        }

        /**
         * Copies predictors (all chunks but the last) of row {@code r} into {@code row}.
         * Returns false as soon as an NA predictor is found.
         */
        private static boolean readRow(Chunk[] chks, int r, double[] row) {
            for (int c = 0; c < chks.length - 1; c++) {
                if (chks[c].isNA(r)) {
                    return false;
                }
                row[c] = chks[c].atd(r);
            }
            return true;
        }

        @Override
        public void reduce(GLMValidationTask gLMValidationTask) {
            this._res.add(gLMValidationTask._res);
        }

        @Override
        public void postGlobal() {
            this._res.computeAIC();
            this._res.computeAUC();
        }
    }

    /**
     * Cross-validation scoring. Row {@code gid} (global row index) is scored by fold model
     * {@code _xmodels[gid % nfolds]}, and the result goes into that fold's {@link GLMValidation}.
     * The last chunk is the response and the preceding chunks are predictors. Rows with any NA
     * are skipped.
     */
    public static class GLMXValidationTask extends GLMValidationTask<GLMXValidationTask> {
        protected final GLMModel[] _xmodels;
        protected GLMValidation[] _xvals;
        long _nobs;
        final float[] _thresholds;

        public static Key makeKey() {
            return Key.make("__GLMValidation_" + Key.make().toString());
        }

        public GLMXValidationTask(GLMModel gLMModel, double d, GLMModel[] gLMModelArr, float[] fArr) {
            this(gLMModel, d, gLMModelArr, fArr, null);
        }

        public GLMXValidationTask(GLMModel gLMModel, double d, GLMModel[] gLMModelArr, float[] fArr, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(gLMModel, d, h2OCountedCompleter);
            this._xmodels = gLMModelArr;
            this._thresholds = fArr;
        }

        /**
         * Fixed two problems:
         * <ul>
         *   <li>the decompiled loop never left a row after scoring it (infinite loop);</li>
         *   <li>the response was read with {@code at8}, truncating non-integer responses. It is now
         *       read with {@code atd}, as in {@link GLMValidationTask#map}.</li>
         * </ul>
         */
        @Override
        public void map(Chunk[] chks) {
            final long start = chks[0].start();
            final int nfolds = this._xmodels.length;
            this._xvals = new GLMValidation[nfolds];
            for (int f = 0; f < nfolds; f++) {
                final GLMModel m = this._xmodels[f];
                this._xvals[f] = new GLMValidation(null, m._ymu, m._parms, m._output.rank(), this._thresholds);
            }
            final int nrows = chks[0]._len;
            final Chunk responseChk = chks[chks.length - 1];
            final double[] row = MemoryManager.malloc8d(this._xmodels[0]._output._names.length);
            final float[] preds = MemoryManager.malloc4f(this._xmodels[0]._parms._family == GLMModel.GLMParameters.Family.binomial ? 3 : 1);
            for (int r = 0; r < nrows; r++) {
                if (responseChk.isNA(r) || !readRow(chks, r, row)) {
                    continue;
                }
                this._nobs++;
                final int fold = (int) ((r + start) % nfolds);
                final GLMModel model = this._xmodels[fold];
                model.score0(row, preds);
                this._xvals[fold].add(responseChk.atd(r), model._parms._family == GLMModel.GLMParameters.Family.binomial ? preds[2] : preds[0]);
            }
        }

        /**
         * Copies predictors (all chunks but the last) of row {@code r} into {@code row}.
         * Returns false as soon as an NA predictor is found.
         */
        private static boolean readRow(Chunk[] chks, int r, double[] row) {
            for (int c = 0; c < chks.length - 1; c++) {
                if (chks[c].isNA(r)) {
                    return false;
                }
                row[c] = chks[c].atd(r);
            }
            return true;
        }

        @Override
        public void reduce(GLMXValidationTask gLMXValidationTask) {
            this._nobs += gLMXValidationTask._nobs;
            for (int i = 0; i < this._xvals.length; i++) {
                this._xvals[i].add(gLMXValidationTask._xvals[i]);
            }
        }

        @Override
        public void postGlobal() {
            // explicit cast: getCompleter() is declared on the generic CountedCompleter base
            final H2O.H2OCountedCompleter completer = (H2O.H2OCountedCompleter) getCompleter();
            if (completer != null) {
                // one pending unit per fold model plus one for the main model
                completer.addToPendingCount(this._xvals.length + 1);
            }
            for (int i = 0; i < this._xvals.length; i++) {
                final GLMValidation val = this._xvals[i];
                val.computeAIC();
                val.computeAUC();
                // Replaces the fold's nobs with the number of rows outside that fold.
                // NOTE(review): presumably the fold model's training-set size — confirm.
                val.nobs = this._nobs - val.nobs;
                GLMModel.setXvalidation(completer, this._xmodels[i]._key, this._lambda, val);
            }
            GLMModel.setXvalidation(completer, this._model._key, this._lambda, new GLMValidation.GLMXValidation(this._model, this._xmodels, this._xvals, this._lambda, this._nobs, this._thresholds));
        }
    }

    /**
     * Runs a validating, gradient-computing {@link GLMIterationTask} on the null-model coefficients.
     * Alongside it, the task accumulates {@code z = sum_rows (y - ymu) * g'(ymu) * x} for every
     * categorical level and numeric column. {@link #lmax()} then returns
     * {@code variance(ymu) * max|z| / (nobs * max(alpha, 1e-3))}.
     * NOTE(review): presumably the smallest lambda that zeroes all penalized coefficients — confirm.
     */
    public static class LMAXTask extends GLMIterationTask {
        private double[] _z;
        private final double _gPrimeMu;
        private double _lmax;
        private final double _alpha;

        public LMAXTask(Key jobKey, FrameTask.DataInfo dinfo, GLMModel.GLMParameters params, double ymu, long nobs, float[] thresholds, H2O.H2OCountedCompleter cmp) {
            super(jobKey, dinfo, params, false, true, true, params.nullModelBeta(dinfo, ymu), ymu, 1.0d / nobs, thresholds, cmp);
            _alpha = params._alpha[0];
            _gPrimeMu = params.linkDeriv(ymu);
        }

        @Override
        public void chunkInit() {
            super.chunkInit();
            // same length as the gradient vector (one slot per coefficient incl. intercept)
            _z = MemoryManager.malloc8d(_grad.length);
        }

        @Override
        public void processRow(long gid, double[] nums, int ncats, int[] cats, double[] responses) {
            final double scaledResidual = (responses[0] - _ymu) * _gPrimeMu;
            for (int k = 0; k < ncats; k++) {
                _z[cats[k]] += scaledResidual;
            }
            final int base = _dinfo.numStart();
            for (int k = 0; k < nums.length; k++) {
                _z[base + k] += scaledResidual * nums[k];
            }
            super.processRow(gid, nums, ncats, cats, responses);
        }

        @Override
        public void reduce(GLMIterationTask other) {
            final LMAXTask that = (LMAXTask) other;
            ArrayUtils.add(_z, that._z);
            super.reduce(other);
        }

        @Override
        protected void postGlobal() {
            super.postGlobal();
            double maxAbs = Math.abs(_z[0]);
            for (int k = 1; k < _z.length; k++) {
                final double a = Math.abs(_z[k]);
                // a NaN entry never wins the comparison, so it is ignored
                if (maxAbs < a) {
                    maxAbs = a;
                }
            }
            final double denom = _nobs * Math.max(_alpha, 0.001d);
            _lmax = (_glm.variance(_ymu) * maxAbs) / denom;
        }

        /** Value computed in postGlobal(); see the class comment for the formula. */
        public double lmax() {
            return _lmax;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/glm/GLMTask$YMUTask.class */
    /**
     * Collects statistics of the first response value: running mean, row count, minimum and
     * maximum.
     *
     * <p>Every array has {@code _nfolds + 1} slots. Slot 0 covers every processed row. When
     * {@code _nfolds > 0}, slot {@code f + 1} covers the rows whose id {@code j} satisfies
     * {@code j % _nfolds != f}, i.e. all rows except the ones assigned to fold {@code f}.
     * Accessors take a fold index and map {@code -1} to the all-rows slot.
     *
     * <p>Slots that saw no rows keep mean 0, count 0, min {@code +Infinity} and
     * max {@code -Infinity}.
     */
    public static class YMUTask extends FrameTask<YMUTask> {
        private long[] _nobs;
        protected double[] _ymu;
        public double[] _ymin;
        public double[] _ymax;
        final int _nfolds;

        public YMUTask(Key key, Key key2, int i) {
            this(key, key2, i, null);
        }

        public YMUTask(Key key, Key key2, int i, H2O.H2OCountedCompleter h2OCountedCompleter) {
            super(key, key2, null, h2OCountedCompleter);
            this._nfolds = i;
        }

        /** Allocates fresh per-chunk accumulators; min/max start at the identity values. */
        @Override // hex.FrameTask
        public void chunkInit() {
            super.chunkInit();
            final int slots = this._nfolds + 1;
            this._ymu = new double[slots];
            this._nobs = new long[slots];
            this._ymax = new double[slots];
            this._ymin = new double[slots];
            Arrays.fill(this._ymax, Double.NEGATIVE_INFINITY);
            Arrays.fill(this._ymin, Double.POSITIVE_INFINITY);
        }

        /**
         * Adds {@code dArr2[0]} to slot 0 and to the slot of every fold this row is not
         * assigned to (row {@code j} belongs to fold {@code j % _nfolds}).
         */
        @Override // hex.FrameTask
        protected void processRow(long j, double[] dArr, int i, int[] iArr, double[] dArr2) {
            final double y = dArr2[0];
            accumulate(0, y);
            // Guard keeps us from evaluating j % 0 when no folds are requested.
            if (this._nfolds > 0) {
                final long rowFold = j % this._nfolds;
                for (int fold = 0; fold < this._nfolds; fold++) {
                    if (fold != rowFold) {
                        accumulate(fold + 1, y);
                    }
                }
            }
        }

        /** Adds {@code y} to the running sum, count, minimum and maximum of {@code slot}. */
        private void accumulate(int slot, double y) {
            this._ymu[slot] += y;
            this._nobs[slot]++;
            // Compare against this slot's own minimum. The old code compared against slot 0,
            // which never triggers and left every per-fold minimum at +Infinity.
            if (y < this._ymin[slot]) {
                this._ymin[slot] = y;
            }
            if (y > this._ymax[slot]) {
                this._ymax[slot] = y;
            }
        }

        /**
         * Merges another task's per-slot statistics into this one. Means are combined as
         * count-weighted averages.
         */
        public void reduce(YMUTask yMUTask) {
            if (yMUTask._nobs[0] == 0) {
                return; // the other side saw no rows: nothing to merge
            }
            if (this._nobs[0] == 0) {
                // This side saw no rows: adopt the other side's statistics as-is.
                this._ymu = yMUTask._ymu;
                this._nobs = yMUTask._nobs;
                this._ymin = yMUTask._ymin;
                this._ymax = yMUTask._ymax;
                return;
            }
            for (int slot = 0; slot < this._nfolds + 1; slot++) {
                final long mine = this._nobs[slot];
                final long theirs = yMUTask._nobs[slot];
                final long total = mine + theirs;
                if (total == 0) {
                    continue;
                }
                // The double cast is essential: long / long truncates to 0 whenever
                // theirs > 0, which would discard this side's mean entirely.
                this._ymu[slot] = this._ymu[slot] * ((double) mine / total)
                        + yMUTask._ymu[slot] * theirs / total;
                this._nobs[slot] = total;
                if (yMUTask._ymax[slot] > this._ymax[slot]) {
                    this._ymax[slot] = yMUTask._ymax[slot];
                }
                if (yMUTask._ymin[slot] < this._ymin[slot]) {
                    this._ymin[slot] = yMUTask._ymin[slot];
                }
            }
        }

        /** Turns this chunk's running sums into means; empty slots keep mean 0. */
        @Override // hex.FrameTask
        protected void chunkDone(long j) {
            for (int slot = 0; slot < this._ymu.length; slot++) {
                if (this._nobs[slot] != 0) {
                    this._ymu[slot] /= this._nobs[slot];
                }
            }
        }

        /** @return the mean over all rows (slot 0) */
        public double ymu() {
            return ymu(-1);
        }

        /** @return the number of rows seen overall (slot 0) */
        public long nobs() {
            return nobs(-1);
        }

        /**
         * @param i fold index, or {@code -1} for all rows
         * @return the mean over the rows not assigned to fold {@code i}
         */
        public double ymu(int i) {
            return this._ymu[i + 1];
        }

        /**
         * @param i fold index, or {@code -1} for all rows
         * @return the number of rows not assigned to fold {@code i}
         */
        public long nobs(int i) {
            return this._nobs[i + 1];
        }
    }

    /**
     * Creates a task with no completer. Equivalent to the four-argument constructor with a
     * {@code null} completer.
     */
    public GLMTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters) {
        this(key, dataInfo, gLMParameters, (H2O.H2OCountedCompleter) null);
    }

    /**
     * Creates a task over the frame described by {@code dataInfo}.
     *
     * <p>A {@code null} {@code dataInfo} is tolerated. The data key and active-column array
     * passed to the parent are then both {@code null}.
     *
     * @param key                 forwarded unchanged to {@code FrameTask}
     * @param dataInfo            data description, may be {@code null}
     * @param gLMParameters       GLM parameters, stored in {@code _glm}
     * @param h2OCountedCompleter completer forwarded to the parent, may be {@code null}
     */
    public GLMTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, H2O.H2OCountedCompleter h2OCountedCompleter) {
        super(key,
              dataInfo != null ? dataInfo._key : null,
              dataInfo != null ? dataInfo._activeCols : null,
              h2OCountedCompleter);
        _glm = gLMParameters;
    }

    /**
     * Computes the linear predictor for one row.
     *
     * <p>The sum has three parts, added in this order: the coefficients at each active
     * categorical index, then each numeric value times its coefficient (coefficients start at
     * {@code _dinfo.numStart()}), then the last coefficient (presumably the intercept).
     *
     * @param i     number of valid entries in {@code iArr}
     * @param iArr  coefficient indices of the active categorical levels
     * @param dArr  numeric predictor values
     * @param dArr2 coefficient vector
     * @return the linear predictor value
     */
    protected final double computeEta(int i, int[] iArr, double[] dArr, double[] dArr2) {
        final double[] beta = dArr2;
        double eta = 0.0d;

        // Categorical levels contribute their coefficient directly.
        for (int c = 0; c < i; c++) {
            eta += beta[iArr[c]];
        }

        // Numeric columns map onto a contiguous coefficient range starting at numStart().
        final int base = _dinfo.numStart();
        for (int k = 0; k < dArr.length; k++) {
            eta += dArr[k] * beta[base + k];
        }

        return eta + beta[beta.length - 1];
    }
}
