package ai.libs.jaicore.ml.evaluation.evaluators.weka;

import ai.libs.jaicore.basic.ILoggingCustomizable;
import ai.libs.jaicore.basic.algorithm.exceptions.ObjectEvaluationFailedException;
import ai.libs.jaicore.basic.events.IEvent;
import ai.libs.jaicore.basic.events.IEventEmitter;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.events.MCCVSplitEvaluationEvent;
import ai.libs.jaicore.ml.evaluation.evaluators.weka.splitevaluation.ISplitBasedClassifierEvaluator;
import ai.libs.jaicore.ml.weka.dataset.splitter.IDatasetSplitter;
import ai.libs.jaicore.ml.weka.dataset.splitter.MulticlassClassStratifiedSplitter;
import ai.libs.jaicore.ml.weka.dataset.splitter.SplitFailedException;
import com.google.common.eventbus.EventBus;
import com.google.common.eventbus.Subscribe;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/evaluation/evaluators/weka/MonteCarloCrossValidationEvaluator.class */
/**
 * Scores a WEKA {@link Classifier} by Monte Carlo cross-validation (MCCV).
 *
 * <p>For each iteration {@code i} in {@code [0, repeats)}, the data is split by the configured
 * {@link IDatasetSplitter} using the seed {@code seed + i} and the configured training portion. The
 * classifier is then scored on that split by the {@link ISplitBasedClassifierEvaluator}. The value
 * returned by {@link #evaluate(Classifier)} is the mean of the per-split scores.
 *
 * <p>Splits are cached by seed, so repeated calls to {@code evaluate} on the same instance reuse
 * identical splits of the data.
 *
 * <p>If the split-based evaluator is an {@link IEventEmitter}, every {@link IEvent} it posts is
 * forwarded to the listeners registered on this object. In addition, once at least one listener is
 * registered, an {@link MCCVSplitEvaluationEvent} is posted after each evaluated split.
 *
 * <p>NOTE(review): the split cache is an unsynchronized {@link HashMap}, so concurrent calls to
 * {@code evaluate} on the same instance are not safe.
 */
public class MonteCarloCrossValidationEvaluator implements IClassifierEvaluator, ILoggingCustomizable, IEventEmitter {

    /** Bus on which this evaluator's own events and forwarded events are posted. */
    private final EventBus eventBus = new EventBus();

    /** Whether {@link #registerListener(Object)} has been called at least once. */
    private boolean hasListeners;

    private Logger logger = LoggerFactory.getLogger(MonteCarloCrossValidationEvaluator.class);

    /**
     * Set by {@link #cancel()} and polled before every split. It is volatile so that a cancel issued
     * from another thread becomes visible to the thread running {@code evaluate}. It is never reset.
     */
    private volatile boolean canceled = false;

    private final IDatasetSplitter datasetSplitter;
    private final int repeats;
    private final Instances data;
    private final double trainingPortion;
    private final long seed;
    private final ISplitBasedClassifierEvaluator<Double> splitBasedEvaluator;

    /** Cache of computed splits, keyed by the seed used to create them. */
    private final Map<Long, List<Instances>> splitCache = new HashMap<>();

