package ai.libs.jaicore.ml.tsc.classifier.neighbors;

import ai.libs.jaicore.basic.algorithm.IAlgorithmConfig;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.tsc.classifier.trees.TimeSeriesTreeLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.distances.ITimeSeriesDistance;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.aeonbits.owner.ConfigCache;

/**
 * k-nearest-neighbor classifier for univariate time series with integer class labels.
 *
 * <p>Training data ({@link #values}, {@link #targets} and optionally {@link #timestamps}) is injected
 * through the setters (presumably by {@link NearestNeighborLearningAlgorithm}). To classify a query,
 * the distance to every training instance is computed with the configured
 * {@link ITimeSeriesDistance}. The {@code k} closest instances are kept, and their labels are
 * aggregated according to the configured {@link VoteType}.
 */
public class NearestNeighborClassifier extends ASimplifiedTSClassifier<Integer> {

    /** Label returned by the voting methods when no label receives a usable vote (e.g. no neighbors). */
    private static final int NO_PREDICTION = -1;

    /**
     * Orders (label, distance) pairs by descending distance. The head of a {@link PriorityQueue} built
     * with it is therefore the farthest neighbor currently held.
     */
    protected static final NearestNeighborComparator nearestNeighborComparator = new NearestNeighborComparator();

    /** Number of neighbors taken into account; validated to be at least 1. */
    private final int k;

    /** Distance measure used to compare the query with each training instance. */
    private final ITimeSeriesDistance distanceMeasure;

    /** Strategy used to aggregate the labels of the nearest neighbors. */
    private final VoteType voteType;

    /** Training instances; row {@code i} is compared against the query. */
    protected double[][] values;

    /** Timestamps of the training instances; stored but not read by this class. */
    protected double[][] timestamps;

    /** Class label of each training instance; {@code targets[i]} belongs to {@code values[i]}. */
    protected int[] targets;

    /**
     * Comparator on (label, distance) pairs that sorts by descending distance.
     */
    public static class NearestNeighborComparator implements Comparator<Pair<Integer, Double>> {

        private NearestNeighborComparator() {
        }

        @Override
        public int compare(final Pair<Integer, Double> first, final Pair<Integer, Double> second) {
            // Swapped arguments yield descending order, i.e. a max-heap on distance.
            return Double.compare(second.getY(), first.getY());
        }
    }

    /** Strategies for aggregating the labels of the nearest neighbors into one prediction. */
    public enum VoteType {
        /** Every neighbor contributes one vote to its label. */
        MAJORITY,
        /** Neighbors get weights 1, 2, ..., n from farthest to nearest. */
        WEIGHTED_STEPWISE,
        /** Every neighbor contributes {@code 1 / distance} to its label. */
        WEIGHTED_PROPORTIONAL_TO_DISTANCE
    }

    /**
     * Creates a k-NN classifier.
     *
     * @param k number of neighbors to consider; must be at least 1
     * @param distanceMeasure distance measure between two time series; must not be {@code null}
     * @param voteType label aggregation strategy; must not be {@code null}
     * @throws IllegalArgumentException if an argument is {@code null} or {@code k < 1}
     */
    public NearestNeighborClassifier(final int k, final ITimeSeriesDistance distanceMeasure, final VoteType voteType) {
        if (distanceMeasure == null) {
            throw new IllegalArgumentException("Distance measure must not be null");
        }
        if (voteType == null) {
            throw new IllegalArgumentException("Vote type must not be null.");
        }
        if (k < 1) {
            // With k <= 0 the neighbor queue would always be emptied and every prediction would be -1.
            throw new IllegalArgumentException("Number of neighbors k must be at least 1, but was " + k + ".");
        }
        this.distanceMeasure = distanceMeasure;
        this.k = k;
        this.voteType = voteType;
    }

    /**
     * Creates a k-NN classifier using {@link VoteType#MAJORITY} voting.
     *
     * @param k number of neighbors to consider; must be at least 1
     * @param distanceMeasure distance measure between two time series; must not be {@code null}
     */
    public NearestNeighborClassifier(final int k, final ITimeSeriesDistance distanceMeasure) {
        this(k, distanceMeasure, VoteType.MAJORITY);
    }

