package ai.libs.jaicore.ml.scikitwrapper;

import ai.libs.jaicore.basic.FileUtil;
import ai.libs.jaicore.basic.ResourceUtil;
import ai.libs.jaicore.ml.evaluation.IInstancesClassifier;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.jtwig.JtwigModel;
import org.jtwig.JtwigTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/scikitwrapper/ScikitLearnWrapper.class */
/**
 * Weka-compatible classifier that delegates training and prediction to a scikit-learn estimator
 * running in an external {@code python} process.
 *
 * <p>On construction, a Python script is rendered from the Jtwig template
 * {@code sklearn/scikit_template.twig.py} into the {@code tmp} folder. The template receives the
 * given import statements ({@code imports}) and the estimator construction call
 * ({@code classifier_construct}). Data is handed to that script as ARFF files, in one of two modes:
 * <ul>
 * <li>By default, {@link #buildClassifier(Instances)} runs the script in train mode with a
 * {@code .pcl} model-dump file as output. {@link #classifyInstances(Instances)} then runs test
 * mode against that model file.</li>
 * <li>With {@code withoutModelDump}, {@link #buildClassifier(Instances)} only stores the training
 * ARFF. Every {@link #classifyInstances(Instances)} call then runs a combined train/test
 * process.</li>
 * </ul>
 * Predictions are read back from a JSON file that must contain a list of lists of numbers.
 */
public class ScikitLearnWrapper implements IInstancesClassifier, Classifier {
    private static final String PYTHON_FILE_EXT = ".py";
    private static final String MODEL_DUMP_FILE_EXT = ".pcl";
    private static final String RESULT_FILE_EXT = ".json";
    private static final String ARFF_FILE_EXT = ".arff";
    private static final String PYTHON_INIT_FILE = "__init__.py";
    private static final Logger L = LoggerFactory.getLogger(ScikitLearnWrapper.class);
    private static final File TMP_FOLDER = new File("tmp");
    private static final String RES_SCIKIT_TEMPLATE_PATH = "sklearn/scikit_template.twig.py";
    private static final File SCIKIT_TEMPLATE = new File(ResourceUtil.getResourceAsTempFile(RES_SCIKIT_TEMPLATE_PATH));
    private static final File MODEL_DUMPS_DIRECTORY = new File(TMP_FOLDER, "model_dumps");
    /** Flag handed to every {@code DefaultProcessListener} created by this class. */
    private static final boolean VERBOSE = false;
    /** Whether generated ARFF and script files are registered for deletion at JVM exit. */
    private static final boolean DELETE_TEMPORARY_FILES_ON_EXIT = true;
    /** Shared JSON reader; an ObjectMapper is thread-safe once configured, so one instance suffices. */
    private static final ObjectMapper RESULT_MAPPER = new ObjectMapper();

    private ProblemType problemType = ProblemType.CLASSIFICATION;
    /** Column indices forwarded to the script via {@code --targets}; empty/null means "not passed". */
    private int[] targetColumns = new int[0];
    /** Hash-derived id that names the generated script and result files of this configuration. */
    private final String configurationUID;
    private File modelFile;
    private File trainArff;
    private final boolean withoutModelDump;
    private final String constructInstruction;
    /** Per-row results of the most recent {@link #classifyInstances(Instances)} call. */
    private transient List<List<Double>> rawLastClassificationResults;

    /** Kind of learning problem; {@link #REGRESSION} adds the {@code --regression} flag. */
    public enum ProblemType {
        REGRESSION,
        CLASSIFICATION
    }

    /**
     * Builds the argument vector for one invocation of the generated sklearn script.
     * Non-static because it reads the problem type, target columns and script location of the
     * enclosing wrapper.
     */
    public class SKLearnWrapperCommandBuilder {
        private static final String PYTHON_EXECUTABLE = "python";
        private static final String ARFF_FLAG = "--arff";
        private static final String TEST_ARFF_FLAG = "--testarff";
        private static final String MODE_FLAG = "--mode";
        private static final String MODEL_FLAG = "--model";
        private static final String OUTPUT_FLAG = "--output";
        private static final String REGRESSION_FLAG = "--regression";
        private static final String TARGETS_FLAG = "--targets";
        private String arffFile;
        private String testArffFile;
        private WrapperExecutionMode mode;
        private String modelFile;
        private String outputFile;

