package ai.libs.jaicore.ml.learningcurve.extrapolation;

import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.DatasetCreationException;
import ai.libs.jaicore.ml.core.dataset.ILabeledAttributeArrayInstance;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledAttributeArrayDataset;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.factories.interfaces.IRerunnableSamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.dataset.weka.WekaInstances;
import ai.libs.jaicore.ml.interfaces.LearningCurve;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.UnsupportedAttributeTypeException;

/* loaded from: input_file:ai/libs/jaicore/ml/learningcurve/extrapolation/LearningCurveExtrapolator.class */
/**
 * Estimates the learning curve of a classifier from a few observed points.
 *
 * <p>The classifier is trained on subsamples of the training split whose sizes are the given anchor points. After
 * each run, the fraction of correctly classified test instances (accuracy) is recorded, together with the training
 * time. The observed values are then passed to a {@link LearningCurveExtrapolationMethod}, which extrapolates the
 * curve.
 *
 * <p>The train/test split is created once, in the constructor. It is stratified by target value and made
 * reproducible by the given seed.
 *
 * @param <I> the instance type
 * @param <D> the dataset type
 */
public class LearningCurveExtrapolator<I extends ILabeledAttributeArrayInstance<?>, D extends IOrderedLabeledAttributeArrayDataset<I, ?>> implements ILoggingCustomizable {
    /** Classifier that is retrained at every anchor point. */
    protected Classifier learner;
    /** Complete dataset from which the train/test split is created. */
    protected D dataset;
    /** Training split; subsamples for the anchor points are drawn from it. */
    protected D train;
    /** Test split on which the accuracy of every trained model is measured. */
    protected D test;
    /** Factory creating the sampling algorithm used for each anchor point. */
    protected ISamplingAlgorithmFactory<I, D, ? extends ASamplingAlgorithm<I, D>> samplingAlgorithmFactory;
    /** Random source handed to the sampling algorithms (seeded in the constructor). */
    protected Random random;
    /** Method that turns the observed anchor-point accuracies into a learning curve. */
    protected LearningCurveExtrapolationMethod extrapolationMethod;
    /** Subsample sizes at which the learner is trained. */
    private final int[] anchorPoints;
    /** Observed test accuracy per anchor point (same indexing as {@link #anchorPoints}). */
    private final double[] yValues;
    /** Training time in milliseconds per anchor point (same indexing as {@link #anchorPoints}). */
    private final int[] trainingTimes;
    private Logger logger = LoggerFactory.getLogger(LearningCurveExtrapolator.class);
    /** Sampling algorithm of the most recent anchor point, or {@code null} before the first run. */
    protected ASamplingAlgorithm<I, D> samplingAlgorithm = null;

    /**
     * Creates an extrapolator and immediately performs the stratified train/test split of {@code dataset}.
     *
     * @param extrapolationMethod method used to extrapolate the curve from the observed anchor points
     * @param learner classifier to train at each anchor point
     * @param dataset complete dataset to split into training and test data
     * @param trainSplitRatio fraction of each class that is placed in the training split
     * @param anchorPoints subsample sizes at which the learner is trained and evaluated
     * @param samplingAlgorithmFactory factory for the algorithm drawing each subsample from the training split
     * @param seed seed used both for the split and for the random source passed to the sampling algorithms
     * @throws DatasetCreationException if the empty train/test datasets cannot be created
     */
    public LearningCurveExtrapolator(LearningCurveExtrapolationMethod extrapolationMethod, Classifier learner, D dataset, double trainSplitRatio, int[] anchorPoints, ISamplingAlgorithmFactory<I, D, ? extends ASamplingAlgorithm<I, D>> samplingAlgorithmFactory, long seed) throws DatasetCreationException {
        this.extrapolationMethod = extrapolationMethod;
        this.learner = learner;
        this.dataset = dataset;
        this.anchorPoints = anchorPoints;
        this.samplingAlgorithmFactory = samplingAlgorithmFactory;
        this.random = new Random(seed);
        this.createSplit(trainSplitRatio, seed);
        this.yValues = new double[this.anchorPoints.length];
        this.trainingTimes = new int[this.anchorPoints.length];
    }

