package hex.tree.drf;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.schemas.DRFV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.ScoreBuildHistogram;
import hex.tree.SharedTree;
import hex.tree.drf.DRFModel;
import hex.tree.drf.TreeMeasuresCollector;
import java.util.Arrays;
import java.util.Random;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.Timer;

/* loaded from: input_file:hex/tree/drf/DRF.class */
public class DRF extends SharedTree<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput> {
    // Number of candidate columns sampled per split; resolved in DRFDriver.buildModel() from
    // _parms._mtries (or sqrt(#cols) / #cols/3 when _mtries == -1).
    protected int _mtry;
    // Decompiler-materialized form of javac's assertion flag (true when assertions are disabled);
    // initialized by the static block at the end of this class and read explicitly by score1().
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Decided (already split) node of a DRF tree. Its children start out as
     * {@link DRFUndecidedNode}s, and its split is the lowest squared-error split among the columns
     * its undecided predecessor selected for scoring.
     */
    static class DRFDecidedNode extends DTree.DecidedNode {
        DRFDecidedNode(DTree.UndecidedNode undecidedNode, DHistogram[] dHistogramArr) {
            super(undecidedNode, dHistogramArr);
        }

        /** Child nodes of a DRF decided node are DRF undecided nodes. */
        @Override // hex.tree.DTree.DecidedNode
        public DTree.UndecidedNode makeUndecidedNode(DHistogram[] hs) {
            DTree.UndecidedNode child = new DRFUndecidedNode(_tree, _nid, hs);
            return child;
        }

        /**
         * Picks the split with the smallest {@code se()} over the node's sampled score columns.
         *
         * @param u  the undecided node whose {@code _scoreCols} lists the candidate columns
         * @param hs per-column histograms; when {@code null} the sentinel split is returned
         * @return the best split found, or a sentinel split on column -1 whose errors are all
         *         {@code Double.MAX_VALUE} when no column yields a split
         */
        @Override // hex.tree.DTree.DecidedNode
        public DTree.Split bestCol(DTree.UndecidedNode u, DHistogram[] hs) {
            DTree.Split best = new DTree.Split(-1, -1, null, (byte) 0, Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0.0d, 0.0d);
            if (hs != null) {
                for (int col : u._scoreCols) {
                    DTree.Split candidate = hs[col].scoreMSE(col, _tree._min_rows);
                    if (candidate == null) {
                        continue; // column could not be split under the min-rows constraint
                    }
                    if (candidate.se() < best.se()) {
                        best = candidate;
                    }
                    if (candidate.se() <= 0.0d) {
                        break; // a non-positive error is treated as unbeatable; stop scanning
                    }
                }
            }
            return best;
        }
    }

    /* loaded from: input_file:hex/tree/drf/DRF$DRFDriver.class */
    private class DRFDriver extends SharedTree<DRFModel, DRFModel.DRFParameters, DRFModel.DRFOutput>.Driver {
        // Number of RNG draws buildModel() skips before growing trees (intended for checkpoint restarts).
        // NOTE(review): never assigned in the visible code, so it stays 0 — confirm whether a checkpoint
        // restart is supposed to set it from the checkpointed model's tree count.
        protected int _ntreesFromCheckpoint;
        // Per-round OOB quality: TreeVotes for classifiers, TreeSSE for regression (see initTreeMeasurements);
        // buildNextKTrees() appends one entry per round.
        public transient TreeMeasuresCollector.TreeMeasures _treeMeasuresOnOOB;
        // One measure collector per column index in [0, _ncols); allocated but not filled in the visible code
        // (presumably for shuffled-OOB variable importance — confirm).
        public transient TreeMeasuresCollector.TreeMeasures[] _treeMeasuresOnSOOB;
        // Per-column accumulator of length _ncols; allocated in initTreeMeasurements() but not read here.
        private transient float[] _improvPerVar;

