package hex.glm;

import hex.glm.Gram;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.H2O;
import water.Iced;
import water.Key;
import water.MemoryManager;

/* loaded from: input_file:hex/glm/LSMSolver.class */
public abstract class LSMSolver extends Iced {
    /** Overall penalty strength (lambda); not final, unlike {@code _alpha}. */
    double _lambda;
    /**
     * Elastic-net mixing parameter: {@code _lambda * _alpha} weights the L1 term and
     * {@code _lambda * (1 - _alpha)} weights the L2 term (see objectiveVal and ADMMSolver.solve).
     */
    final double _alpha;
    /** Job key; not read anywhere in this file. */
    public Key _jobKey;
    /** Identifier forwarded to {@code Gram.cholesky(...)} by the ADMM solver. */
    public String _id;
    /** Convergence flag recorded by the most recent solve and reported by {@code converged()}. */
    protected boolean _converged;

    /* loaded from: input_file:hex/glm/LSMSolver$ADMMSolver.class */
    public static final class ADMMSolver extends LSMSolver {
        /** Default elastic-net mixing parameter; not referenced elsewhere in this file. */
        public static final double DEFAULT_ALPHA = 0.5d;
        /** Reference coefficients for the proximal term; only used when {@code _proximalPenalty > 0}. */
        public double[] _wgiven;
        /**
         * Weight of the proximal term. When positive (and {@code _wgiven != null}) it is added to the
         * Gram diagonal and {@code _proximalPenalty * _wgiven} is added to the linear term.
         */
        public double _proximalPenalty;
        /** Threshold on {@code gerr} below which the iterative solve reports convergence. */
        public final double _gradientEps;
        // NOTE(review): not referenced anywhere in this file.
        private static final double GLM1_RHO = 0.001d;
        /** Max-abs subgradient (last coefficient excluded) computed at the most recent gradient check. */
        public double gerr;
        /** Number of ADMM iterations executed by the most recent iterative solve. */
        public int iterations;
        /** Wall time in milliseconds of the initial Cholesky decomposition in the most recent solve. */
        public long decompTime;
        /** Extra ridge term added to the diagonal; grown automatically when the Gram matrix is not SPD. */
        public double _addedL2;
        /** Gradient-error level at which the sequential solve stops early. */
        static final double RELTOL = 1.0E-4d;
        // Decompiler-emitted stand-in for the compiler's synthetic assertion-status flag.
        static final /* synthetic */ boolean $assertionsDisabled;

        /* loaded from: input_file:hex/glm/LSMSolver$ADMMSolver$NonSPDMatrixException.class */
        public static class NonSPDMatrixException extends LSMSolverException {
            /** Message shared by both constructors. */
            private static final String MESSAGE = "Matrix is not SPD, can't solve without regularization\n";

            /** Creates the exception carrying only the generic message. */
            public NonSPDMatrixException() {
                super(MESSAGE);
            }

            /**
             * Creates the exception with the generic message followed by the Gram matrix's string form.
             *
             * @param gram the matrix that failed to factor
             */
            public NonSPDMatrixException(Gram gram) {
                super(MESSAGE + gram);
            }
        }

        /* loaded from: input_file:hex/glm/LSMSolver$ADMMSolver$ParallelSolver.class */
        public final class ParallelSolver extends H2O.H2OCountedCompleter {
            /** Gram matrix; its diagonal is shifted in compute2() and restored in onCompletion(). */
            final Gram gram;
            /** ADMM penalty parameter rho. */
            final double rho;
            /** Soft-threshold level {@code lambda * alpha / rho} used in the z-update. */
            final double kappa;
            // NOTE(review): _bestErr and _lastErr are initialized but never read in this file.
            double _bestErr;
            double _lastErr;
            /** Linear term of the problem; this is the caller's array, not a copy. */
            final double[] xy;
            /** Right-hand side for, and result of, each Cholesky solve (the x-iterate). */
            double[] _xyPrime;
            /** Over-relaxation factor (set to 1.8 in compute2). */
            double _orlx;
            /** Iteration count at which the next gradient check runs. */
            int _k;
            /** Scaled dual variable of the ADMM iteration. */
            final double[] u;
            /** Output coefficient vector (the caller's array). */
            final double[] z;
            /** Cholesky factorization of the shifted Gram matrix. */
            Gram.Cholesky chol;
            /** {@code gram._diagAdded} at construction time; restored when the task completes. */
            final double d;
            /** Number of ADMM iterations started so far. */
            int _iter;
            /** Number of coefficients ({@code xy.length}). */
            final int N;
            /** Upper bound on ADMM iterations. */
            final int max_iter;
            /** Number of iterations between gradient checks. */
            final int round;
            /** Block-size arguments forwarded to {@code Gram.Cholesky.parSolver}. */
            final int _iBlock;
            final int _rBlock;
            // Decompiler-emitted stand-in for the compiler's synthetic assertion-status flag.
            static final /* synthetic */ boolean $assertionsDisabled;

