package hex.glm;

import hex.FrameTask;
import hex.Model;
import hex.SupervisedModelBuilder;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import hex.glm.LSMSolver;
import hex.optimization.L_BFGS;
import hex.schemas.GLMV2;
import hex.schemas.ModelBuilderSchema;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import jsr166y.CountedCompleter;
import water.DKV;
import water.DTask;
import water.Futures;
import water.H2O;
import water.H2ONode;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.MathUtils;
import water.util.ModelUtils;

/* loaded from: input_file:hex/glm/GLM.class */
public class GLM extends SupervisedModelBuilder<GLMModel, GLMModel.GLMParameters, GLMModel.GLMOutput> {
    private static final int WORK_TOTAL = 100000000;
    private boolean _clean_enums;
    private static double GLM_GRAD_EPS;
    private static final int MAX_ITERATIONS_PER_LAMBDA = 10;
    private static final int MAX_ITER = 50;
    private static final int sparseCoefThreshold = 750;
    private static final double beta_epsilon = 1.0E-4d;
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * L-BFGS gradient oracle that evaluates the GLM objective and gradient for a batch of
     * coefficient vectors with a single {@link GLMTask.ColGradientTask} pass, then adds an
     * L2 penalty of strength {@code _lambda}.
     *
     * <p>The last coefficient of every vector is left out of the penalty gradient (it is the
     * intercept elsewhere in this file).
     */
    public static final class GLMColBasedGradientSolver extends L_BFGS.GradientSolver {
        final Key _jobKey = null;
        final GLMModel.GLMParameters _glmp;
        final FrameTask.DataInfo _dinfo;
        final double _ymu;
        final double _lambda;
        final long _nobs;

        /**
         * @param glmp   GLM parameters forwarded to the gradient task
         * @param dinfo  data layout; its adapted frame is scanned
         * @param lambda L2 penalty strength
         * @param ymu    response mean (stored, not used by this solver)
         * @param nobs   observation count; the task is scaled by {@code 1 / nobs}
         */
        public GLMColBasedGradientSolver(GLMModel.GLMParameters glmp, FrameTask.DataInfo dinfo, double lambda, double ymu, long nobs) {
            _glmp = glmp;
            _dinfo = dinfo;
            _ymu = ymu;
            _nobs = nobs;
            _lambda = lambda;
        }

        /**
         * Computes penalized objective values and gradients for every coefficient vector.
         * The task's gradient arrays are updated in place and returned.
         */
        @Override // hex.optimization.L_BFGS.GradientSolver
        public L_BFGS.GradientInfo[] getGradient(double[][] betas) {
            GLMTask.ColGradientTask task = (GLMTask.ColGradientTask) new GLMTask.ColGradientTask(_dinfo, _glmp, betas, 1.0d / _nobs).doAll(_dinfo._adaptedFrame);
            L_BFGS.GradientInfo[] result = new L_BFGS.GradientInfo[betas.length];
            for (int b = 0; b < result.length; b++) {
                double[] beta = betas[b];
                double[] grad = task._gradient[b];
                // Gradient of 0.5 * lambda * ||beta||^2, skipping the last (intercept) coefficient.
                int penalized = beta.length - 1;
                for (int c = 0; c < penalized; c++) {
                    grad[c] += _lambda * beta[c];
                }
                // NOTE(review): presumably l2norm2(beta, true) also skips the intercept — confirm.
                double penalty = 0.5d * _lambda * ArrayUtils.l2norm2(beta, true);
                result[b] = new L_BFGS.GradientInfo(task._objVals[b] + penalty, grad);
            }
            return result;
        }
    }

    /* loaded from: input_file:hex/glm/GLM$GLMDriver.class */
    public final class GLMDriver extends DTask<GLMDriver> {
        final FrameTask.DataInfo _dinfo;
        transient ArrayList<FrameTask.DataInfo> _foldInfos;
        double[] lambdas;
        final GLMTaskInfo[] _state;
        int _lambdaId;
        int _maxLambda;
        transient AtomicBoolean _gotException;

        /* JADX INFO: Access modifiers changed from: package-private */
        /* renamed from: hex.glm.GLM$GLMDriver$1, reason: invalid class name */
        /* loaded from: input_file:hex/glm/GLM$GLMDriver$1.class */
        public class AnonymousClass1 extends H2O.H2OCallback<GLMTask.YMUTask> {
            AnonymousClass1(H2O.H2OCountedCompleter h2OCountedCompleter) {
                super(h2OCountedCompleter);
            }

            public String toString() {
                return new StringBuilder().append("YMUTask callback. completer = ").append(getCompleter()).toString() != null ? "null" : getCompleter().toString();
            }