        /**
         * Out-of-bag scoring pass, run once after each round of trees.
         *
         * <p>For every row whose node id marks it out-of-bag for a new tree, the row is routed
         * down that tree to a leaf. The leaf prediction is then folded into the tree's running
         * prediction column ({@code chk_tree}): a running average for classifiers, a running sum
         * for regression. The row's OOB counter ({@code chk_oobt}) is incremented. For every
         * non-null tree the node-id column is reset to 0 for the next round.
         *
         * <p>Rows that were out-of-bag and have a non-NA response also contribute to this round's
         * quality totals: the correct-vote count for classifiers, or the sum of squared errors for
         * regression.
         *
         * <p>Uses a plain {@code assert} rather than a hand-written static {@code $assertionsDisabled}
         * field. A static field or static initializer in an inner (non-static) class is a compile
         * error before Java 16.
         */
        public class CollectPreds extends MRTask<CollectPreds> {
            final DTree[] _trees;
            double _threshold;
            long rightVotes;
            long allRows;
            float sse;
            // Constant flag; not read anywhere in this class.
            final boolean importance = true;

            /**
             * @param trees     the trees of this round, indexed by class; null entries are skipped
             * @param leafs     per-tree node counts from the caller; not used by this task
             * @param threshold decision threshold handed to {@code GenModel.getPrediction}
             */
            CollectPreds(DTree[] trees, int[] leafs, double threshold) {
                this._trees = trees;
                this._threshold = threshold;
            }

            public void map(Chunk[] chks) {
                Chunk ys = DRF.this.chk_resp(chks);
                Chunk oobt = DRF.this.chk_oobt(chks);
                // Per-row scratch: slot 1 + k holds tree k's prediction; slot 0 is never written.
                double[] rpred = new double[1 + DRF.this._nclass];
                double[] rowdata = new double[DRF.this._ncols];
                for (int row = 0; row < oobt._len; row++) {
                    boolean wasOOBRow = false;
                    for (int k = 0; k < DRF.this._nclass; k++) {
                        DTree tree = this._trees[k];
                        if (tree == null) {
                            continue;
                        }
                        Chunk ct = DRF.this.chk_tree(chks, k);
                        Chunk nids = DRF.this.chk_nids(chks, k);
                        int nid = (int) nids.at8(row);
                        if (ScoreBuildHistogram.isOOBRow(nid)) {
                            assert k == 0 || wasOOBRow : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
                            wasOOBRow = true;
                            int oobNid = ScoreBuildHistogram.oob2Nid(nid);
                            // A row parked at an undecided node is scored from that node's parent.
                            if (tree.node(oobNid) instanceof DTree.UndecidedNode) {
                                oobNid = tree.node(oobNid).pid();
                            }
                            int leafnid;
                            if (tree.root() instanceof DTree.LeafNode) {
                                leafnid = 0; // trivial single-leaf tree
                            } else {
                                DTree.DecidedNode dn = tree.decided(oobNid);
                                if (dn._split.col() == -1) {
                                    // No usable split here: fall back to the parent's decision.
                                    dn = tree.decided(tree.node(oobNid).pid());
                                }
                                leafnid = dn.ns(chks, row);
                            }
                            double pred = ((DTree.LeafNode) tree.node(leafnid)).pred();
                            rpred[1 + k] = (float) pred;
                            double numPrevTrees = oobt.atd(row);
                            if (DRF.this.isClassifier()) {
                                // Running average of the votes cast for this row so far.
                                ct.set(row, ((float) ((ct.atd(row) * numPrevTrees) + pred)) / (numPrevTrees + 1.0d));
                            } else {
                                // Running sum; score1() divides by the OOB counter.
                                ct.set(row, (float) (ct.atd(row) + pred));
                            }
                            oobt.set(row, oobt.atd(row) + 1.0d);
                        }
                        nids.set(row, 0L);
                    }
                    if (wasOOBRow && !ys.isNA(row)) {
                        if (DRF.this.isClassifier()) {
                            if (GenModel.getPrediction(rpred, DRF.this.data_row(chks, row, rowdata), this._threshold) == ((int) ys.at8(row))) {
                                this.rightVotes++;
                            }
                        } else {
                            double err = ys.atd(row) - rpred[1];
                            this.sse += err * err;
                        }
                        this.allRows++;
                    }
                }
            }

            public void reduce(CollectPreds other) {
                this.rightVotes += other.rightVotes;
                this.allRows += other.allRows;
                this.sse += other.sse;
            }
        }

