package hex.tree;

import hex.AUC;
import hex.ConfusionMatrix2;
import java.util.Arrays;
import water.DKV;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.ModelUtils;

/* loaded from: input_file:hex/tree/Score.class */
/**
 * Distributed scoring pass over a tree model's rows.
 *
 * <p>Each {@link #map(Chunk[])} call accumulates, for the rows of one chunk set, the sum of squared
 * per-row errors, the number of rows scored and (for classification) confusion-matrix counts.
 * {@link #reduce(Score)} merges partial results. The totals are exposed through {@link #r2()},
 * {@link #cm()} and {@link #auc()}.
 */
public class Score extends MRTask<Score> {
    /** Tree builder that supplies the model, the frames and the per-row scoring helpers. */
    final SharedTree _bldr;
    /** Copied from {@code SharedTree._nclass}; values greater than 1 select the classification paths. */
    final int _nclass;
    /** Copied from {@code SharedTree._ncols}. */
    final int _ncols;
    /** When true, rows for which {@code _bldr.outOfBagRow(...)} returns true are skipped. */
    final boolean _oob;
    /** True when the builder had a validation frame configured at construction time. */
    final boolean _validation;
    /** Cardinality of the response vector; -1 disables the confusion matrices. */
    final int _cmlen;
    /** Running sum of squared per-row errors. */
    double _sum;
    /** Number of rows that contributed to {@link #_sum}. */
    long _snrows;
    /** Confusion-matrix counts indexed as {@code [actual class][predicted class]}. */
    long[][] _cm;
    /** One 2x2 matrix per default threshold: {@code [threshold][actual class][score >= threshold ? 1 : 0]}. */
    long[][][] _cms;

    /**
     * Computes {@code 1 - (mean squared error / sigma^2)}, where sigma is read from the response
     * vector stored under the builder's response key.
     *
     * @return the R^2 value; NaN when no rows have been scored, since that is a 0/0 division
     */
    public double r2() {
        double meanSquaredError = _sum / _snrows;
        double sigma = DKV.get(_bldr._response_key).get().sigma();
        return 1.0d - (meanSquaredError / (sigma * sigma));
    }

    /**
     * Wraps the accumulated confusion-matrix counts.
     *
     * @return a new {@link ConfusionMatrix2}, or {@code null} when no matrix has been accumulated
     */
    public ConfusionMatrix2 cm() {
        if (_cm == null) {
            return null;
        }
        return new ConfusionMatrix2(_cm);
    }

    /**
     * Builds an AUC object from the per-threshold 2x2 confusion matrices.
     *
     * @return the AUC, or {@code null} when the problem is not two-class or when no threshold
     *         matrices have been accumulated yet
     */
    public AUC auc() {
        // The _cms null check mirrors cm(): without it, calling auc() before map() has run
        // fails with a NullPointerException instead of reporting "no result".
        if (_nclass != 2 || _cms == null) {
            return null;
        }
        ConfusionMatrix2[] perThreshold = new ConfusionMatrix2[_cms.length];
        for (int t = 0; t < perThreshold.length; t++) {
            perThreshold[t] = new ConfusionMatrix2(_cms[t]);
        }
        return new AUC(perThreshold, ModelUtils.DEFAULT_THRESHOLDS, _bldr.vresponse().domain());
    }

    /**
     * Captures the scoring configuration from the tree builder.
     *
     * @param sharedTree builder whose model is scored
     * @param z          stored as {@link #_oob}
     */
    public Score(SharedTree sharedTree, boolean z) {
        this._bldr = sharedTree;
        this._nclass = sharedTree._nclass;
        this._ncols = sharedTree._ncols;
        this._oob = z;
        this._validation = sharedTree._parms._valid != null;
        this._cmlen = sharedTree.vresponse().cardinality();
    }

    /**
     * Runs this task. With a validation frame configured, the model first scores that frame and the
     * task runs over the scored result, which is deleted afterwards. Otherwise the task runs
     * directly over the training frame.
     *
     * @param z forwarded unchanged as the second argument of {@code doAll}
     * @return this task, holding the accumulated results
     */
    public Score doIt(boolean z) {
        if (_bldr._parms._valid != null) {
            Frame scored = _bldr._model.score(_bldr.valid(), false);
            try {
                doAll(scored, z);
            } finally {
                // Release the temporary frame even when the task fails part-way through.
                scored.delete();
            }
        } else {
            doAll(_bldr.train(), z);
        }
        return this;
    }

