package meka.classifiers.multilabel;

import java.io.File;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import meka.classifiers.MultiXClassifier;
import meka.classifiers.multitarget.MultiTargetClassifier;
import meka.core.MLEvalUtils;
import meka.core.MLUtils;
import meka.core.Result;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SerializationHelper;
import weka.core.Utils;
import weka.core.converters.AbstractFileSaver;
import weka.core.converters.ArffSaver;
import weka.core.converters.ConverterUtils;

/* loaded from: input_file:meka/classifiers/multilabel/Evaluation.class */
/**
 * Evaluation utilities for MEKA multi-label / multi-target classifiers.
 *
 * <p>Provides the command-line driver {@link #runExperiment(MultiLabelClassifier, String[])}
 * (train/test split, separate test file, cross-validation, loading/saving models, saving
 * predictions) as well as programmatic helpers to build, test and cross-validate a classifier.
 */
public class Evaluation {

    /**
     * Runs a complete experiment driven by command-line options, then terminates the JVM.
     *
     * <p>The classifier consumes its own options first; the remaining evaluation options are
     * described in {@link #printOptions(Enumeration)}. On success the JVM exits with status 0;
     * if an error occurs inside the evaluation phase, the stack trace and option help are printed
     * and the JVM exits with status 1.
     *
     * @param multiLabelClassifier the classifier to configure and evaluate (replaced by the
     *        deserialized model when {@code -l} is given)
     * @param strArr the command-line options; recognised options are consumed from this array
     * @throws Exception if the training data cannot be loaded, the number of labels is not set,
     *         or a model file given with {@code -l} cannot be read
     */
    public static void runExperiment(MultiLabelClassifier multiLabelClassifier, String[] strArr) throws Exception {
        if (Utils.getOptionPos('h', strArr) >= 0) {
            System.out.println("\nHelp requested");
            printOptions(multiLabelClassifier.listOptions());
            return;
        }
        MultiLabelClassifier classifier = multiLabelClassifier;
        classifier.setOptions(strArr);

        if (classifier.getDebug()) {
            System.out.println("Loading and preparing dataset ...");
        }
        Instances train = loadDataset(strArr);
        MLUtils.prepareData(train);
        if (Utils.getOptionPos('C', strArr) >= 0) {
            train.setClassIndex(Integer.parseInt(Utils.getOption('C', strArr)));
        }
        // In MEKA the labels are the first classIndex() attributes.
        int numLabels = train.classIndex();
        if (numLabels <= 0) {
            throw new Exception("[Error] Number of labels not specified.\n\tYou must set the number of labels with the -C option, either inside the @relation tag of the Instances file, or on the command line.");
        }

        int seed = Utils.getOptionPos('s', strArr) >= 0 ? Integer.parseInt(Utils.getOption('s', strArr)) : 0;
        if (Utils.getFlag('R', strArr)) {
            train.randomize(new Random(seed));
        }
        boolean threaded = Utils.getOptionPos("Thr", strArr) >= 0 && Utils.getFlag("Thr", strArr);
        String verbosity = Utils.getOptionPos("verbosity", strArr) >= 0 ? Utils.getOption("verbosity", strArr) : "1";
        String dumpFile = Utils.getOptionPos('d', strArr) >= 0 ? Utils.getOption('d', strArr) : null;

        String loadFile = null;
        if (Utils.getOptionPos('l', strArr) >= 0) {
            loadFile = Utils.getOption('l', strArr);
            Object[] stored = SerializationHelper.readAll(loadFile);
            classifier = (MultiLabelClassifier) stored[0];
            // Models saved with -d carry the training header as a second object.
            if (stored.length > 1 && stored[1] instanceof Instances) {
                warnIfHeaderMismatch((Instances) stored[1], train);
            }
        }

        try {
            Result result = null;
            String threshold = Utils.getOptionPos("threshold", strArr) >= 0 ? Utils.getOption("threshold", strArr) : "PCut1";
            String predictionsFile = Utils.getOption("predictions", strArr);
            boolean doEval = !Utils.getFlag("no-eval", strArr);

            if (Utils.getOptionPos('x', strArr) >= 0) {
                if (!predictionsFile.isEmpty()) {
                    System.err.println("Predictions cannot be saved when using cross-validation!");
                }
                int numFolds = MLUtils.getIntegerOption(Utils.getOption('x', strArr), 10);
                Utils.checkForRemainingOptions(strArr);
                System.out.println(cvModel(classifier, train, numFolds, threshold, verbosity).toString());
            } else {
                Instances test;
                if (Utils.getOptionPos('T', strArr) >= 0) {
                    try {
                        test = loadDataset(strArr, 'T');
                        MLUtils.prepareData(test);
                    } catch (Exception e) {
                        throw new Exception("[Error] Failed to Load Test Instances from file.", e);
                    }
                } else {
                    // No test file: split the training data (60% train by default).
                    int numTrain = (int) (train.numInstances() * 0.6d);
                    if (Utils.getOptionPos("split-percentage", strArr) >= 0) {
                        double percentage = Double.parseDouble(Utils.getOption("split-percentage", strArr));
                        numTrain = (int) Math.round(train.numInstances() * (percentage / 100.0d));
                    } else if (Utils.getOptionPos("split-number", strArr) >= 0) {
                        numTrain = Integer.parseInt(Utils.getOption("split-number", strArr));
                    }
                    test = new Instances(train, numTrain, train.numInstances() - numTrain);
                    train = new Instances(train, 0, numTrain);
                }
                if (Utils.getFlag('i', strArr)) {
                    Instances swap = test;
                    test = train;
                    train = swap;
                }
                Utils.checkForRemainingOptions(strArr);

                if (classifier.getDebug()) {
                    System.out.println(":- Dataset -: " + MLUtils.getDatasetName(train) + "\tL=" + numLabels + "\tD(t:T)=(" + train.numInstances() + ":" + test.numInstances() + ")\tLC(t:T)=" + Utils.roundDouble(MLUtils.labelCardinality(train, numLabels), 2) + ":" + Utils.roundDouble(MLUtils.labelCardinality(test, numLabels), 2) + ")");
                }

                if (loadFile != null) {
                    if (doEval) {
                        // The model is already built; PCut thresholds are calibrated from the
                        // test predictions against the training data.
                        Result testResult = testClassifier(classifier, test);
                        String calibrated = threshold;
                        if (threshold.startsWith("PCut")) {
                            calibrated = MLEvalUtils.getThreshold(testResult.predictions, train, threshold);
                        }
                        result = evaluateModel(classifier, test, calibrated, verbosity);
                    }
                } else if (train.numInstances() <= 0 || test.numInstances() <= 0) {
                    classifier.buildClassifier(train);
                } else if (doEval) {
                    result = threaded
                            ? evaluateModelM(classifier, train, test, threshold, verbosity)
                            : evaluateModel(classifier, train, test, threshold, verbosity);
                } else {
                    classifier.buildClassifier(train);
                }

                if (train.numInstances() > 0 && test.numInstances() > 0 && result != null) {
                    System.out.println(result.toString());
                }
                if (!predictionsFile.isEmpty()) {
                    savePredictions(classifier, test, predictionsFile);
                }
            }
            if (dumpFile != null) {
                SerializationHelper.writeAll(dumpFile, new Object[]{classifier, new Instances(train, 0)});
            }
        } catch (Exception e) {
            e.printStackTrace();
            printOptions(classifier.listOptions());
            System.exit(1);
        }
        System.exit(0);
    }