        private DRFDriver() {
            super();
        }

        /**
         * Allocates the OOB bookkeeping for {@code _parms._ntrees} trees.
         *
         * <p>Classifiers get {@code TreeVotes} collectors and regressors get {@code TreeSSE}
         * collectors. This covers the overall OOB collector and one collector per column index in
         * [0, _ncols). The per-column float accumulator is also allocated.
         */
        private void initTreeMeasurements() {
            final int ncols = DRF.this._ncols;
            this._improvPerVar = new float[ncols];
            final int ntrees = ((DRFModel.DRFParameters) DRF.this._parms)._ntrees;
            final boolean classifier = ((DRFModel.DRFOutput) ((DRFModel) DRF.this._model)._output).isClassifier();
            if (classifier) {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeVotes(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeVotes[ncols];
            } else {
                this._treeMeasuresOnOOB = new TreeMeasuresCollector.TreeSSE(ntrees);
                this._treeMeasuresOnSOOB = new TreeMeasuresCollector.TreeSSE[ncols];
            }
            for (int col = 0; col < ncols; col++) {
                if (classifier) {
                    this._treeMeasuresOnSOOB[col] = new TreeMeasuresCollector.TreeVotes(ntrees);
                } else {
                    this._treeMeasuresOnSOOB[col] = new TreeMeasuresCollector.TreeSSE(ntrees);
                }
            }
        }

        /**
         * Main DRF training loop.
         *
         * <p>The loop runs in these steps:
         * <ol>
         *   <li>Resolve {@code _mtry}: the {@code _mtries} parameter, or, when it is -1,
         *       max(sqrt(#cols), 1) for classification and max(#cols / 3, 1) for regression.</li>
         *   <li>Allocate the OOB bookkeeping and add the {@code OUT_BAG_TREES} column.</li>
         *   <li>Seed the working columns and, on a checkpoint restart, rebuild the OOB state.</li>
         *   <li>Grow {@code _ntrees} rounds of trees.</li>
         * </ol>
         *
         * <p>Before each round the model is scored, except for the first round of a checkpoint
         * restart. Training stops once the returned value reaches {@code _r2_stopping}, or when the
         * job is no longer running. The last {@code doScoringAndSaveModel(true, ...)} call happens
         * only when every round completed.
         *
         * <p>NOTE(review): {@code _ntreesFromCheckpoint} is never assigned in the visible code. The
         * RNG-skip loop is therefore a no-op, and a checkpoint restart presumably re-draws the same
         * tree seeds as the original run — confirm.
         *
         * @throws IllegalArgumentException if the resolved mtry lies outside [1, #cols]
         */
        @Override // hex.tree.SharedTree.Driver
        protected void buildModel() {
            final DRFModel.DRFParameters parms = (DRFModel.DRFParameters) DRF.this._parms;
            final int ncols = DRF.this._ncols;
            if (parms._mtries != -1) {
                DRF.this._mtry = parms._mtries;
            } else if (DRF.this.isClassifier()) {
                DRF.this._mtry = Math.max((int) Math.sqrt(ncols), 1);
            } else {
                DRF.this._mtry = Math.max(ncols / 3, 1);
            }
            if (DRF.this._mtry < 1 || DRF.this._mtry > ncols) {
                throw new IllegalArgumentException("Computed mtry should be in interval <1,#cols> but it is " + DRF.this._mtry);
            }
            initTreeMeasurements();
            DRF.this._train.add("OUT_BAG_TREES", DRF.this._response.makeZero());
            new SetWrkTask().doAll(DRF.this._train);
            if (parms._checkpoint) {
                Timer oobTimer = new Timer();
                new OOBScorer(ncols, DRF.this._nclass, parms._sample_rate, ((DRFModel.DRFOutput) ((DRFModel) DRF.this._model)._output)._treeKeys).doAll(DRF.this._train);
                Log.info(new Object[]{"Reconstructing oob stats from checkpointed model took " + oobTimer});
            }
            Random rng = SharedTree.createRNG(parms._seed);
            for (int skipped = 0; skipped < this._ntreesFromCheckpoint; skipped++) {
                rng.nextLong();
            }
            for (int tid = 0; tid < parms._ntrees; tid++) {
                final boolean skipScoring = tid == 0 && parms._checkpoint;
                if (!skipScoring) {
                    if (DRF.this.doScoringAndSaveModel(false, true, parms._build_tree_one_node) >= parms._r2_stopping) {
                        return;
                    }
                }
                Timer treeTimer = new Timer();
                buildNextKTrees(DRF.this._train, DRF.this._mtry, parms._sample_rate, rng, tid);
                Log.info(new Object[]{(tid + 1) + ". tree was built " + treeTimer.toString()});
                DRF.this.update(1L);
                if (!DRF.this.isRunning()) {
                    return;
                }
            }
            DRF.this.doScoringAndSaveModel(true, true, parms._build_tree_one_node);
        }

        /**
         * Grows one round of trees and appends them to the model output. For classification the
         * round has up to one tree per class; for regression it has a single tree.
         *
         * <p>Steps:
         * <ol>
         *   <li>Create a DRFTree and its root DRFUndecidedNode for every class with a nonzero
         *       entry in {@code _output._distribution}. For a 2-class problem, class 1 gets a tree
         *       only when {@code _binomial_double_trees} is set.</li>
         *   <li>Fork one {@link Sample} task per tree to mark out-of-bag rows, then join them
         *       all.</li>
         *   <li>Add split layers via {@code buildLayer} until {@code _max_depth} is reached or it
         *       returns null.</li>
         *   <li>Turn unsplittable children into {@link DRFLeafNode}s.</li>
         *   <li>Score the OOB rows with {@link CollectPreds} and record this round's votes or
         *       SSE.</li>
         * </ol>
         * Returns early, without adding any trees, if the job stops running while layers are built.
         *
         * @param frame  training frame including the per-class working columns
         * @param i      number of candidate columns per split (mtry), stored in each DRFTree
         * @param f      row sample rate handed to {@link Sample}
         * @param random driver RNG; exactly one long is drawn and shared by all trees of this round
         * @param i2     index of this round; not read in this method
         */
        private void buildNextKTrees(Frame frame, int i, float f, Random random, int i2) {
            DTree[] dTreeArr = new DTree[DRF.this._nclass];
            // One root histogram set per class: [class][root leaf only][column].
            DHistogram[][][] dHistogramArr = new DHistogram[DRF.this._nclass][1][DRF.this._ncols];
            // The root level uses at least _nbins_top_level bins.
            int max = Math.max(((DRFModel.DRFParameters) DRF.this._parms)._nbins_top_level, ((DRFModel.DRFParameters) DRF.this._parms)._nbins);
            long[] jArr = ((DRFModel.DRFOutput) ((DRFModel) DRF.this._model)._output)._distribution;
            // Same seed for every tree of this round.
            long nextLong = random.nextLong();
            for (int i3 = 0; i3 < DRF.this._nclass; i3++) {
                if (jArr[i3] != 0 && (i3 != 1 || DRF.this._nclass != 2 || ((DRFModel.DRFParameters) DRF.this._parms)._binomial_double_trees)) {
                    dTreeArr[i3] = new DRFTree(frame, DRF.this._ncols, (char) ((DRFModel.DRFParameters) DRF.this._parms)._nbins, (char) ((DRFModel.DRFParameters) DRF.this._parms)._nbins_cats, (char) DRF.this._nclass, ((DRFModel.DRFParameters) DRF.this._parms)._min_rows, i, nextLong);
                    // Result is discarded: presumably the constructor registers the node as the tree's root — confirm in DTree.
                    new DRFUndecidedNode(dTreeArr[i3], -1, DHistogram.initialHist(frame, DRF.this._ncols, max, ((DRFModel.DRFParameters) DRF.this._parms)._nbins_cats, dHistogramArr[i3][0], DRF.this.isClassifier()));
                }
            }
            Timer timer = new Timer();
            // Fork all sampling tasks first so they run concurrently, then join them below.
            Sample[] sampleArr = new Sample[DRF.this._nclass];
            for (int i4 = 0; i4 < DRF.this._nclass; i4++) {
                if (dTreeArr[i4] != null) {
                    sampleArr[i4] = (Sample) new Sample((DRFTree) dTreeArr[i4], f).dfork(0, new Frame(new Vec[]{DRF.this.vec_nids(frame, i4), DRF.this.vec_resp(frame)}), ((DRFModel.DRFParameters) DRF.this._parms)._build_tree_one_node);
                }
            }
            for (int i5 = 0; i5 < DRF.this._nclass; i5++) {
                if (sampleArr[i5] != null) {
                    sampleArr[i5].getResult();
                }
            }
            // NOTE(review): this log text contains a stray "+ ".
            Log.debug(new Object[]{"Sampling took: + " + timer});
            // Per-tree node counts; all zero while layers are built, filled in after the depth loop.
            int[] iArr = new int[DRF.this._nclass];
            Timer timer2 = new Timer();
            for (int i6 = 0; i6 < ((DRFModel.DRFParameters) DRF.this._parms)._max_depth; i6++) {
                if (!DRF.this.isRunning()) {
                    return;
                }
                dHistogramArr = DRF.this.buildLayer(frame, ((DRFModel.DRFParameters) DRF.this._parms)._nbins, ((DRFModel.DRFParameters) DRF.this._parms)._nbins_cats, dTreeArr, iArr, dHistogramArr, true, ((DRFModel.DRFParameters) DRF.this._parms)._build_tree_one_node);
                // null: no new splits were made, so the trees cannot grow further.
                if (dHistogramArr == null) {
                    break;
                }
            }
            Log.debug(new Object[]{"Tree build took: " + timer2});
            Timer timer3 = new Timer();
            for (int i7 = 0; i7 < DRF.this._nclass; i7++) {
                DTree dTree = dTreeArr[i7];
                if (dTree != null) {
                    // Only nodes that existed before leaf insertion are visited; new leaves get ids >= len.
                    int len = dTree.len();
                    iArr[i7] = len;
                    for (int i8 = 0; i8 < len; i8++) {
                        if (dTree.node(i8) instanceof DTree.DecidedNode) {
                            DTree.DecidedNode decided = dTree.decided(i8);
                            if (decided._split._col != -1) {
                                for (int i9 = 0; i9 < decided._nids.length; i9++) {
                                    int i10 = decided._nids[i9];
                                    // A child becomes a leaf if it is missing, still undecided (depth cut-off), or decided without a split.
                                    if (i10 == -1 || (dTree.node(i10) instanceof DTree.UndecidedNode) || ((dTree.node(i10) instanceof DTree.DecidedNode) && ((DTree.DecidedNode) dTree.node(i10))._split.col() == -1)) {
                                        DRFLeafNode dRFLeafNode = new DRFLeafNode(dTree, i8);
                                        dRFLeafNode._pred = (float) decided.pred(i9);
                                        decided._nids[i9] = dRFLeafNode.nid();
                                    }
                                }
                            } else if (i8 == 0) {
                                // Root could not split: a single leaf predicting the class prior (classifier) or the response mean.
                                new DRFLeafNode(dTree, -1, 0)._pred = (float) (DRF.this.isClassifier() ? ((DRFModel.DRFOutput) ((DRFModel) DRF.this._model)._output)._priorClassDist[i7] : DRF.this._response.mean());
                            }
                        }
                    }
                }
            }
            Log.debug(new Object[]{"Nodes propagation: " + timer3});
            Timer timer4 = new Timer();
            // Score OOB rows through the new trees and record this round's OOB quality.
            CollectPreds collectPreds = (CollectPreds) new CollectPreds(dTreeArr, iArr, ((DRFModel) DRF.this._model).defaultThreshold()).doAll(frame, ((DRFModel.DRFParameters) DRF.this._parms)._build_tree_one_node);
            if (DRF.this.isClassifier()) {
                TreeMeasuresCollector.asVotes(this._treeMeasuresOnOOB).append(collectPreds.rightVotes, collectPreds.allRows);
            } else {
                TreeMeasuresCollector.asSSE(this._treeMeasuresOnOOB).append(collectPreds.sse, collectPreds.allRows);
            }
            Log.debug(new Object[]{"CollectPreds done: " + timer4});
            ((DRFModel.DRFOutput) ((DRFModel) DRF.this._model)._output).addKTrees(dTreeArr);
        }

        /**
         * Creates the DRF model object for this build.
         *
         * @param key   key under which the model is created
         * @param parms the DRF parameters for the model
         * @param d     first value forwarded unchanged to {@code DRFOutput}
         * @param d2    second value forwarded unchanged to {@code DRFOutput}
         * @return a new {@code DRFModel} whose output is bound to this builder
         */
        @Override // hex.tree.SharedTree.Driver
        public DRFModel makeModel(Key key, DRFModel.DRFParameters parms, double d, double d2) {
            DRFModel.DRFOutput output = new DRFModel.DRFOutput(DRF.this, d, d2);
            return new DRFModel(key, parms, output);
        }
    }