            public void callback(final GLMTask.YMUTask yMUTask) {
                double ymu;
                long nobs;
                if (yMUTask._ymin == yMUTask._ymax) {
                    throw new IllegalArgumentException("GLM2(" + GLM.this._dest + "): attempted to run with constant response. Response == " + yMUTask._ymin + " for all rows in the training set.");
                }
                if (yMUTask.nobs() / GLMDriver.this._dinfo._adaptedFrame.numRows() < 0.75d) {
                    ymu = GLMDriver.this._dinfo._adaptedFrame.lastVec().mean();
                    nobs = GLMDriver.this._dinfo._adaptedFrame.numRows();
                } else {
                    ymu = yMUTask.ymu();
                    nobs = yMUTask.nobs();
                }
                if (GLM.this._parms._family != GLMModel.GLMParameters.Family.binomial || GLM.this._parms._prior == -1.0d || GLM.this._parms._prior == ymu || Double.isNaN(GLM.this._parms._prior)) {
                    GLM.this._parms._prior = ymu;
                } else {
                    double d = GLM.this._parms._prior / ymu;
                    double d2 = 1.0d;
                    double d3 = 1.0d;
                    if (d > 1.0d) {
                        d3 = 1.0d / d;
                    } else if (d < 1.0d) {
                        d2 = d;
                    }
                    Math.log(d2 / d3);
                }
                H2O.H2OCountedCompleter completer = getCompleter();
                completer.addToPendingCount(1);
                final long j = nobs;
                final double d4 = ymu;
                new GLMTask.LMAXTask(GLM.this._key, GLMDriver.this._dinfo, GLM.this._parms, ymu, nobs, ModelUtils.DEFAULT_THRESHOLDS, new H2O.H2OCallback<GLMTask.LMAXTask>(completer) { // from class: hex.glm.GLM.GLMDriver.1.1
                    static final /* synthetic */ boolean $assertionsDisabled;

                    public String toString() {
                        return "LMAXTask callback. completer = " + (getCompleter() != null ? "NULL" : getCompleter().toString());
                    }

                    public void callback(final GLMTask.LMAXTask lMAXTask) {
                        GLMModel.GLMOutput gLMOutput = new GLMModel.GLMOutput(GLM.this, GLMDriver.this._dinfo, GLM.this._parms._family == GLMModel.GLMParameters.Family.binomial);
                        if (!GLM.this._parms._lambda_search) {
                            if (GLM.this._parms._lambda == null || GLM.this._parms._lambda.length == 0) {
                                GLMDriver.this.lambdas = new double[]{0.01d * lMAXTask.lmax()};
                            } else {
                                GLMDriver.this.lambdas = GLM.this._parms._lambda;
                            }
                            int i = 0;
                            while (i < GLMDriver.this.lambdas.length && GLMDriver.this.lambdas[i] >= lMAXTask.lmax()) {
                                i++;
                            }
                            if (i == GLMDriver.this.lambdas.length) {
                                throw new IllegalArgumentException("Given lambda(s) are all > lambda_max = " + lMAXTask.lmax() + ", have nothing to run with. lambda = " + Arrays.toString(GLMDriver.this.lambdas));
                            }
                            r26 = i > 0 ? "Removed " + i + " lambdas greater than lambda_max." : null;
                            GLMDriver.this.lambdas = ArrayUtils.append(new double[]{lMAXTask.lmax()}, Arrays.copyOfRange(GLMDriver.this.lambdas, i, GLMDriver.this.lambdas.length));
                        } else {
                            if (!$assertionsDisabled && Double.isNaN(lMAXTask.lmax())) {
                                throw new AssertionError("running lambda_value search, but don't know what is the lambda_value max!");
                            }
                            if (GLM.this._parms._lambda_min_ratio == -1.0d) {
                                GLM.this._parms._lambda_min_ratio = j > ((long) (25 * GLMDriver.this._dinfo.fullN())) ? GLM.beta_epsilon : 0.01d;
                            }
                            double pow = Math.pow(GLM.this._parms._lambda_min_ratio, 1.0d / (GLM.this._parms._nlambdas - 1));
                            GLMDriver.this.lambdas = new double[GLM.this._parms._nlambdas];
                            GLMDriver.this.lambdas[0] = lMAXTask.lmax();
                            if (GLM.this._parms._nlambdas == 1) {
                                throw new IllegalArgumentException("Number of lambdas must be > 1 when running with lambda_search!");
                            }
                            for (int i2 = 1; i2 < GLMDriver.this.lambdas.length; i2++) {
                                GLMDriver.this.lambdas[i2] = GLMDriver.this.lambdas[i2 - 1] * pow;
                            }
                        }
                        double d5 = GLMDriver.this.lambdas[1];
                        if (GLMDriver.this.lambdas.length > 1) {
                            gLMOutput.addNullSubmodel(lMAXTask.lmax(), GLM.this._parms.link(d4), lMAXTask._val);
                        }
                        GLMDriver.this._maxLambda = GLMDriver.this.lambdas.length;
                        GLMModel gLMModel = new GLMModel(GLM.this._dest, GLM.this._parms, gLMOutput, GLMDriver.this._dinfo, d4, lMAXTask.lmax(), j);
                        if (r26 != null) {
                            gLMModel.addWarning(r26);
                        }
                        gLMModel.delete_and_lock(GLM.this._key);
                        double lmax = lMAXTask.lmax();
                        GLMDriver.this._state[0] = new GLMTaskInfo(GLM.this._dest, GLMDriver.this._dinfo, GLM.this._parms, lMAXTask._nobs, lMAXTask._ymu, lmax, lmax, null, lMAXTask.gradient(GLM.this._parms._alpha[0], lmax), GLM.objval(lMAXTask, GLM.this._parms._alpha[0], lMAXTask.lmax()));
                        getCompleter().addToPendingCount(1);
                        if (GLM.this._parms._n_folds <= 1) {
                            LambdaSearchIteration lambdaSearchIteration = new LambdaSearchIteration(getCompleter(), GLM.this._parms._solver == GLMModel.GLMParameters.Solver.L_BFGS);
                            Key key = GLM.this._key;
                            Key key2 = GLM.this._progressKey;
                            GLMTaskInfo gLMTaskInfo = GLMDriver.this._state[0];
                            double[] dArr = GLMDriver.this.lambdas;
                            GLMDriver gLMDriver = GLMDriver.this;
                            int i3 = gLMDriver._lambdaId + 1;
                            gLMDriver._lambdaId = i3;
                            new GLMLambdaTask(lambdaSearchIteration, key, key2, gLMTaskInfo, dArr[i3], GLM.this._parms._solver == GLMModel.GLMParameters.Solver.L_BFGS).fork();
                            return;
                        }
                        H2O.H2OCallback h2OCallback = new H2O.H2OCallback(getCompleter()) { // from class: hex.glm.GLM.GLMDriver.1.1.1
                            public void callback(H2O.H2OCountedCompleter h2OCountedCompleter) {
                                GLMLambdaTask[] gLMLambdaTaskArr = new GLMLambdaTask[GLMDriver.this._state.length];
                                for (int i4 = 0; i4 < gLMLambdaTaskArr.length; i4++) {
                                    gLMLambdaTaskArr[i4] = new GLMLambdaTask(null, GLM.this._key, GLM.this._progressKey, GLMDriver.this._state[i4], GLMDriver.this.lambdas[GLMDriver.this._lambdaId], GLM.this._parms._solver == GLMModel.GLMParameters.Solver.L_BFGS);
                                }
                                getCompleter().addToPendingCount(1);
                                new MRUtils.ParallelTasks(new LambdaSearchIteration(getCompleter(), GLM.this._parms._solver == GLMModel.GLMParameters.Solver.L_BFGS), gLMLambdaTaskArr).fork();
                            }
                        };
                        h2OCallback.addToPendingCount(GLMDriver.this._state.length - 2);
                        for (int i4 = 1; i4 < GLMDriver.this._state.length; i4++) {
                            final int i5 = i4;
                            final GLMModel.GLMParameters clone = GLM.this._parms.clone();
                            clone._n_folds = 0;
                            final FrameTask.DataInfo fold = GLMDriver.this._dinfo.getFold(i4 - 1, GLM.this._parms._n_folds);
                            GLMDriver.this._foldInfos.add(fold);
                            DKV.put(fold._key, fold);
                            Log.info(new Object[]{"inserted dinfo for fold " + i4 + " under key " + fold._key});
                            if (i4 != 0) {
                                new GLMTask.LMAXTask(GLM.this._key, fold, clone, yMUTask.ymu(i5 - 1), yMUTask.nobs(i5 - 1), ModelUtils.DEFAULT_THRESHOLDS, new H2O.H2OCallback<GLMTask.LMAXTask>(h2OCallback) { // from class: hex.glm.GLM.GLMDriver.1.1.2
                                    static final /* synthetic */ boolean $assertionsDisabled;

                                    public String toString() {
                                        return new StringBuilder().append("Xval LMAXTask callback., completer = ").append(getCompleter()).toString() == null ? "null" : getCompleter().toString();
                                    }

                                    public void callback(GLMTask.LMAXTask lMAXTask2) {
                                        double lmax2 = lMAXTask2.lmax();
                                        Key make = Key.make(GLM.this._dest.toString() + "_xval_" + i5, (byte) 1, (byte) 31, true, new H2ONode[]{H2O.SELF});
                                        GLMDriver.this._state[i5] = new GLMTaskInfo(make, fold, clone, lMAXTask2._nobs, lMAXTask2._ymu, lMAXTask2.lmax(), lMAXTask.lmax(), GLMDriver.this.nullBeta(fold, clone, lMAXTask2._ymu), lMAXTask2.gradient(GLM.this._parms._alpha[0], lmax2), GLM.objval(lMAXTask2, GLM.this._parms._alpha[0], lMAXTask2.lmax()));
                                        if (!$assertionsDisabled && DKV.get(fold._key) == null) {
                                            throw new AssertionError();
                                        }
                                        new GLMModel(make, clone, new GLMModel.GLMOutput(GLM.this, fold, GLM.this._parms._family == GLMModel.GLMParameters.Family.binomial), fold, lMAXTask2._ymu, lmax2, j).delete_and_lock(GLM.this._key);
                                        if (lMAXTask2.lmax() > lMAXTask.lmax()) {
                                            getCompleter().addToPendingCount(1);
                                            new GLMLambdaTask(getCompleter(), GLM.this._key, GLM.this._progressKey, GLMDriver.this._state[i5], lMAXTask.lmax(), GLM.this._parms._solver == GLMModel.GLMParameters.Solver.L_BFGS).fork();
                                        }
                                    }

                                    static {
                                        $assertionsDisabled = !GLM.class.desiredAssertionStatus();
                                    }
                                }).asyncExec(fold._adaptedFrame);
                            }
                        }
                    }

                    static {
                        $assertionsDisabled = !GLM.class.desiredAssertionStatus();
                    }
                }).asyncExec(GLMDriver.this._dinfo._adaptedFrame);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:hex/glm/GLM$GLMDriver$LambdaSearchIteration.class */
        public class LambdaSearchIteration extends H2O.H2OCallback {
            final boolean _forceLBFGS;

            public LambdaSearchIteration(H2O.H2OCountedCompleter h2OCountedCompleter, boolean z) {
                super(h2OCountedCompleter);
                this._forceLBFGS = z;
            }

            public void callback(H2O.H2OCountedCompleter h2OCountedCompleter) {
                double d = GLMDriver.this.lambdas[GLMDriver.this._lambdaId];
                if (GLM.this._parms._n_folds > 1) {
                    MRUtils.ParallelTasks parallelTasks = (MRUtils.ParallelTasks) h2OCountedCompleter;
                    for (int i = 0; i < ((GLMLambdaTask[]) parallelTasks._tasks).length; i++) {
                        GLMDriver.this._state[i] = ((GLMLambdaTask[]) parallelTasks._tasks)[i]._taskInfo;
                    }
                }
                GLMDriver gLMDriver = GLMDriver.this;
                int i2 = gLMDriver._lambdaId + 1;
                gLMDriver._lambdaId = i2;
                if (i2 < GLMDriver.this._maxLambda) {
                    getCompleter().addToPendingCount(1);
                    double d2 = GLMDriver.this.lambdas[GLMDriver.this._lambdaId];
                    if (GLM.this._parms._n_folds <= 1) {
                        GLMDriver.this._state[0]._lastLambda = d;
                        new GLMLambdaTask(new LambdaSearchIteration(getCompleter(), this._forceLBFGS), GLM.this._key, GLM.this._progressKey, GLMDriver.this._state[0], d2, this._forceLBFGS).fork();
                        return;
                    }
                    GLMLambdaTask[] gLMLambdaTaskArr = new GLMLambdaTask[GLMDriver.this._state.length];
                    LambdaSearchIteration lambdaSearchIteration = new LambdaSearchIteration(getCompleter(), this._forceLBFGS);
                    lambdaSearchIteration.addToPendingCount(gLMLambdaTaskArr.length - 1);
                    for (int i3 = 0; i3 < gLMLambdaTaskArr.length; i3++) {
                        GLMDriver.this._state[i3]._lastLambda = d;
                        gLMLambdaTaskArr[i3] = new GLMLambdaTask(lambdaSearchIteration, GLM.this._key, GLM.this._progressKey, GLMDriver.this._state[i3], d2, false);
                    }
                    new MRUtils.ParallelTasks(new LambdaSearchIteration(getCompleter(), this._forceLBFGS), gLMLambdaTaskArr).fork();
                }
            }
        }

        public GLMDriver(H2O.H2OCountedCompleter h2OCountedCompleter, FrameTask.DataInfo dataInfo) {
            super(h2OCountedCompleter);
            this._foldInfos = new ArrayList<>();
            this._gotException = new AtomicBoolean();
            this._dinfo = dataInfo;
            this._state = GLM.this._parms._n_folds > 1 ? new GLMTaskInfo[GLM.this._parms._n_folds + 1] : new GLMTaskInfo[1];
        }

        /* JADX INFO: Access modifiers changed from: private */
        public double[] nullBeta(FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, double d) {
            double[] malloc8d = MemoryManager.malloc8d(dataInfo.fullN() + 1);
            malloc8d[malloc8d.length - 1] = gLMParameters.linkInv(d);
            return malloc8d;
        }

        private void doCleanup() {
            DKV.remove(this._dinfo._key);
            Iterator<FrameTask.DataInfo> it = this._foldInfos.iterator();
            while (it.hasNext()) {
                DKV.remove(it.next()._key);
            }
        }

        public boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
            doCleanup();
            if (this._gotException.getAndSet(true)) {
                return false;
            }
            if (!(th instanceof TooManyPredictorsException)) {
                new DTask.RemoveCall((H2O.H2OCountedCompleter) null, GLM.this._dest).invokeTask();
                return true;
            }
            this._maxLambda = this._lambdaId;
            tryComplete();
            return false;
        }

