/*
 * Decompiled with CFR 0.152.
 */
package smile.validation.metric;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;

/**
 * The confusion matrix of a classifier. The cell {@code matrix[i][j]} holds the number of
 * samples whose true label is {@code i} and whose predicted label is {@code j}
 * (rows = truth, columns = prediction).
 *
 * <p>Note: like any record with an array component, the accessor {@link #matrix()} returns the
 * internal array itself (not a copy), and the implicit {@code equals}/{@code hashCode} compare
 * the array by reference rather than by content.
 *
 * @param matrix the count matrix, indexed as {@code matrix[truth][prediction]}.
 */
public record ConfusionMatrix(int[][] matrix) implements Serializable
{
    private static final long serialVersionUID = 3L;

    /**
     * Builds a confusion matrix from ground-truth and predicted class labels.
     *
     * <p>The result is a {@code (k + 1) x (k + 1)} matrix where {@code k} is the largest label
     * appearing in either array (or 0 when both arrays are empty, giving a single zero cell).
     *
     * @param truth the ground-truth class labels; every label must be non-negative.
     * @param prediction the predicted class labels; every label must be non-negative.
     * @return the confusion matrix.
     * @throws IllegalArgumentException if the arrays differ in length or contain a negative label.
     */
    public static ConfusionMatrix of(int[] truth, int[] prediction) {
        if (truth.length != prediction.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, prediction.length));
        }

        // The largest label in either array fixes the matrix dimension, so a single pass
        // tracking the maximum is enough (no need to collect distinct labels first).
        // Negative labels would index outside the matrix, so reject them here with a clear message.
        int k = 0;
        for (int i = 0; i < truth.length; ++i) {
            int t = truth[i];
            int p = prediction[i];
            if (t < 0 || p < 0) {
                throw new IllegalArgumentException(String.format("Negative class label at index %d: truth = %d, prediction = %d", i, t, p));
            }
            k = Math.max(k, Math.max(t, p));
        }

        int[][] matrix = new int[k + 1][k + 1];
        for (int i = 0; i < truth.length; ++i) {
            matrix[truth[i]][prediction[i]]++;
        }
        return new ConfusionMatrix(matrix);
    }

    /**
     * Renders the matrix as a table, one line per true class, preceded by a header line
     * explaining the row/column convention. Trailing whitespace is trimmed.
     *
     * <p>Assumes the matrix is square, as produced by {@link #of(int[], int[])}; the column
     * count is taken from the number of rows.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("ROW=truth and COL=predicted\n");
        for (int i = 0; i < this.matrix.length; ++i) {
            sb.append(String.format("class %2d |", i));
            for (int j = 0; j < this.matrix.length; ++j) {
                sb.append(String.format("%8d |", this.matrix[i][j]));
            }
            sb.append('\n');
        }
        return sb.toString().trim();
    }
}