    /**
     * Leaf node of a DRF tree. Its serialized form is just the prediction, stored as one 4-byte
     * float.
     */
    public static class DRFLeafNode extends DTree.LeafNode {
        DRFLeafNode(DTree tree, int pid) {
            super(tree, pid);
        }

        DRFLeafNode(DTree tree, int pid, int nid) {
            super(tree, pid, nid);
        }

        /** Writes the prediction; with assertions enabled, a NaN prediction is rejected. */
        @Override // hex.tree.DTree.Node
        protected AutoBuffer compress(AutoBuffer ab) {
            assert !Double.isNaN(_pred);
            return ab.put4f(_pred);
        }

        /** Serialized size in bytes: a single float. */
        @Override // hex.tree.DTree.Node
        protected int size() {
            final int floatBytes = 4;
            return floatBytes;
        }
    }

    /**
     * DRF tree carrying two extras.
     *
     * <ul>
     *   <li>The number of candidate columns to sample per split ({@code _mtrys}).</li>
     *   <li>One RNG seed per chunk of the training frame's first vector. The seeds are drawn in
     *       chunk order from an RNG seeded with the tree seed.</li>
     * </ul>
     */
    public static class DRFTree extends DTree {
        final int _mtrys;
        final long[] _seeds;
        final transient Random _rand;

