package hex.tree;

import hex.SupervisedModel;
import hex.VarImp;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModel.SharedTreeOutput;
import hex.tree.SharedTreeModel.SharedTreeParameters;
import java.util.Arrays;
import water.DKV;
import water.Futures;
import water.Key;

/* loaded from: input_file:hex/tree/SharedTreeModel.class */
/**
 * Base class for tree-ensemble models whose trees are stored as {@link CompressedTree} values in
 * the DKV.
 *
 * <p>Scoring sums the per-tree predictions of every stored tree iteration. Removal deletes every
 * stored tree key before delegating to the superclass.
 *
 * @param <M> concrete model type
 * @param <P> parameter type
 * @param <O> output type
 */
public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput> extends SupervisedModel<M, P, O> {
    /** Prediction type label. It is not referenced in this file; presumably other code in the package consumes it. */
    static final String PRED_TYPE = "float";

    /** Output shared by tree-based models: tree keys, per-iteration MSE history and tree statistics. */
    public static abstract class SharedTreeOutput extends SupervisedModel.SupervisedOutput {
        /** Initial prediction. This class never assigns it; presumably subclasses or builders do. */
        public double _initialPrediction;
        /** Number of tree iterations added so far through {@link #addKTrees}. */
        public int _ntrees;
        /** Aggregate tree statistics, updated with every batch passed to {@link #addKTrees}. */
        final TreeStats _treeStats;
        /**
         * {@code _treeKeys[t][c]} is the DKV key of the tree for iteration {@code t} and class
         * {@code c}. It is null when no tree was supplied for that class.
         */
        public Key<CompressedTree>[][] _treeKeys;
        /** Training MSE history. Entry 0 is the value given to the constructor. */
        public double[] _mse_train;
        /** Validation MSE history, or null when the constructor received NaN for it. */
        public double[] _mse_valid;
        /** Variable importance. This class never assigns it. */
        public VarImp _varimp;

        /**
         * Creates an empty output with no trees.
         *
         * @param sharedTree the builder, forwarded to the superclass constructor
         * @param initialMseTrain first entry of the training MSE history
         * @param initialMseValid first entry of the validation MSE history; NaN disables validation
         *     MSE tracking ({@link #_mse_valid} stays null)
         */
        @SuppressWarnings("unchecked") // generic array creation
        public SharedTreeOutput(SharedTree sharedTree, double initialMseTrain, double initialMseValid) {
            super(sharedTree);
            this._ntrees = 0;
            // Two-dimensional and initially empty: one row of per-class keys is appended per iteration.
            this._treeKeys = new Key[0][];
            this._treeStats = new TreeStats();
            this._mse_train = new double[]{initialMseTrain};
            this._mse_valid = Double.isNaN(initialMseValid) ? null : new double[]{initialMseValid};
        }

        /**
         * Adds one tree iteration.
         *
         * <p>Each non-null tree is compressed and stored in the DKV, and its key is recorded in a new
         * row of {@link #_treeKeys}. The method also updates the tree statistics, increments
         * {@link #_ntrees}, and grows the MSE history arrays by one slot. It blocks until all DKV puts
         * have completed.
         *
         * @param dTreeArr this iteration's trees, indexed by class. A null entry leaves a null key.
         *     The array length is expected to equal {@code nclasses()}; this is checked only when
         *     assertions are enabled.
         */
        @SuppressWarnings("unchecked") // generic array creation
        public void addKTrees(DTree[] dTreeArr) {
            assert nclasses() == dTreeArr.length
                : "expected " + nclasses() + " trees, got " + dTreeArr.length;
            this._treeStats.updateBy(dTreeArr);

            final int iteration = this._ntrees;
            Key<CompressedTree>[] classKeys = new Key[dTreeArr.length];
            this._treeKeys = Arrays.copyOf(this._treeKeys, iteration + 1);
            this._treeKeys[iteration] = classKeys;

            Futures futures = new Futures();
            // Iterate over the array's own length so every supplied tree gets a key slot, even when
            // assertions are disabled and the length differs from nclasses().
            for (int c = 0; c < dTreeArr.length; c++) {
                if (dTreeArr[c] != null) {
                    CompressedTree compressed = dTreeArr[c].compress(iteration, c);
                    classKeys[c] = compressed._key;
                    DKV.put(compressed._key, compressed, futures);
                }
            }

            this._ntrees++;
            // Each history array keeps one more slot than the tree count: entry 0 is the initial value.
            this._mse_train = Arrays.copyOf(this._mse_train, this._ntrees + 1);
            if (this._mse_valid != null) {
                this._mse_valid = Arrays.copyOf(this._mse_valid, this._ntrees + 1);
            }
            futures.blockForPending();
        }

        /**
         * Renders one stored tree as a string.
         *
         * @param i tree iteration index
         * @param i2 class index within that iteration
         * @return the tree's {@code toString(this)} rendering
         */
        public String toStringTree(int i, int i2) {
            return ((CompressedTree) this._treeKeys[i][i2].get()).toString(this);
        }
    }

    /** Parameters shared by tree-based model builders. The defaults are the values assigned below. */
    public static abstract class SharedTreeParameters extends SupervisedModel.SupervisedParameters {
        /** Limit on tree levels. Not referenced in this file; presumably checked by the builder. */
        static final int MAX_SUPPORTED_LEVELS = 1000;
        public int _ntrees = 50;
        public int _max_depth = 5;
        public int _min_rows = 10;
        public int _nbins = 20;
        public boolean _variable_importance = false;
        public long _seed;
        public boolean _checkpoint;
    }

    /**
     * Creates a model, forwarding all arguments to the superclass.
     *
     * @param key model key
     * @param p model parameters
     * @param o model output
     */
    public SharedTreeModel(Key key, P p, O o) {
        super(key, p, o);
    }

    /**
     * Scores one row by summing the predictions of every stored tree iteration.
     *
     * @param dArr the input row
     * @param fArr prediction buffer. It is zeroed first and then accumulated into.
     * @return {@code fArr}
     */
    public float[] score0(double[] dArr, float[] fArr) {
        Arrays.fill(fArr, 0.0f);
        for (int iteration = 0; iteration < this._output._treeKeys.length; iteration++) {
            score0(dArr, fArr, iteration);
        }
        return fArr;
    }

    /**
     * Adds the predictions of a single tree iteration to {@code fArr}.
     *
     * <p>An iteration with one tree accumulates into slot 0. Otherwise the tree for class {@code c}
     * accumulates into slot {@code c + 1}. Presumably slot 0 is reserved for the predicted label in
     * the multi-class case; confirm this against the caller.
     *
     * @param dArr the input row
     * @param fArr prediction buffer to accumulate into
     * @param i tree iteration index
     */
    public void score0(double[] dArr, float[] fArr, int i) {
        Key<CompressedTree>[] keys = this._output._treeKeys[i];
        for (int c = 0; c < keys.length; c++) {
            if (keys[c] != null) {
                CompressedTree tree = (CompressedTree) DKV.get(keys[c]).get();
                // Compound assignment narrows the tree score to the float buffer implicitly.
                fArr[keys.length == 1 ? 0 : c + 1] += tree.score(dArr);
            }
        }
    }

    /** Returns whether this model originates from SpeeDRF; always {@code false} here. */
    public boolean isFromSpeeDRF() {
        return false;
    }

    /**
     * Removes every stored tree key and then delegates to the superclass removal.
     *
     * @param futures collects the pending removals
     * @return the result of the superclass {@code remove_impl}
     */
    protected Futures remove_impl(Futures futures) {
        for (Key<CompressedTree>[] iterationKeys : this._output._treeKeys) {
            for (Key<CompressedTree> key : iterationKeys) {
                if (key != null) {
                    key.remove(futures);
                }
            }
        }
        return super.remove_impl(futures);
    }
}