        public void onCompletion(CountedCompleter countedCompleter) {
            doCleanup();
            H2O.H2OCountedCompleter completer = getCompleter();
            completer.addToPendingCount(1);
            new GLMModel.FinalizeAndUnlockTsk(completer, GLM.this._dest, GLM.this._key).fork();
        }

        protected void compute2() {
            if (GLM.this._parms._alpha.length > 1) {
                return;
            }
            if (GLM.this._parms._nlambdas == -1) {
                GLM.this._parms._nlambdas = 100;
            }
            if (GLM.this._parms._lambda_search && GLM.this._parms._nlambdas <= 1) {
                throw new IllegalArgumentException("GLM2(" + GLM.this._dest + ") nlambdas must be > 1 when running with lambda search.");
            }
            new Futures();
            new GLMTask.YMUTask(GLM.this._key, this._dinfo._key, GLM.this._parms._n_folds, new AnonymousClass1(this)).asyncExec(this._dinfo._adaptedFrame);
        }
    }

    /**
     * L-BFGS gradient info built from one {@link GLMTask.GLMIterationTask}.
     * It also keeps that task's validation metrics.
     */
    public static final class GLMGradientInfo extends L_BFGS.GradientInfo {
        public final GLMValidation _val;

        /**
         * @param task   finished iteration task supplying deviance, beta and gradient
         * @param lambda L2 penalty strength (the gradient is requested with alpha = 0)
         */
        public GLMGradientInfo(GLMTask.GLMIterationTask task, double lambda) {
            super(penalizedObjective(task, lambda), task.gradient(0.0d, lambda));
            _val = task._val;
        }

        /** Mean residual deviance plus {@code 0.5 * lambda * l2norm2(beta, true)}. */
        private static double penalizedObjective(GLMTask.GLMIterationTask task, double lambda) {
            double meanDeviance = task._val.residualDeviance() / task._nobs;
            return meanDeviance + 0.5d * lambda * ArrayUtils.l2norm2(task._beta, true);
        }
    }

    /**
     * L-BFGS gradient oracle that evaluates the GLM objective and gradient for a batch of
     * coefficient vectors in one {@link GLMTask.GLMLineSearchTask} pass over the data.
     */
    public static final class GLMGradientSolver extends L_BFGS.GradientSolver {
        /** Job key forwarded to the line-search task; always {@code null} here. */
        final Key _jobKey = null;
        final GLMModel.GLMParameters _glmp;
        final FrameTask.DataInfo _dinfo;
        final double _ymu;
        /** L2 penalty strength, applied by {@link GLMGradientInfo}. */
        final double _lambda;
        final long _nobs;

        /**
         * @param glmp   GLM parameters forwarded to the line-search task
         * @param dinfo  data layout; its adapted frame is scanned
         * @param lambda L2 penalty strength
         * @param ymu    response mean forwarded to the task
         * @param nobs   observation count forwarded to the task
         */
        public GLMGradientSolver(GLMModel.GLMParameters glmp, FrameTask.DataInfo dinfo, double lambda, double ymu, long nobs) {
            _glmp = glmp;
            _dinfo = dinfo;
            _ymu = ymu;
            _nobs = nobs;
            _lambda = lambda;
        }

        /**
         * Returns one penalized objective/gradient pair per coefficient vector in {@code betas}.
         * The unused {@code 1.0 / _nobs} local of the original was removed; normalization by
         * nobs happens inside {@link GLMGradientInfo}.
         */
        @Override // hex.optimization.L_BFGS.GradientSolver
        public L_BFGS.GradientInfo[] getGradient(double[][] betas) {
            GLMTask.GLMLineSearchTask lineSearch = (GLMTask.GLMLineSearchTask) new GLMTask.GLMLineSearchTask(_jobKey, _dinfo, _glmp, betas, _ymu, _nobs, null).doAll(_dinfo._adaptedFrame);
            GLMTask.GLMIterationTask[] perBeta = lineSearch._glmts;
            L_BFGS.GradientInfo[] result = new L_BFGS.GradientInfo[perBeta.length];
            for (int i = 0; i < result.length; i++) {
                result[i] = new GLMGradientInfo(perBeta[i], _lambda);
            }
            return result;
        }
    }

    /* loaded from: input_file:hex/glm/GLM$GLMLambdaTask.class */
    public static final class GLMLambdaTask extends DTask<GLMLambdaTask> {
        FrameTask.DataInfo _activeData;
        GLMTaskInfo _taskInfo;
        final double _currentLambda;
        int _iter;
        final Key _jobKey;
        Key _progressKey;
        long _start_time;
        double _addedL2;
        final boolean _forceLBFGS;
        int[] _activeCols;
        private transient IterationInfo _lastResult;
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Callback run when a {@link GLMTask.GLMIterationTask} finishes for the current lambda.
         * The finished task carries {@code _gram}, {@code _xy}, {@code _yy}, {@code _beta} and,
         * optionally, {@code _val}/gradient.
         *
         * <p>Depending on the result, the callback does one of the following:
         * <ul>
         *   <li>Starts a line search.</li>
         *   <li>Declares convergence and calls {@code checkKKTAndComplete}. Convergence means a
         *       small max |sub-gradient|, a gaussian family, {@code beta_diff < beta_epsilon}, or
         *       the iteration limit reached.</li>
         *   <li>Solves the next beta with ADMM and schedules the next iteration task.</li>
         * </ul>
         */
        public class Iteration extends H2O.H2OCallback<GLMTask.GLMIterationTask> {
            // Wall-clock creation time; reported as "Gram computed in ...ms".
            public final long _iterationStartTime;
            // Whether this callback increments GLMLambdaTask._iter.
            final boolean _countIteration;
            // Step applied to the new beta (1.0 = full step); the next iteration doubles it, capped at 1.
            final double _lineSearchStep;
            // Decompiler artifact that stands in for Java `assert` statements.
            static final /* synthetic */ boolean $assertionsDisabled;

            // Full-step, counted iteration. The first parameter is not used by the body.
            public Iteration(GLMLambdaTask gLMLambdaTask, CountedCompleter countedCompleter) {
                this(countedCompleter, true, 1.0d);
            }

            public Iteration(CountedCompleter countedCompleter, boolean z, double d) {
                super((H2O.H2OCountedCompleter) countedCompleter);
                this._lineSearchStep = d;
                this._countIteration = z;
                this._iterationStartTime = System.currentTimeMillis();
            }