        DRFTree(Frame fr, int ncols, char nbins, char nbinsCats, char nclass, int minRows, int mtrys, long seed) {
            super(fr._names, ncols, nbins, nbinsCats, nclass, minRows, seed);
            this._mtrys = mtrys;
            this._rand = SharedTree.createRNG(seed);
            final int nChunks = fr.vecs()[0].nChunks();
            this._seeds = new long[nChunks];
            int cidx = 0;
            while (cidx < nChunks) {
                this._seeds[cidx] = this._rand.nextLong();
                cidx++;
            }
        }

        /** Returns a fresh RNG for chunk {@code cidx}, seeded deterministically from this tree. */
        public Random rngForChunk(int cidx) {
            final long chunkSeed = this._seeds[cidx];
            return SharedTree.createRNG(chunkSeed);
        }
    }

    /**
     * Undecided node of a DRF tree. It limits split scoring to a random subset of at most
     * {@code _mtrys} columns, drawn without replacement from the columns that still have a
     * histogram.
     */
    public static class DRFUndecidedNode extends DTree.UndecidedNode {
        DRFUndecidedNode(DTree tree, int pid, DHistogram[] hs) {
            super(tree, pid, hs);
        }

        /**
         * Chooses the columns to score at this node.
         *
         * <p>Active (non-null) column indices are gathered first. A partial Fisher-Yates shuffle
         * then moves up to {@code _mtrys} randomly chosen ones to the tail of the array, and that
         * tail is returned.
         */
        @Override // hex.tree.DTree.UndecidedNode
        public int[] scoreCols(DHistogram[] hs) {
            final DRFTree tree = (DRFTree) _tree;
            final int[] cols = new int[hs.length];
            int active = 0;
            for (int c = 0; c < hs.length; c++) {
                final DHistogram h = hs[c];
                if (h == null) {
                    continue; // column not tracked at this node
                }
                assert !(h._min >= h._maxEx || h.nbins() <= 1) : "broken histo range " + h;
                cols[active++] = c;
            }
            final int choices = active;
            assert choices > 0;
            int remaining = choices;
            for (int pick = 0; pick < tree._mtrys && remaining != 0; pick++) {
                final int idx = tree._rand.nextInt(remaining);
                final int chosen = cols[idx];
                remaining--;
                cols[idx] = cols[remaining]; // compact the unchosen prefix
                cols[remaining] = chosen;    // park the pick just after it
            }
            assert choices - remaining > 0;
            return Arrays.copyOfRange(cols, remaining, choices);
        }
    }

