package hex.coxph;

import Jama.Matrix;
import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.SupervisedModelBuilder;
import hex.coxph.CoxPHModel;
import hex.schemas.ModelBuilderSchema;
import java.util.Arrays;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MemoryManager;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/coxph/CoxPH.class */
public class CoxPH extends SupervisedModelBuilder<CoxPHModel, CoxPHModel.CoxPHParameters, CoxPHModel.CoxPHOutput> {

    /**
     * Fits a Cox proportional-hazards model for the enclosing {@link CoxPH} builder.
     *
     * <p>{@link #compute2()} locks the input frames, appends the offset, weight, start and stop
     * columns to the model-builder frame (keeping that frame's original last column last), builds a
     * {@link DataInfo}, and then runs Newton-Raphson iterations with step halving. Each iteration is
     * one {@code CoxPHTask.doAll} pass that aggregates per-time-bucket risk-set sums; from those the
     * partial log-likelihood, gradient and Hessian are formed with either Efron or Breslow handling
     * of tied event times, and the statistics of the accepted iterate are written to the model output.
     */
    public class CoxPHDriver extends H2O.H2OCountedCompleter<CoxPHDriver> {
        /** Frame supplied by the model builder; the side-effect helpers append columns to it in place. */
        private Frame _modelBuilderTrain = null;

        public CoxPHDriver() {
        }

        /**
         * Sets the frame that {@link #compute2()} rearranges and trains on.
         *
         * @param frame model-builder training frame
         */
        public void setModelBuilderTrain(Frame frame) {
            this._modelBuilderTrain = frame;
        }

        /**
         * Appends {@code vec} to {@link #_modelBuilderTrain} under the name it has in the training frame.
         *
         * <p>The column index and the column name are both taken from {@code _train}; looking the name
         * up in a different frame with the same index is only correct when both frames share a layout.
         *
         * @param vec   column to append; must be a column of {@code _train}
         * @param label suffix of the error message when {@code vec} is not found
         * @throws RuntimeException if {@code vec} is not a column of {@code _train}
         */
        private void addTrainColumn(Vec vec, String label) {
            final int idx = CoxPH.this._train.find(vec);
            if (idx < 0) {
                throw new RuntimeException("CoxPHDriver failed to find " + label);
            }
            _modelBuilderTrain.add(CoxPH.this._train.names()[idx], vec);
        }

        /**
         * Appends the offset columns (if any) in front of the frame's last column.
         *
         * <p>The last column is removed, every {@code offset_columns} vec is appended, and the removed
         * column is re-appended so it stays last. Does nothing when there are no offset columns.
         *
         * @throws RuntimeException if an offset vec is not a column of the training frame
         */
        private void applyScoringFrameSideEffects() {
            final Vec[] offsets = CoxPH.this._parms.offset_columns;
            if (offsets == null || offsets.length == 0) {
                return;
            }
            final int lastCol = _modelBuilderTrain.numCols() - 1;
            final String lastName = _modelBuilderTrain.names()[lastCol];
            final Vec lastVec = _modelBuilderTrain.remove(lastCol);
            for (Vec offset : offsets) {
                addTrainColumn(offset, "offsetVec");
            }
            _modelBuilderTrain.add(lastName, lastVec);
        }

        /**
         * Appends the weights (optional), start (optional) and stop columns ahead of the frame's last
         * column, producing the trailing layout {@code [weights] [start] stop last}.
         *
         * <p>{@code CoxPHTask.processRow} reads the weight from response index 0 (when present) and the
         * start/stop/event values from the last three/two/one response positions, matching this order.
         *
         * @throws RuntimeException if one of these vecs is not a column of the training frame
         */
        private void applyTrainingFrameSideEffects() {
            final CoxPHModel.CoxPHParameters parms = CoxPH.this._parms;
            final int lastCol = _modelBuilderTrain.numCols() - 1;
            final String lastName = _modelBuilderTrain.names()[lastCol];
            final Vec lastVec = _modelBuilderTrain.remove(lastCol);
            if (parms.weights_column != null) {
                addTrainColumn(parms.weights_column, "weightVec");
            }
            if (parms.start_column != null) {
                addTrainColumn(parms.start_column, "startVec");
            }
            addTrainColumn(parms.stop_column, "stopVec");
            _modelBuilderTrain.add(lastName, lastVec);
        }