            /* loaded from: input_file:hex/glm/LSMSolver$ADMMSolver$ParallelSolver$ADMMIteration.class */
            /**
             * One ADMM step of the parallel solver: builds the x-update right-hand side, forks a parallel
             * Cholesky solve, and in onCompletion performs the z/u updates, a periodic gradient check and,
             * if needed, schedules the next iteration.
             */
            private final class ADMMIteration extends CountedCompleter {
                /** Creation time in milliseconds; not read in this file. */
                final long t1;

                /**
                 * @param h2OCountedCompleter completer to notify when this iteration finishes. The parameter
                 *        type is CountedCompleter so that the value of getCompleter() can be passed back in
                 *        when chaining iterations.
                 */
                public ADMMIteration(CountedCompleter h2OCountedCompleter) {
                    super(h2OCountedCompleter);
                    this.t1 = System.currentTimeMillis();
                }

                public void compute() {
                    final ParallelSolver s = ParallelSolver.this;
                    s._iter++;
                    final double[] xyPrime = s._xyPrime;
                    final int last = s.N - 1;
                    // x-update right-hand side: xy + rho * (z - u); the last coefficient uses xy alone.
                    for (int i = 0; i < last; i++) {
                        xyPrime[i] = s.xy[i] + (s.rho * (s.z[i] - s.u[i]));
                    }
                    xyPrime[last] = s.xy[last];
                    // Solve in place; `this` is handed over as the completer (presumably so that
                    // onCompletion below runs once the parallel solve is done).
                    s.chol.parSolver(this, xyPrime, s._iBlock, s._rBlock).fork();
                }

                public void onCompletion(CountedCompleter countedCompleter) {
                    final ParallelSolver s = ParallelSolver.this;
                    final double[] xyPrime = s._xyPrime;
                    final double orlx = s._orlx;
                    final int last = s.N - 1;
                    for (int i = 0; i < last; i++) {
                        // Over-relaxed x, then soft-thresholded z and dual update of u.
                        final double xHat = (xyPrime[i] * orlx) + ((1.0d - orlx) * s.z[i]);
                        s.z[i] = LSMSolver.shrinkage(xHat + s.u[i], s.kappa);
                        s.u[i] += xHat - s.z[i];
                    }
                    // The last coefficient is never shrunk.
                    s.z[last] = xyPrime[last];
                    if (s._iter == s._k) {
                        final double[] grad = ADMMSolver.this.grad(s.gram, s.z, s.xy);
                        LSMSolver.subgrad(ADMMSolver.this._alpha, ADMMSolver.this._lambda, s.z, grad);
                        // Max-abs subgradient over all but the last coefficient. The maximum is computed from
                        // 0 rather than from the previous gerr, which compute2() sets to +Infinity and which
                        // would otherwise keep the check below from ever succeeding.
                        double err = 0.0d;
                        for (int j = 0; j < grad.length - 1; j++) {
                            if (err < grad[j]) {
                                err = grad[j];
                            } else if (err < -grad[j]) {
                                err = -grad[j];
                            }
                        }
                        ADMMSolver.this.gerr = err;
                        if (err < 9.0E-4d) {
                            return; // converged: do not schedule another iteration
                        }
                        s._k += s.round;
                    }
                    if (s._iter < s.max_iter) {
                        getCompleter().addToPendingCount(1);
                        new ADMMIteration(getCompleter()).fork();
                    }
                }
            }

            private ParallelSolver(Gram gram, double[] dArr, double[] dArr2, double d, int i, int i2) {
                this._bestErr = Double.POSITIVE_INFINITY;
                this._lastErr = Double.POSITIVE_INFINITY;
                this._iBlock = i;
                this._rBlock = i2;
                this.gram = gram;
                this.xy = dArr;
                this.z = dArr2;
                this.N = dArr.length;
                this.d = this.gram._diagAdded;
                this.rho = d;
                this.u = MemoryManager.malloc8d(this.N);
                this.kappa = (ADMMSolver.this._lambda * ADMMSolver.this._alpha) / d;
                this.max_iter = (int) (10000.0d * (250.0d / (1 + dArr.length)));
                this.round = Math.max(20, (int) (this.max_iter * 0.01d));
                this._k = this.round;
            }