    /**
     * Row-sampling task over a (node-id, response) column pair.
     *
     * <p>A row is kept in-bag when the chunk RNG's next float is below {@code _rate} and the
     * response is not NaN. Otherwise its node id is set to -2, the out-of-bag marker (presumably
     * the value {@code ScoreBuildHistogram.isOOBRow} recognises — confirm).
     */
    public static class Sample extends MRTask<Sample> {
        final DRFTree _tree;
        final float _rate;

        Sample(DRFTree tree, float rate) {
            _tree = tree;
            _rate = rate;
        }

        public void map(Chunk nids, Chunk ys) {
            final Random rng = _tree.rngForChunk(nids.cidx());
            for (int row = 0; row < nids._len; row++) {
                final boolean outOfBag = rng.nextFloat() >= _rate || Double.isNaN(ys.atd(row));
                if (outOfBag) {
                    nids.set(row, -2L);
                }
            }
        }
    }

    /**
     * Seeds the working columns from the response.
     *
     * <p>For classifiers, 1 is written into the working column of the row's class. For
     * regression, the response (cast to float) is copied into working column 0. Rows with an NA
     * response are left untouched.
     */
    private class SetWrkTask extends MRTask<SetWrkTask> {
        private SetWrkTask() {
        }

        public void map(Chunk[] chks) {
            final Chunk ys = DRF.this.chk_resp(chks);
            for (int row = 0; row < ys._len; row++) {
                if (ys.isNA(row)) {
                    continue;
                }
                if (DRF.this.isClassifier()) {
                    final int cls = (int) ys.at8(row);
                    DRF.this.chk_work(chks, cls).set(row, 1L);
                } else {
                    final float y = (float) ys.atd(row);
                    DRF.this.chk_work(chks, 0).set(row, y);
                }
            }
        }
    }