    /**
     * Prints a warning when the header stored with a serialized model differs from the data
     * supplied on the command line (different attributes, nominal values or label count).
     *
     * @param modelHeader the header stored alongside the model
     * @param data the data the model is about to be used with
     */
    private static void warnIfHeaderMismatch(Instances modelHeader, Instances data) {
        String msg = modelHeader.equalHeadersMsg(data);
        if (msg != null) {
            System.err.println("[Warning] Data set is not compatible with the header stored in the model file:\n" + msg);
        }
    }

    /**
     * Returns a copy of {@code instance} whose first {@code numLabels} attribute values (the
     * labels) are set to 0, so a classifier cannot see the true labels when predicting.
     *
     * @param instance the instance to copy (not modified)
     * @param numLabels number of leading label attributes to hide
     * @return the label-hidden copy
     */
    private static Instance hideLabels(Instance instance, int numLabels) {
        Instance copy = (Instance) instance.copy();
        for (int j = 0; j < numLabels; j++) {
            copy.setValue(j, 0.0d);
        }
        return copy;
    }

    /**
     * Predicts every test instance and writes the rounded predictions, in place of the label
     * values, to {@code file}. The format is chosen from the file extension, falling back to ARFF.
     *
     * @param classifier the built classifier
     * @param test the instances to predict
     * @param file the output file name
     * @throws Exception if prediction or writing fails
     */
    private static void savePredictions(MultiLabelClassifier classifier, Instances test, String file) throws Exception {
        int numLabels = test.classIndex();
        Instances predictions = new Instances(test, 0);
        for (int i = 0; i < test.numInstances(); i++) {
            Instance original = test.instance(i);
            // Hide the true labels, exactly as testClassifier does, so saved predictions match evaluated ones.
            double[] dist = classifier.distributionForInstance(hideLabels(original, numLabels));
            if (classifier instanceof MultiTargetClassifier) {
                dist = Arrays.copyOf(dist, numLabels);
            }
            Instance row = (Instance) original.copy();
            for (int j = 0; j < dist.length; j++) {
                row.setValue(j, Math.round(dist[j]));
            }
            predictions.add(row);
        }
        AbstractFileSaver saver = ConverterUtils.getSaverForFile(file);
        if (saver == null) {
            System.err.println("Failed to determine saver for '" + file + "', using " + ArffSaver.class.getName());
            saver = new ArffSaver();
        }
        saver.setFile(new File(file));
        saver.setInstances(predictions);
        saver.writeBatch();
        System.out.println("Predictions saved to: " + file);
    }