        private SKLearnWrapperCommandBuilder() {
        }

        /** Sets the ARFF file passed via {@code --testarff} (used by train/test mode). */
        public SKLearnWrapperCommandBuilder withTestArffFile(final File file) {
            this.testArffFile = file.getAbsolutePath();
            return this;
        }

        public SKLearnWrapperCommandBuilder withTrainMode() {
            return this.withMode(WrapperExecutionMode.TRAIN);
        }

        public SKLearnWrapperCommandBuilder withTestMode() {
            return this.withMode(WrapperExecutionMode.TEST);
        }

        public SKLearnWrapperCommandBuilder withTrainTestMode() {
            return this.withMode(WrapperExecutionMode.TRAIN_TEST);
        }

        private SKLearnWrapperCommandBuilder withMode(final WrapperExecutionMode wrapperExecutionMode) {
            this.mode = wrapperExecutionMode;
            return this;
        }

        /**
         * Sets the model dump passed via {@code --model} (only emitted in test mode).
         *
         * @throws NullPointerException if {@code file} is null
         * @throws IllegalArgumentException if the file does not exist
         */
        public SKLearnWrapperCommandBuilder withModelFile(final File file) {
            Objects.requireNonNull(file, "Model dump file must not be null");
            if (!file.exists()) {
                throw new IllegalArgumentException("Model dump does not exist");
            }
            this.modelFile = file.getAbsolutePath();
            return this;
        }

        /** Sets the file passed via {@code --output}. */
        public SKLearnWrapperCommandBuilder withOutputFile(final File file) {
            this.outputFile = file.getAbsolutePath();
            return this;
        }

        /**
         * Sets the ARFF file passed via {@code --arff}.
         *
         * @throws NullPointerException if {@code file} is null
         * @throws IllegalArgumentException if the file does not exist
         */
        public SKLearnWrapperCommandBuilder withArffFile(final File file) {
            Objects.requireNonNull(file, "Arff file must not be null");
            if (!file.exists()) {
                throw new IllegalArgumentException("Arff File does not exist.");
            }
            this.arffFile = file.getAbsolutePath();
            return this;
        }

        /**
         * Assembles the full command line.
         *
         * @return the arguments for {@link ProcessBuilder}, starting with {@code python -u <script>}
         * @throws NullPointerException if mode, output or ARFF file (or the model file in test mode) is unset
         * @throws IllegalArgumentException if the generated script file does not exist
         */
        public String[] toCommandArray() {
            Objects.requireNonNull(this.mode, "Execution mode must be set");
            Objects.requireNonNull(this.outputFile, "Output file must be set");
            Objects.requireNonNull(this.arffFile, "Arff file must be set");
            File scriptFile = ScikitLearnWrapper.this.getSKLearnScriptFile();
            if (!scriptFile.exists()) {
                throw new IllegalArgumentException("The wrapped sklearn script " + scriptFile.getAbsolutePath() + " file does not exist");
            }
            List<String> command = new ArrayList<>();
            command.add(PYTHON_EXECUTABLE);
            command.add("-u"); // unbuffered stdout/stderr so the process listener sees output promptly
            command.add(scriptFile.getAbsolutePath());
            command.addAll(Arrays.asList(MODE_FLAG, this.mode.toString()));
            command.addAll(Arrays.asList(ARFF_FLAG, this.arffFile));
            if (this.testArffFile != null) {
                command.addAll(Arrays.asList(TEST_ARFF_FLAG, this.testArffFile));
            }
            command.addAll(Arrays.asList(OUTPUT_FLAG, this.outputFile));
            if (ScikitLearnWrapper.this.problemType == ProblemType.REGRESSION) {
                command.add(REGRESSION_FLAG);
            }
            if (this.mode == WrapperExecutionMode.TEST) {
                Objects.requireNonNull(this.modelFile, "Model file must be set in test mode");
                command.addAll(Arrays.asList(MODEL_FLAG, this.modelFile));
            }
            int[] targets = ScikitLearnWrapper.this.targetColumns;
            if (targets != null && targets.length > 0) {
                command.add(TARGETS_FLAG);
                for (int target : targets) {
                    command.add(String.valueOf(target));
                }
            }
            return command.toArray(new String[0]);
        }
    }

    /** Value of the script's {@code --mode} argument; {@link #toString()} yields the CLI token. */
    public enum WrapperExecutionMode {
        TRAIN("train"),
        TEST("test"),
        TRAIN_TEST("traintest");