    /**
     * Computes the extrapolated learning curve from the configured anchor points.
     *
     * <p>For every anchor point, this method:
     * <ol>
     * <li>draws a subsample of that size from the training split;</li>
     * <li>trains the learner on it;</li>
     * <li>records the training time and the accuracy on the test split.</li>
     * </ol>
     * The observed accuracies are then extrapolated over the size of the complete dataset.
     *
     * @return the extrapolated learning curve
     * @throws InvalidAnchorPointsException propagated unchanged if raised while computing the curve
     * @throws AlgorithmException if sampling, training, evaluation or extrapolation fails
     * @throws InterruptedException propagated unchanged if the computation is interrupted
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public LearningCurve extrapolateLearningCurve() throws InvalidAnchorPointsException, AlgorithmException, InterruptedException {
        try {
            Instances testInstances = ((WekaInstances) this.test).getList();
            for (int i = 0; i < this.anchorPoints.length; i++) {
                int sampleSize = this.anchorPoints[i];
                // Rerunnable factories are handed the previous run, presumably so they can build on the last subsample.
                if (this.samplingAlgorithm != null && this.samplingAlgorithmFactory instanceof IRerunnableSamplingAlgorithmFactory) {
                    ((IRerunnableSamplingAlgorithmFactory) this.samplingAlgorithmFactory).setPreviousRun(this.samplingAlgorithm);
                }
                this.samplingAlgorithm = this.samplingAlgorithmFactory.getAlgorithm(sampleSize, this.train, this.random);
                D subsample = this.samplingAlgorithm.m25call();

                this.logger.debug("Running classifier with {} data points.", sampleSize);
                // nanoTime is monotonic, so the measured duration is immune to wall-clock adjustments.
                long startNanos = System.nanoTime();
                this.learner.buildClassifier(((WekaInstances) subsample).getList());
                this.trainingTimes[i] = (int) ((System.nanoTime() - startNanos) / 1_000_000L);

                this.yValues[i] = this.computeAccuracy(testInstances);
                this.logger.debug("Training finished. Observed learning curve value (accuracy) of {}.", this.yValues[i]);
            }
            if (this.logger.isInfoEnabled()) {
                this.logger.info("Computed accuracies of {} for anchor points {}. Now extrapolating a curve from these observations.", Arrays.toString(this.yValues), Arrays.toString(this.anchorPoints));
            }
            return this.extrapolationMethod.extrapolateLearningCurveFromAnchorPoints(this.anchorPoints, this.yValues, this.dataset.size());
        } catch (AlgorithmExecutionCanceledException | TimeoutException | AlgorithmException e) {
            throw new AlgorithmException(e, "Error during creation of the subsamples for the anchorpoints");
        } catch (InvalidAnchorPointsException | InterruptedException e) {
            throw e;
        } catch (ExecutionException e) {
            throw new AlgorithmException(e, "Error during learning curve extrapolation");
        } catch (UnsupportedAttributeTypeException e) {
            // Must precede the generic Exception handler, otherwise this clause is unreachable (compile error).
            throw new AlgorithmException(e, "Error during convertion of the dataset to WEKA instances");
        } catch (Exception e) {
            throw new AlgorithmException(e, "Error during training/testing the classifier");
        }
    }

    /**
     * Computes the fraction of {@code testInstances} whose class value is predicted exactly by the learner.
     *
     * @param testInstances instances to evaluate on
     * @return the accuracy in [0, 1], or {@code NaN} if {@code testInstances} is empty
     * @throws Exception if the learner fails to classify an instance
     */
    private double computeAccuracy(final Instances testInstances) throws Exception {
        int correct = 0;
        for (Instance instance : testInstances) {
            if (this.learner.classifyInstance(instance) == instance.classValue()) {
                correct++;
            }
        }
        return (double) correct / testInstances.size();
    }