    /**
     * Reports whether the data set is multi-target, i.e. whether any label attribute is nominal
     * with more than two values. Non-nominal label attributes produce a warning and are ignored.
     *
     * @param instances the data set; its first {@code classIndex()} attributes are the labels
     * @return {@code true} if some label attribute has more than two nominal values
     */
    public static boolean isMT(Instances instances) {
        int numLabels = instances.classIndex();
        for (int j = 0; j < numLabels; j++) {
            if (!instances.attribute(j).isNominal()) {
                System.err.println("[Warning] Found a non-nominal class -- not sure how this happened?");
            } else if (instances.attribute(j).numValues() > 2) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the classifier on {@code instances} and evaluates it on {@code instances2} with
     * verbosity {@code "1"}.
     *
     * @param multiXClassifier the classifier to build and test
     * @param instances training data
     * @param instances2 test data
     * @param str threshold option (e.g. {@code PCut1}, {@code PCutL}, or a number)
     * @return the evaluation result, including formatted statistics in {@code output}
     * @throws Exception if building or testing fails
     */
    public static Result evaluateModel(MultiXClassifier multiXClassifier, Instances instances, Instances instances2, String str) throws Exception {
        return evaluateModel(multiXClassifier, instances, instances2, str, "1");
    }

    /**
     * Builds the classifier on {@code instances}, evaluates it on {@code instances2} and computes
     * statistics. For multi-label results the threshold is calibrated against the training data.
     *
     * @param multiXClassifier the classifier to build and test
     * @param instances training data
     * @param instances2 test data
     * @param str threshold option (e.g. {@code PCut1}, {@code PCutL}, or a number)
     * @param str2 verbosity level passed to {@link Result#getStats}
     * @return the evaluation result, including formatted statistics in {@code output}
     * @throws Exception if building or testing fails
     */
    public static Result evaluateModel(MultiXClassifier multiXClassifier, Instances instances, Instances instances2, String str, String str2) throws Exception {
        Result result = evaluateModel(multiXClassifier, instances, instances2);
        if ((multiXClassifier instanceof MultiTargetClassifier) || isMT(instances2)) {
            result.setInfo("Type", "MT");
        } else if (multiXClassifier instanceof MultiLabelClassifier) {
            result.setInfo("Type", "ML");
            result.setInfo("Threshold", MLEvalUtils.getThreshold(result.predictions, instances, str));
        }
        result.setInfo("Verbosity", str2);
        result.output = Result.getStats(result, str2);
        return result;
    }

    /**
     * Evaluates an already-built classifier on {@code instances} using a fixed threshold string.
     *
     * @param multiXClassifier the built classifier
     * @param instances test data
     * @param str the threshold to record (used as-is, not calibrated)
     * @param str2 verbosity level passed to {@link Result#getStats}
     * @return the evaluation result, including formatted statistics in {@code output}
     * @throws Exception if testing fails
     */
    public static Result evaluateModel(MultiXClassifier multiXClassifier, Instances instances, String str, String str2) throws Exception {
        Result result = testClassifier(multiXClassifier, instances);
        if ((multiXClassifier instanceof MultiTargetClassifier) || isMT(instances)) {
            result.setInfo("Type", "MT");
        } else if (multiXClassifier instanceof MultiLabelClassifier) {
            result.setInfo("Type", "ML");
        }
        result.setInfo("Threshold", str);
        result.setInfo("Verbosity", str2);
        result.output = Result.getStats(result, str2);
        return result;
    }

    /**
     * Cross-validates the classifier with verbosity {@code "1"}.
     *
     * @param multiLabelClassifier the classifier to cross-validate
     * @param instances the full data set
     * @param i number of folds
     * @param str threshold option; only a numeric value is honoured (otherwise 0.5 is used)
     * @return the combined result over all folds
     * @throws Exception if any fold fails or the thread is interrupted
     */
    public static Result cvModel(MultiLabelClassifier multiLabelClassifier, Instances instances, int i, String str) throws Exception {
        return cvModel(multiLabelClassifier, instances, i, str, "1");
    }

    /**
     * Runs {@code i}-fold cross-validation and combines the per-fold predictions into a single
     * result. Automatic (PCut) threshold calibration is not supported here: a non-numeric
     * threshold falls back to 0.5 with a warning.
     *
     * @param multiLabelClassifier the classifier to cross-validate (rebuilt for each fold)
     * @param instances the full data set
     * @param i number of folds
     * @param str threshold option; parsed as a number for multi-label data
     * @param str2 verbosity level passed to {@link Result#getStats}
     * @return the combined result over all folds
     * @throws InterruptedException if the current thread is interrupted between folds
     * @throws Exception if building or testing any fold fails
     */
    public static Result cvModel(MultiLabelClassifier multiLabelClassifier, Instances instances, int i, String str, String str2) throws Exception {
        int numLabels = instances.classIndex();
        Result[] folds = new Result[i];
        for (int fold = 0; fold < i; fold++) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException("Thread has been interrupted.");
            }
            Instances trainFold = instances.trainCV(i, fold);
            Instances testFold = instances.testCV(i, fold);
            if (multiLabelClassifier.getDebug()) {
                System.out.println(":- Fold [" + fold + "/" + i + "] -: " + MLUtils.getDatasetName(instances) + "\tL=" + numLabels + "\tD(t:T)=(" + trainFold.numInstances() + ":" + testFold.numInstances() + ")\tLC(t:T)=" + Utils.roundDouble(MLUtils.labelCardinality(trainFold, numLabels), 2) + ":" + Utils.roundDouble(MLUtils.labelCardinality(testFold, numLabels), 2) + ")");
            }
            folds[fold] = evaluateModel(multiLabelClassifier, trainFold, testFold);
        }
        Result combined = MLEvalUtils.combinePredictions(folds);
        if ((multiLabelClassifier instanceof MultiTargetClassifier) || isMT(instances)) {
            combined.setInfo("Type", "MT-CV");
        } else {
            combined.setInfo("Type", "ML-CV");
            try {
                combined.setInfo("Threshold", String.valueOf(Double.parseDouble(str)));
            } catch (Exception ignored) {
                // Non-numeric threshold (e.g. PCut1): calibration per fold is not implemented.
                System.err.println("[WARNING] Automatic threshold calibration not currently enabled for cross-fold validation, setting threshold = 0.5.\n");
                combined.setInfo("Threshold", String.valueOf(0.5d));
            }
        }
        combined.setInfo("Verbosity", str2);
        combined.output = Result.getStats(combined, str2);
        combined.setValue("Number of training instances", instances.numInstances());
        combined.setValue("Number of test instances", instances.numInstances());
        return combined;
    }