    /**
     * Creates a 1-NN classifier.
     *
     * @param distanceMeasure distance measure between two time series; must not be {@code null}
     */
    public NearestNeighborClassifier(final ITimeSeriesDistance distanceMeasure) {
        this(1, distanceMeasure, VoteType.MAJORITY);
    }

    /**
     * Predicts the label of a single univariate instance.
     *
     * @param univInstance the query time series; must not be {@code null}
     * @return the predicted label, or -1 if no neighbor produced a usable vote
     * @throws IllegalArgumentException if {@code univInstance} is {@code null}
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override
    public Integer predict(final double[] univInstance) throws PredictionException {
        if (univInstance == null) {
            throw new IllegalArgumentException("Instance to predict must not be null.");
        }
        return calculatePrediction(univInstance);
    }

    /**
     * Predicts the label of every instance in {@code dataset}.
     *
     * @param dataset the instances to classify
     * @return one predicted label per instance, in dataset order
     */
    @Override
    public List<Integer> predict(final TimeSeriesDataset dataset) throws PredictionException {
        final double[][] instances = checkWhetherPredictionIsPossible(dataset);
        final List<Integer> predictions = new ArrayList<>(dataset.getNumberOfInstances());
        for (final double[] instance : instances) {
            predictions.add(calculatePrediction(instance));
        }
        return predictions;
    }

    /**
     * Finds the nearest neighbors of {@code univInstance} and aggregates their labels.
     *
     * @param univInstance the query time series
     * @return the predicted label
     */
    protected int calculatePrediction(final double[] univInstance) {
        return vote(calculateNearestNeigbors(univInstance));
    }

    /**
     * Determines the (at most) {@code k} training instances closest to {@code univInstance}.
     *
     * <p>The returned queue is ordered by {@link #nearestNeighborComparator}, so its head is the
     * farthest of the retained neighbors.
     *
     * @param univInstance the query time series
     * @return queue of (label, distance) pairs holding at most {@code k} entries
     * @throws IllegalStateException if training values or targets have not been set
     */
    protected PriorityQueue<Pair<Integer, Double>> calculateNearestNeigbors(final double[] univInstance) {
        if (this.values == null || this.targets == null) {
            throw new IllegalStateException("Classifier has not been trained: training values and targets must be set before predicting.");
        }
        final PriorityQueue<Pair<Integer, Double>> nearestNeighbors = new PriorityQueue<>(nearestNeighborComparator);
        for (int i = 0; i < this.values.length; i++) {
            final double distance = this.distanceMeasure.distance(univInstance, this.values[i]);
            nearestNeighbors.add(new Pair<Integer, Double>(this.targets[i], distance));
            // Evict the current farthest neighbor once more than k are held.
            if (nearestNeighbors.size() > this.k) {
                nearestNeighbors.poll();
            }
        }
        return nearestNeighbors;
    }

    /**
     * Aggregates the neighbors' labels using the configured {@link VoteType}.
     *
     * @param nearestNeighbors (label, distance) pairs of the nearest neighbors
     * @return the winning label, or -1 if no label received a usable vote
     */
    protected int vote(final PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        switch (this.voteType) {
        case WEIGHTED_STEPWISE:
            return voteWeightedStepwise(nearestNeighbors);
        case WEIGHTED_PROPORTIONAL_TO_DISTANCE:
            return voteWeightedProportionalToDistance(nearestNeighbors);
        case MAJORITY:
        default:
            return voteMajority(nearestNeighbors);
        }
    }

    /**
     * Votes with weights 1, 2, ..., n assigned from the farthest to the nearest neighbor.
     *
     * <p>Note: this drains {@code nearestNeighbors}; the queue is empty afterwards.
     *
     * @param nearestNeighbors (label, distance) pairs ordered by {@link #nearestNeighborComparator}
     * @return the label with the highest summed weight, or -1 if the queue is empty
     */
    protected int voteWeightedStepwise(final PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        final Map<Integer, Double> votes = new HashMap<>();
        int weight = 1;
        // poll() yields the farthest neighbor first, so the nearest one ends up with the largest weight.
        while (!nearestNeighbors.isEmpty()) {
            addVote(votes, nearestNeighbors.poll().getX(), weight);
            weight++;
        }
        return argMax(votes);
    }