            /**
             * Shifts the Gram diagonal by the L2/ridge terms (plus rho and the proximal weight when they
             * apply), factors it with Cholesky, and grows {@code _addedL2} up to 10 times while the matrix
             * is not SPD. It then forks either a single parallel solve (no L1 penalty) or the first
             * ADMMIteration.
             *
             * @throws NonSPDMatrixException if the matrix is still not SPD. The Gram diagonal is restored to
             *         its value at construction before the exception propagates.
             */
            public void compute2() {
                final ADMMSolver owner = ADMMSolver.this;
                Arrays.fill(this.z, 0.0d);
                if (owner._lambda > 0.0d || owner._addedL2 > 0.0d) {
                    this.gram.addDiag((owner._lambda * (1.0d - owner._alpha)) + owner._addedL2);
                }
                if (owner._alpha > 0.0d && owner._lambda > 0.0d) {
                    this.gram.addDiag(this.rho);
                }
                if (owner._proximalPenalty > 0.0d && owner._wgiven != null) {
                    this.gram.addDiag(owner._proximalPenalty, true);
                    // NOTE(review): unlike ADMMSolver.solve(...), which clones xy first, this adds the proximal
                    // term into the caller's xy array in place, and the change is never undone.
                    for (int i = 0; i < this.xy.length; i++) {
                        this.xy[i] += owner._proximalPenalty * owner._wgiven[i];
                    }
                }
                int retries = 0;
                final long start = System.currentTimeMillis();
                this.chol = this.gram.cholesky(null, true, owner._id);
                final long end = System.currentTimeMillis();
                while (!this.chol.isSPD() && retries < 10) {
                    owner._addedL2 = owner._addedL2 == 0.0d ? 1.0E-5d : owner._addedL2 * 10.0d;
                    retries++;
                    this.gram.addDiag(owner._addedL2);
                    this.gram.cholesky(this.chol);
                }
                owner.decompTime = end - start;
                if (!this.chol.isSPD()) {
                    final NonSPDMatrixException failure = new NonSPDMatrixException(this.gram);
                    // Do not leave the caller's Gram matrix shifted when failing. This is the same restore that
                    // onCompletion performs, and it is idempotent, so running both is harmless.
                    this.gram.addDiag((-this.gram._diagAdded) + this.d);
                    throw failure;
                }
                if (owner._alpha == 0.0d || owner._lambda == 0.0d) {
                    // No L1 penalty: a single (parallel) direct solve gives the answer.
                    System.arraycopy(this.xy, 0, this.z, 0, this.xy.length);
                    this.chol.parSolver(this, this.z, this._iBlock, this._rBlock).fork();
                } else {
                    owner.gerr = Double.POSITIVE_INFINITY;
                    this._xyPrime = this.xy.clone();
                    this._orlx = 1.8d;
                    new ADMMIteration(this).fork();
                }
            }

            /** Undoes every diagonal shift applied in compute2(), returning gram._diagAdded to its initial value. */
            public void onCompletion(CountedCompleter countedCompleter) {
                final double restoreBy = this.d - this.gram._diagAdded;
                this.gram.addDiag(restoreBy);
                if (!$assertionsDisabled && this.gram._diagAdded != this.d) {
                    throw new AssertionError();
                }
            }

            // Decompiler-emitted form of the assertion-status initialization; the flag gates the explicit
            // assertion check in onCompletion.
            static {
                $assertionsDisabled = !LSMSolver.class.desiredAssertionStatus();
            }
        }

        /** Returns {@code true} exactly when {@code _lambda} is non-zero (including NaN). */
        public boolean normalize() {
            final boolean unpenalized = this._lambda == 0.0d;
            return !unpenalized;
        }

        /**
         * Creates an ADMM solver without an initial extra ridge term.
         *
         * @param d  penalty strength lambda
         * @param d2 elastic-net mixing parameter alpha
         * @param d3 threshold on the gradient error used to report convergence
         */
        public ADMMSolver(double d, double d2, double d3) {
            super(d, d2);
            this._gradientEps = d3;
            this.iterations = 0;
            this.gerr = Double.POSITIVE_INFINITY;
        }

        /**
         * Creates an ADMM solver with an initial extra ridge term.
         *
         * @param d  penalty strength lambda
         * @param d2 elastic-net mixing parameter alpha
         * @param d3 threshold on the gradient error used to report convergence
         * @param d4 initial value of {@code _addedL2} (extra diagonal regularization)
         */
        public ADMMSolver(double d, double d2, double d3, double d4) {
            this(d, d2, d3);
            this._addedL2 = d4;
        }