    /**
     * Builds the classifier on {@code instances}, tests it on {@code instances2}, and records
     * data set sizes, label cardinalities, timings and classifier information in the result.
     * Semi-supervised classifiers additionally receive the test data with labels set missing.
     *
     * @param multiXClassifier the classifier to build and test
     * @param instances training data
     * @param instances2 test data
     * @return the raw evaluation result (no statistics computed yet)
     * @throws Exception if building or testing fails
     */
    public static Result evaluateModel(MultiXClassifier multiXClassifier, Instances instances, Instances instances2) throws Exception {
        long buildStart = System.currentTimeMillis();
        if (multiXClassifier instanceof SemisupervisedClassifier) {
            ((SemisupervisedClassifier) multiXClassifier).introduceUnlabelledData(MLUtils.setLabelsMissing(new Instances(instances2)));
        }
        multiXClassifier.buildClassifier(instances);
        long buildEnd = System.currentTimeMillis();
        long testStart = System.currentTimeMillis();
        Result result = testClassifier(multiXClassifier, instances2);
        long testEnd = System.currentTimeMillis();
        result.setValue("Number of training instances", instances.numInstances());
        result.setValue("Number of test instances", instances2.numInstances());
        result.setValue("Label cardinality (train set)", MLUtils.labelCardinality(instances));
        result.setValue("Label cardinality (test set)", MLUtils.labelCardinality(instances2));
        result.setValue("Build Time", (buildEnd - buildStart) / 1000.0d);
        result.setValue("Test Time", (testEnd - testStart) / 1000.0d);
        result.setValue("Total Time", (testEnd - buildStart) / 1000.0d);
        result.setInfo("Classifier", multiXClassifier.getClass().getName());
        result.setInfo("Options", Arrays.toString(multiXClassifier.getOptions()));
        result.setInfo("Additional Info", multiXClassifier.toString());
        result.setInfo("Dataset", MLUtils.getDatasetName(instances));
        result.setInfo("Number of labels (L)", String.valueOf(instances.classIndex()));
        if (multiXClassifier.getModel().length() > 0) {
            result.setModel("Model", multiXClassifier.getModel());
        }
        return result;
    }