    /**
     * Creates an MCCV evaluator that uses the given dataset splitter.
     *
     * @param splitBasedEvaluator computes the score of a classifier on a single split. If it is an
     *        {@link IEventEmitter}, this evaluator subscribes to it and forwards its events.
     * @param datasetSplitter creates each split from {@code data}
     * @param repeats number of splits to evaluate per call of {@code evaluate}
     * @param data the data to split
     * @param trainingPortion portion passed to the splitter for every split
     * @param seed base seed; split {@code i} (0-based) is created with seed {@code seed + i}
     * @throws IllegalArgumentException if {@code data}, {@code splitBasedEvaluator} or
     *         {@code datasetSplitter} is null
     */
    public MonteCarloCrossValidationEvaluator(final ISplitBasedClassifierEvaluator<Double> splitBasedEvaluator, final IDatasetSplitter datasetSplitter, final int repeats, final Instances data,
            final double trainingPortion, final long seed) {
        if (data == null) {
            throw new IllegalArgumentException("Cannot work with NULL data");
        }
        if (splitBasedEvaluator == null) {
            throw new IllegalArgumentException("Cannot work with NULL split based evaluator");
        }
        // Fail fast: without this check a null splitter only surfaces as an NPE inside evaluate(...).
        if (datasetSplitter == null) {
            throw new IllegalArgumentException("Cannot work with NULL dataset splitter");
        }
        this.datasetSplitter = datasetSplitter;
        this.repeats = repeats;
        this.splitBasedEvaluator = splitBasedEvaluator;
        this.data = data;
        this.trainingPortion = trainingPortion;
        this.seed = seed;

        // Subscribe only after every field is assigned, so 'this' is never handed out half-initialized.
        if (splitBasedEvaluator instanceof IEventEmitter) {
            ((IEventEmitter) splitBasedEvaluator).registerListener(this);
        }
    }

    /**
     * Creates an MCCV evaluator that splits the data with a {@link MulticlassClassStratifiedSplitter}.
     *
     * @param splitBasedEvaluator computes the score of a classifier on a single split
     * @param repeats number of splits to evaluate per call of {@code evaluate}
     * @param data the data to split
     * @param trainingPortion portion passed to the splitter for every split
     * @param seed base seed; split {@code i} (0-based) is created with seed {@code seed + i}
     * @throws IllegalArgumentException if {@code data} or {@code splitBasedEvaluator} is null
     */
    public MonteCarloCrossValidationEvaluator(final ISplitBasedClassifierEvaluator<Double> splitBasedEvaluator, final int repeats, final Instances data, final double trainingPortion,
            final long seed) {
        this(splitBasedEvaluator, new MulticlassClassStratifiedSplitter(), repeats, data, trainingPortion, seed);
    }

    /**
     * Requests cancellation. Any running and all future {@code evaluate} calls stop before their next
     * split and return the mean of the splits evaluated so far. {@code NaN} is returned if no split was
     * scored (the mean of empty {@link DescriptiveStatistics}). The flag is never reset.
     */
    public void cancel() {
        this.logger.info("Received cancel");
        this.canceled = true;
    }

    /**
     * Computes the mean MCCV score of the given classifier.
     *
     * @param classifier the classifier to evaluate; must not be null
     * @return the mean score over the evaluated splits
     * @throws ObjectEvaluationFailedException if a split cannot be created or evaluated
     * @throws InterruptedException if the current thread is interrupted between splits, or if the
     *         split-based evaluator throws it
     */
    public Double evaluate(final Classifier classifier) throws ObjectEvaluationFailedException, InterruptedException {
        return this.evaluate(classifier, new DescriptiveStatistics());
    }