            public void callback(GLMTask.GLMIterationTask gLMIterationTask) {
                // Abort if the owning job was cancelled.
                if (GLMLambdaTask.this._jobKey != null && !Job.isRunning(GLMLambdaTask.this._jobKey)) {
                    throw new Job.JobCancelledException();
                }
                // Sanity checks: beta has one entry per active column plus the intercept.
                if (!$assertionsDisabled && GLMLambdaTask.this._activeCols != null && gLMIterationTask._beta != null && gLMIterationTask._beta.length != GLMLambdaTask.this._activeCols.length + 1) {
                    throw new AssertionError(GLMLambdaTask.this.LogInfo("betalen = " + gLMIterationTask._beta.length + ", activecols = " + GLMLambdaTask.this._activeCols.length));
                }
                if (!$assertionsDisabled && GLMLambdaTask.this._activeCols != null && GLMLambdaTask.this._activeCols.length != GLMLambdaTask.this._activeData.fullN()) {
                    throw new AssertionError();
                }
                if (!$assertionsDisabled && getCompleter().getPendingCount() > 1) {
                    throw new AssertionError(GLMLambdaTask.this.LogInfo("unexpected pending count, expected <=  1, got " + getCompleter().getPendingCount()));
                }
                if (this._countIteration) {
                    GLMLambdaTask.this._iter++;
                }
                long currentTimeMillis = System.currentTimeMillis();
                // Objective got worse: search along the segment from the last accepted beta
                // (or the intercept-only beta link(ymu)) towards the new beta.
                if (GLMLambdaTask.this.needLineSearch(gLMIterationTask, this._lineSearchStep)) {
                    getCompleter().addToPendingCount(1);
                    GLMLambdaTask.this.LogInfo("invoking line search");
                    double[] dArr = GLMLambdaTask.this._lastResult._beta;
                    if (dArr == null) {
                        dArr = MemoryManager.malloc8d(GLMLambdaTask.this._taskInfo._dinfo.fullN() + 1);
                        dArr[dArr.length - 1] = GLMLambdaTask.this._taskInfo._params.link(GLMLambdaTask.this._taskInfo._ymu);
                    }
                    new GLMTask.GLMLineSearchTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, GLMLambdaTask.this._taskInfo._params, dArr, gLMIterationTask._beta, GLM.beta_epsilon, GLMLambdaTask.this._taskInfo._ymu, GLMLambdaTask.this._taskInfo._nobs, new LineSearchIteration(getCompleter())).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
                    return;
                }
                // Replace classification thresholds with the (sorted) ones produced by this pass.
                if (gLMIterationTask._newThresholds != null) {
                    GLMLambdaTask.this._taskInfo._thresholds = ArrayUtils.join(gLMIterationTask._newThresholds[0], gLMIterationTask._newThresholds[1]);
                    Arrays.sort(GLMLambdaTask.this._taskInfo._thresholds);
                }
                // d = max |sub-gradient| when a gradient was computed, NaN otherwise (used only in logs).
                double d = Double.NaN;
                if (gLMIterationTask._val != null && gLMIterationTask._computeGradient) {
                    GLMLambdaTask.this._lastResult = new IterationInfo(GLMLambdaTask.this._iter, gLMIterationTask._beta, gLMIterationTask.gradient(GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda), GLM.objval(gLMIterationTask, GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda));
                    double[] dArr2 = (double[]) GLMLambdaTask.this._lastResult._grad.clone();
                    LSMSolver.ADMMSolver.subgrad(GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda, gLMIterationTask._beta, dArr2);
                    d = 0.0d;
                    for (double d2 : dArr2) {
                        if (d2 > d) {
                            d = d2;
                        } else if (d2 < (-d)) {
                            d = -d2;
                        }
                    }
                    if (d <= GLM.GLM_GRAD_EPS) {
                        GLMLambdaTask.this.LogInfo("converged by reaching small enough gradient, with max |subgradient| = " + d);
                        GLMLambdaTask.this.checkKKTAndComplete(gLMIterationTask._beta, false);
                        return;
                    }
                }
                // Solve for the next beta with ADMM (elastic-net penalty lambda, alpha).
                double[] malloc8d = MemoryManager.malloc8d(gLMIterationTask._xy.length);
                long currentTimeMillis2 = System.currentTimeMillis();
                LSMSolver.ADMMSolver aDMMSolver = new LSMSolver.ADMMSolver(GLMLambdaTask.this._currentLambda, GLMLambdaTask.this._taskInfo._params._alpha[0], GLM.GLM_GRAD_EPS, GLMLambdaTask.this._addedL2);
                aDMMSolver.solve(gLMIterationTask._gram, gLMIterationTask._xy, gLMIterationTask._yy, malloc8d, GLMLambdaTask.this._currentLambda * GLMLambdaTask.this._taskInfo._params._alpha[0]);
                // Partial step: interpolate between the previous beta (or 0) and the ADMM solution.
                if (this._lineSearchStep < 1.0d) {
                    if (gLMIterationTask._beta != null) {
                        for (int i = 0; i < malloc8d.length; i++) {
                            malloc8d[i] = (gLMIterationTask._beta[i] * (1.0d - this._lineSearchStep)) + (this._lineSearchStep * malloc8d[i]);
                        }
                    } else {
                        for (int i2 = 0; i2 < malloc8d.length; i2++) {
                            int i3 = i2;
                            malloc8d[i3] = malloc8d[i3] * this._lineSearchStep;
                        }
                    }
                }
                // NOTE(review): when d is not NaN this message contains ",," (", step" follows a trailing ",").
                GLMLambdaTask.this.LogInfo("Gram computed in " + (currentTimeMillis - this._iterationStartTime) + "ms, " + (Double.isNaN(d) ? "" : "gradient = " + d + ",") + ", step = " + this._lineSearchStep + ", ADMM: " + aDMMSolver.iterations + " iterations, " + (System.currentTimeMillis() - currentTimeMillis2) + "ms (" + aDMMSolver.decompTime + "), subgrad_err=" + aDMMSolver.gerr);
                if (aDMMSolver._addedL2 > GLMLambdaTask.this._addedL2) {
                    GLMLambdaTask.this.LogInfo("added " + (aDMMSolver._addedL2 - GLMLambdaTask.this._addedL2) + "L2 penalty");
                }
                new Job.ProgressUpdate(1L).fork(GLMLambdaTask.this._progressKey);
                // Remember any extra L2 regularization ADMM had to add, for later iterations.
                GLMLambdaTask.this._addedL2 = aDMMSolver._addedL2;
                if (ArrayUtils.hasNaNsOrInfs(malloc8d)) {
                    throw new RuntimeException(GLMLambdaTask.this.LogInfo("got NaNs and/or Infs in beta"));
                }
                // Continue iterating unless gaussian (one solve suffices), beta changed by less
                // than beta_epsilon, or the iteration limit was reached.
                double beta_diff = GLM.beta_diff(gLMIterationTask._beta, malloc8d);
                if (GLMLambdaTask.this._taskInfo._params._family != GLMModel.GLMParameters.Family.gaussian && beta_diff >= GLM.beta_epsilon && GLMLambdaTask.this._iter < GLMLambdaTask.this._taskInfo._max_iter) {
                    if (gLMIterationTask._beta != null) {
                        GLMLambdaTask.this.setSubmodel(gLMIterationTask._beta, gLMIterationTask._val, getCompleter().getCompleter());
                    }
                    // Compute validation/gradient every 5th iteration unless higher accuracy is requested.
                    boolean z = GLMLambdaTask.this._taskInfo._params._higher_accuracy || GLMLambdaTask.this._iter % 5 == 0;
                    getCompleter().addToPendingCount(1);
                    new GLMTask.GLMIterationTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, gLMIterationTask._glm, true, z, z, malloc8d, GLMLambdaTask.this._taskInfo._ymu, 1.0d / GLMLambdaTask.this._taskInfo._nobs, GLMLambdaTask.this._taskInfo._thresholds, new Iteration(getCompleter(), true, Math.min(1.0d, 2.0d * this._lineSearchStep))).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
                    return;
                }
                // Converged. NOTE(review): if beta_diff == 0, log10 is -Infinity and the int cast
                // yields Integer.MIN_VALUE; this affects only the log text.
                int log10 = (int) Math.log10(beta_diff);
                int i4 = 0;
                for (double d3 : malloc8d) {
                    if (d3 != 0.0d) {
                        i4++;
                    }
                }
                GLMLambdaTask.this.LogInfo("converged (reached a fixed point with ~ 1e" + log10 + " precision), got " + i4 + " nzs");
                GLMLambdaTask.this.checkKKTAndComplete(malloc8d, false);
            }

            static {
                $assertionsDisabled = !GLM.class.desiredAssertionStatus();
            }
        }

        /**
         * Snapshot of one evaluated iteration: its index, coefficients, gradient and objective value.
         */
        public static final class IterationInfo {
            final double[] _beta;
            final double[] _grad;
            final double _objval;
            final int _iter;

            /**
             * @param iter   iteration index
             * @param beta   coefficient vector (stored by reference)
             * @param grad   gradient at {@code beta} (stored by reference)
             * @param objval objective value at {@code beta}
             */
            public IterationInfo(int iter, double[] beta, double[] grad, double objval) {
                _objval = objval;
                _grad = grad;
                _beta = beta;
                _iter = iter;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:hex/glm/GLM$GLMLambdaTask$LineSearchIteration.class */
        /**
         * Completion callback for a {@link GLMTask.GLMLineSearchTask}.
         *
         * <p>Walks the candidate iterates produced by the line search in order, and resumes iterating
         * from the first candidate for which {@code needLineSearch} reports an admissible step. When
         * no candidate is admissible, it either finishes the current lambda from the last accepted
         * iterate (if high accuracy was already on), or restarts the current lambda from the
         * task-info warm-start beta with high accuracy on and an enlarged iteration budget.
         */
        public class LineSearchIteration extends H2O.H2OCallback<GLMTask.GLMLineSearchTask> {
            static final /* synthetic */ boolean $assertionsDisabled;

            LineSearchIteration(CountedCompleter countedCompleter) {
                super((H2O.H2OCountedCompleter) countedCompleter);
            }

            public void callback(GLMTask.GLMLineSearchTask gLMLineSearchTask) {
                // Sanity check: at most one pending child is expected on the parent completer here.
                if (!$assertionsDisabled && getCompleter().getPendingCount() > 1) {
                    throw new AssertionError("unexpected pending count, expected 1, got " + getCompleter().getPendingCount());
                }
                // Step size paired with candidate i: starts at 0.5 and is halved for each candidate.
                // NOTE(review): assumes GLMLineSearchTask evaluated _glmts[i] at step 0.5^(i+1) — confirm.
                double d = 0.5d;
                for (int i = 0; i < gLMLineSearchTask._glmts.length; i++) {
                    if (!GLMLambdaTask.this.needLineSearch(gLMLineSearchTask._glmts[i], d)) {
                        // Admissible candidate: continue iterating from its beta with high accuracy forced on.
                        GLMLambdaTask.this.LogInfo("line search: found admissible step = " + d + ",  objval = " + GLM.objval(gLMLineSearchTask._glmts[i], GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda));
                        GLMLambdaTask.this._taskInfo._params._higher_accuracy = true;
                        getCompleter().addToPendingCount(1);
                        new GLMTask.GLMIterationTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, GLMLambdaTask.this._taskInfo._params, true, true, true, gLMLineSearchTask._glmts[i]._beta, GLMLambdaTask.this._taskInfo._ymu, 1.0d / GLMLambdaTask.this._taskInfo._nobs, GLMLambdaTask.this._taskInfo._thresholds, new Iteration(getCompleter(), false, d)).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
                        return;
                    }
                    d *= 0.5d;
                }
                // No admissible step. If we were already in high-accuracy mode, accept the last good iterate.
                if (GLMLambdaTask.this._taskInfo._params._higher_accuracy) {
                    GLMLambdaTask.this.LogInfo("Line search did not find feasible step, converged.");
                    GLMLambdaTask.this.checkKKTAndComplete(GLMLambdaTask.this._lastResult._beta, true);
                    return;
                }
                // Otherwise restart this lambda from the warm-start beta with high accuracy on,
                // extending the iteration cap by the number of iterations already spent on it.
                GLMLambdaTask.this._taskInfo._params._higher_accuracy = true;
                int i2 = GLMLambdaTask.this._iter - GLMLambdaTask.this._taskInfo._iter;
                GLMLambdaTask.this.LogInfo("Line search failed to progress, rerunning current lambda from scratch with high accuracy on, adding " + i2 + " to max iterations");
                GLMLambdaTask.this._taskInfo._max_iter += i2;
                getCompleter().addToPendingCount(1);
                new GLMTask.GLMIterationTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, GLMLambdaTask.this._taskInfo._params, true, true, true, GLM.contractVec(GLMLambdaTask.this._taskInfo._beta, GLMLambdaTask.this._activeCols), GLMLambdaTask.this._taskInfo._ymu, 1.0d / GLMLambdaTask.this._taskInfo._nobs, GLMLambdaTask.this._taskInfo._thresholds, new Iteration(getCompleter(), false, 1.0d)).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
            }