        /**
         * Delegates to the five-argument solve with {@code rho = +Infinity}.
         *
         * NOTE(review): when {@code _alpha > 0} and {@code _lambda > 0}, the five-argument solve adds rho to
         * the Gram diagonal. Its first x-update then computes {@code xy + rho * (z - u)} with z zero-filled
         * and u freshly allocated (presumably zeros), i.e. {@code +Inf * 0 = NaN}. The kappa threshold also
         * collapses to 0. Confirm the intended default rho (e.g. the unused GLM1_RHO constant).
         */
        @Override // hex.glm.LSMSolver
        public boolean solve(Gram gram, double[] dArr, double d, double[] dArr2) {
            final double rho = Double.POSITIVE_INFINITY;
            return this.solve(gram, dArr, d, dArr2, rho);
        }

        /** Returns the sum of absolute values of {@code dArr} (0 for an empty array); unused in this file. */
        private static double l1_norm(double[] dArr) {
            double total = 0.0d;
            for (int k = 0; k < dArr.length; k++) {
                total += Math.abs(dArr[k]);
            }
            return total;
        }

        /**
         * Returns the sum of squares of {@code dArr}, i.e. the SQUARED L2 norm (no square root is taken).
         * Unused in this file.
         */
        private static double l2_norm(double[] dArr) {
            double sumSq = 0.0d;
            for (int k = 0; k < dArr.length; k++) {
                final double v = dArr[k];
                sumSq += v * v;
            }
            return sumSq;
        }

        /**
         * Returns the max-abs subgradient of the penalized objective at {@code dArr}. Every entry is
         * included, the last one too. Unused in this file.
         */
        private double converged(Gram gram, double[] dArr, double[] dArr2) {
            final double[] g = grad(gram, dArr, dArr2);
            subgrad(this._alpha, this._lambda, dArr, g);
            double worst = 0.0d;
            for (final double gi : g) {
                final double mag = Math.abs(gi);
                if (mag > worst) {
                    worst = mag;
                }
            }
            return worst;
        }

        /** Returns the max-abs entry of the plain least-squares gradient (no subgradient); unused in this file. */
        private double getGrad(Gram gram, double[] dArr, double[] dArr2) {
            final double[] g = grad(gram, dArr, dArr2);
            double worst = 0.0d;
            for (int k = 0; k < g.length; k++) {
                final double mag = Math.abs(g[k]);
                if (mag > worst) {
                    worst = mag;
                }
            }
            return worst;
        }

        /**
         * Creates (but does not fork) a ParallelSolver that uses this solver's settings.
         *
         * @param gram  Gram matrix
         * @param dArr  linear term xy (may be modified in place when a proximal penalty is set)
         * @param dArr2 output coefficient vector
         * @param d     ADMM penalty rho
         * @param i     block-size argument for the parallel Cholesky solve
         * @param i2    block-size argument for the parallel Cholesky solve
         * @return the new, un-forked task
         */
        public ParallelSolver parSolver(Gram gram, double[] dArr, double[] dArr2, double d, int i, int i2) {
            final ParallelSolver task = new ParallelSolver(gram, dArr, dArr2, d, i, i2);
            return task;
        }