        /**
         * Allocates every output array of the model and fills in the static metadata.
         *
         * <p>Coefficients cover {@code dinfo.fullN()} minus the offset columns; the trailing coefficient
         * names are the offset names. The time range is {@code [min(start)+1, max(stop)]} with a start
         * column and {@code [min(stop), max(stop)]} without, and the per-time output arrays get one slot
         * per distinct stop value.
         *
         * @param coxPHModel model whose {@code _output} is initialised
         * @param dataInfo   data layout built from the training frame
         */
        protected void initStats(CoxPHModel coxPHModel, DataInfo dataInfo) {
            final CoxPHModel.CoxPHParameters p = coxPHModel._parms;
            final CoxPHModel.CoxPHOutput o = coxPHModel._output;
            o.n = p.stop_column.length();
            o.data_info = dataInfo;

            final int nOffsets = p.offset_columns == null ? 0 : p.offset_columns.length;
            final int nCoef = o.data_info.fullN() - nOffsets;
            final int nNum = o.data_info._nums - nOffsets;
            final String[] coefNames = o.data_info.coefNames();

            o.coef_names = new String[nCoef];
            System.arraycopy(coefNames, 0, o.coef_names, 0, nCoef);
            o.coef = MemoryManager.malloc8d(nCoef);
            o.exp_coef = MemoryManager.malloc8d(nCoef);
            o.exp_neg_coef = MemoryManager.malloc8d(nCoef);
            o.se_coef = MemoryManager.malloc8d(nCoef);
            o.z_coef = MemoryManager.malloc8d(nCoef);
            o.gradient = MemoryManager.malloc8d(nCoef);
            o.hessian = CoxPH.malloc2DArray(nCoef, nCoef);
            o.var_coef = CoxPH.malloc2DArray(nCoef, nCoef);
            o.x_mean_cat = MemoryManager.malloc8d(nCoef - nNum);
            o.x_mean_num = MemoryManager.malloc8d(nNum);
            o.mean_offset = MemoryManager.malloc8d(nOffsets);
            o.offset_names = new String[nOffsets];
            System.arraycopy(coefNames, nCoef, o.offset_names, 0, nOffsets);

            final Vec start = p.start_column;
            final Vec stop = p.stop_column;
            o.min_time = start == null ? (long) stop.min() : (long) start.min() + 1;
            o.max_time = (long) stop.max();

            final int nTimes = new Vec.CollectDomain().doAll(new Vec[]{stop}).domain().length;
            o.time = MemoryManager.malloc8(nTimes);
            o.n_risk = MemoryManager.malloc8d(nTimes);
            o.n_event = MemoryManager.malloc8d(nTimes);
            o.n_censor = MemoryManager.malloc8d(nTimes);
            o.cumhaz_0 = MemoryManager.malloc8d(nTimes);
            o.var_cumhaz_1 = MemoryManager.malloc8d(nTimes);
            o.var_cumhaz_2 = CoxPH.malloc2DArray(nTimes, nCoef);
        }

        /**
         * Copies row counts, predictor means and per-time event/censor/risk counts from the first pass.
         *
         * <p>Only time buckets with event or censoring weight are kept (compacted to the front of the
         * output arrays). Without a start column the per-bucket risk-set weights are turned into a
         * reverse cumulative sum, so {@code n_risk[i]} counts everyone still at risk at {@code time[i]}.
         *
         * @param coxPHModel model whose {@code _output} is filled
         * @param coxPHTask  aggregated sums of the first iteration
         */
        protected void calcCounts(CoxPHModel coxPHModel, CoxPHTask coxPHTask) {
            final CoxPHModel.CoxPHParameters p = coxPHModel._parms;
            final CoxPHModel.CoxPHOutput o = coxPHModel._output;
            o.n_missing = o.n - coxPHTask.n;
            o.n = coxPHTask.n;

            for (int j = 0; j < o.x_mean_cat.length; j++) {
                o.x_mean_cat[j] = coxPHTask.sumWeightedCatX[j] / coxPHTask.sumWeights;
            }
            // Numeric columns were demeaned by DataInfo, so add the subtracted mean back.
            for (int j = 0; j < o.x_mean_num.length; j++) {
                o.x_mean_num[j] = coxPHTask.dinfo()._normSub[j] + coxPHTask.sumWeightedNumX[j] / coxPHTask.sumWeights;
            }
            // Offsets follow the numeric predictors in _normSub.
            System.arraycopy(coxPHTask.dinfo()._normSub, o.x_mean_num.length, o.mean_offset, 0, o.mean_offset.length);

            int nz = 0;
            for (int t = 0; t < coxPHTask.countEvents.length; t++) {
                o.total_event += coxPHTask.countEvents[t];
                if (coxPHTask.sizeEvents[t] > 0 || coxPHTask.sizeCensored[t] > 0) {
                    o.time[nz] = o.min_time + t;
                    o.n_risk[nz] = coxPHTask.sizeRiskSet[t];
                    o.n_event[nz] = coxPHTask.sizeEvents[t];
                    o.n_censor[nz] = coxPHTask.sizeCensored[t];
                    nz++;
                }
            }
            if (p.start_column == null) {
                for (int k = o.n_risk.length - 2; k >= 0; k--) {
                    o.n_risk[k] += o.n_risk[k + 1];
                }
            }
        }

        /**
         * Computes the partial log-likelihood and overwrites {@code _output.gradient} and
         * {@code _output.hessian} for the coefficients the task was run with.
         *
         * @param coxPHModel model providing the ties method and the output arrays
         * @param coxPHTask  aggregated sums for the current coefficients
         * @return the partial log-likelihood
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected double calcLoglik(CoxPHModel coxPHModel, final CoxPHTask coxPHTask) {
            final CoxPHModel.CoxPHOutput o = coxPHModel._output;
            final int nCoef = o.coef.length;
            Arrays.fill(o.gradient, 0, nCoef, 0.0);
            for (int j = 0; j < nCoef; j++) {
                Arrays.fill(o.hessian[j], 0, nCoef, 0.0);
            }
            switch (coxPHModel._parms.ties) {
                case efron:
                    return efronLoglik(o, coxPHTask, nCoef);
                case breslow:
                    return breslowLoglik(o, coxPHTask, nCoef);
                default:
                    throw new IllegalArgumentException("ties method must be either efron or breslow");
            }
        }

        /**
         * Efron-tie log-likelihood: per-time-bucket terms are computed in parallel fork/join tasks and
         * then summed into the output gradient and Hessian in ascending time order.
         */
        private double efronLoglik(final CoxPHModel.CoxPHOutput o, final CoxPHTask coxMR, final int nCoef) {
            final int nTime = coxMR.sizeEvents.length;
            final double[] loglikT = MemoryManager.malloc8d(nTime);
            final double[][] gradientT = CoxPH.malloc2DArray(nTime, nCoef);
            final double[][][] hessianT = CoxPH.malloc3DArray(nTime, nCoef, nCoef);

            final RecursiveAction[] tasks = new RecursiveAction[nTime];
            for (int t = nTime - 1; t >= 0; t--) {
                final int time = t;
                tasks[t] = new RecursiveAction() {
                    @Override
                    protected void compute() {
                        efronTerm(coxMR, time, nCoef, loglikT, gradientT[time], hessianT[time]);
                    }
                };
            }
            ForkJoinTask.invokeAll(tasks);

            double loglik = 0.0;
            for (int t = 0; t < nTime; t++) {
                loglik += loglikT[t];
            }
            for (int t = 0; t < nTime; t++) {
                for (int j = 0; j < nCoef; j++) {
                    o.gradient[j] += gradientT[t][j];
                }
            }
            for (int t = 0; t < nTime; t++) {
                for (int j = 0; j < nCoef; j++) {
                    for (int k = 0; k < nCoef; k++) {
                        o.hessian[j][k] += hessianT[t][j][k];
                    }
                }
            }
            return loglik;
        }