    /** Model categories this builder supports: regression, binomial and multinomial. */
    public ModelCategory[] can_build() {
        final ModelCategory[] supported = {
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial
        };
        return supported;
    }

    /** DRF is advertised with {@code Stable} builder visibility. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /**
     * Creates a DRF builder and immediately runs {@code init(false)}. Parameter problems are
     * therefore recorded, and an invalid sample rate thrown, during construction.
     *
     * @param parms DRF parameters for this build
     */
    public DRF(DRFModel.DRFParameters parms) {
        super(DRFGrid.MODEL_NAME, parms);
        this.init(false);
    }

    /**
     * Returns a new REST schema object for DRF.
     *
     * <p>The decompiler emitted this method under the synthetic name {@code m164schema} (see its
     * "renamed from: schema" note), so the {@code schema()} override was lost. This restores it.
     */
    @Override
    public DRFV3 schema() {
        return new DRFV3();
    }

    /**
     * Decompiler-generated alias of {@link #schema()}, kept for backward compatibility with any
     * caller that already uses this name.
     */
    public DRFV3 m164schema() {
        return schema();
    }

    /**
     * Launches DRF training through {@code start(...)} with a new {@link DRFDriver}. The work
     * total is {@code _ntrees}; the driver reports one unit per tree round.
     */
    public Job<DRFModel> trainModel() {
        final DRFDriver driver = new DRFDriver();
        final int ntrees = ((DRFModel.DRFParameters) this._parms)._ntrees;
        return start(driver, ntrees);
    }