        /**
         * Solves the penalized least-squares problem on the calling thread.
         *
         * <p>Steps:
         * <ol>
         *   <li>Shift the Gram diagonal by {@code lambda*(1-alpha) + _addedL2}. When an L1 penalty is
         *       present, also shift it by rho. When a proximal penalty is set, shift it by the proximal weight
         *       as well.</li>
         *   <li>Factor with Cholesky, growing {@code _addedL2} up to 10 times while the matrix is not SPD.</li>
         *   <li>Without an L1 penalty, solve once directly. Otherwise run over-relaxed ADMM with
         *       soft-thresholding, leaving the last coefficient unshrunk.</li>
         * </ol>
         * The Gram diagonal is restored on every exit path, including exceptions.
         *
         * @param gram  Gram matrix; temporarily modified, restored before returning or throwing
         * @param dArr  linear term xy; never modified (copied when the proximal term applies)
         * @param d     constant term of the objective; not used by this implementation
         * @param dArr2 output coefficient vector; zero-filled on entry
         * @param d2    ADMM penalty parameter rho
         * @return {@code true} for the direct (no-L1) solve; otherwise whether the last gradient check
         *         produced {@code gerr < _gradientEps}
         * @throws NonSPDMatrixException if the matrix cannot be made SPD
         */
        public boolean solve(Gram gram, double[] dArr, double d, double[] dArr2, double d2) {
            final double rho = d2;
            final double[] z = dArr2;
            final double diagOnEntry = gram._diagAdded;
            final int length = dArr.length;
            double[] xy = dArr;
            this.gerr = 0.0d;
            Arrays.fill(z, 0.0d);
            int iter = 0;
            try {
                if (this._lambda > 0.0d || this._addedL2 > 0.0d) {
                    gram.addDiag((this._lambda * (1.0d - this._alpha)) + this._addedL2);
                }
                if (this._alpha > 0.0d && this._lambda > 0.0d) {
                    gram.addDiag(rho);
                }
                if (this._proximalPenalty > 0.0d && this._wgiven != null) {
                    gram.addDiag(this._proximalPenalty, true);
                    xy = xy.clone(); // keep the caller's array untouched
                    for (int i = 0; i < xy.length; i++) {
                        xy[i] += this._proximalPenalty * this._wgiven[i];
                    }
                }
                int retries = 0;
                final long start = System.currentTimeMillis();
                final Gram.Cholesky cholesky = gram.cholesky(null, true, this._id);
                final long end = System.currentTimeMillis();
                while (!cholesky.isSPD() && retries < 10) {
                    this._addedL2 = this._addedL2 == 0.0d ? 1.0E-5d : this._addedL2 * 10.0d;
                    retries++;
                    gram.addDiag(this._addedL2);
                    gram.cholesky(cholesky);
                }
                this.decompTime = end - start;
                if (!cholesky.isSPD()) {
                    throw new NonSPDMatrixException(gram);
                }
                if (this._alpha == 0.0d || this._lambda == 0.0d) {
                    // No L1 penalty: one direct solve is exact.
                    System.arraycopy(xy, 0, z, 0, xy.length);
                    cholesky.solve(z);
                    this._converged = true;
                    return true;
                }
                // Only an actual gradient check may lower gerr. If the iteration budget runs out before any
                // check, the solve must not report convergence.
                this.gerr = Double.POSITIVE_INFINITY;
                final double[] u = MemoryManager.malloc8d(length);
                final double[] xyPrime = xy.clone();
                final double kappa = (this._lambda * this._alpha) / rho;
                final int maxIter = Math.max(500, (int) (50000.0d / (1 + (xy.length >> 3))));
                double orlx = 1.8d;      // over-relaxation factor, decays toward 1 every 20 iterations
                double reltol = 1.0E-4d; // relative residual tolerance, tightened adaptively
                while (iter < maxIter) {
                    // x-update: solve (G + shifts) x = xy + rho * (z - u); the last entry uses xy alone.
                    for (int j = 0; j < length - 1; j++) {
                        xyPrime[j] = xy[j] + (rho * (z[j] - u[j]));
                    }
                    xyPrime[length - 1] = xy[length - 1];
                    cholesky.solve(xyPrime);
                    double rnorm = 0.0d;
                    double snorm = 0.0d;
                    double unorm = 0.0d;
                    double xnorm = 0.0d;
                    for (int j = 0; j < length - 1; j++) {
                        final double x = xyPrime[j];
                        final double zOld = z[j];
                        final double xHat = (x * orlx) + ((1.0d - orlx) * zOld);
                        z[j] = shrinkage(xHat + u[j], kappa);
                        u[j] += xHat - z[j];
                        final double r = xyPrime[j] - z[j];
                        final double s = z[j] - zOld;
                        rnorm += r * r;
                        snorm += s * s;
                        xnorm += x * x;
                        unorm += u[j] * u[j];
                    }
                    z[length - 1] = xyPrime[length - 1];
                    if (rnorm < reltol * xnorm && snorm < reltol * unorm) {
                        // Residuals are small: check the true (sub)gradient, excluding the last coefficient.
                        final double[] grad = grad(gram, z, xy);
                        subgrad(this._alpha, this._lambda, z, grad);
                        double err = 0.0d;
                        for (int j = 0; j < grad.length - 1; j++) {
                            if (err < grad[j]) {
                                err = grad[j];
                            } else if (err < -grad[j]) {
                                err = -grad[j];
                            }
                        }
                        this.gerr = err;
                        if (err < RELTOL || reltol <= 1.0E-6d) {
                            break;
                        }
                        while (rnorm < reltol * xnorm && snorm < reltol * unorm) {
                            reltol *= 0.1d;
                        }
                    }
                    if (iter % 20 == 0) {
                        orlx = (1.0d + (15.0d * orlx)) * 0.0625d;
                    }
                    iter++;
                }
            } finally {
                gram.addDiag((-gram._diagAdded) + diagOnEntry);
            }
            if (!$assertionsDisabled && gram._diagAdded != diagOnEntry) {
                throw new AssertionError();
            }
            this.iterations = iter;
            final boolean ok = this.gerr < this._gradientEps;
            this._converged = ok;
            return ok;
        }