    /**
     * Like {@link #evaluateModel(MultiXClassifier, Instances, Instances, String, String)}, but
     * tests with {@link #testClassifierM} (multi-threaded prediction where supported) and
     * records its values under shorter keys.
     *
     * @param multiXClassifier the classifier to build and test
     * @param instances training data
     * @param instances2 test data
     * @param str threshold option (e.g. {@code PCut1}, {@code PCutL}, or a number)
     * @param str2 verbosity level passed to {@link Result#getStats}
     * @return the evaluation result, including formatted statistics in {@code output}
     * @throws Exception if building or testing fails
     */
    public static Result evaluateModelM(MultiXClassifier multiXClassifier, Instances instances, Instances instances2, String str, String str2) throws Exception {
        long buildStart = System.currentTimeMillis();
        multiXClassifier.buildClassifier(instances);
        long buildEnd = System.currentTimeMillis();
        long testStart = System.currentTimeMillis();
        Result result = testClassifierM(multiXClassifier, instances2);
        long testEnd = System.currentTimeMillis();
        result.setValue("N_train", instances.numInstances());
        result.setValue("N_test", instances2.numInstances());
        result.setValue("LCard_train", MLUtils.labelCardinality(instances));
        result.setValue("LCard_test", MLUtils.labelCardinality(instances2));
        result.setValue("Build_time", (buildEnd - buildStart) / 1000.0d);
        result.setValue("Test_time", (testEnd - testStart) / 1000.0d);
        result.setValue("Total_time", (testEnd - buildStart) / 1000.0d);
        result.setInfo("Classifier_name", multiXClassifier.getClass().getName());
        result.setInfo("Classifier_ops", Arrays.toString(multiXClassifier.getOptions()));
        result.setInfo("Classifier_info", multiXClassifier.toString());
        result.setInfo("Dataset_name", MLUtils.getDatasetName(instances));
        if ((multiXClassifier instanceof MultiTargetClassifier) || isMT(instances2)) {
            result.setInfo("Type", "MT");
        } else if (multiXClassifier instanceof MultiLabelClassifier) {
            result.setInfo("Type", "ML");
            // Thresholds only apply to multi-label output, matching the non-threaded evaluateModel.
            result.setInfo("Threshold", MLEvalUtils.getThreshold(result.predictions, instances, str));
        }
        result.setInfo("Verbosity", str2);
        result.output = Result.getStats(result, str2);
        return result;
    }

