/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.validation.metric.AUC;
import smile.validation.metric.Accuracy;
import smile.validation.metric.CrossEntropy;
import smile.validation.metric.Error;
import smile.validation.metric.FScore;
import smile.validation.metric.LogLoss;
import smile.validation.metric.MatthewsCorrelation;
import smile.validation.metric.Precision;
import smile.validation.metric.Sensitivity;
import smile.validation.metric.Specificity;

public record ClassificationMetrics(double fitTime, double scoreTime, int size, int error, double accuracy, double sensitivity, double specificity, double precision, double f1, double mcc, double auc, double logloss, double crossentropy) implements Serializable
{
    private static final long serialVersionUID = 3L;

    public ClassificationMetrics(double fitTime, double scoreTime, int size, int error, double accuracy) {
        this(fitTime, scoreTime, size, error, accuracy, Double.NaN);
    }

    public ClassificationMetrics(double fitTime, double scoreTime, int size, int error, double accuracy, double crossentropy) {
        this(fitTime, scoreTime, size, error, accuracy, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, crossentropy);
    }

    public ClassificationMetrics(double fitTime, double scoreTime, int size, int error, double accuracy, double sensitivity, double specificity, double precision, double f1, double mcc) {
        this(fitTime, scoreTime, size, error, accuracy, sensitivity, specificity, precision, f1, mcc, Double.NaN, Double.NaN);
    }

    public ClassificationMetrics(double fitTime, double scoreTime, int size, int error, double accuracy, double sensitivity, double specificity, double precision, double f1, double mcc, double auc, double logloss) {
        this(fitTime, scoreTime, size, error, accuracy, sensitivity, specificity, precision, f1, mcc, auc, logloss, logloss);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("{\n");
        if (!Double.isNaN(this.fitTime)) {
            sb.append(String.format("  fit time: %.3f ms,\n", this.fitTime));
        }
        sb.append(String.format("  score time: %.3f ms,\n", this.scoreTime));
        sb.append(String.format("  validation data size: %d,\n", this.size));
        sb.append(String.format("  error: %d,\n", this.error));
        sb.append(String.format("  accuracy: %.2f%%", 100.0 * this.accuracy));
        if (!Double.isNaN(this.sensitivity)) {
            sb.append(String.format(",\n  sensitivity: %.2f%%", 100.0 * this.sensitivity));
        }
        if (!Double.isNaN(this.specificity)) {
            sb.append(String.format(",\n  specificity: %.2f%%", 100.0 * this.specificity));
        }
        if (!Double.isNaN(this.precision)) {
            sb.append(String.format(",\n  precision: %.2f%%", 100.0 * this.precision));
        }
        if (!Double.isNaN(this.f1)) {
            sb.append(String.format(",\n  F1 score: %.2f%%", 100.0 * this.f1));
        }
        if (!Double.isNaN(this.mcc)) {
            sb.append(String.format(",\n  MCC: %.2f%%", 100.0 * this.mcc));
        }
        if (!Double.isNaN(this.auc)) {
            sb.append(String.format(",\n  AUC: %.2f%%", 100.0 * this.auc));
        }
        if (!Double.isNaN(this.logloss)) {
            sb.append(String.format(",\n  log loss: %.4f", this.logloss));
        } else if (!Double.isNaN(this.crossentropy)) {
            sb.append(String.format(",\n  cross entropy: %.4f", this.crossentropy));
        }
        sb.append("\n}");
        return sb.toString();
    }

    public static ClassificationMetrics of(double fitTime, double scoreTime, int[] truth, int[] prediction) {
        if (MathEx.unique((int[])truth).length == 2) {
            return new ClassificationMetrics(fitTime, scoreTime, truth.length, Error.of(truth, prediction), Accuracy.of(truth, prediction), Sensitivity.of(truth, prediction), Specificity.of(truth, prediction), Precision.of(truth, prediction), FScore.F1.score(truth, prediction), MatthewsCorrelation.of(truth, prediction));
        }
        return new ClassificationMetrics(fitTime, scoreTime, truth.length, Error.of(truth, prediction), Accuracy.of(truth, prediction));
    }

    public static ClassificationMetrics of(double fitTime, double scoreTime, int[] truth, int[] prediction, double[][] posteriori) {
        if (posteriori[0].length == 2) {
            double[] probability = Arrays.stream(posteriori).mapToDouble(p -> p[1]).toArray();
            return new ClassificationMetrics(fitTime, scoreTime, truth.length, Error.of(truth, prediction), Accuracy.of(truth, prediction), Sensitivity.of(truth, prediction), Specificity.of(truth, prediction), Precision.of(truth, prediction), FScore.F1.score(truth, prediction), MatthewsCorrelation.of(truth, prediction), AUC.of(truth, probability), LogLoss.of(truth, probability));
        }
        return new ClassificationMetrics(fitTime, scoreTime, truth.length, Error.of(truth, prediction), Accuracy.of(truth, prediction), CrossEntropy.of(truth, posteriori));
    }

    public static <T, M extends Classifier<T>> ClassificationMetrics of(M model, T[] testx, int[] testy) {
        return ClassificationMetrics.of(Double.NaN, model, testx, testy);
    }

    public static <T, M extends Classifier<T>> ClassificationMetrics of(double fitTime, M model, T[] testx, int[] testy) {
        int k = MathEx.unique((int[])testy).length;
        long start = System.nanoTime();
        if (model.soft()) {
            double[][] posteriori = new double[testx.length][k];
            int[] prediction = model.predict(testx, posteriori);
            double scoreTime = (double)(System.nanoTime() - start) / 1000000.0;
            return ClassificationMetrics.of(fitTime, scoreTime, testy, prediction, posteriori);
        }
        int[] prediction = model.predict(testx);
        double scoreTime = (double)(System.nanoTime() - start) / 1000000.0;
        return ClassificationMetrics.of(fitTime, scoreTime, testy, prediction);
    }

    public static <M extends DataFrameClassifier> ClassificationMetrics of(M model, Formula formula, DataFrame test) {
        return ClassificationMetrics.of(Double.NaN, model, formula, test);
    }

    public static <M extends DataFrameClassifier> ClassificationMetrics of(double fitTime, M model, Formula formula, DataFrame test) {
        int[] testy = formula.y(test).toIntArray();
        int k = MathEx.unique((int[])testy).length;
        long start = System.nanoTime();
        int n = test.nrow();
        int[] prediction = new int[n];
        if (model.soft()) {
            double[][] posteriori = new double[n][k];
            for (int i = 0; i < n; ++i) {
                prediction[i] = model.predict((Tuple)test.get(i), posteriori[i]);
            }
            double scoreTime = (double)(System.nanoTime() - start) / 1000000.0;
            return ClassificationMetrics.of(fitTime, scoreTime, testy, prediction, posteriori);
        }
        for (int i = 0; i < n; ++i) {
            prediction[i] = model.predict((Tuple)test.get(i));
        }
        double scoreTime = (double)(System.nanoTime() - start) / 1000000.0;
        return ClassificationMetrics.of(fitTime, scoreTime, testy, prediction);
    }
}