        /** Returns this solver's display name. */
        @Override // hex.glm.LSMSolver
        public String name() {
            final String label = "ADMM";
            return label;
        }

        // Decompiler-emitted form of the assertion-status initialization; the flag gates the explicit
        // assertion check at the end of solve(...).
        static {
            $assertionsDisabled = !LSMSolver.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hex/glm/LSMSolver$LSMSolverException.class */
    /** Unchecked base exception for solver failures (see NonSPDMatrixException). */
    public static class LSMSolverException extends RuntimeException {
        /**
         * Creates the exception with the given detail message.
         *
         * @param str detail message
         */
        public LSMSolverException(String str) {
            super(str);
        }
    }

    /* loaded from: input_file:hex/glm/LSMSolver$LSMSolverType.class */
    /** Named solver kinds; the constants are not referenced elsewhere in this file. */
    public enum LSMSolverType {
        AUTO, ADMM, GenGradient
    }

    /* loaded from: input_file:hex/glm/LSMSolver$ProxSolver.class */
    /**
     * Solver that delegates to ADMM when a Gram matrix is supplied and otherwise contains a
     * proximal-gradient implementation with backtracking line search.
     */
    public static final class ProxSolver extends LSMSolver {
        /**
         * @param d  penalty strength lambda
         * @param d2 elastic-net mixing parameter alpha
         */
        public ProxSolver(double d, double d2) {
            super(d, d2);
        }

        /**
         * Quadratic upper model used by the line search:
         * {@code d + <dArr3 - dArr4, dArr - dArr2> + 0.25 * ||dArr - dArr2||^2 / d2}.
         */
        private static final double f_hat(double[] dArr, double d, double[] dArr2, double[] dArr3, double[] dArr4, double d2) {
            double d3 = d;
            double d4 = 0.0d;
            for (int i = 0; i < dArr.length; i++) {
                double d5 = dArr[i] - dArr2[i];
                d3 += (dArr3[i] - dArr4[i]) * d5;
                d4 += d5 * d5;
            }
            return d3 + ((0.25d * d4) / d2);
        }

        /** Elastic-net penalty {@code lambda*(alpha*|b|_1 + 0.5*(1-alpha)*|b|^2)}; not called in this class. */
        private double penalty(double[] dArr) {
            double d = 0.0d;
            double d2 = 0.0d;
            for (int i = 0; i < dArr.length; i++) {
                d += Math.abs(dArr[i]);
                d2 += dArr[i] * dArr[i];
            }
            return this._lambda * ((this._alpha * d) + ((1.0d - this._alpha) * d2 * 0.5d));
        }

        /**
         * Returns the largest absolute element-wise difference between two equal-length vectors (0 when empty).
         * The previous version discarded the Math.max result and always returned 0.
         */
        private static double betaDiff(double[] dArr, double[] dArr2) {
            double res = 0.0d;
            for (int i = 0; i < dArr.length; i++) {
                res = Math.max(res, Math.abs(dArr[i] - dArr2[i]));
            }
            return res;
        }

        /**
         * Solves the penalized least-squares problem. With a non-null {@code gram}, the work is delegated to an
         * ADMMSolver (gradient tolerance 0.01) and its result is recorded as this solver's convergence flag.
         */
        @Override // hex.glm.LSMSolver
        public boolean solve(Gram gram, double[] dArr, double d, double[] dArr2) {
            ADMMSolver aDMMSolver = new ADMMSolver(this._lambda, this._alpha, 0.01d);
            if (gram != null) {
                final boolean ok = aDMMSolver.solve(gram, dArr, d, dArr2);
                this._converged = ok;
                return ok;
            }
            // NOTE(review): the proximal-gradient path below is only reached when gram == null, yet it calls
            // gram.mul(...) straight away, so today it always throws NullPointerException. The loop has been
            // fixed (betaDiff, accepted-step break) so it behaves correctly if that dispatch is ever changed.
            Arrays.fill(dArr2, 0.0d);
            final double[] gBeta = gram.mul(dArr2);
            final double[] newBeta = MemoryManager.malloc8d(dArr2.length);
            final double[] gNewBeta = MemoryManager.malloc8d(dArr2.length);
            final double l1pen = this._lambda * this._alpha;
            final double l2pen = this._lambda * (1.0d - this._alpha);
            double obj = lsm_objectiveVal(dArr, d, dArr2, gBeta);
            boolean converged = false;
            final int last = dArr2.length - 1;
            int iter = 0;
            while (!converged && iter < 1000) {
                iter++;
                double step = 1.0d;
                while (true) {
                    if (step <= 1.0E-12d) {
                        converged = true; // step collapsed: no further progress possible
                        break;
                    }
                    final double l2shrink = 1.0d / (1.0d + (step * l2pen));
                    final double l1shrink = l1pen * step;
                    for (int j = 0; j < last; j++) {
                        newBeta[j] = l2shrink * shrinkage(dArr2[j] - (step * (gBeta[j] - dArr[j])), l1shrink);
                    }
                    newBeta[last] = dArr2[last] - (step * (gBeta[last] - dArr[last]));
                    gram.mul(newBeta, gNewBeta);
                    final double newObj = lsm_objectiveVal(dArr, d, newBeta, gNewBeta);
                    if (newObj <= f_hat(newBeta, obj, dArr2, gBeta, dArr, step)) {
                        obj = newObj;
                        converged = betaDiff(dArr2, newBeta) < 1.0E-6d;
                        System.arraycopy(newBeta, 0, dArr2, 0, newBeta.length);
                        System.arraycopy(gNewBeta, 0, gBeta, 0, gNewBeta.length);
                        // Accept the step and start the next outer iteration. Without this break the search
                        // would keep re-accepting forever once newBeta == beta.
                        break;
                    }
                    step *= 0.8d;
                }
            }
            this._converged = converged;
            return converged;
        }

        @Override // hex.glm.LSMSolver
        public String name() {
            return "ProximalGradientSolver";
        }
    }

    /**
     * @param d  penalty strength lambda
     * @param d2 elastic-net mixing parameter alpha
     */
    public LSMSolver(double d, double d2) {
        this._alpha = d2;
        this._lambda = d;
    }

    /**
     * Returns {@code gram * dArr - dArr2}. This is the gradient of the smooth objective
     * {@code 0.5*b'Gb - b'xy} evaluated in lsm_objectiveVal. The array returned by gram.mul is
     * updated in place and returned.
     */
    public final double[] grad(Gram gram, double[] dArr, double[] dArr2) {
        final double[] g = gram.mul(dArr);
        int k = 0;
        while (k < g.length) {
            g[k] -= dArr2[k];
            k++;
        }
        return g;
    }

    /**
     * Turns a smooth gradient into the minimum-norm subgradient of the L1-penalized objective, in place.
     * The last entry is left unchanged.
     *
     * @param d     alpha
     * @param d2    lambda
     * @param dArr  current coefficients; when null the method does nothing
     * @param dArr2 gradient, overwritten with the subgradient
     */
    public static void subgrad(double d, double d2, double[] dArr, double[] dArr2) {
        if (dArr == null) {
            return;
        }
        final double l1 = d2 * d;
        final int last = dArr2.length - 1;
        for (int j = 0; j < last; j++) {
            final double beta = dArr[j];
            if (beta < 0.0d) {
                dArr2[j] -= l1;
            } else if (beta > 0.0d) {
                dArr2[j] += l1;
            } else {
                // Zero coefficient: the subdifferential is [-l1, l1], so take the entry closest to zero.
                dArr2[j] = shrinkage(dArr2[j], l1);
            }
        }
    }

    /**
     * Solves the least-squares problem described by {@code gram} and writes the coefficients to {@code dArr2}.
     *
     * @param gram  Gram matrix (implementations in this file shift its diagonal temporarily)
     * @param dArr  linear term of the objective (presumably X'y — confirm with callers)
     * @param d     constant term of the objective (presumably y'y — confirm with callers)
     * @param dArr2 output coefficient vector
     * @return whether the implementation considers the solve converged
     */
    public abstract boolean solve(Gram gram, double[] dArr, double d, double[] dArr2);

    /** Returns the convergence flag recorded by the most recent solve. */
    public final boolean converged() {
        final boolean result = _converged;
        return result;
    }

    /** Returns a human-readable name for the solver implementation. */
    public abstract String name();

    /**
     * Soft-thresholding operator: moves {@code d} toward zero by {@code d2} and returns 0 when
     * {@code |d| <= d2}. NaN input yields NaN.
     */
    protected static double shrinkage(double d, double d2) {
        if (d < 0.0d) {
            final double magnitude = -d;
            return magnitude <= d2 ? 0.0d : -(magnitude - d2);
        }
        return d <= d2 ? 0.0d : d - d2;
    }

    /**
     * Returns the elastic-net objective: lsm_objectiveVal plus {@code alpha*lambda*|b|_1} plus
     * {@code 0.5*(1-alpha)*lambda*|b|^2}, with every coefficient of {@code dArr2} penalized.
     */
    protected double objectiveVal(double[] dArr, double d, double[] dArr2, double[] dArr3) {
        final double base = lsm_objectiveVal(dArr, d, dArr2, dArr3);
        double l1 = 0.0d;
        double l2sq = 0.0d;
        for (final double b : dArr2) {
            l1 += Math.abs(b);
            l2sq += b * b;
        }
        final double l1Term = this._alpha * this._lambda * l1;
        final double l2Term = 0.5d * (1.0d - this._alpha) * this._lambda * l2sq;
        return base + l1Term + l2Term;
    }

    /**
     * Returns the unpenalized least-squares objective {@code 0.5*d + sum_i b_i * (0.5*(Gb)_i - xy_i)}.
     *
     * @param dArr  linear term xy
     * @param d     constant term
     * @param dArr2 coefficients b
     * @param dArr3 precomputed {@code G*b}; its length bounds the sum
     */
    protected double lsm_objectiveVal(double[] dArr, double d, double[] dArr2, double[] dArr3) {
        double acc = d * 0.5d;
        for (int k = 0; k < dArr3.length; ++k) {
            final double halfGb = 0.5d * dArr3[k];
            acc += dArr2[k] * (halfGb - dArr[k]);
        }
        return acc;
    }

    /**
     * Computes the dense matrix-vector product {@code dArr3 = dArr * dArr2}.
     *
     * <p>Each row of {@code dArr} must have at least {@code dArr2.length} entries, and {@code dArr3} at
     * least {@code dArr.length}. {@code dArr3} must not alias {@code dArr2}. An empty vector gives a zero
     * result; the previous version threw ArrayIndexOutOfBoundsException.
     *
     * @return {@code dArr3}
     */
    static final double[] mul(double[][] dArr, double[] dArr2, double[] dArr3) {
        final int rows = dArr.length;
        final int cols = dArr2.length;
        for (int i = 0; i < rows; i++) {
            final double[] row = dArr[i];
            double sum = cols == 0 ? 0.0d : row[0] * dArr2[0];
            for (int j = 1; j < cols; j++) {
                sum += row[j] * dArr2[j];
            }
            dArr3[i] = sum;
        }
        return dArr3;
    }

    /** Writes {@code d * dArr[i]} into {@code dArr2[i]} for every index of {@code dArr}; returns {@code dArr2}. */
    static final double[] mul(double[] dArr, double d, double[] dArr2) {
        for (int k = 0, n = dArr.length; k < n; ++k) {
            dArr2[k] = d * dArr[k];
        }
        return dArr2;
    }

    /** Writes the element-wise sum {@code dArr + dArr2} into {@code dArr3}; returns {@code dArr3}. */
    static final double[] plus(double[] dArr, double[] dArr2, double[] dArr3) {
        for (int k = 0, n = dArr.length; k < n; ++k) {
            final double sum = dArr[k] + dArr2[k];
            dArr3[k] = sum;
        }
        return dArr3;
    }

    /** Writes the element-wise difference {@code dArr - dArr2} into {@code dArr3}; returns {@code dArr3}. */
    static final double[] minus(double[] dArr, double[] dArr2, double[] dArr3) {
        for (int k = 0, n = dArr.length; k < n; ++k) {
            final double diff = dArr[k] - dArr2[k];
            dArr3[k] = diff;
        }
        return dArr3;
    }

    /**
     * Soft-thresholds every entry of {@code dArr} except the last by {@code d} and copies the last entry
     * through unchanged, writing the result to {@code dArr2}.
     *
     * <p>An empty {@code dArr} leaves {@code dArr2} untouched; the previous version indexed -1 and threw.
     *
     * @return {@code dArr2}
     */
    static final double[] shrink(double[] dArr, double[] dArr2, double d) {
        final int n = dArr.length;
        if (n == 0) {
            return dArr2;
        }
        for (int i = 0; i < n - 1; i++) {
            dArr2[i] = shrinkage(dArr[i], d);
        }
        dArr2[n - 1] = dArr[n - 1];
        return dArr2;
    }
}
