package ai.libs.jaicore.ml.classification.multiclass.reduction;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.lang3.builder.HashCodeBuilder;

import ai.libs.jaicore.ml.WekaUtil;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.MultiClassClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaException;

/**
 * A node of a multi-class reduction tree ("MCTree").
 *
 * <p>During training, the classes contained in each child are grouped into one cluster per child, the class labels
 * of the training data are merged accordingly (via {@link WekaUtil#mergeClassesOfInstances}) and this node's
 * classifier is trained on that reduced problem. The children are then trained recursively on the original data.
 * At prediction time, the class scores written by the children are multiplied by the probability that this node's
 * classifier assigns to the respective child.
 *
 * <p>Trained node classifiers are memoized in a static cache shared by all nodes, so that an identical sub-problem
 * (same classifier type, same class clusters, same data) is trained only once.
 */
public class MCTreeNode implements Classifier, ITreeClassifier, Serializable, Iterable<MCTreeNode> {
    private static final long serialVersionUID = 8873192747068561266L;

    /** Number of times a node classifier was copied from the cache instead of being trained. */
    public static AtomicInteger cacheRetrievals = new AtomicInteger();

    /** Trained node classifiers, keyed by classifier type, class clusters, data size and a hash of the data. */
    private static final Map<String, Classifier> classifierCacheMap = new HashMap<>();

    /** Guards every access to {@link #classifierCacheMap}; sibling subtrees are trained in parallel. */
    private static final Lock classifierCacheMapLock = new ReentrantLock();

    private EMCNodeType nodeType;
    private final List<MCTreeNode> children = new ArrayList<>();
    private Classifier classifier;
    /** Class name of the classifier passed to {@link #setBaseClassifier}; only used for string output. */
    private String classifierID;
    private final List<Integer> containedClasses;
    private boolean trained;
    private boolean fromCache;

    /**
     * Creates a node without contained classes, node type or classifier.
     *
     * <p>NOTE(review): none of the arguments is used; the signature is kept only for compatibility with existing
     * callers. Confirm whether this constructor is still needed.
     */
    public MCTreeNode(final Classifier classifier, final Classifier classifier2, final String str) {
        this(new ArrayList<>());
    }

    /**
     * Creates a node for the given classes without node type or classifier.
     *
     * @param list the indices of the classes handled by this node (stored as-is, not copied)
     */
    public MCTreeNode(final List<Integer> list) {
        this.containedClasses = list;
    }

    /**
     * Creates a node whose base classifier is instantiated by Weka from its class name.
     *
     * @param list the indices of the classes handled by this node
     * @param eMCNodeType the reduction type of this node
     * @param str the fully qualified class name of the base classifier
     * @throws Exception if Weka cannot instantiate the classifier
     */
    public MCTreeNode(final List<Integer> list, final EMCNodeType eMCNodeType, final String str) throws Exception {
        this(list, eMCNodeType, AbstractClassifier.forName(str, (String[]) null));
    }

    /**
     * Creates a node with the given classes, node type and base classifier.
     *
     * @param list the indices of the classes handled by this node
     * @param eMCNodeType the reduction type of this node; determines how the base classifier is wrapped
     * @param classifier the base classifier
     */
    public MCTreeNode(final List<Integer> list, final EMCNodeType eMCNodeType, final Classifier classifier) {
        this(list);
        this.setNodeType(eMCNodeType);
        this.setBaseClassifier(classifier);
    }

    public EMCNodeType getNodeType() {
        return this.nodeType;
    }

    /**
     * Attaches a child to this node. A MERGE node is not attached itself; its children are attached directly to
     * this node instead.
     *
     * @param mCTreeNode the node to attach
     */
    public void addChild(final MCTreeNode mCTreeNode) {
        if (mCTreeNode.getNodeType() == EMCNodeType.MERGE) {
            this.children.addAll(mCTreeNode.getChildren());
        } else {
            this.children.add(mCTreeNode);
        }
    }

    /** @return the live (mutable) list of children of this node */
    public List<MCTreeNode> getChildren() {
        return this.children;
    }

    public Collection<Integer> getContainedClasses() {
        return this.containedClasses;
    }

    /**
     * @return {@code true} iff this node has a classifier, at least one child, and every child is completely
     *         configured as well
     */
    public boolean isCompletelyConfigured() {
        if (this.classifier == null || this.children.isEmpty()) {
            return false;
        }
        return this.children.stream().allMatch(MCTreeNode::isCompletelyConfigured);
    }

    /**
     * Trains this node's classifier on the data with class labels merged into one cluster per child (or copies an
     * identical classifier from the cache), then trains all children recursively and in parallel.
     *
     * @param instances the training data with its original class labels
     * @throws Exception if training this node or any child fails
     */
    @Override
    public void buildClassifier(final Instances instances) throws Exception {
        assert this.getNodeType() != EMCNodeType.MERGE : "MERGE node detected while building classifier. This must not happen!";
        assert !instances.isEmpty() : "Cannot train MCTree with empty set of instances.";
        assert !this.children.isEmpty() : "Cannot train MCTree without children";

        List<Set<String>> classClusters = this.getClassClustersOfChildren(instances);
        String cacheKey = this.classifier.getClass().getName() + "#" + classClusters + "#" + instances.size() + "#"
                + new HashCodeBuilder().append(instances.toString()).toHashCode();
        Instances mergedData = WekaUtil.mergeClassesOfInstances(instances, classClusters);

        this.trainOrRetrieveClassifier(cacheKey, mergedData);
        this.buildChildren(instances);
        this.trained = true;
    }