    /**
     * Scores every row of one chunk set and accumulates squared error, row count and
     * confusion-matrix counts into this task instance.
     *
     * @param chunkArr chunks of the frame being processed
     */
    public void map(Chunk[] chunkArr) {
        Chunk response = _bldr.chk_resp(chunkArr);
        // Slots 1.._nclass receive the per-class values; slot 0 is not written in this method.
        float[] dist = new float[_nclass + 1];
        _cm = _cmlen == -1 ? (long[][]) null : new long[_cmlen][_cmlen];
        _cms = _cmlen == -1 ? (long[][][]) null : new long[ModelUtils.DEFAULT_THRESHOLDS.length][2][2];

        for (int row = 0; row < response._len; row++) {
            if (response.isNA0(row)) {
                continue; // rows with a missing response are not scored
            }

            float sum;
            if (_validation) {
                // Values come from columns 1.._nclass of the chunks being processed
                // (presumably the scored frame produced in doIt -- confirm).
                for (int c = 0; c < _nclass; c++) {
                    dist[c + 1] = (float) chunkArr[c + 1].at0(row);
                }
                sum = _nclass > 1 ? 1.0f : dist[1];
            } else {
                sum = _bldr.score2(chunkArr, dist, row);
            }

            // Deliberately evaluated after scoring, exactly as the original control flow did.
            if (_oob && _bldr.outOfBagRow(chunkArr, row)) {
                continue;
            }

            int actual = 0;
            float err;
            if (_nclass > 1) {
                actual = (int) response.at80(row);
                err = classificationError(dist, sum, actual, response, row);
                assert !Double.isNaN(err) : "fs[cls]=" + dist[actual + 1] + ", sum=" + sum;
            } else {
                err = ((float) response.at0(row)) - sum;
            }

            // The square is computed in float precision and only then widened, as before.
            _sum += err * err;
            assert !Double.isNaN(_sum);

            if (_nclass > 1) {
                if (_nclass == 2) {
                    recordThresholdCounts(dist, sum, actual);
                }
                int predicted = _validation
                        ? (int) _bldr.chk_work(chunkArr, 0).at80(row)
                        : ModelUtils.getPrediction(dist, row);
                _cm[actual][predicted]++;
            }
            _snrows++;
        }
    }

    /** Returns the classification error for one row: 1 minus the normalized value of the actual class. */
    private float classificationError(float[] dist, float sum, int actual, Chunk response, int row) {
        if (sum == 0.0f) {
            // Nothing to normalize by; the original code uses this fixed value instead.
            return 1.0f - (1.0f / _nclass);
        }
        assert 0 <= actual && actual < _nclass : "weird ycls=" + actual + ", y=" + response.at0(row);
        if (Float.isInfinite(sum)) {
            return Float.isInfinite(dist[actual + 1]) ? 0.0f : 1.0f;
        }
        return 1.0f - (dist[actual + 1] / sum);
    }

    /** Updates the per-threshold 2x2 matrices using the value stored for class index 1 (dist[2]). */
    private void recordThresholdCounts(float[] dist, float sum, int actual) {
        float score;
        if (_validation) {
            score = dist[2];
        } else if (!Float.isInfinite(sum)) {
            score = dist[2] / sum;
        } else {
            score = Float.isInfinite(dist[2]) ? 1.0f : 0.0f;
        }
        for (int t = 0; t < ModelUtils.DEFAULT_THRESHOLDS.length; t++) {
            int predicted = score >= ModelUtils.DEFAULT_THRESHOLDS[t] ? 1 : 0;
            _cms[t][actual][predicted]++;
        }
    }

    /**
     * Merges another task's partial results into this one.
     *
     * @param score the partial result to fold in
     */
    public void reduce(Score score) {
        _sum += score._sum;
        if (_cm != null) {
            ArrayUtils.add(_cm, score._cm);
        }
        _snrows += score._snrows;
        if (_cms != null) {
            for (int t = 0; t < _cms.length; t++) {
                ArrayUtils.add(_cms[t], score._cms[t]);
            }
        }
    }

    /**
     * Logs a summary of the accumulated results: R^2, tree and node counts and, for
     * classification, the error count and the confusion matrix.
     *
     * @param i        number written in front of "x" in the tree-count part of the log line
     * @param dTreeArr trees whose {@code _len} values are summed for the node count; may be
     *                 {@code null} and may contain {@code null} entries
     * @return this task
     */
    public Score report(int i, DTree[] dTreeArr) {
        assert !Double.isNaN(_sum);
        Log.info(new Object[]{"============================================================== "});

        int totalNodes = 0;
        if (dTreeArr != null) {
            for (DTree tree : dTreeArr) {
                if (tree != null) {
                    totalNodes += tree._len;
                }
            }
        }

        long errors = _snrows;
        Log.info(new Object[]{"r2 is " + r2() + ", with " + i + "x" + _nclass + " trees (average of " + (totalNodes / _nclass) + " nodes)"});
        if (_nclass > 1) {
            // Errors are all scored rows minus the diagonal (correctly classified) counts.
            for (int c = 0; c < _nclass; c++) {
                errors -= _cm[c][c];
            }
            Log.info(new Object[]{"Total of " + errors + " errors on " + _snrows + " rows, CM= " + Arrays.deepToString(_cm)});
        } else {
            Log.info(new Object[]{"Reported on " + _snrows + " rows."});
        }
        return this;
    }
}