        private final String name;

        WrapperExecutionMode(final String name) {
            this.name = name;
        }

        @Override
        public String toString() {
            return this.name;
        }
    }

    /**
     * Creates the wrapper and renders its Python script into the {@code tmp} folder.
     *
     * @param constructInstruction Python code constructing the estimator (template value
     *            {@code classifier_construct}); must be non-empty
     * @param imports Python import statements (template value {@code imports}); {@code null} means none
     * @param withoutModelDump if {@code true}, no model dump is written and every classification
     *            retrains in a combined train/test process
     * @throws AssertionError if {@code constructInstruction} is null or empty
     * @throws IOException if the script file cannot be created or written
     */
    public ScikitLearnWrapper(final String constructInstruction, final String imports, final boolean withoutModelDump) throws IOException {
        this.withoutModelDump = withoutModelDump;
        this.constructInstruction = constructInstruction;
        Map<String, Object> templateValues = this.getTemplateValueMap(constructInstruction, imports);
        this.configurationUID = toNonNegativeId(StringUtils.join(new String[] { constructInstruction, imports }).hashCode());
        if (!TMP_FOLDER.exists()) {
            TMP_FOLDER.mkdirs();
        }
        File scriptFile = this.getSKLearnScriptFile();
        if (!scriptFile.createNewFile() && L.isDebugEnabled()) {
            L.debug("Script file for configuration UID {} already exists in {}", this.configurationUID, scriptFile.getAbsolutePath());
        }
        if (DELETE_TEMPORARY_FILES_ON_EXIT) {
            scriptFile.deleteOnExit();
        }
        // Close the stream so the handle does not leak (and the file is not left locked on Windows).
        try (FileOutputStream scriptOut = new FileOutputStream(scriptFile)) {
            JtwigTemplate.fileTemplate(SCIKIT_TEMPLATE).render(JtwigModel.newModel(templateValues), scriptOut);
        }
    }

    /** Same as {@code ScikitLearnWrapper(constructInstruction, imports, false)}. */
    public ScikitLearnWrapper(final String constructInstruction, final String imports) throws IOException {
        this(constructInstruction, imports, false);
    }

    /**
     * Creates a model-dump based wrapper that starts out with the given, already existing model file.
     *
     * @param modelFile model dump to use in test mode until {@link #buildClassifier(Instances)} replaces it
     */
    public ScikitLearnWrapper(final String constructInstruction, final String imports, final File modelFile) throws IOException {
        this(constructInstruction, imports, false);
        this.modelFile = modelFile;
    }

    /** @return the generated script location, {@code tmp/<configurationUID>.py} */
    public File getSKLearnScriptFile() {
        Objects.requireNonNull(this.configurationUID);
        return new File(TMP_FOLDER, this.configurationUID + PYTHON_FILE_EXT);
    }

    private File getResultFile(final String arffName) {
        return new File(MODEL_DUMPS_DIRECTORY, arffName + "_" + this.configurationUID + RESULT_FILE_EXT);
    }

    /**
     * Writes the training data to an ARFF file. Unless {@code withoutModelDump} is set, it then runs
     * the script in train mode, producing a model dump in {@code tmp/model_dumps}.
     */
    @Override
    public void buildClassifier(final Instances instances) throws Exception {
        MODEL_DUMPS_DIRECTORY.mkdirs();
        String arffName = this.getArffName(instances);
        this.trainArff = this.getArffFile(instances, arffName);
        if (this.withoutModelDump) {
            return;
        }
        this.modelFile = new File(MODEL_DUMPS_DIRECTORY, this.configurationUID + "_" + arffName + MODEL_DUMP_FILE_EXT);
        String[] command = new SKLearnWrapperCommandBuilder().withTrainMode().withArffFile(this.trainArff).withOutputFile(this.modelFile).toCommandArray();
        if (L.isDebugEnabled()) {
            L.debug("{} run train mode {}", Thread.currentThread().getName(), Arrays.toString(command));
        }
        this.runProcess(command, new DefaultProcessListener(VERBOSE));
    }

