/*
 * Decompiled with CFR 0.152.
 */
package smile.validation.metric;

import smile.math.MathEx;
import smile.validation.metric.Averaging;
import smile.validation.metric.ClassificationMetric;

/**
 * Recall: the fraction of the samples of a class that the classifier assigns
 * to that class. For binary labels this is {@code TP / (TP + FN)}, with class
 * {@code 1} as the positive class.
 *
 * <p>Labels with more than two classes require an {@link Averaging} strategy:
 * <ul>
 *   <li>{@code Micro} - the number of samples whose prediction equals the
 *       truth, divided by the number of samples;</li>
 *   <li>{@code Macro} - the unweighted mean of the per-class recalls;</li>
 *   <li>{@code Weighted} - the mean of the per-class recalls weighted by each
 *       class's support (its number of true samples).</li>
 * </ul>
 */
public class Recall implements ClassificationMetric {
    private static final long serialVersionUID = 2L;

    /** A shared binary recall metric (no averaging strategy). */
    public static final Recall instance = new Recall();

    /** The multi-class averaging strategy, or {@code null} for binary recall. */
    private final Averaging strategy;

    /** Constructs a binary recall metric. */
    public Recall() {
        this(null);
    }

    /**
     * Constructs a recall metric.
     *
     * @param strategy the multi-class averaging strategy, or {@code null}
     *                 for binary recall.
     */
    public Recall(Averaging strategy) {
        this.strategy = strategy;
    }

    @Override
    public double score(int[] truth, int[] prediction) {
        return Recall.of(truth, prediction, strategy);
    }

    @Override
    public String toString() {
        return strategy == null ? "Recall" : strategy + "-Recall";
    }

    /**
     * Returns the binary recall of class {@code 1}.
     *
     * @param truth the ground-truth labels, each 0 or 1.
     * @param prediction the predicted labels, each 0 or 1.
     * @return {@code TP / (TP + FN)}, or {@code NaN} if {@code truth}
     *         contains no positive sample.
     * @throws IllegalArgumentException if a label is neither 0 nor 1, or if
     *         the inputs are empty or have different lengths.
     */
    public static double of(int[] truth, int[] prediction) {
        requireBinary(truth);
        requireBinary(prediction);
        return Recall.of(truth, prediction, null);
    }

    /**
     * Returns the recall under the given averaging strategy.
     *
     * <p>Per-class recalls are computed for every class index from 0 to the
     * largest label seen in either input. For {@code Macro}, a class index
     * with no true samples has an undefined (NaN) per-class recall, and it is
     * included in the mean as is. For {@code Weighted}, such classes have zero
     * weight and are skipped.
     *
     * @param truth the ground-truth labels; must be non-negative.
     * @param prediction the predicted labels.
     * @param strategy the averaging strategy, or {@code null} for the binary
     *                 recall of class {@code 1}.
     * @return the recall; {@code NaN} for binary recall when {@code truth}
     *         contains no positive sample.
     * @throws IllegalArgumentException if the inputs have different lengths or
     *         are empty, if a true label is negative, if {@code strategy} is
     *         {@code null} for more than two classes, or if the strategy is
     *         not supported.
     */
    public static double of(int[] truth, int[] prediction, Averaging strategy) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, prediction.length));
        }
        if (truth.length == 0) {
            throw new IllegalArgumentException("The input vectors are empty.");
        }
        // True labels index the per-class arrays below, so reject negatives up front
        // instead of failing later with an array index/size exception.
        for (int t : truth) {
            if (t < 0) {
                throw new IllegalArgumentException("Negative class label: " + t);
            }
        }

        int numClasses = Math.max(MathEx.max(truth), MathEx.max(prediction)) + 1;
        if (numClasses > 2 && strategy == null) {
            throw new IllegalArgumentException("Averaging strategy is null for multi-class");
        }

        if (strategy == null) {
            return binaryRecall(truth, prediction);
        }

        int n = truth.length;
        if (strategy == Averaging.Micro) {
            int correct = 0;
            for (int i = 0; i < n; i++) {
                if (truth[i] == prediction[i]) {
                    correct++;
                }
            }
            return (double) correct / n;
        }

        if (strategy != Averaging.Macro && strategy != Averaging.Weighted) {
            throw new IllegalArgumentException("Unsupported averaging strategy: " + strategy);
        }

        // support[c]: number of samples whose true label is c.
        // tp[c]: number of those samples that were also predicted as c.
        int[] support = new int[numClasses];
        int[] tp = new int[numClasses];
        for (int i = 0; i < n; i++) {
            support[truth[i]]++;
            if (truth[i] == prediction[i]) {
                tp[truth[i]]++;
            }
        }

        double[] recall = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            recall[c] = (double) tp[c] / (double) support[c];
        }

        if (strategy == Averaging.Macro) {
            return MathEx.mean(recall);
        }

        // Weighted: classes without true samples have zero weight. They are
        // skipped so their NaN recall (0/0) does not poison the sum as NaN * 0.
        double weighted = 0.0;
        for (int c = 0; c < numClasses; c++) {
            if (support[c] > 0) {
                weighted += recall[c] * (double) support[c];
            }
        }
        return weighted / (double) n;
    }

    /**
     * Computes {@code TP / (TP + FN)} for the positive class {@code 1}.
     * Returns {@code NaN} (0.0 / 0) when no sample is truly positive.
     */
    private static double binaryRecall(int[] truth, int[] prediction) {
        int tp = 0;
        int positives = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] == 1) {
                positives++;
                if (prediction[i] == 1) {
                    tp++;
                }
            }
        }
        return (double) tp / (double) positives;
    }

    /** Throws if any label is neither 0 nor 1. */
    private static void requireBinary(int[] labels) {
        for (int label : labels) {
            if (label != 0 && label != 1) {
                throw new IllegalArgumentException("Recall can only be applied to binary classification: " + label);
            }
        }
    }
}