    /** Collects, for every child (in order), the set of class label names that child contains. */
    private List<Set<String>> getClassClustersOfChildren(final Instances data) {
        List<Set<String>> clusters = new ArrayList<>(this.children.size());
        for (MCTreeNode child : this.children) {
            Set<String> cluster = new HashSet<>();
            for (Integer classIndex : child.getContainedClasses()) {
                cluster.add(data.classAttribute().value(classIndex));
            }
            clusters.add(cluster);
        }
        return clusters;
    }

    /** Copies a cached classifier for {@code cacheKey} if present; otherwise trains this node's classifier and caches it. */
    private void trainOrRetrieveClassifier(final String cacheKey, final Instances mergedData) throws Exception {
        Classifier cached;
        classifierCacheMapLock.lock();
        try {
            cached = classifierCacheMap.get(cacheKey);
        } finally {
            classifierCacheMapLock.unlock();
        }

        if (cached != null) {
            this.classifier = AbstractClassifier.makeCopy(cached);
            this.fromCache = true;
            cacheRetrievals.incrementAndGet();
            return;
        }

        try {
            this.classifier.buildClassifier(mergedData);
        } catch (WekaException e) {
            // fall back to ZeroR (majority-class prediction) if the configured classifier rejects the data
            this.classifier = new ZeroR();
            this.classifier.buildClassifier(mergedData);
        }

        classifierCacheMapLock.lock();
        try {
            classifierCacheMap.put(cacheKey, this.classifier);
        } finally {
            classifierCacheMapLock.unlock();
        }
    }

    /** Trains all children in parallel on {@code data}; rethrows one of the failures after all children were attempted. */
    private void buildChildren(final Instances data) throws Exception {
        AtomicReference<Exception> failure = new AtomicReference<>();
        this.children.parallelStream().forEach(child -> {
            try {
                child.buildClassifier(data);
            } catch (Exception e) {
                failure.compareAndSet(null, e);
            }
        });
        Exception error = failure.get();
        if (error != null) {
            throw error;
        }
    }