            // Decompiled form of the compiler-generated assertion flag.
            static {
                $assertionsDisabled = !GLM.class.desiredAssertionStatus();
            }
        }

        /**
         * Creates the task that fits a single lambda of the regularization path.
         *
         * @param parent      completer notified when this lambda finishes
         * @param jobKey      key of the owning job
         * @param progressKey key used for progress updates
         * @param taskInfo    shared state carried between lambdas; its DataInfo must already be in the DKV
         * @param lambda      the lambda value fitted by this task
         * @param forceLBFGS  when true, {@code compute2} takes the L-BFGS path instead of IRLS iterations
         */
        public GLMLambdaTask(H2O.H2OCountedCompleter parent, Key jobKey, Key progressKey, GLMTaskInfo taskInfo, double lambda, boolean forceLBFGS) {
            super(parent);
            this._taskInfo = taskInfo;
            this._jobKey = jobKey;
            this._progressKey = progressKey;
            this._currentLambda = lambda;
            this._forceLBFGS = forceLBFGS;
            if (!$assertionsDisabled && DKV.get(taskInfo._dinfo._key) == null) {
                throw new AssertionError();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Logs {@code str} at info level, prefixed with the destination key, iteration and lambda.
         *
         * @return the full prefixed message, so callers can reuse it (e.g. in exception messages)
         */
        public String LogInfo(String str) {
            final String message = new StringBuilder("GLM2[dest=")
                    .append(this._taskInfo._dstKey)
                    .append(", iteration=").append(this._iter)
                    .append(", lambda = ").append(this._currentLambda)
                    .append("]: ").append(str)
                    .toString();
            Log.info(new Object[]{message});
            return message;
        }

        /**
         * Applies the strong-rule screen to choose the active predictor columns for this lambda.
         *
         * <p>With {@code alpha > 0}, a column stays active if it was already in {@code _activeCols}
         * or if its gradient magnitude exceeds {@code alpha * (2*d - d2)}. When {@code alpha == 0},
         * or when every column survives, all columns are used: {@code _activeCols} becomes
         * {@code null} and {@code _activeData} is the full DataInfo. Otherwise {@code _activeData}
         * is filtered to the surviving columns. Side effects: sets {@code _activeCols} and
         * {@code _activeData}.
         *
         * @param d    current lambda
         * @param d2   previous lambda
         * @param dArr gradient over the full expanded column space (indexed 0..fullN-1)
         * @return the new active column indices, or {@code null} when all columns are active
         */
        private int[] activeCols(double d, double d2, double[] dArr) {
            int i = 0;
            int[] iArr = null;
            if (this._taskInfo._params._alpha[0] > 0.0d) {
                // Strong-rule threshold on |gradient|.
                double d3 = this._taskInfo._params._alpha[0] * ((2.0d * d) - d2);
                iArr = MemoryManager.malloc4(this._taskInfo._dinfo.fullN());
                // i2 walks the previous active set in step with i3. This merge relies on
                // _activeCols being sorted ascending.
                int i2 = 0;
                if (this._activeCols == null) {
                    // Sentinel that never matches a real column index.
                    this._activeCols = new int[]{-1};
                }
                for (int i3 = 0; i3 < this._taskInfo._dinfo.fullN(); i3++) {
                    if ((i2 < this._activeCols.length && i3 == this._activeCols[i2]) || dArr[i3] > d3 || dArr[i3] < (-d3)) {
                        int i4 = i;
                        i++;
                        iArr[i4] = i3;
                        if (i2 < this._activeCols.length && i3 == this._activeCols[i2]) {
                            i2++;
                        }
                    }
                }
            }
            if (this._taskInfo._params._alpha[0] == 0.0d || i == this._taskInfo._dinfo.fullN()) {
                // Everything is active: work on the unfiltered data.
                this._activeCols = null;
                this._activeData = this._taskInfo._dinfo;
                i = this._taskInfo._dinfo.fullN();
            } else {
                this._activeCols = Arrays.copyOf(iArr, i);
                this._activeData = this._taskInfo._dinfo.filterExpandedColumns(this._activeCols);
                if (!$assertionsDisabled && DKV.get(this._activeData._key) == null) {
                    throw new AssertionError();
                }
            }
            LogInfo("strong rule at lambda_value=" + d + ", got " + i + " active cols out of " + this._taskInfo._dinfo.fullN() + " total.");
            // Decompiled assertion: the filtered DataInfo must expose exactly the active columns.
            if ($assertionsDisabled || this._activeCols == null || this._activeData.fullN() == this._activeCols.length) {
                return this._activeCols;
            }
            throw new AssertionError(LogInfo("mismatched number of cols, got " + this._activeCols.length + " active cols, but data info claims " + this._activeData.fullN()));
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Publishes the solution for the current lambda as a submodel of the destination model.
         *
         * <p>The beta is expanded from the active-column layout to the full column space. A
         * {@code null} beta stands for the null model: all coefficients are zero and the intercept is
         * {@code link(ymu)}, the same value {@code compute2} uses to initialize the intercept. When
         * predictors were standardized, a de-standardized copy is also computed: numeric coefficients
         * are scaled by {@code _normMul}, and the intercept is shifted by the matching
         * {@code _normSub} terms. That copy is passed as the primary beta, with the standardized one
         * passed alongside.
         *
         * @param dArr              beta over the active columns (plus intercept), or {@code null}
         * @param gLMValidation     validation metrics to attach (may be {@code null})
         * @param h2OCountedCompleter completer forwarded to {@code GLMModel.setSubmodel}
         * @return the full-space (standardized, if applicable) beta
         */
        public double[] setSubmodel(double[] dArr, GLMValidation gLMValidation, H2O.H2OCountedCompleter h2OCountedCompleter) {
            final int fullN = this._taskInfo._dinfo.fullN();
            double[] fullBeta = (this._activeCols == null || dArr == null) ? dArr : GLM.expandVec(dArr, this._activeCols, fullN + 1);
            if (fullBeta == null) {
                // Null model. The intercept lives on the linear-predictor scale, so it is link(ymu),
                // not linkInv(ymu). This is consistent with compute2().
                fullBeta = MemoryManager.malloc8d(fullN + 1);
                fullBeta[fullBeta.length - 1] = this._taskInfo._params.link(this._taskInfo._ymu);
            }
            double[] denormBeta = null;
            if (this._taskInfo._dinfo._predictor_transform == FrameTask.DataInfo.TransformType.STANDARDIZE) {
                denormBeta = fullBeta.clone();
                double interceptShift = 0.0d;
                final int numStart = this._taskInfo._dinfo.numStart();
                for (int i = numStart; i < fullBeta.length - 1; i++) {
                    final double scaled = denormBeta[i] * this._taskInfo._dinfo._normMul[i - numStart];
                    interceptShift += scaled * this._taskInfo._dinfo._normSub[i - numStart];
                    denormBeta[i] = scaled;
                }
                denormBeta[denormBeta.length - 1] -= interceptShift;
            }
            GLMModel.setSubmodel(h2OCountedCompleter, this._taskInfo._dstKey, this._currentLambda, denormBeta == null ? fullBeta : denormBeta, denormBeta == null ? null : fullBeta, this._iter + 1, System.currentTimeMillis() - this._start_time, fullN >= GLM.sparseCoefThreshold, gLMValidation);
            return fullBeta;
        }

        /**
         * Checks the KKT optimality conditions of the solution at the current lambda. It then either
         * finishes the lambda or resumes iterating with an enlarged active set.
         *
         * <p>The beta is expanded to the full column space, and a {@link GLMTask.GLMIterationTask}
         * over the full data computes the gradient there. In the callback:
         * <ul>
         *   <li>If the gradient has NaNs/Infs and {@code z} is false, a line search between the last
         *       accepted iterate and this beta is launched instead.</li>
         *   <li>If {@code z} is false and the active set is restricted, inactive columns whose
         *       subgradient exceeds the largest active-column subgradient (floored at
         *       {@code GLM_GRAD_EPS}) are added to the active set, and iteration resumes.</li>
         *   <li>Otherwise the task info is updated, the unused per-lambda progress is reported, and
         *       the submodel is published.</li>
         * </ul>
         *
         * @param dArr beta over the active columns (plus intercept), or {@code null} for the null model
         * @param z    when true, skip both the line-search fallback and the active-set extension
         */
        protected void checkKKTAndComplete(final double[] dArr, final boolean z) {
            H2O.H2OCountedCompleter completer = getCompleter();
            completer.addToPendingCount(1);
            final int fullN = this._taskInfo._dinfo.fullN();
            final double[] dArr2;
            if (dArr == null) {
                // Null model: zero coefficients and intercept at link(ymu), consistent with compute2().
                dArr2 = MemoryManager.malloc8d(fullN + 1);
                dArr2[dArr2.length - 1] = this._taskInfo._params.link(this._taskInfo._ymu);
            } else {
                dArr2 = GLM.expandVec(dArr, this._activeCols, fullN + 1);
            }
            new GLMTask.GLMIterationTask(this._jobKey, this._taskInfo._dinfo, this._taskInfo._params, false, true, true, dArr2, this._taskInfo._ymu, 1.0d / this._taskInfo._nobs, this._taskInfo._thresholds, new H2O.H2OCallback<GLMTask.GLMIterationTask>(completer) { // from class: hex.glm.GLM.GLMLambdaTask.1
                public String toString() {
                    // Parenthesized so the null check applies to the completer, not to the whole
                    // concatenated string (which is never null).
                    Object cmp = getCompleter();
                    return "checkKKTAndComplete.Callback, completer = " + (cmp == null ? "null" : cmp.toString());
                }

                public void callback(GLMTask.GLMIterationTask gLMIterationTask) {
                    double[] gradient = gLMIterationTask.gradient(GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda);
                    if (ArrayUtils.hasNaNsOrInfs(gradient)) {
                        if (!z) {
                            GLMLambdaTask.this.LogInfo("Check KKT got NaNs. Invoking line search");
                            GLMLambdaTask.this._taskInfo._params._higher_accuracy = true;
                            getCompleter().addToPendingCount(1);
                            new GLMTask.GLMLineSearchTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, GLMLambdaTask.this._taskInfo._params, GLMLambdaTask.this._lastResult._beta, GLM.contractVec(dArr2, GLMLambdaTask.this._activeCols), GLM.beta_epsilon, GLMLambdaTask.this._taskInfo._ymu, GLMLambdaTask.this._taskInfo._nobs, new LineSearchIteration(getCompleter())).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
                            return;
                        }
                        GLMLambdaTask.this.LogInfo("got NaNs/Infs in gradient at lambda " + GLMLambdaTask.this._currentLambda);
                    }
                    double[] dArr3 = gradient.clone();
                    LSMSolver.ADMMSolver.subgrad(GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._currentLambda, dArr2, dArr3);
                    // Tolerance: the largest |subgradient| over active columns, but never below GLM_GRAD_EPS.
                    double d = GLM.GLM_GRAD_EPS;
                    if (!z && GLMLambdaTask.this._activeCols != null) {
                        for (int i : GLMLambdaTask.this._activeCols) {
                            if (dArr3[i] > d) {
                                d = dArr3[i];
                            } else if (dArr3[i] < (-d)) {
                                d = -dArr3[i];
                            }
                        }
                        // Collect inactive columns (intercept excluded) that violate the tolerance.
                        int[] iArr = new int[64];
                        int i2 = 0;
                        for (int i3 = 0; i3 < gradient.length - 1; i3++) {
                            if (Arrays.binarySearch(GLMLambdaTask.this._activeCols, i3) < 0 && (dArr3[i3] > d || (-dArr3[i3]) > d)) {
                                if (i2 == iArr.length) {
                                    iArr = Arrays.copyOf(iArr, iArr.length << 1);
                                }
                                iArr[i2++] = i3;
                            }
                        }
                        if (i2 > 0) {
                            int length = GLMLambdaTask.this._activeCols.length;
                            int[] iArr2 = GLMLambdaTask.this._activeCols;
                            GLMLambdaTask.this._activeCols = Arrays.copyOf(GLMLambdaTask.this._activeCols, length + i2);
                            System.arraycopy(iArr, 0, GLMLambdaTask.this._activeCols, length, i2);
                            // Sort BEFORE re-laying-out _lastResult. resizeVec contracts in the order
                            // of the given column list, and filterExpandedColumns below uses the
                            // sorted order. Resizing against the unsorted list would misalign
                            // coefficients with columns.
                            Arrays.sort(GLMLambdaTask.this._activeCols);
                            if (GLMLambdaTask.this._lastResult != null) {
                                GLMLambdaTask.this._lastResult = new IterationInfo(GLMLambdaTask.this._lastResult._iter, GLM.resizeVec(GLMLambdaTask.this._lastResult._beta, GLMLambdaTask.this._activeCols, iArr2, fullN + 1), GLM.resizeVec(GLMLambdaTask.this._lastResult._grad, GLMLambdaTask.this._activeCols, iArr2, fullN + 1), GLMLambdaTask.this._lastResult._objval);
                            }
                            GLMLambdaTask.this.LogInfo(i2 + " variables failed KKT conditions check! Adding them to the model and continuing computation.(grad_eps = " + d + ", activeCols = " + (GLMLambdaTask.this._activeCols.length > 100 ? "lost" : Arrays.toString(GLMLambdaTask.this._activeCols)));
                            GLMLambdaTask.this._activeData = GLMLambdaTask.this._taskInfo._dinfo.filterExpandedColumns(GLMLambdaTask.this._activeCols);
                            getCompleter().addToPendingCount(1);
                            new GLMTask.GLMIterationTask(GLMLambdaTask.this._jobKey, GLMLambdaTask.this._activeData, GLMLambdaTask.this._taskInfo._params, true, true, true, GLM.contractVec(gLMIterationTask._beta, GLMLambdaTask.this._activeCols), GLMLambdaTask.this._taskInfo._ymu, 1.0d / GLMLambdaTask.this._taskInfo._nobs, GLMLambdaTask.this._taskInfo._thresholds, new Iteration(GLMLambdaTask.this, getCompleter())).asyncExec(GLMLambdaTask.this._activeData._adaptedFrame);
                            return;
                        }
                    }
                    GLMLambdaTask.this._taskInfo._beta = gLMIterationTask._beta;
                    // NOTE(review): the gradient is evaluated at _lastLambda, not _currentLambda — confirm intent.
                    GLMLambdaTask.this._taskInfo._gradient = gLMIterationTask.gradient(GLMLambdaTask.this._taskInfo._params._alpha[0], GLMLambdaTask.this._taskInfo._lastLambda);
                    // Unused iteration budget for this lambda. It must be computed while
                    // _taskInfo._iter still holds the start-of-lambda counter; otherwise it is
                    // always MAX_ITERATIONS_PER_LAMBDA.
                    int i6 = (GLM.MAX_ITERATIONS_PER_LAMBDA - GLMLambdaTask.this._iter) + GLMLambdaTask.this._taskInfo._iter;
                    GLMLambdaTask.this._taskInfo._iter = GLMLambdaTask.this._iter;
                    if (i6 > 0) {
                        new Job.ProgressUpdate(i6).fork(GLMLambdaTask.this._progressKey);
                    }
                    GLMLambdaTask.this.setSubmodel(dArr, gLMIterationTask._val, getCompleter().getCompleter());
                }
            }).asyncExec(this._taskInfo._dinfo._adaptedFrame);
        }

        /** Line-search check for a full (unit) step; see {@code needLineSearch(task, step)}. */
        protected boolean needLineSearch(GLMTask.GLMIterationTask gLMIterationTask) {
            final double unitStep = 1.0d;
            return needLineSearch(gLMIterationTask, unitStep);
        }

        /**
         * Decides whether the iterate produced by {@code task} must be rejected in favor of a line search.
         *
         * <p>Never for the gaussian family or a missing beta. Always when the task's xy, gradient or
         * gram contain NaNs/Infs, or when the residual deviance exceeds the null deviance. Without
         * validation data, no search is needed. Otherwise the sufficient-decrease test on the
         * objective decides.
         *
         * @param task iteration result to check
         * @param step step size associated with the iterate
         */
        protected boolean needLineSearch(GLMTask.GLMIterationTask task, double step) {
            final boolean exempt = this._taskInfo._params._family == GLMModel.GLMParameters.Family.gaussian || task._beta == null;
            if (exempt) {
                return false;
            }
            final boolean numericallyBroken = ArrayUtils.hasNaNsOrInfs(task._xy)
                    || (task._grad != null && ArrayUtils.hasNaNsOrInfs(task._grad))
                    || (task._gram != null && task._gram.hasNaNsOrInfs());
            if (numericallyBroken) {
                return true;
            }
            if (task._val == null) {
                return false;
            }
            if (task._val.residual_deviance > task._val.null_deviance) {
                return true;
            }
            final double objective = GLM.objval(task, this._taskInfo._params._alpha[0], this._currentLambda);
            return needLineSearch(task._beta, objective, step);
        }

        /**
         * Sufficient-decrease test for a candidate beta against the last accepted iterate.
         *
         * <p>With {@code delta = dArr - prevBeta} (where {@code prevBeta} is taken as zero when
         * unavailable), the candidate is rejected if
         * {@code d > prevObjval + beta_epsilon * sum(d2 * prevGrad * delta + 0.5 * delta^2)}.
         * NaN objectives are always rejected.
         *
         * <p>The reference point is {@code _lastResult}. If none has been recorded yet, the test
         * falls back to the task-info warm-start state, contracted to the active columns. That is
         * the same state {@code compute2} uses to seed {@code _lastResult}. Previously
         * {@code _lastResult} was dereferenced before its null check, which made the fallback dead
         * code and produced an NPE instead.
         *
         * @param dArr candidate beta in the active-column layout (must be non-null)
         * @param d    objective value of the candidate
         * @param d2   step size
         * @return true when a line search is needed
         */
        protected boolean needLineSearch(double[] dArr, double d, double d2) {
            if (!$assertionsDisabled && dArr == null) {
                throw new AssertionError();
            }
            if (Double.isNaN(d)) {
                return true;
            }
            final double[] prevBeta;
            final double[] prevGrad;
            final double prevObjval;
            if (this._lastResult != null) {
                prevBeta = this._lastResult._beta;
                prevGrad = this._lastResult._grad;
                prevObjval = this._lastResult._objval;
            } else {
                prevBeta = GLM.contractVec(this._taskInfo._beta, this._activeCols);
                prevGrad = GLM.contractVec(this._taskInfo._gradient, this._activeCols);
                prevObjval = this._taskInfo._objval;
            }
            double d3 = 0.0d;
            for (int i = 0; i < dArr.length; i++) {
                final double delta = prevBeta == null ? dArr[i] : dArr[i] - prevBeta[i];
                d3 += (d2 * prevGrad[i] * delta) + (0.5d * delta * delta);
            }
            return d > (GLM.beta_epsilon * d3) + prevObjval;
        }

        /**
         * Reports whether most columns of {@code frame} are sparse.
         *
         * <p>A column counts as sparse when fewer than 1/8 of its rows are non-zero, i.e.
         * {@code nzCnt * 8 < length}. The previous test used {@code >}, which counted the DENSE
         * columns, so the method returned true for mostly-dense frames.
         *
         * @return true when more than half of the columns are sparse
         */
        private static final boolean isSparse(Frame frame) {
            int sparseCols = 0;
            for (Vec vec : frame.vecs()) {
                if ((vec.nzCnt() << 3) < vec.length()) {
                    sparseCols++;
                }
            }
            return (frame.numCols() >> 1) < sparseCols;
        }

        /**
         * Adds {@code d * beta[k]} to the gradient in place for every entry except the last
         * (intercept), and returns the same {@code GradientInfo}. Used by {@code compute2} to shift
         * a cached gradient by the change in lambda.
         */
        private L_BFGS.GradientInfo adjustL2(L_BFGS.GradientInfo gradientInfo, double[] dArr, double d) {
            final int penalizedCount = dArr.length - 1;
            for (int k = 0; k < penalizedCount; k++) {
                gradientInfo._gradient[k] += d * dArr[k];
            }
            return gradientInfo;
        }

        /** Placeholder factory that ignores its arguments and always yields {@code null}; it is not called in this section. */
        private MRTask makeGLMTask(Key key, FrameTask.DataInfo dataInfo, GLMModel.GLMParameters gLMParameters, boolean z, boolean z2, boolean z3, double[] dArr) {
            final MRTask none = null;
            return none;
        }

        /**
         * Fits the current lambda.
         *
         * <p>Lambdas above {@code lambdaMax} complete immediately. Otherwise the strong rule selects
         * the active columns (failing with {@link TooManyPredictorsException} if too many remain),
         * and {@code _lastResult} is seeded from the task-info warm start. Then either the
         * asynchronous IRLS iteration is started, or, when {@code _forceLBFGS} is set (only
         * supported for {@code alpha == 0} with all columns active), L-BFGS is run synchronously and
         * its solution is published.
         */
        protected void compute2() {
            this._start_time = System.currentTimeMillis();
            if (this._currentLambda > this._taskInfo._lambdaMax) {
                tryComplete();
                return;
            }
            this._iter = this._taskInfo._iter;
            LogInfo("starting computation of lambda = " + this._currentLambda + ", previous lambda = " + this._taskInfo._lastLambda);
            int[] activeCols = activeCols(this._currentLambda, this._taskInfo._lastLambda, this._taskInfo._gradient);
            if ((activeCols == null ? this._taskInfo._dinfo.fullN() : activeCols.length) > this._taskInfo._params._max_active_predictors) {
                throw new TooManyPredictorsException();
            }
            double[] contractVec = GLM.contractVec(this._taskInfo._beta, this._activeCols);
            this._lastResult = new IterationInfo(this._taskInfo._iter, contractVec, GLM.contractVec(this._taskInfo._gradient, this._activeCols), this._taskInfo._objval);
            if (!this._forceLBFGS) {
                new GLMTask.GLMIterationTask(this._jobKey, this._activeData, this._taskInfo._params, true, false, false, contractVec, this._taskInfo._ymu, 1.0d / this._taskInfo._nobs, this._taskInfo._thresholds, new Iteration(this, this)).asyncExec(this._activeData._adaptedFrame);
                return;
            }
            if (this._taskInfo._params._alpha[0] > 0.0d || this._activeCols != null) {
                throw H2O.unimpl();
            }
            Log.info(new Object[]{"current lambda = " + this._currentLambda});
            L_BFGS.GradientSolver solver = (this._activeData._adaptedFrame.numCols() >= 100 || isSparse(this._activeData._adaptedFrame)) ? new GLMColBasedGradientSolver(this._taskInfo._params, this._activeData, this._currentLambda, this._taskInfo._ymu, this._taskInfo._nobs) : new GLMGradientSolver(this._taskInfo._params, this._activeData, this._currentLambda, this._taskInfo._ymu, this._taskInfo._nobs);
            if (contractVec == null) {
                // Cold start: zero coefficients, intercept at link(ymu).
                contractVec = MemoryManager.malloc8d(this._activeData.fullN() + 1);
                contractVec[contractVec.length - 1] = this._taskInfo._params.link(this._taskInfo._ymu);
            }
            long startMs = System.currentTimeMillis();
            if (this._taskInfo._lbfgs == null) {
                this._taskInfo._lbfgs = new L_BFGS();
            }
            // Reuse the previous lambda's gradient, shifted by the L2 change, when available.
            L_BFGS.GradientInfo gradient = this._taskInfo._gOld == null ? solver.getGradient(contractVec) : adjustL2(this._taskInfo._gOld, contractVec, this._currentLambda - this._taskInfo._lastLambda);
            final int workPerIteration = (GLM.WORK_TOTAL / this._taskInfo._params._lambda.length) / this._taskInfo._lbfgs.maxIter();
            L_BFGS.Result solve = this._taskInfo._lbfgs.solve(solver, contractVec, gradient, new L_BFGS.ProgressMonitor() { // from class: hex.glm.GLM.GLMLambdaTask.2
                @Override // hex.optimization.L_BFGS.ProgressMonitor
                public boolean progress(L_BFGS.GradientInfo gradientInfo) {
                    Job.update(workPerIteration, GLMLambdaTask.this._jobKey);
                    return Job.isRunning(GLMLambdaTask.this._jobKey);
                }
            });
            // Report the penalty of the solution (solve.coefs), matching the logged objval.
            // Previously this was computed on the starting beta.
            Log.info(new Object[]{"L_BFGS (k = " + this._taskInfo._lbfgs.k() + ") done after " + solve.iter + " iterations and " + ((System.currentTimeMillis() - startMs) / 1000) + " seconds, objval = " + solve.ginfo._objVal + ", penalty = " + (this._currentLambda * 0.5d * ArrayUtils.l2norm2(solve.coefs, true)) + ",  gradient norm2 = " + MathUtils.l2norm2(solve.ginfo._gradient)});
            this._taskInfo._gOld = solve.ginfo;
            double[] beta = solve.coefs;
            this._taskInfo._beta = beta;
            this._iter += solve.iter;
            this._taskInfo._iter = this._iter;
            setSubmodel(beta, null, this);
            tryComplete();
        }

        static {
            $assertionsDisabled = !GLM.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/glm/GLM$GLMTaskInfo.class */
    /**
     * State carried from one lambda of the regularization path to the next: warm-start beta and
     * gradient, iteration counters, the previous lambda, and (for the L-BFGS path) solver state.
     */
    public static final class GLMTaskInfo extends Iced {
        // Observation count; the code uses it as 1/nobs scaling.
        final long _nobs;
        final double _ymu;
        final double _lambdaMax;
        double[] _beta;
        double[] _gradient;
        int _iter;
        // Iteration cap: MAX_ITERATIONS_PER_LAMBDA under lambda search, MAX_ITER otherwise.
        int _max_iter;
        double _lastLambda;
        // DEFAULT_THRESHOLDS for the binomial family, null otherwise.
        float[] _thresholds;
        double _objval;
        final Key _dstKey;
        final FrameTask.DataInfo _dinfo;
        final GLMModel.GLMParameters _params;
        L_BFGS _lbfgs;
        L_BFGS.GradientInfo _gOld;

        public GLMTaskInfo(Key dstKey, FrameTask.DataInfo dinfo, GLMModel.GLMParameters params, long nobs, double ymu, double lambdaMax, double lastLambda, double[] beta, double[] gradient, double objval) {
            this._dstKey = dstKey;
            this._dinfo = dinfo;
            this._params = params;
            this._nobs = nobs;
            this._ymu = ymu;
            this._lambdaMax = lambdaMax;
            this._lastLambda = lastLambda;
            this._beta = beta;
            this._gradient = gradient;
            this._objval = objval;
            this._max_iter = params._lambda_search ? GLM.MAX_ITERATIONS_PER_LAMBDA : GLM.MAX_ITER;
            this._thresholds = params._family == GLMModel.GLMParameters.Family.binomial ? ModelUtils.DEFAULT_THRESHOLDS : null;
        }
    }

    /* loaded from: input_file:hex/glm/GLM$TooManyPredictorsException.class */
    /** Thrown by the lambda task when the active column count exceeds {@code _max_active_predictors}. */
    private static class TooManyPredictorsException extends RuntimeException {
        private TooManyPredictorsException() {
            super();
        }
    }

    /** Model categories this builder can produce: regression only. */
    public Model.ModelCategory[] can_build() {
        final Model.ModelCategory[] categories = {Model.ModelCategory.Regression};
        return categories;
    }

    /** Creates a GLM builder from an explicit key, a description string and parameters; runs {@code init(false)}. */
    public GLM(Key key, String desc, GLMModel.GLMParameters parms) {
        super(key, desc, parms);
        init(false);
    }

    /** Creates a GLM builder labelled {@code "GLM"} from the given parameters; runs {@code init(false)}. */
    public GLM(GLMModel.GLMParameters parms) {
        super("GLM", parms);
        init(false);
    }

    /**
     * Initializes the builder and validates parameters. When {@code expensive} is true and the
     * link is {@code family_default}, the link is resolved to the family's default link before
     * validation.
     */
    public void init(boolean expensive) {
        super.init(expensive);
        final boolean resolveLink = expensive && this._parms._link == GLMModel.GLMParameters.Link.family_default;
        if (resolveLink) {
            this._parms._link = this._parms._family.defaultLink;
        }
        this._parms.validate(this);
    }

    /** Returns a fresh {@link GLMV2} schema instance for this builder. */
    public ModelBuilderSchema schema() {
        final GLMV2 glmSchema = new GLMV2();
        return glmSchema;
    }

    /**
     * Starts asynchronous GLM training and returns this job.
     *
     * <p>The training/validation frames are read-locked and the builder is initialized with
     * {@code init(true)}. A {@link FrameTask.DataInfo} is then built and published to the DKV,
     * and a {@code GLMDriver} is submitted under a completer. That completer releases the locks,
     * and drops any response column added by enum conversion, on both normal and exceptional
     * completion.
     *
     * <p>If setup throws before the completer exists, the read locks are released here, because
     * no completer callback would ever do it.
     */
    public Job<GLMModel> trainModel() {
        this._clean_enums = this._parms._convert_to_enum && !this._response.isEnum();
        this._parms.read_lock_frames(this);
        FrameTask.DataInfo dataInfo;
        boolean setupComplete = false;
        try {
            init(true);
            dataInfo = new FrameTask.DataInfo(Key.make(), this._train, this._valid, 1, this._parms._use_all_factor_levels || this._parms._lambda_search, this._parms._standardize ? FrameTask.DataInfo.TransformType.STANDARDIZE : FrameTask.DataInfo.TransformType.NONE, FrameTask.DataInfo.TransformType.NONE);
            DKV.put(dataInfo._key, dataInfo);
            setupComplete = true;
        } finally {
            if (!setupComplete) {
                // Completer not created yet: release the locks ourselves to avoid leaking them.
                this._parms.read_unlock_frames(this);
            }
        }
        H2O.H2OCountedCompleter h2OCountedCompleter = new H2O.H2OCountedCompleter() { // from class: hex.glm.GLM.1
            // Ensures the failure cleanup runs only once even if several children fail.
            AtomicBoolean _gotException = new AtomicBoolean(false);

            public void compute2() {
            }

            public void onCompletion(CountedCompleter countedCompleter) {
                GLM.this.done();
                GLM.this._parms.read_unlock_frames(GLM.this);
                GLM.this.removeConvertedResponseVecs();
            }

            public boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
                if (this._gotException.getAndSet(true)) {
                    return false;
                }
                GLM.this.failed(th);
                GLM.this._parms.read_unlock_frames(GLM.this);
                GLM.this.removeConvertedResponseVecs();
                return true;
            }
        };
        start(h2OCountedCompleter, (long) WORK_TOTAL);
        H2O.submitTask(new GLMDriver(h2OCountedCompleter, dataInfo));
        return this;
    }

    /**
     * When {@code _clean_enums} is set, removes the last vec of the training frame and, if
     * present, of the validation frame. These are the response columns added by enum conversion.
     */
    private void removeConvertedResponseVecs() {
        if (!this._clean_enums) {
            return;
        }
        train().lastVec().remove();
        if (valid() != null) {
            valid().lastVec().remove();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Largest absolute element-wise difference between two coefficient vectors (L-infinity
     * distance), used as the convergence measure.
     *
     * @param prev previous beta; {@code null} yields {@link Double#MAX_VALUE} ("not converged")
     * @param next new beta, at least as long as {@code prev}
     */
    public static final double beta_diff(double[] prev, double[] next) {
        if (prev == null) {
            return Double.MAX_VALUE;
        }
        double largest;
        if (prev[0] >= next[0]) {
            largest = prev[0] - next[0];
        } else {
            largest = next[0] - prev[0];
        }
        int idx = 1;
        while (idx < prev.length) {
            final double delta = prev[idx] - next[idx];
            if (delta > largest) {
                largest = delta;
            } else if (-delta > largest) {
                largest = -delta;
            }
            idx++;
        }
        return largest;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Scatters an active-column vector into a zero-filled vector of length {@code i}.
     * Entry k goes to position {@code iArr[k]}, and the trailing intercept is copied to the last slot.
     * A {@code null} column list means the vector is already full-size and is returned as is.
     */
    public static final double[] expandVec(double[] dArr, int[] iArr, int i) {
        if (!$assertionsDisabled && dArr == null) {
            throw new AssertionError();
        }
        if (iArr == null) {
            return dArr;
        }
        final double[] full = MemoryManager.malloc8d(i);
        for (int k = 0; k < iArr.length; k++) {
            full[iArr[k]] = dArr[k];
        }
        full[full.length - 1] = dArr[dArr.length - 1];
        return full;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Gathers the entries at {@code iArr} (in that order) plus the trailing intercept into a new
     * vector of length {@code iArr.length + 1}. Returns {@code null} for a {@code null} input, and
     * a copy of the input when {@code iArr} is {@code null}.
     */
    public static final double[] contractVec(double[] dArr, int[] iArr) {
        if (dArr == null) {
            return null;
        }
        if (iArr == null) {
            return dArr.clone();
        }
        final double[] packed = MemoryManager.malloc8d(iArr.length + 1);
        for (int k = 0; k < iArr.length; k++) {
            packed[k] = dArr[iArr[k]];
        }
        packed[iArr.length] = dArr[dArr.length - 1];
        return packed;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Re-lays-out a vector from the {@code iArr2} column layout to the {@code iArr} layout by way of
     * the full space of length {@code i}. A {@code null} layout means "all columns". A {@code null}
     * vector, or identical layouts, returns the input unchanged.
     */
    public static final double[] resizeVec(double[] dArr, int[] iArr, int[] iArr2, int i) {
        if (dArr != null && !Arrays.equals(iArr, iArr2)) {
            final double[] full = expandVec(dArr, iArr2, i);
            return iArr != null ? contractVec(full, iArr) : full;
        }
        return dArr;
    }

    /**
     * Sum of absolute values of all entries except the last (intercept). Returns 0 for
     * {@code null} or for vectors with at most one entry.
     */
    protected static double l1norm(double[] dArr) {
        double total = 0.0d;
        if (dArr != null) {
            final int coefCount = dArr.length - 1;
            for (int j = 0; j < coefCount; j++) {
                total += Math.abs(dArr[j]);
            }
        }
        return total;
    }

    /**
     * Elastic-net penalty {@code lambda * (alpha * ||beta||_1 + 0.5 * (1 - alpha) * ||beta||_2^2)}.
     * The intercept (last entry) is excluded from both norms.
     *
     * <p>The L2 term uses the SQUARED norm ({@code l2norm2}), as {@code compute2} does when it
     * reports the ridge penalty. The previous {@code l2norm} used the unsquared norm.
     *
     * @param dArr beta (may be {@code null}, giving 0, consistent with {@code l1norm})
     * @param d    alpha, the L1/L2 mixing parameter
     * @param d2   lambda, the overall regularization strength
     */
    private static double penalty(double[] dArr, double d, double d2) {
        if (dArr == null) {
            return 0.0d;
        }
        return d2 * ((d * l1norm(dArr)) + (0.5d * (1.0d - d) * ArrayUtils.l2norm2(dArr, true)));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Objective value of an iteration: residual deviance per observation plus the elastic-net
     * penalty of its beta, for mixing parameter {@code alpha} and strength {@code lambda}.
     */
    public static double objval(GLMTask.GLMIterationTask task, double alpha, double lambda) {
        final double devPerObs = task._val.residual_deviance / task._nobs;
        return devPerObs + penalty(task._beta, alpha, lambda);
    }

    static {
        $assertionsDisabled = !GLM.class.desiredAssertionStatus();
        GLM_GRAD_EPS = beta_epsilon;
    }
}