        /**
         * Efron contribution of one time bucket {@code t}; buckets without event weight are skipped.
         *
         * <p>The {@code e}-th of the bucket's {@code countEvents} tied events sees the risk set reduced
         * by the fraction {@code e / countEvents} of the events' own risk.
         */
        private void efronTerm(CoxPHTask coxMR, int t, int nCoef, double[] loglikT, double[] gradT, double[][] hessT) {
            final double weightEvent = coxMR.sizeEvents[t];
            if (weightEvent <= 0.0) {
                return;
            }
            final long countEvents = coxMR.countEvents[t];
            final double sumRiskEvents = coxMR.sumRiskEvents[t];
            final double rcumsumRisk = coxMR.rcumsumRisk[t];
            final double avgEvents = weightEvent / countEvents;
            loglikT[t] = coxMR.sumLogRiskEvents[t];
            System.arraycopy(coxMR.sumXEvents[t], 0, gradT, 0, nCoef);
            for (long e = 0; e < countEvents; e++) {
                // Floating-point fraction: integer division here would always yield 0 and reduce Efron to Breslow.
                final double frac = (double) e / countEvents;
                final double riskTerm = rcumsumRisk - frac * sumRiskEvents;
                loglikT[t] -= avgEvents * Math.log(riskTerm);
                for (int j = 0; j < nCoef; j++) {
                    final double xjMean = (coxMR.rcumsumXRisk[t][j] - frac * coxMR.sumXRiskEvents[t][j]) / riskTerm;
                    gradT[j] -= avgEvents * xjMean;
                    for (int k = 0; k < nCoef; k++) {
                        final double xkTerm = coxMR.rcumsumXRisk[t][k] - frac * coxMR.sumXRiskEvents[t][k];
                        final double xjxkTerm = coxMR.rcumsumXXRisk[t][j][k] - frac * coxMR.sumXXRiskEvents[t][j][k];
                        hessT[j][k] -= avgEvents * ((xjxkTerm / riskTerm) - (xjMean * (xkTerm / riskTerm)));
                    }
                }
            }
        }

        /** Breslow-tie log-likelihood, accumulated directly into the output gradient and Hessian. */
        private double breslowLoglik(CoxPHModel.CoxPHOutput o, CoxPHTask coxMR, int nCoef) {
            double loglik = 0.0;
            for (int t = coxMR.sizeEvents.length - 1; t >= 0; t--) {
                final double weightEvent = coxMR.sizeEvents[t];
                if (weightEvent > 0.0) {
                    final double rcumsumRisk = coxMR.rcumsumRisk[t];
                    loglik = (loglik + coxMR.sumLogRiskEvents[t]) - (weightEvent * Math.log(rcumsumRisk));
                    for (int j = 0; j < nCoef; j++) {
                        final double xjMean = coxMR.rcumsumXRisk[t][j] / rcumsumRisk;
                        o.gradient[j] += coxMR.sumXEvents[t][j];
                        o.gradient[j] -= weightEvent * xjMean;
                        for (int k = 0; k < nCoef; k++) {
                            o.hessian[j][k] -= weightEvent
                                * ((coxMR.rcumsumXXRisk[t][j][k] / rcumsumRisk) - (xjMean * (coxMR.rcumsumXRisk[t][k] / rcumsumRisk)));
                        }
                    }
                }
            }
            return loglik;
        }

        /**
         * Stores the coefficients and derives their statistics from the current gradient and Hessian.
         *
         * <p>{@code var_coef} is {@code -inverse(hessian)}, mirrored from its lower triangle. On the first
         * iteration the null log-likelihood, max R^2 and score test are also recorded.
         *
         * @param coxPHModel model whose {@code _output} is updated
         * @param dArr       coefficients of the accepted iterate
         * @param d          partial log-likelihood at {@code dArr}
         */
        protected void calcModelStats(CoxPHModel coxPHModel, double[] dArr, double d) {
            final CoxPHModel.CoxPHParameters p = coxPHModel._parms;
            final CoxPHModel.CoxPHOutput o = coxPHModel._output;
            final int nCoef = o.coef.length;

            final Matrix invHessian = new Matrix(o.hessian).inverse();
            for (int j = 0; j < nCoef; j++) {
                for (int k = 0; k <= j; k++) {
                    final double v = -invHessian.get(j, k);
                    o.var_coef[j][k] = v;
                    o.var_coef[k][j] = v;
                }
            }
            for (int j = 0; j < nCoef; j++) {
                o.coef[j] = dArr[j];
                o.exp_coef[j] = Math.exp(o.coef[j]);
                o.exp_neg_coef[j] = Math.exp(-o.coef[j]);
                o.se_coef[j] = Math.sqrt(o.var_coef[j][j]);
                o.z_coef[j] = o.coef[j] / o.se_coef[j];
            }
            if (o.iter == 0) {
                o.null_loglik = d;
                o.maxrsq = 1.0 - Math.exp((2.0 * o.null_loglik) / o.n);
                o.score_test = 0.0;
                for (int j = 0; j < nCoef; j++) {
                    double varTimesGrad = 0.0;
                    for (int k = 0; k < nCoef; k++) {
                        varTimesGrad += o.var_coef[j][k] * o.gradient[k];
                    }
                    o.score_test += o.gradient[j] * varTimesGrad;
                }
            }
            o.loglik = d;
            o.loglik_test = -2.0 * (o.null_loglik - o.loglik);
            o.rsq = 1.0 - Math.exp((-o.loglik_test) / o.n);
            o.wald_test = 0.0;
            for (int j = 0; j < nCoef; j++) {
                double hessTimesDelta = 0.0;
                for (int k = 0; k < nCoef; k++) {
                    hessTimesDelta -= o.hessian[j][k] * (o.coef[k] - p.init);
                }
                o.wald_test += (o.coef[j] - p.init) * hessTimesDelta;
            }
        }