    /**
     * Predicts the class with the highest positive score of {@link #distributionForInstance(Instance)}; ties go to
     * the lowest index.
     *
     * @param instance the instance to classify
     * @return the predicted class index, or {@code Double.NaN} (Weka's missing value) if no class has a positive score
     * @throws Exception if computing the distribution fails
     */
    @Override
    public double classifyInstance(final Instance instance) throws Exception {
        double[] distribution = this.distributionForInstance(instance);
        int bestIndex = -1;
        double bestScore = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            if (distribution[i] > bestScore) {
                bestScore = distribution[i];
                bestIndex = i;
            }
        }
        if (bestIndex < 0) {
            return Double.NaN;
        }
        return this.containedClasses.get(bestIndex).intValue();
    }

    /**
     * Writes the scores of the classes below this node into {@code dArr}.
     *
     * <p>Each child first writes the entries of its own classes; afterwards those entries are multiplied by the
     * probability that this node's classifier assigns to that child. NOTE(review): a plain inner node never writes
     * a non-zero value itself, so non-zero scores presumably originate from a leaf subclass overriding this
     * method — confirm.
     *
     * @param instance the instance to classify, with its original class attribute
     * @param dArr the array, indexed by class index, to write the scores into
     * @throws Exception if this node's or a child's classifier fails to produce a distribution
     */
    public void distributionForInstance(final Instance instance, final double[] dArr) throws Exception {
        // presumably matches the child-cluster labels produced by WekaUtil.mergeClassesOfInstances — confirm
        List<String> childLabels = IntStream.range(0, this.children.size()).mapToObj(i -> i + ".0").collect(Collectors.toList());
        Instance refactoredInstance = WekaUtil.getRefactoredInstance(instance, childLabels);
        double[] childProbabilities = this.classifier.distributionForInstance(refactoredInstance);

        for (int childIndex = 0; childIndex < this.children.size(); childIndex++) {
            MCTreeNode child = this.children.get(childIndex);
            child.distributionForInstance(instance, dArr);
            for (Integer classIndex : child.getContainedClasses()) {
                dArr[classIndex] *= childProbabilities[childIndex];
            }
        }
    }

    /**
     * @param instance the instance to classify
     * @return an array of size {@code getContainedClasses().size()} filled by
     *         {@link #distributionForInstance(Instance, double[])}
     * @throws Exception if computing the distribution fails
     */
    @Override
    public double[] distributionForInstance(final Instance instance) throws Exception {
        assert this.trained : "Cannot get distribution from untrained classifier " + this.toStringWithOffset();
        double[] distribution = new double[this.containedClasses.size()];
        this.distributionForInstance(instance, distribution);
        return distribution;
    }

    @Override
    public Capabilities getCapabilities() {
        return this.classifier.getCapabilities();
    }

    /** @return 1 plus the maximum height of the children (1 for a node without children) */
    @Override
    public int getHeight() {
        return 1 + this.children.stream().mapToInt(MCTreeNode::getHeight).max().orElse(0);
    }

    /**
     * Descends into the first child containing all given classes, as long as such a child exists.
     *
     * @param list the class indices to locate
     * @return 1 plus the number of such descents
     */
    @Override
    public int getDepthOfFirstCommonParent(final List<Integer> list) {
        for (MCTreeNode child : this.children) {
            if (child.getContainedClasses().containsAll(list)) {
                return 1 + child.getDepthOfFirstCommonParent(list);
            }
        }
        return 1;
    }

    /** Removes all entries from the shared classifier cache. */
    public static void clearCache() {
        classifierCacheMapLock.lock();
        try {
            classifierCacheMap.clear();
        } finally {
            classifierCacheMapLock.unlock();
        }
    }

    /** @return the live shared classifier cache; callers must not modify it while trees are being trained */
    public static Map<String, Classifier> getClassifierCache() {
        return classifierCacheMap;
    }

    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Sets this node's classifier, wrapping {@code classifier} according to the node type: ONEVSREST and ALLPAIRS
     * wrap it in a {@link MultiClassClassifier} (one-against-all and one-against-one respectively), DIRECT uses it
     * as-is, and any other type (such as MERGE) leaves the node's classifier unchanged.
     *
     * @param classifier the base classifier; must not be {@code null}
     * @throws IllegalStateException if the one-against-one wrapper cannot be configured
     */
    public void setBaseClassifier(final Classifier classifier) {
        assert classifier != null : "Cannot set null classifier!";
        this.classifierID = classifier.getClass().getName();
        switch (this.nodeType) {
            case ONEVSREST:
                MultiClassClassifier oneVsRest = new MultiClassClassifier();
                oneVsRest.setClassifier(classifier);
                this.classifier = oneVsRest;
                break;
            case ALLPAIRS:
                MultiClassClassifier allPairs = new MultiClassClassifier();
                try {
                    // "-M 3" selects MultiClassClassifier's 1-against-1 decomposition
                    allPairs.setOptions(new String[] { "-M", "3" });
                } catch (Exception e) {
                    throw new IllegalStateException("Could not configure all-pairs MultiClassClassifier", e);
                }
                allPairs.setClassifier(classifier);
                this.classifier = allPairs;
                break;
            case DIRECT:
                this.classifier = classifier;
                break;
            default:
                break;
        }
    }

    public void setNodeType(final EMCNodeType eMCNodeType) {
        this.nodeType = eMCNodeType;
    }

    @Override
    public String toString() {
        return "(" + this.classifierID + ":" + this.nodeType + "){"
                + this.children.stream().map(String::valueOf).collect(Collectors.joining(",")) + "}";
    }

    /** @return a multi-line, indented representation of this subtree */
    public String toStringWithOffset() {
        return this.toStringWithOffset("");
    }

    /**
     * @param str the prefix for this node's lines; each child level adds two spaces
     * @return a multi-line, indented representation of this subtree
     */
    public String toStringWithOffset(final String str) {
        StringBuilder sb = new StringBuilder();
        sb.append(str).append("(").append(this.getContainedClasses()).append(":").append(this.classifierID)
                .append(":").append(this.nodeType).append(") {");
        String childOffset = str + "  ";
        for (int i = 0; i < this.children.size(); i++) {
            if (i > 0) {
                sb.append(",");
            }
            sb.append("\n").append(this.children.get(i).toStringWithOffset(childOffset));
        }
        sb.append("\n").append(str).append("}");
        return sb.toString();
    }

    /** @return a pre-order iterator: this node first, then the nodes of each child's subtree in child order */
    @Override
    public Iterator<MCTreeNode> iterator() {
        return new Iterator<MCTreeNode>() {
            /** -1 until this node itself has been returned; afterwards the index of the child being traversed. */
            private int currentChild = -1;
            private Iterator<MCTreeNode> childIterator = null;

            @Override
            public boolean hasNext() {
                if (this.currentChild < 0) {
                    return true;
                }
                while (this.currentChild < MCTreeNode.this.children.size()) {
                    if (this.childIterator == null) {
                        this.childIterator = MCTreeNode.this.children.get(this.currentChild).iterator();
                    }
                    if (this.childIterator.hasNext()) {
                        return true;
                    }
                    this.currentChild++;
                    this.childIterator = null;
                }
                return false;
            }

            @Override
            public MCTreeNode next() {
                if (!this.hasNext()) {
                    throw new NoSuchElementException();
                }
                if (this.currentChild < 0) {
                    this.currentChild = 0;
                    return MCTreeNode.this;
                }
                return this.childIterator.next();
            }
        };
    }
}