    /**
     * Validates the DRF-specific parameters after {@link SharedTree}'s own checks.
     *
     * <p>Most problems are recorded via {@code error(...)} or {@code warn(...)}. An invalid sample
     * rate is the exception: it throws immediately.
     *
     * @param z flag forwarded unchanged to {@code super.init}
     * @throws IllegalArgumentException if {@code _sample_rate} is not in (0, 1], including NaN
     */
    @Override // hex.tree.SharedTree
    public void init(boolean z) {
        super.init(z);
        final DRFModel.DRFParameters parms = (DRFModel.DRFParameters) this._parms;
        // Phrased as "not inside (0,1]" instead of "<= 0 || > 1". Every comparison with NaN is
        // false, so the latter form would silently accept a NaN sample rate.
        if (!(0.0d < parms._sample_rate && parms._sample_rate <= 1.0d)) {
            throw new IllegalArgumentException("Sample rate should be interval (0,1> but it is " + parms._sample_rate);
        }
        if (parms._mtries < 1 && parms._mtries != -1) {
            error("_mtries", "mtries must be -1 (converted to sqrt(features)), or >= 1 but it is " + parms._mtries);
        }
        if (this._train != null) {
            // numCols() counts every column of the training frame, not only the predictors.
            int numCols = this._train.numCols();
            if (parms._mtries != -1 && (parms._mtries < 1 || parms._mtries >= numCols)) {
                error("_mtries", "Computed mtries should be -1 or in interval <1,#cols> but it is " + parms._mtries);
            }
        }
        // With no held-out rows in training and no validation frame there is nothing to score OOB.
        if (parms._sample_rate == 1.0f && this._valid == null) {
            error("_sample_rate", "Sample rate is 100% and no validation dataset is specified.  There are no OOB data to compute out-of-bag error estimation!");
        }
        if (this._train != null && parms._response_column != null && this._nclass != 2 && parms._binomial_double_trees) {
            warn("_binomial_double_trees", "Binomial double tree is ignored for non-binomial response.");
        }
    }

    /** Factory hook used by SharedTree: DRF represents splits with {@link DRFDecidedNode}. */
    @Override // hex.tree.SharedTree
    protected DTree.DecidedNode makeDecided(DTree.UndecidedNode udn, DHistogram[] hs) {
        final DTree.DecidedNode decided = new DRFDecidedNode(udn, hs);
        return decided;
    }

    /**
     * Fills {@code fs} with one row's out-of-bag prediction and returns the value SharedTree uses
     * as the row's sum.
     *
     * <ul>
     *   <li>One tree per class (multinomial, or binomial with double trees): {@code fs[1 + k]} is
     *       class k's accumulated OOB value and the return value is their sum.</li>
     *   <li>Binomial with a single tree: {@code fs[1]} is the class-0 tree's accumulated value,
     *       {@code fs[2]} is its complement, and the return value is 0.</li>
     *   <li>Regression: {@code fs[0]} is the accumulated prediction divided by the row's OOB
     *       counter, {@code fs[1]} is 0, and the return value is {@code 0.0 + fs[0]}. A row that
     *       was never out-of-bag presumably yields 0/0 = NaN here.</li>
     * </ul>
     */
    @Override // hex.tree.SharedTree
    protected double score1(Chunk[] chks, double[] fs, int row) {
        final boolean doubleTrees = ((DRFModel.DRFParameters) this._parms)._binomial_double_trees;
        if (this._nclass > 2 || (this._nclass == 2 && doubleTrees)) {
            double total = 0.0d;
            for (int k = 0; k < this._nclass; k++) {
                fs[k + 1] = chk_tree(chks, k).atd(row);
                total += fs[k + 1];
            }
            return total;
        }
        if (this._nclass == 2) {
            final double p0 = chk_tree(chks, 0).atd(row);
            fs[1] = p0;
            if (!$assertionsDisabled && (p0 < 0.0d || p0 > 1.0d)) {
                throw new AssertionError();
            }
            fs[2] = 1.0d - p0;
            return 0.0d;
        }
        final double mean = chk_tree(chks, 0).atd(row) / chk_oobt(chks).atd(row);
        fs[0] = mean;
        fs[1] = 0.0d;
        return 0.0d + mean;
    }

    // Initializes the decompiler-materialized assertion flag from DRF's desired assertion status:
    // true means assertions are disabled, which makes score1()'s explicit check a no-op.
    static {
        $assertionsDisabled = !DRF.class.desiredAssertionStatus();
    }
}