    /**
     * Computes the mean MCCV score of the given classifier and records every per-split score in
     * {@code stats}.
     *
     * @param classifier the classifier to evaluate; must not be null
     * @param stats receives one value per evaluated split. The returned value is its mean, so values
     *        already present are included in that mean.
     * @return {@code stats.getMean()} after all splits have been evaluated (or cancellation was seen)
     * @throws IllegalArgumentException if {@code classifier} is null
     * @throws ObjectEvaluationFailedException if a split cannot be created or evaluated
     * @throws InterruptedException if the current thread is interrupted between splits, or if the
     *         split-based evaluator throws it
     */
    public Double evaluate(final Classifier classifier, final DescriptiveStatistics stats) throws ObjectEvaluationFailedException, InterruptedException {
        if (classifier == null) {
            throw new IllegalArgumentException("Cannot compute score for null pipeline!");
        }
        final long evaluationStart = System.currentTimeMillis();
        this.logger.info("Starting MCCV evaluation of {} (Description: {})", classifier.getClass().getName(), WekaUtil.getClassifierDescriptor(classifier));

        for (int i = 0; i < this.repeats && !this.canceled; i++) {
            this.logger.debug("Obtaining predictions of {} for split #{}/{}", classifier, i + 1, this.repeats);
            // Thread.interrupted() clears the flag; throwing InterruptedException reports it to the caller.
            if (Thread.interrupted()) {
                this.logger.info("MCCV has been interrupted, leaving MCCV.");
                throw new InterruptedException("MCCV has been interrupted.");
            }

            final List<Instances> split = this.getSplit(this.seed + i);
            try {
                final long splitStart = System.currentTimeMillis();
                final double score = this.splitBasedEvaluator.evaluateSplit(classifier, split.get(0), split.get(1));
                if (this.hasListeners) {
                    this.eventBus.post(new MCCVSplitEvaluationEvent(classifier, split.get(0).size(), split.get(1).size(), (int) (System.currentTimeMillis() - splitStart), score));
                }
                this.logger.info("Score for evaluation of {} with split #{}/{}: {} after {}ms", classifier.getClass().getName(), i + 1, this.repeats, score,
                        System.currentTimeMillis() - evaluationStart);
                stats.addValue(score);
            } catch (InterruptedException e) {
                // Propagate interruption unchanged rather than wrapping it below.
                throw e;
            } catch (Exception e) {
                throw new ObjectEvaluationFailedException("Could not evaluate classifier!", e);
            }
        }

        final Double meanScore = stats.getMean();
        this.logger.info("Obtained score of {} for classifier {} in {}ms.", meanScore, classifier.getClass().getName(), System.currentTimeMillis() - evaluationStart);
        return meanScore;
    }

    /**
     * Returns the split for the given seed. The split is created with the configured splitter and
     * training portion on first use, then cached.
     *
     * @param splitSeed seed passed to the splitter and used as cache key
     * @return the cached or newly created split
     * @throws ObjectEvaluationFailedException if the splitter fails
     * @throws InterruptedException if the splitter is interrupted
     */
    private List<Instances> getSplit(final long splitSeed) throws ObjectEvaluationFailedException, InterruptedException {
        List<Instances> split = this.splitCache.get(splitSeed);
        if (split == null) {
            try {
                split = this.datasetSplitter.split(this.data, splitSeed, this.trainingPortion);
            } catch (SplitFailedException e) {
                throw new ObjectEvaluationFailedException("Could not evaluate classifier!", e);
            }
            this.splitCache.put(splitSeed, split);
        }
        return split;
    }

    /**
     * Returns the evaluator that scores a classifier on a single split.
     *
     * @return the split-based evaluator passed to the constructor
     */
    public ISplitBasedClassifierEvaluator<Double> getBridge() {
        return this.splitBasedEvaluator;
    }

    /**
     * Returns the name of the logger currently used by this evaluator.
     *
     * @return the logger name
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this evaluator to the logger with the given name. The switch is logged on both the old
     * and the new logger.
     *
     * @param name name of the new logger
     */
    public void setLoggerName(final String name) {
        this.logger.info("Switching logger of {} from {} to {}", this, this.logger.getName(), name);
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Switched logger of {} to {}", this, name);
    }

    /**
     * Registers a listener on this evaluator's event bus. From then on, per-split
     * {@link MCCVSplitEvaluationEvent}s are posted as well.
     *
     * @param listener object whose subscriber methods should receive events
     */
    public void registerListener(final Object listener) {
        this.hasListeners = true;
        this.eventBus.register(listener);
    }

    /**
     * Forwards an event received from the split-based evaluator (this object subscribes to it in the
     * constructor when it is an {@link IEventEmitter}) to this evaluator's listeners.
     *
     * @param event the event to forward
     */
    @Subscribe
    public void receiveEvent(final IEvent event) {
        this.eventBus.post(event);
    }
}