    /**
     * Predicts each instance of {@code instances} with its label values hidden (set to 0) and
     * collects the predictions alongside the true instances. Multi-target predictions are
     * truncated to the number of labels.
     *
     * @param multiXClassifier the built classifier
     * @param instances test data; the first {@code classIndex()} attributes are the labels
     * @return the raw result containing one prediction per instance
     * @throws InterruptedException if the current thread is interrupted during testing
     * @throws Exception if prediction fails
     */
    public static Result testClassifier(MultiXClassifier multiXClassifier, Instances instances) throws Exception {
        int numLabels = instances.classIndex();
        int numInstances = instances.numInstances();
        Result result = new Result(numInstances, numLabels);
        if (multiXClassifier.getDebug()) {
            System.out.print(":- Evaluate ");
        }
        int progress = 0;
        for (int i = 0; i < numInstances; i++) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException("Thread has been interrupted.");
            }
            if (multiXClassifier.getDebug()) {
                // Progress bar of up to 50 '#'; long arithmetic avoids int overflow on huge test sets.
                int step = (int) ((i * 50L) / numInstances);
                if (step > progress) {
                    System.out.print("#");
                    progress = step;
                }
            }
            double[] dist = multiXClassifier.distributionForInstance(hideLabels(instances.instance(i), numLabels));
            if (multiXClassifier instanceof MultiTargetClassifier) {
                dist = Arrays.copyOf(dist, numLabels);
            }
            result.addResult(dist, instances.instance(i));
        }
        if (multiXClassifier.getDebug()) {
            System.out.println(":-");
        }
        return result;
    }

    /**
     * Predicts all instances at once using a {@link MultiLabelClassifierThreaded} in threaded mode.
     * Classifiers without threaded support are evaluated with {@link #testClassifier} instead, so
     * the returned result always contains one prediction per instance.
     *
     * @param multiXClassifier the built classifier
     * @param instances test data
     * @return the raw result containing one prediction per instance
     * @throws Exception if prediction fails
     */
    public static Result testClassifierM(MultiXClassifier multiXClassifier, Instances instances) throws Exception {
        if (!(multiXClassifier instanceof MultiLabelClassifierThreaded)) {
            return testClassifier(multiXClassifier, instances);
        }
        MultiLabelClassifierThreaded threadedClassifier = (MultiLabelClassifierThreaded) multiXClassifier;
        Result result = new Result(instances.numInstances(), instances.classIndex());
        if (multiXClassifier.getDebug()) {
            System.out.print(":- Evaluate ");
        }
        threadedClassifier.setThreaded(true);
        // NOTE(review): unlike testClassifier, the label values are passed unhidden here; confirm
        // that distributionForInstanceM ignores them.
        double[][] dists = threadedClassifier.distributionForInstanceM(instances);
        for (int i = 0; i < instances.numInstances(); i++) {
            result.addResult(dists[i], instances.instance(i));
        }
        if (multiXClassifier.getDebug()) {
            System.out.println(":-");
        }
        return result;
    }

    /**
     * Loads the training data set named by the {@code -t} option.
     *
     * @param strArr the command-line options ({@code -t} is consumed)
     * @return the loaded data set
     * @throws Exception if the option is missing, the file is missing/a directory, or loading fails
     */
    public static Instances loadDataset(String[] strArr) throws Exception {
        return loadDataset(strArr, 't');
    }

    /**
     * Loads the data set named by option {@code -c}.
     *
     * @param strArr the command-line options (the option is consumed)
     * @param c the option character, e.g. {@code 't'} for training or {@code 'T'} for test data
     * @return the loaded data set
     * @throws Exception if the option is missing, the file is missing/a directory, or loading
     *         fails (the underlying loader exception is attached as the cause)
     */
    public static Instances loadDataset(String[] strArr, char c) throws Exception {
        String fileName = Utils.getOption(c, strArr);
        if (fileName == null || fileName.isEmpty()) {
            throw new Exception("[Error] You did not specify a dataset!");
        }
        File file = new File(fileName);
        if (!file.exists()) {
            throw new Exception("[Error] File does not exist: " + fileName);
        }
        if (file.isDirectory()) {
            throw new Exception("[Error] " + fileName + " points to a directory!");
        }
        try {
            return new ConverterUtils.DataSource(fileName).getDataSet();
        } catch (Exception e) {
            e.printStackTrace();
            throw new Exception("[Error] Failed to load Instances from file '" + fileName + "'.", e);
        }
    }

    /**
     * Prints the evaluation options followed by the classifier's own options to stdout.
     *
     * @param enumeration the classifier's options, as returned by {@code listOptions()}; each
     *        element must be a {@link Option}
     */
    public static void printOptions(Enumeration<?> enumeration) {
        StringBuilder text = new StringBuilder();
        text.append("\n\nEvaluation Options:\n\n");
        text.append("-h\n");
        text.append("\tOutput help information.\n");
        text.append("-t <name of training file>\n");
        text.append("\tSets training file.\n");
        text.append("-T <name of test file>\n");
        text.append("\tSets test file (will be used for making predictions).\n");
        text.append("-predictions <name of output file for predictions>\n");
        text.append("\tSets the file to store the predictions in (does not work with cross-validation).\n");
        text.append("-x <number of folds>\n");
        text.append("\tDo cross-validation with this many folds.\n");
        text.append("-no-eval\n");
        text.append("\tSkips evaluation, e.g., used when test set contains no class labels.\n");
        text.append("-R\n");
        text.append("\tRandomize the order of instances in the dataset.\n");
        text.append("-split-percentage <percentage>\n");
        text.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
        text.append("-split-number <number>\n");
        text.append("\tSets the number of training examples, e.g., 800\n");
        text.append("-i\n");
        text.append("\tInvert the specified train/test split.\n");
        text.append("-s <random number seed>\n");
        text.append("\tSets random number seed (use with -R, for different CV or train/test splits).\n");
        text.append("-threshold <threshold>\n");
        text.append("\tSets the type of thresholding; where\n\t\t'PCut1' automatically calibrates a threshold (the default);\n\t\t'PCutL' automatically calibrates one threshold for each label;\n\t\tany number, e.g. '0.5', specifies that threshold.\n");
        text.append("-C <number of labels>\n");
        text.append("\tSets the number of target variables (labels) to assume (indexed from the beginning).\n");
        text.append("-d <classifier_file>\n");
        text.append("\tSpecify a file to dump classifier into.\n");
        text.append("-l <classifier_file>\n");
        text.append("\tSpecify a file to load classifier from.\n");
        text.append("-verbosity <verbosity level>\n");
        text.append("\tSpecify more/less evaluation output\n");
        text.append("-Thr\n");
        text.append("\tUse multi-threaded evaluation on a train/test split (classifiers without threaded support are evaluated sequentially).\n");
        text.append("\n\nClassifier Options:\n\n");
        while (enumeration.hasMoreElements()) {
            Option opt = (Option) enumeration.nextElement();
            text.append("-" + opt.name() + '\n');
            text.append("" + opt.description() + '\n');
        }
        System.out.println(text);
    }
}