    /**
     * Splits {@link #dataset} into {@link #train} and {@link #test}, stratified by target value.
     *
     * <p>The instances are shuffled with {@code seed} before being grouped by class. This makes the split random but
     * reproducible. Within each class:
     * <ol>
     * <li>one instance is first reserved for each split, as far as available;</li>
     * <li>the training split then receives up to {@code ceil(trainPortion * classSize)} further instances;</li>
     * <li>the test split receives up to {@code ceil((1 - trainPortion) * classSize)} of the remaining ones.</li>
     * </ol>
     * Both splits are shuffled at the end.
     *
     * @param trainPortion fraction of each class assigned to the training split
     * @param seed seed of the random generator used for all shuffles
     * @throws DatasetCreationException if the empty train/test datasets cannot be created
     */
    @SuppressWarnings("unchecked")
    private void createSplit(final double trainPortion, final long seed) throws DatasetCreationException {
        long start = System.currentTimeMillis();
        this.logger.debug("Creating split with training portion {} and seed {}", trainPortion, seed);
        Random splitRandom = new Random(seed);
        this.train = (D) this.dataset.createEmpty();
        this.test = (D) this.dataset.createEmpty();

        // Stratify the *shuffled* instances. Grouping the original dataset would make the split membership depend
        // only on the original instance order and ignore the seed.
        List<I> shuffled = new ArrayList<>(this.dataset);
        Collections.shuffle(shuffled, splitRandom);
        Map<Object, List<I>> strata = new HashMap<>();
        for (I instance : shuffled) {
            strata.computeIfAbsent(instance.getTargetValue2(), label -> new ArrayList<>()).add(instance);
        }

        for (List<I> stratum : strata.values()) {
            int classSize = stratum.size();
            int next = 0;
            // Reserve one instance of the class for each split so that both see every class whenever possible.
            next = this.appendNext(stratum, next, 1, this.train);
            next = this.appendNext(stratum, next, 1, this.test);
            next = this.appendNext(stratum, next, (int) Math.ceil(trainPortion * classSize), this.train);
            this.appendNext(stratum, next, (int) Math.ceil((1.0d - trainPortion) * classSize), this.test);
        }

        this.logger.debug("Shuffling train and test data");
        Collections.shuffle(this.train, splitRandom);
        Collections.shuffle(this.test, splitRandom);
        this.logger.debug("Finished split creation after {}ms", System.currentTimeMillis() - start);
    }

    /**
     * Appends up to {@code count} elements of {@code source}, starting at index {@code from}, to {@code target}.
     * Non-positive counts append nothing.
     *
     * @return the index of the first element of {@code source} that has not been appended yet
     */
    private int appendNext(final List<I> source, final int from, final int count, final D target) {
        int taken = Math.max(0, Math.min(source.size() - from, count));
        target.addAll(source.subList(from, from + taken));
        return from + taken;
    }

    /** @return the classifier that is trained at every anchor point */
    public Classifier getLearner() {
        return this.learner;
    }

    /** @return the complete dataset the split was created from */
    public D getDataset() {
        return this.dataset;
    }

    /** @return the method used to extrapolate the learning curve */
    public LearningCurveExtrapolationMethod getExtrapolationMethod() {
        return this.extrapolationMethod;
    }

    /** @return the anchor points (the internal array, not a copy) */
    public int[] getAnchorPoints() {
        return this.anchorPoints;
    }

    /** @return the observed accuracy per anchor point (the internal array, not a copy) */
    public double[] getyValues() {
        return this.yValues;
    }

    /** @return the training time in milliseconds per anchor point (the internal array, not a copy) */
    public int[] getTrainingTimes() {
        return this.trainingTimes;
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
    }
}