        /**
         * Computes the baseline cumulative hazard and its two variance components per kept time.
         *
         * <p>Only buckets with event or censoring weight are written (same compaction as
         * {@link #calcCounts}); the per-bucket increments are then turned into running sums.
         *
         * @param coxPHModel model providing the ties method and output arrays
         * @param coxPHTask  aggregated sums for the accepted coefficients
         * @throws IllegalArgumentException if the ties method is neither efron nor breslow
         */
        protected void calcCumhaz_0(CoxPHModel coxPHModel, CoxPHTask coxPHTask) {
            final CoxPHModel.CoxPHParameters p = coxPHModel._parms;
            final CoxPHModel.CoxPHOutput o = coxPHModel._output;
            final int nCoef = o.coef.length;
            int nz = 0;
            switch (p.ties) {
                case efron:
                    for (int t = 0; t < coxPHTask.sizeEvents.length; t++) {
                        final double weightEvent = coxPHTask.sizeEvents[t];
                        final double weightCensored = coxPHTask.sizeCensored[t];
                        if (weightEvent > 0.0 || weightCensored > 0.0) {
                            final long countEvents = coxPHTask.countEvents[t];
                            final double sumRiskEvents = coxPHTask.sumRiskEvents[t];
                            final double rcumsumRisk = coxPHTask.rcumsumRisk[t];
                            final double avgEvents = weightEvent / countEvents;
                            o.cumhaz_0[nz] = 0.0;
                            o.var_cumhaz_1[nz] = 0.0;
                            Arrays.fill(o.var_cumhaz_2[nz], 0, nCoef, 0.0);
                            // Bounded loop: the decompiled while(true) never terminated once all events were consumed.
                            for (long e = 0; e < countEvents; e++) {
                                final double frac = (double) e / countEvents;
                                final double haz = 1.0 / (rcumsumRisk - frac * sumRiskEvents);
                                final double hazSq = haz * haz;
                                o.cumhaz_0[nz] += avgEvents * haz;
                                o.var_cumhaz_1[nz] += avgEvents * hazSq;
                                for (int j = 0; j < nCoef; j++) {
                                    o.var_cumhaz_2[nz][j] += avgEvents * (coxPHTask.rcumsumXRisk[t][j] - frac * coxPHTask.sumXRiskEvents[t][j]) * hazSq;
                                }
                            }
                            nz++;
                        }
                    }
                    break;
                case breslow:
                    for (int t = 0; t < coxPHTask.sizeEvents.length; t++) {
                        final double weightEvent = coxPHTask.sizeEvents[t];
                        final double weightCensored = coxPHTask.sizeCensored[t];
                        if (weightEvent > 0.0 || weightCensored > 0.0) {
                            final double rcumsumRisk = coxPHTask.rcumsumRisk[t];
                            final double haz = weightEvent / rcumsumRisk;
                            o.cumhaz_0[nz] = haz;
                            o.var_cumhaz_1[nz] = weightEvent / (rcumsumRisk * rcumsumRisk);
                            for (int j = 0; j < nCoef; j++) {
                                o.var_cumhaz_2[nz][j] = (coxPHTask.rcumsumXRisk[t][j] / rcumsumRisk) * haz;
                            }
                            nz++;
                        }
                    }
                    break;
                default:
                    throw new IllegalArgumentException("ties method must be either efron or breslow");
            }
            for (int i = 1; i < o.cumhaz_0.length; i++) {
                o.cumhaz_0[i] = o.cumhaz_0[i - 1] + o.cumhaz_0[i];
                o.var_cumhaz_1[i] = o.var_cumhaz_1[i - 1] + o.var_cumhaz_1[i];
                for (int j = 0; j < nCoef; j++) {
                    o.var_cumhaz_2[i][j] = o.var_cumhaz_2[i - 1][j] + o.var_cumhaz_2[i][j];
                }
            }
        }

