package ai.libs.jaicore.ml.classification.multiclass.reduction;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/classification/multiclass/reduction/AllPairsTable.class */
/**
 * Table of pairwise class separabilities for a multi-class dataset.
 *
 * <p>For every unordered pair of classes actually contained in the training data, the given
 * classifier is trained on the training instances of just those two classes and evaluated on the
 * instances of the same two classes in a second (validation) dataset. The resulting accuracy,
 * as a fraction in [0, 1], is stored as the pair's "separability". The table also remembers the
 * number of training instances per class, which {@link #getUpperBoundOnSeparability(Collection)}
 * uses to weight pairwise error rates.
 */
public class AllPairsTable {
    /** Number of training instances per class name. */
    private final Map<String, Integer> classCount;

    /**
     * Separability per class pair. The outer key is the lexicographically smaller class name and
     * the inner key the larger one, so every unordered pair is stored exactly once.
     */
    private final Map<String, Map<String, Double>> separabilities = new HashMap<>();

    /** Total number of training instances. */
    private final int sum;

    /**
     * Builds the table by training and evaluating one binary model per class pair.
     *
     * <p>Note: {@code classifier} is rebuilt for every pair. After construction it is therefore
     * left trained on whichever pair was processed last.
     *
     * @param instances training data; the classes actually contained in it define the pairs
     * @param instances2 validation data on which each pairwise model's accuracy is measured
     * @param classifier classifier used for every pair (mutated via {@code buildClassifier})
     * @throws Exception if training or evaluating a pairwise model fails
     */
    public AllPairsTable(Instances instances, Instances instances2, Classifier classifier) throws Exception {
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(WekaUtil.getClassesActuallyContainedInDataset(instances), 2)) {
            // Sort so the first name is the canonical outer key used by getSeparability.
            List<String> sortedPair = pair.stream().sorted().collect(Collectors.toList());
            String first = sortedPair.get(0);
            String second = sortedPair.get(1);

            Instances trainPair = WekaUtil.getInstancesOfClass(instances, first);
            trainPair.addAll(WekaUtil.getInstancesOfClass(instances, second));
            classifier.buildClassifier(trainPair);

            Instances validationPair = WekaUtil.getInstancesOfClass(instances2, first);
            validationPair.addAll(WekaUtil.getInstancesOfClass(instances2, second));
            Evaluation evaluation = new Evaluation(trainPair);
            evaluation.evaluateModel(classifier, validationPair);

            // pctCorrect() is a percentage; store the accuracy as a fraction in [0, 1].
            this.separabilities.computeIfAbsent(first, k -> new HashMap<>())
                    .put(second, evaluation.pctCorrect() / 100.0d);
        }
        this.classCount = WekaUtil.getNumberOfInstancesPerClass(instances);
        this.sum = instances.size();
    }

    /**
     * Returns the separability (validation accuracy of the pairwise model, in [0, 1]) of two
     * classes. The order of the arguments does not matter.
     *
     * @param str name of one class
     * @param str2 name of the other class
     * @return the stored separability of the pair
     * @throws IllegalArgumentException if both names are equal, or if the pair was not part of the
     *     training data this table was built from
     */
    public double getSeparability(String str, String str2) {
        if (str.equals(str2)) {
            throw new IllegalArgumentException("Cannot separate a class from itself.");
        }
        // Entries are keyed by the lexicographically smaller name first (see constructor).
        boolean inOrder = str.compareTo(str2) < 0;
        String first = inOrder ? str : str2;
        String second = inOrder ? str2 : str;
        Map<String, Double> row = this.separabilities.get(first);
        Double separability = (row == null) ? null : row.get(second);
        if (separability == null) {
            throw new IllegalArgumentException("No separability known for classes \"" + first + "\" and \"" + second + "\".");
        }
        return separability;
    }

    /**
     * Computes an upper bound on the separability of the given classes.
     *
     * <p>For each pair, the pairwise error rate {@code 1 - separability} is weighted by the
     * pair's share of all training instances. The bound is one minus the largest such weighted
     * error.
     *
     * @param collection class names to consider
     * @return {@code 1 - max over pairs of (1 - sep(a, b)) * (count(a) + count(b)) / sum};
     *     {@code 1.0} if no pairs are produced
     * @throws IllegalArgumentException if a pair is unknown to this table
     */
    public double getUpperBoundOnSeparability(Collection<String> collection) {
        double maxWeightedError = 0.0d;
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(collection, 2)) {
            Iterator<String> members = pair.iterator();
            String first = members.next();
            String second = members.next();
            double separability = getSeparability(first, second);
            int pairCount = this.classCount.get(first) + this.classCount.get(second);
            // Divide in double precision; the original float conversion of sum lost precision.
            double weightedError = (1.0d - separability) * pairCount / (double) this.sum;
            maxWeightedError = Math.max(maxWeightedError, weightedError);
        }
        return 1.0d - maxWeightedError;
    }

    /**
     * Returns the arithmetic mean of the separabilities of all pairs of the given classes.
     *
     * @param collection class names to consider
     * @return the mean pairwise separability, as reported by {@link DescriptiveStatistics#getMean()}
     *     (presumably NaN when no pairs are produced)
     * @throws IllegalArgumentException if a pair is unknown to this table
     */
    public double getAverageSeparability(Collection<String> collection) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(collection, 2)) {
            Iterator<String> members = pair.iterator();
            stats.addValue(getSeparability(members.next(), members.next()));
        }
        return stats.getMean();
    }

    /**
     * Returns the product of the separabilities of all pairs of the given classes.
     *
     * @param collection class names to consider
     * @return the product of pairwise separabilities; {@code 1.0} if no pairs are produced
     * @throws IllegalArgumentException if a pair is unknown to this table
     */
    public double getMultipliedSeparability(Collection<String> collection) {
        double product = 1.0d;
        for (Collection<String> pair : SetUtil.getAllPossibleSubsetsWithSize(collection, 2)) {
            Iterator<String> members = pair.iterator();
            product *= getSeparability(members.next(), members.next());
        }
        return product;
    }
}