    /**
     * Votes with each neighbor contributing {@code 1 / distance} to its label.
     *
     * <p>A distance of 0 yields an infinite weight, so exact matches dominate the vote.
     *
     * @param nearestNeighbors (label, distance) pairs of the nearest neighbors; not modified
     * @return the label with the highest summed weight, or -1 if no label has a weight greater than
     *         negative infinity (e.g. the queue is empty)
     */
    protected int voteWeightedProportionalToDistance(final PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        final Map<Integer, Double> votes = new HashMap<>();
        for (final Pair<Integer, Double> neighbor : nearestNeighbors) {
            addVote(votes, neighbor.getX(), 1.0d / neighbor.getY());
        }
        return argMax(votes);
    }

    /**
     * Votes with every neighbor contributing exactly one vote to its label.
     *
     * @param nearestNeighbors (label, distance) pairs of the nearest neighbors; not modified
     * @return the most frequent label, or -1 if the queue is empty
     */
    protected int voteMajority(final PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        final Map<Integer, Double> votes = new HashMap<>();
        for (final Pair<Integer, Double> neighbor : nearestNeighbors) {
            addVote(votes, neighbor.getX(), 1.0d);
        }
        return argMax(votes);
    }

    /** Adds {@code weight} to the accumulated vote of {@code label}. */
    private static void addVote(final Map<Integer, Double> votes, final Integer label, final double weight) {
        final Double current = votes.get(label);
        votes.put(label, current == null ? weight : current + weight);
    }

    /**
     * Returns the label with the strictly highest score.
     *
     * <p>Ties go to the label encountered first in the map's iteration order. NaN scores never win.
     * Starting from negative infinity (not {@code Double.MIN_VALUE}, which is the smallest positive
     * double) lets scores of 0 or below still win.
     *
     * @return the best label, or -1 if no score exceeds negative infinity
     */
    private static int argMax(final Map<Integer, Double> votes) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestLabel = NO_PREDICTION;
        for (final Map.Entry<Integer, Double> entry : votes.entrySet()) {
            final double score = entry.getValue();
            if (score > bestScore) {
                bestScore = score;
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /**
     * Sets the training instances.
     *
     * @param values training time series; must not be {@code null}
     * @throws IllegalArgumentException if {@code values} is {@code null}
     */
    public void setValues(final double[][] values) {
        if (values == null) {
            throw new IllegalArgumentException("Values must not be null");
        }
        this.values = values;
    }

    /**
     * Sets the timestamps of the training instances (may be {@code null}).
     *
     * @param timestamps timestamps of the training time series
     */
    public void setTimestamps(final double[][] timestamps) {
        this.timestamps = timestamps;
    }

    /**
     * Sets the class labels of the training instances.
     *
     * @param targets one label per training instance; must not be {@code null}
     * @throws IllegalArgumentException if {@code targets} is {@code null}
     */
    public void setTargets(final int[] targets) {
        if (targets == null) {
            throw new IllegalArgumentException("Targets must not be null");
        }
        this.targets = targets;
    }

    /** @return the number of neighbors considered */
    public int getK() {
        return this.k;
    }

    /** @return the label aggregation strategy */
    public VoteType getVoteType() {
        return this.voteType;
    }

    /** @return the distance measure between time series */
    public ITimeSeriesDistance getDistanceMeasure() {
        return this.distanceMeasure;
    }

    /**
     * Creates the learning algorithm that trains this classifier on {@code dataset}.
     *
     * @param dataset the training data
     * @return a learning algorithm bound to this classifier and {@code dataset}
     */
    @Override
    public NearestNeighborLearningAlgorithm getLearningAlgorithm(final TimeSeriesDataset dataset) {
        return new NearestNeighborLearningAlgorithm(ConfigCache.getOrCreate(IAlgorithmConfig.class), this, dataset);
    }
}