    /**
     * Returns {@code tmp/<arffName>.arff}, writing the instances there unless the file already exists.
     * A file left over from an earlier call with the same name is reused as-is.
     */
    private File getArffFile(final Instances instances, final String arffName) throws IOException {
        File arffFile = new File(TMP_FOLDER, arffName + ARFF_FILE_EXT);
        if (DELETE_TEMPORARY_FILES_ON_EXIT) {
            arffFile.deleteOnExit();
        }
        if (arffFile.exists()) {
            L.debug("Reusing {}.arff", arffName);
            return arffFile;
        }
        // Render before opening the file so a failure here cannot leave an empty file behind.
        String content = instances.toString();
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(arffFile))) {
            writer.write(content);
        } catch (IOException e) {
            // A truncated file would be silently reused by the existence check above; remove it.
            if (arffFile.exists() && !arffFile.delete()) {
                L.warn("Could not delete incompletely written arff file {}", arffFile.getAbsolutePath());
            }
            throw e;
        }
        return arffFile;
    }

    /**
     * Runs the script on the given instances and returns the flattened numeric predictions.
     * The per-row results are kept in {@link #getRawLastClassificationResults()}.
     *
     * @throws NullPointerException if no training data (without model dump) or no model file is available
     * @throws IOException if the result file cannot be read or is not a JSON list of lists of numbers
     */
    @Override
    public double[] classifyInstances(final Instances instances) throws Exception {
        MODEL_DUMPS_DIRECTORY.mkdirs();
        String arffName = this.getArffName(instances);
        File arffFile = this.getArffFile(instances, arffName);
        File resultFile = this.getResultFile(arffName);
        String[] command;
        if (this.withoutModelDump) {
            Objects.requireNonNull(this.trainArff, "No training data available; buildClassifier must be called first");
            command = new SKLearnWrapperCommandBuilder().withTrainTestMode().withArffFile(this.trainArff).withTestArffFile(arffFile).withOutputFile(resultFile).toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run train test mode with {}", Arrays.toString(command));
            }
        } else {
            Objects.requireNonNull(this.modelFile, "No model file available; call buildClassifier or set a model path first");
            command = new SKLearnWrapperCommandBuilder().withTestMode().withArffFile(arffFile).withModelFile(this.modelFile).withOutputFile(resultFile).toCommandArray();
            if (L.isDebugEnabled()) {
                L.debug("Run test mode with {}", Arrays.toString(command));
            }
        }
        this.runProcess(command, new DefaultProcessListener(VERBOSE));
        try {
            String json = FileUtil.readFileAsString(resultFile);
            Files.delete(resultFile.toPath());
            this.rawLastClassificationResults = parseResultMatrix(json);
        } catch (IOException e) {
            throw new IOException("Could not read result file or parse the json content to a list", e);
        }
        return this.rawLastClassificationResults.stream().flatMap(List::stream).mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * Parses the script's JSON output into a list of rows of doubles.
     * Any JSON number is accepted; Jackson yields Integer/Long for integral values, so casting
     * straight to Double would fail for e.g. integer class labels.
     */
    private static List<List<Double>> parseResultMatrix(final String json) throws IOException {
        List<?> rows = RESULT_MAPPER.readValue(json, List.class);
        if (rows == null) {
            throw new IOException("Result file contains no JSON list");
        }
        List<List<Double>> matrix = new ArrayList<>(rows.size());
        for (Object row : rows) {
            if (!(row instanceof List)) {
                throw new IOException("Expected a JSON list of lists but found element: " + row);
            }
            List<?> cells = (List<?>) row;
            List<Double> values = new ArrayList<>(cells.size());
            for (Object cell : cells) {
                if (!(cell instanceof Number)) {
                    throw new IOException("Expected a numeric result value but found: " + cell);
                }
                values.add(((Number) cell).doubleValue());
            }
            matrix.add(values);
        }
        return matrix;
    }

    /** Classifies a single instance by wrapping a copy of it in a one-element dataset. */
    @Override
    public double classifyInstance(final Instance instance) throws Exception {
        Instances singleton = new Instances(instance.dataset(), 0);
        DenseInstance copy = new DenseInstance(instance);
        copy.setDataset(singleton);
        singleton.add(copy);
        return this.classifyInstances(singleton)[0];
    }

    /**
     * Builds Python code that appends {@code file} to {@code sys.path} and imports every
     * {@code .py} module directly inside it. Entries starting with {@code __} are skipped.
     * Creates an empty {@code __init__.py} in the folder if it is missing.
     *
     * @param file folder containing Python modules; may be null
     * @param importAsModule {@code true} emits {@code import <module>}, {@code false} emits
     *            {@code from <module> import *}
     * @return the import code, or {@code ""} if {@code file} is null, not a directory, or empty
     * @throws IOException if {@code __init__.py} cannot be created
     */
    public static String createImportStatementFromImportFolder(final File file, final boolean importAsModule) throws IOException {
        if (file == null || !file.isDirectory()) {
            return "";
        }
        String[] entries = file.list();
        if (entries == null || entries.length == 0) {
            return "";
        }
        if (!Arrays.asList(entries).contains(PYTHON_INIT_FILE)) {
            File initFile = new File(file, PYTHON_INIT_FILE);
            if (!initFile.createNewFile() && L.isDebugEnabled()) {
                L.debug("Init file {} exists already", initFile.getAbsolutePath());
            }
        }
        StringBuilder code = new StringBuilder();
        code.append("\n");
        code.append("sys.path.append(r'").append(file.getAbsolutePath()).append("')\n");
        File[] children = file.listFiles();
        if (children == null) {
            return code.toString();
        }
        for (File child : children) {
            String fileName = child.getName();
            // Only plain Python source files map to importable module names.
            if (fileName.startsWith("__") || !child.isFile() || !fileName.endsWith(PYTHON_FILE_EXT)) {
                continue;
            }
            String moduleName = fileName.substring(0, fileName.length() - PYTHON_FILE_EXT.length());
            if (importAsModule) {
                code.append("import ").append(moduleName).append("\n");
            } else {
                code.append("from ").append(moduleName).append(" import *\n");
            }
        }
        return code.toString();
    }

    /** Builds the Jtwig model values ({@code imports}, {@code classifier_construct}) for the script template. */
    private Map<String, Object> getTemplateValueMap(final String constructInstruction, final String imports) {
        if (constructInstruction == null || constructInstruction.isEmpty()) {
            throw new AssertionError("Construction command for classifier must be stated.");
        }
        Map<String, Object> values = new HashMap<>();
        values.put("imports", imports != null ? imports : "");
        values.put("classifier_construct", constructInstruction);
        return values;
    }

    /** @return one {@code import <name>} line per entry, or {@code ""} for a null/empty collection */
    public static String getImportString(final Collection<String> collection) {
        return (collection == null || collection.isEmpty()) ? "" : "import " + StringUtils.join(collection, "\nimport ");
    }

    /** @return the per-row results of the last {@link #classifyInstances(Instances)} call, or null if none */
    public List<List<Double>> getRawLastClassificationResults() {
        return this.rawLastClassificationResults;
    }

    public void setProblemType(final ProblemType problemType) {
        this.problemType = problemType;
    }

    /** Sets the column indices forwarded via {@code --targets}; the array is copied. */
    public void setTargets(final int... targetColumns) {
        this.targetColumns = targetColumns == null ? null : targetColumns.clone();
    }

    public void setModelPath(final File file) {
        this.modelFile = file;
    }

    public File getModelPath() {
        return this.modelFile;
    }

    private String getArffName(final Instances instances) {
        return toNonNegativeId(instances.hashCode());
    }

    /**
     * Turns a hash code into a sign-free id usable in file names: the digits are prefixed with
     * {@code "1"} for negative and {@code "0"} for non-negative values, so ids of h and -h differ.
     */
    private static String toNonNegativeId(final int hash) {
        String digits = Integer.toString(hash);
        return digits.startsWith("-") ? "1" + digits.substring(1) : "0" + digits;
    }

    /** Starts the given command with {@code tmp} as working directory and hands the process to the listener. */
    private void runProcess(final String[] command, final AProcessListener listener) throws InterruptedException, IOException {
        if (L.isDebugEnabled()) {
            L.debug("Starting process {}", String.join(" ", command));
        }
        listener.listenTo(new ProcessBuilder(command).directory(TMP_FOLDER).start());
    }

    @Override
    public double[] distributionForInstance(final Instance instance) throws Exception {
        throw new UnsupportedOperationException("This method is not yet implemented");
    }

    /** NOTE(review): returns null; Weka tooling that queries capabilities may not handle this — confirm. */
    @Override
    public Capabilities getCapabilities() {
        return null;
    }

    @Override
    public String toString() {
        return this.constructInstruction;
    }
}