        /** Returns true when no entry of {@code v} is NaN or infinite. */
        private boolean allFinite(double[] v) {
            for (double x : v) {
                if (Double.isNaN(x) || Double.isInfinite(x)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Runs Newton-Raphson iterations starting from {@code init} for every coefficient.
         *
         * <p>An iterate is accepted when its log-likelihood improves; its statistics are then stored and
         * the loop stops once the log relative error reaches {@code lre_min}. Otherwise the previous step
         * is halved. Iteration also stops when the next Newton step is NaN or infinite, keeping the
         * statistics of the last accepted iterate.
         */
        private void fitCoefficients(CoxPHModel model, DataInfo dinfo) {
            final CoxPHModel.CoxPHParameters p = model._parms;
            final CoxPHModel.CoxPHOutput o = model._output;
            final int nOffsets = p.offset_columns == null ? 0 : p.offset_columns.length;
            final int nCoef = dinfo.fullN() - nOffsets;
            final double[] step = MemoryManager.malloc8d(nCoef);
            final double[] oldCoef = MemoryManager.malloc8d(nCoef);
            final double[] newCoef = MemoryManager.malloc8d(nCoef);
            Arrays.fill(step, Double.NaN);
            Arrays.fill(oldCoef, Double.NaN);
            Arrays.fill(newCoef, p.init);

            double oldLoglik = -Double.MAX_VALUE;
            final int nTime = (int) ((o.max_time - o.min_time) + 1);
            final boolean hasStartColumn = p.start_column != null;
            final boolean hasWeightsColumn = p.weights_column != null;

            for (int i = 0; i <= p.iter_max; i++) {
                o.iter = i;
                final CoxPHTask coxMR = (CoxPHTask) new CoxPHTask(self(), dinfo, newCoef, o.min_time, nTime, nOffsets,
                        hasStartColumn, hasWeightsColumn).doAll(dinfo._adaptedFrame);
                final double newLoglik = calcLoglik(model, coxMR);
                if (newLoglik > oldLoglik) {
                    if (i == 0) {
                        calcCounts(model, coxMR);
                    }
                    calcModelStats(model, newCoef, newLoglik);
                    calcCumhaz_0(model, coxMR);
                    if (newLoglik == 0.0) {
                        o.lre = -Math.log10(Math.abs(oldLoglik - newLoglik));
                    } else {
                        o.lre = -Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik));
                    }
                    if (o.lre >= p.lre_min) {
                        break;
                    }
                    // Newton step: step = -var_coef * gradient, with var_coef = -inverse(hessian).
                    Arrays.fill(step, 0.0);
                    for (int j = 0; j < nCoef; j++) {
                        for (int k = 0; k < nCoef; k++) {
                            step[j] -= o.var_coef[j][k] * o.gradient[k];
                        }
                    }
                    // A non-finite step can never be accepted; stop instead of burning the remaining iterations.
                    if (!allFinite(step)) {
                        break;
                    }
                    oldLoglik = newLoglik;
                    System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
                } else {
                    for (int j = 0; j < nCoef; j++) {
                        step[j] /= 2.0;
                    }
                }
                for (int j = 0; j < nCoef; j++) {
                    newCoef[j] = oldCoef[j] - step[j];
                }
            }
        }

        /**
         * Builds the model, then always releases the frame locks, exits the scope and marks the job done.
         *
         * <p>Failures of a non-cancelled job are reported via {@code failed} and rethrown; a cancelled job
         * is only logged. {@code tryComplete()} runs after cleanup and only when nothing was rethrown.
         */
        @Override
        protected void compute2() {
            try {
                Scope.enter();
                CoxPH.this._parms.read_lock_frames(CoxPH.this);
                CoxPH.this.init(true);
                applyScoringFrameSideEffects();
                final CoxPHModel model = new CoxPHModel(CoxPH.this.dest(), CoxPH.this._parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
                model.delete_and_lock(CoxPH.this._key);
                applyTrainingFrameSideEffects();
                final DataInfo dinfo = new DataInfo(Key.make(), this._modelBuilderTrain, (Frame) null, 1, false,
                        DataInfo.TransformType.DEMEAN, DataInfo.TransformType.NONE, true, false);
                initStats(model, dinfo);
                fitCoefficients(model, dinfo);
                model.update(CoxPH.this._key);
            } catch (Throwable t) {
                final Job<?> job = DKV.getGet(CoxPH.this._key);
                if (job._state != Job.JobState.CANCELLED) {
                    t.printStackTrace();
                    CoxPH.this.failed(t);
                    throw t;
                }
                Log.info("Job cancelled by user.");
            } finally {
                CoxPH.this._parms.read_unlock_frames(CoxPH.this);
                Scope.exit();
                CoxPH.this.done();
            }
            tryComplete();
        }

        /** Key of the enclosing job, passed to every {@code CoxPHTask}. */
        Key self() {
            return CoxPH.this._key;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:hex/coxph/CoxPH$CoxPHTask.class */
    public static class CoxPHTask extends FrameTask<CoxPHTask> {
        /** Current coefficients; indexed by categorical ids and by {@code numStart + i} for numeric predictors. */
        private final double[] _beta;
        /** Number of integer time buckets; length of every per-time accumulator. */
        private final int _n_time;
        /** Time value of bucket 0; subtracted from start/stop times to get bucket indices. */
        private final long _min_time;
        /** Number of trailing numeric columns treated as offsets (added to the linear predictor with coefficient 1). */
        private final int _n_offsets;
        /** Whether each row's response values carry a start time (third-to-last position). */
        private final boolean _has_start_column;
        /** Whether each row's response values carry a weight (position 0). */
        private final boolean _has_weights_column;
        /** Number of rows processed. */
        protected long n;
        /** NOTE(review): not updated by the visible methods of this class — confirm where it is set. */
        protected long n_missing;
        /** Sum of row weights. */
        protected double sumWeights;
        /** Per categorical coefficient: sum of weights of rows where that level is active. */
        protected double[] sumWeightedCatX;
        /** Per numeric column (offsets included): weighted sum of the (DataInfo-transformed) values. */
        protected double[] sumWeightedNumX;
        /** Per time bucket: weight of rows added to the risk set (stop bucket only when there is no start column). */
        protected double[] sizeRiskSet;
        /** Per stop bucket: weight of censored rows (event indicator <= 0). */
        protected double[] sizeCensored;
        /** Per stop bucket: weight of event rows. */
        protected double[] sizeEvents;
        /** Per stop bucket: unweighted count of event rows. */
        protected long[] countEvents;
        /** Per stop bucket and coefficient: weighted sum of x over event rows. */
        protected double[][] sumXEvents;
        /** Per stop bucket: sum of weight * exp(eta) over event rows. */
        protected double[] sumRiskEvents;
        /** Per stop bucket and coefficient: sum of x * weight * exp(eta) over event rows. */
        protected double[][] sumXRiskEvents;
        /** Per stop bucket and coefficient pair: sum of x_j * x_k * weight * exp(eta) over event rows. */
        protected double[][][] sumXXRiskEvents;
        /** Per stop bucket: sum of weight * eta over event rows. */
        protected double[] sumLogRiskEvents;
        /**
         * Per time bucket: weight * exp(eta) added at every bucket from start+1 to stop (start column present)
         * or at the stop bucket only. The "rcumsum" name suggests a reverse cumulative sum is taken later —
         * not visible here, confirm.
         */
        protected double[] rcumsumRisk;
        /** Per time bucket and coefficient: x * weight * exp(eta), bucketed like {@code rcumsumRisk}. */
        protected double[][] rcumsumXRisk;
        /** Per time bucket and coefficient pair: x_j * x_k * weight * exp(eta), bucketed like {@code rcumsumRisk}. */
        protected double[][][] rcumsumXXRisk;

        /**
         * Creates a pass over the adapted frame for one set of coefficients.
         *
         * @param jobKey            key of the owning job
         * @param dinfo             data layout used to expand rows
         * @param beta              coefficients (offsets excluded) used for the linear predictor
         * @param minTime           time value of bucket 0
         * @param nTime             number of time buckets
         * @param nOffsets          number of trailing numeric offset columns
         * @param hasStartColumn    whether rows carry a start time
         * @param hasWeightsColumn  whether rows carry a weight
         */
        CoxPHTask(Key jobKey, DataInfo dinfo, double[] beta, long minTime, int nTime, int nOffsets,
                  boolean hasStartColumn, boolean hasWeightsColumn) {
            super(jobKey, dinfo);
            _beta = beta;
            _min_time = minTime;
            _n_time = nTime;
            _n_offsets = nOffsets;
            _has_weights_column = hasWeightsColumn;
            _has_start_column = hasStartColumn;
        }

        /**
         * Allocates zeroed per-chunk accumulators: per-coefficient sums, per-time-bucket scalars, and
         * per-time-bucket vectors/matrices sized by the number of coefficients.
         */
        @Override
        protected void chunkInit() {
            final int nCoef = _beta.length;
            final int nTime = _n_time;
            final int nCatCoef = nCoef - (_dinfo._nums - _n_offsets);

            sumWeightedCatX = MemoryManager.malloc8d(nCatCoef);
            sumWeightedNumX = MemoryManager.malloc8d(_dinfo._nums);

            sizeRiskSet = MemoryManager.malloc8d(nTime);
            sizeCensored = MemoryManager.malloc8d(nTime);
            sizeEvents = MemoryManager.malloc8d(nTime);
            countEvents = MemoryManager.malloc8(nTime);
            sumRiskEvents = MemoryManager.malloc8d(nTime);
            sumLogRiskEvents = MemoryManager.malloc8d(nTime);
            rcumsumRisk = MemoryManager.malloc8d(nTime);

            sumXEvents = CoxPH.malloc2DArray(nTime, nCoef);
            sumXRiskEvents = CoxPH.malloc2DArray(nTime, nCoef);
            rcumsumXRisk = CoxPH.malloc2DArray(nTime, nCoef);

            sumXXRiskEvents = CoxPH.malloc3DArray(nTime, nCoef, nCoef);
            rcumsumXXRisk = CoxPH.malloc3DArray(nTime, nCoef, nCoef);
        }

        /**
         * Accumulates one row's contribution to the per-time-bucket sums.
         *
         * <p>The weight is read from response index 0 when a weights column is present (1 otherwise).
         * The event indicator, stop time and (optional) start time come from the last, second-to-last
         * and third-to-last response positions. The linear predictor is
         * {@code eta = sum(beta_cat) + sum(x_num * beta_num) + sum(offsets)} and the row's risk is
         * {@code weight * exp(eta)}. Event rows feed the {@code *Events} arrays at the stop bucket.
         * Every row feeds the risk-set arrays at the stop bucket, or at every bucket from
         * {@code start + 1} through {@code stop} when a start column is present.
         *
         * @param gid row id (unused)
         * @param row expanded row from {@link DataInfo}
         * @throws IllegalArgumentException if the weight is {@code <= 0} or start is not before stop
         */
        @Override
        protected void processRow(long gid, DataInfo.Row row) {
            n++;
            final double[] response = row.response;
            final int nCats = row.nBins;
            final int[] catIds = row.numIds;
            final double[] nums = row.numVals;

            final double weight = _has_weights_column ? response[0] : 1.0;
            if (weight <= 0.0) {
                throw new IllegalArgumentException("weights must be positive values");
            }
            final long event = (long) response[response.length - 1];
            final int startBucket = _has_start_column
                    ? (int) ((long) response[response.length - 3] + 1 - _min_time)
                    : -1;
            final int stopBucket = (int) ((long) response[response.length - 2] - _min_time);
            if (startBucket > stopBucket) {
                throw new IllegalArgumentException("start times must be strictly less than stop times");
            }

            final int numStart = _dinfo.numStart();
            sumWeights += weight;
            for (int c = 0; c < nCats; c++) {
                sumWeightedCatX[catIds[c]] += weight;
            }
            for (int v = 0; v < nums.length; v++) {
                sumWeightedNumX[v] += weight * nums[v];
            }

            // Linear predictor: active categorical coefficients, numeric terms, then offsets with coefficient 1.
            final int nPredictors = nums.length - _n_offsets;
            double eta = 0.0;
            for (int c = 0; c < nCats; c++) {
                eta += _beta[catIds[c]];
            }
            for (int v = 0; v < nPredictors; v++) {
                eta += nums[v] * _beta[numStart + v];
            }
            for (int v = nPredictors; v < nums.length; v++) {
                eta += nums[v];
            }

            final double risk = weight * Math.exp(eta);
            final double weightedEta = eta * weight;
            final boolean isEvent = event > 0;
            if (isEvent) {
                countEvents[stopBucket]++;
                sizeEvents[stopBucket] += weight;
                sumLogRiskEvents[stopBucket] += weightedEta;
                sumRiskEvents[stopBucket] += risk;
            } else {
                sizeCensored[stopBucket] += weight;
            }

            // Without a start column only the stop bucket is touched.
            final int firstBucket = _has_start_column ? startBucket : stopBucket;
            for (int t = firstBucket; t <= stopBucket; t++) {
                sizeRiskSet[t] += weight;
                rcumsumRisk[t] += risk;
            }

            // Terms 0..nCats-1 are categorical (x = 1); the rest map to coefficient numStart + (term - nCats).
            final int nTerms = nCats + nPredictors;
            final int numShift = numStart - nCats;
            for (int a = 0; a < nTerms; a++) {
                final boolean aIsCat = a < nCats;
                final int ja = aIsCat ? catIds[a] : numShift + a;
                final double xa = aIsCat ? 1.0 : nums[a - nCats];
                final double xaRisk = xa * risk;
                if (isEvent) {
                    sumXEvents[stopBucket][ja] += weight * xa;
                    sumXRiskEvents[stopBucket][ja] += xaRisk;
                }
                for (int t = firstBucket; t <= stopBucket; t++) {
                    rcumsumXRisk[t][ja] += xaRisk;
                }
                for (int b = 0; b < nTerms; b++) {
                    final boolean bIsCat = b < nCats;
                    final int jb = bIsCat ? catIds[b] : numShift + b;
                    final double xabRisk = (bIsCat ? 1.0 : nums[b - nCats]) * xaRisk;
                    if (isEvent) {
                        sumXXRiskEvents[stopBucket][ja][jb] += xabRisk;
                    }
                    for (int t = firstBucket; t <= stopBucket; t++) {
                        rcumsumXXRisk[t][ja][jb] += xabRisk;
                    }
                }
            }
        }

        /**
         * Folds the partial accumulators of another {@code CoxPHTask} into this one.
         *
         * <p>The scalar counters are summed directly. Every array accumulator is combined with the
         * matching array of {@code other} through {@code ArrayUtils.add}. Its return value is
         * ignored here, so the addition is presumably done in place (confirm against ArrayUtils).
         *
         * @param other another instance of this task whose partial sums are merged into this one
         */
        public void reduce(CoxPHTask other) {
            // Row count and total observation weight.
            this.n += other.n;
            this.sumWeights += other.sumWeights;

            // Weighted column totals: categorical indicators, then numeric columns.
            ArrayUtils.add(this.sumWeightedCatX, other.sumWeightedCatX);
            ArrayUtils.add(this.sumWeightedNumX, other.sumWeightedNumX);

            // Per-time-bin sizes and event counts.
            ArrayUtils.add(this.sizeRiskSet, other.sizeRiskSet);
            ArrayUtils.add(this.sizeCensored, other.sizeCensored);
            ArrayUtils.add(this.sizeEvents, other.sizeEvents);
            ArrayUtils.add(this.countEvents, other.countEvents);

            // Per-time-bin sums restricted to rows that recorded an event.
            ArrayUtils.add(this.sumXEvents, other.sumXEvents);
            ArrayUtils.add(this.sumRiskEvents, other.sumRiskEvents);
            ArrayUtils.add(this.sumXRiskEvents, other.sumXRiskEvents);
            ArrayUtils.add(this.sumXXRiskEvents, other.sumXXRiskEvents);
            ArrayUtils.add(this.sumLogRiskEvents, other.sumLogRiskEvents);

            // Risk sums; postGlobal turns these into reverse cumulative sums when there is no
            // start column.
            ArrayUtils.add(this.rcumsumRisk, other.rcumsumRisk);
            ArrayUtils.add(this.rcumsumXRisk, other.rcumsumXRisk);
            ArrayUtils.add(this.rcumsumXXRisk, other.rcumsumXXRisk);
        }

        /**
         * Converts the per-time-bin risk accumulators into reverse cumulative sums.
         *
         * <p>After this step, bin {@code t} holds the total over bins {@code t, t+1, ..., last}.
         * This applies to {@code rcumsumRisk}, {@code rcumsumXRisk} and {@code rcumsumXXRisk}.
         *
         * <p>Without a start column, each row was added only to the bin of its stop time. The
         * suffix sum makes every bin include all rows whose stop bin is at or after it.
         *
         * <p>With a start column, each row was already added to every bin of its interval, so
         * nothing is done.
         */
        protected void postGlobal() {
            if (!this._has_start_column) {
                final double[] risk = this.rcumsumRisk;
                for (int t = risk.length - 1; t > 0; t--) {
                    risk[t - 1] += risk[t];
                }

                final double[][] xRisk = this.rcumsumXRisk;
                for (int t = xRisk.length - 1; t > 0; t--) {
                    final double[] dst = xRisk[t - 1];
                    final double[] src = xRisk[t];
                    for (int j = 0; j < dst.length; j++) {
                        dst[j] += src[j];
                    }
                }

                final double[][][] xxRisk = this.rcumsumXXRisk;
                for (int t = xxRisk.length - 1; t > 0; t--) {
                    for (int j = 0; j < xxRisk[t - 1].length; j++) {
                        final double[] dst = xxRisk[t - 1][j];
                        final double[] src = xxRisk[t][j];
                        for (int k = 0; k < dst.length; k++) {
                            dst[k] += src[k];
                        }
                    }
                }
            }
        }
    }

    /**
     * Reports the model categories this builder can produce.
     *
     * @return a new single-element array containing {@code Model.ModelCategory.Unknown}
     */
    public Model.ModelCategory[] can_build() {
        final Model.ModelCategory[] categories = {Model.ModelCategory.Unknown};
        return categories;
    }

    /**
     * Reports how this builder is exposed.
     *
     * @return always {@code ModelBuilder.BuilderVisibility.Experimental}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Experimental;
        return visibility;
    }

    /**
     * Creates a CoxPH builder named {@code "CoxPHLearning"}.
     *
     * <p>The constructor immediately validates the parameters by calling {@code init(false)}.
     * Note that {@code init} is an overridable method invoked from the constructor.
     *
     * @param parms model parameters, handed unchanged to the superclass constructor
     */
    public CoxPH(CoxPHModel.CoxPHParameters parms) {
        super("CoxPHLearning", parms);
        init(false);
    }

    /**
     * Schema accessor; not implemented for this experimental builder.
     *
     * <p>NOTE(review): {@code H2O.unimpl()} is called as a bare statement. If it only returns an
     * exception rather than throwing one, this method silently returns {@code null}; in that case
     * it should be {@code throw H2O.unimpl();}. Confirm against H2O.
     *
     * @return {@code null} if {@code H2O.unimpl()} returns normally
     */
    public ModelBuilderSchema schema() {
        H2O.unimpl();
        return null;
    }

    /**
     * Launches training for this builder.
     *
     * <p>Creates a {@link CoxPHDriver}, binds it to the training frame, and passes it to
     * {@code start} together with {@code iter_max} (presumably the job's work estimate).
     *
     * @return the job returned by {@code start}, cast to {@code CoxPH}
     */
    public Job<CoxPHModel> trainModel() {
        final Frame trainingFrame = this._train;
        final CoxPHDriver driver = new CoxPHDriver();
        driver.setModelBuilderTrain(trainingFrame);
        return (CoxPH) start(driver, this._parms.iter_max);
    }

    /**
     * Validates the CoxPH parameters, reporting each problem through {@code error(field, msg)}.
     *
     * <p>The following are checked:
     * <ul>
     *   <li>the optional start column and the stop column are integer-typed;</li>
     *   <li>the event column is integer or factor;</li>
     *   <li>{@code lre_min} is a positive, non-NaN number;</li>
     *   <li>{@code iter_max} is at least 1;</li>
     *   <li>the number of time bins spanned by the start/stop columns is between 1 and 10000
     *       inclusive.</li>
     * </ul>
     *
     * <p>A missing stop or event column is reported as a validation error rather than surfacing
     * as a {@link NullPointerException}. This matters because the constructor calls this method.
     * The time-span check is skipped when there is no stop column, because it cannot be computed.
     *
     * @param z forwarded unchanged to {@code super.init}
     */
    public void init(boolean z) {
        super.init(z);
        if (this._parms.start_column != null && !this._parms.start_column.isInt()) {
            error("start_column", "start time must be null or of type integer");
        }
        if (this._parms.stop_column == null) {
            error("stop_column", "stop_column must be specified");
        } else if (!this._parms.stop_column.isInt()) {
            error("stop_column", "stop time must be of type integer");
        }
        if (this._parms.event_column == null) {
            error("event_column", "event_column must be specified");
        } else if (!this._parms.event_column.isInt() && !this._parms.event_column.isEnum()) {
            error("event_column", "event must be of type integer or factor");
        }
        if (Double.isNaN(this._parms.lre_min) || this._parms.lre_min <= 0.0d) {
            error("lre_min", "lre_min must be a positive number");
        }
        if (this._parms.iter_max < 1) {
            error("iter_max", "iter_max must be a positive integer");
        }
        if (this._parms.stop_column == null) {
            // Already reported above; the time span cannot be computed without stop times.
            return;
        }

        final int maxTimeBins = 10000;
        // Where the first time bin starts:
        //  - without a start column, at the earliest stop time;
        //  - with a start column, one unit after the earliest start time.
        final long minTime = this._parms.start_column == null
                ? (long) this._parms.stop_column.min()
                : (long) this._parms.start_column.min() + 1;
        // The (int) cast saturates for huge spans, which still trips the upper-bound check below.
        final int nTimes = (int) (this._parms.stop_column.max() - minTime + 1.0d);
        if (nTimes < 1) {
            error("start_column", "start times must be strictly less than stop times");
        }
        if (nTimes > maxTimeBins) {
            error("stop_column", "number of distinct stop times is " + nTimes
                    + "; maximum number allowed is " + maxTimeBins);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    public static double[][] malloc2DArray(int i, int i2) {
        ?? r0 = new double[i];
        for (int i3 = 0; i3 < i; i3++) {
            r0[i3] = MemoryManager.malloc8d(i2);
        }
        return r0;
    }

    /**
     * Allocates an {@code i x i2 x i3} array whose innermost rows are obtained from
     * {@link MemoryManager#malloc8d(int)}.
     *
     * <p>The decompiler reports that this method was originally {@code private}. It is kept
     * {@code public} so that existing callers continue to compile.
     *
     * <p>The outer allocation must be {@code new double[i][i2][]}. The decompiled
     * {@code new double[i][i2]} has type {@code double[][]} and does not compile when assigned
     * to a {@code double[][][]}.
     *
     * @param i  size of the first dimension
     * @param i2 size of the second dimension
     * @param i3 length of each innermost row
     * @return the newly allocated {@code i x i2 x i3} array
     */
    public static double[][][] malloc3DArray(int i, int i2, int i3) {
        double[][][] cube = new double[i][i2][];
        for (int a = 0; a < i; a++) {
            for (int b = 0; b < i2; b++) {
                cube[a][b] = MemoryManager.malloc8d(i3);
            }
        }
        return cube;
    }
}
