/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import smile.classification.AbstractClassifier;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;

/**
 * K-nearest neighbor classifier. A query is labeled by a vote among the
 * labels of its {@code k} nearest training samples, as found by a
 * {@link KNNSearch} structure built over the training data.
 *
 * @param <T> the type of input samples.
 */
public class KNN<T>
extends AbstractClassifier<T> {
    private static final long serialVersionUID = 2L;
    /** Nearest-neighbor search over the training samples. */
    private final KNNSearch<T, T> knn;
    /** Training labels, looked up by {@code Neighbor.index} of a search result. */
    private final int[] y;
    /** The number of neighbors that vote on a prediction. */
    private final int k;

    /**
     * Constructor.
     *
     * @param knn the nearest-neighbor search structure over the training samples.
     * @param y the class labels of the training samples.
     * @param k the number of neighbors used for voting.
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        super(y);
        this.knn = knn;
        this.k = k;
        this.y = y;
    }

    /**
     * Fits a 1-NN classifier with a custom distance.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @param distance the distance function.
     * @return the classifier.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, Distance<T> distance) {
        return KNN.fit(x, y, 1, distance);
    }

    /**
     * Fits a k-NN classifier with a custom distance. A {@link CoverTree} is
     * used when the distance is a {@link Metric}; otherwise the search falls
     * back to a {@link LinearSearch}.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @param k the number of neighbors used for voting.
     * @param distance the distance function.
     * @return the classifier.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length or {@code k < 1}.
     */
    public static <T> KNN<T> fit(T[] x, int[] y, int k, Distance<T> distance) {
        checkArguments(x.length, y, k);
        KNNSearch<T, T> knn;
        if (distance instanceof Metric) {
            // Metric<T> is the only parameterization of Metric that is a
            // Distance<T>, so this downcast is type-safe.
            knn = CoverTree.of(x, (Metric<T>) distance);
        } else {
            knn = LinearSearch.of(x, distance);
        }
        return new KNN<>(knn, y, k);
    }

    /**
     * Fits a 1-NN classifier on real-valued vectors with Euclidean distance.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @return the classifier.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y) {
        return KNN.fit(x, y, 1);
    }

    /**
     * Fits a k-NN classifier on real-valued vectors with Euclidean distance.
     * A {@link KDTree} is used for fewer than 10 dimensions, a
     * {@link CoverTree} otherwise.
     *
     * @param x the training samples.
     * @param y the class labels of the training samples.
     * @param k the number of neighbors used for voting.
     * @return the classifier.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, {@code k < 1}, or {@code x} is empty.
     */
    public static KNN<double[]> fit(double[][] x, int[] y, int k) {
        checkArguments(x.length, y, k);
        if (x.length == 0) {
            // The dimension is read from x[0] below; fail clearly instead of
            // with an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Empty training data");
        }
        KNNSearch<double[], double[]> knn;
        // Presumably chosen because k-d trees lose efficiency as the
        // dimensionality grows.
        if (x[0].length < 10) {
            knn = KDTree.of(x);
        } else {
            knn = CoverTree.of(x, new EuclideanDistance());
        }
        return new KNN<>(knn, y, k);
    }

    /**
     * Validates the arguments shared by every {@code fit} overload.
     *
     * @param n the number of training samples.
     * @param y the class labels of the training samples.
     * @param k the number of neighbors used for voting.
     * @throws IllegalArgumentException if {@code n != y.length} or {@code k < 1}.
     */
    private static void checkArguments(int n, int[] y, int k) {
        if (n != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
    }

    /**
     * Returns the first search result, failing if there is none.
     *
     * @throws IllegalStateException if the result is empty or its first slot is null.
     */
    private static <K, V> Neighbor<K, V> nearest(Neighbor<K, V>[] neighbors) {
        if (neighbors.length == 0 || neighbors[0] == null) {
            throw new IllegalStateException("No neighbor found.");
        }
        return neighbors[0];
    }

    /**
     * Counts, per class index, how many of the given neighbors carry that
     * class label. Null entries are skipped (presumably slots the search
     * could not fill).
     *
     * @param neighbors the search result.
     * @return the vote count per class index.
     */
    private int[] vote(Neighbor<T, T>[] neighbors) {
        int[] count = new int[this.classes.size()];
        for (Neighbor<T, T> neighbor : neighbors) {
            if (neighbor != null) {
                count[this.classes.indexOf(this.y[neighbor.index])]++;
            }
        }
        return count;
    }

    /**
     * Predicts the class label of a sample by majority vote among its
     * {@code k} nearest training samples.
     *
     * @param x the sample.
     * @return the predicted class label.
     * @throws IllegalStateException if no neighbor is found.
     */
    @Override
    public int predict(T x) {
        Neighbor<T, T>[] neighbors = this.knn.search(x, this.k);
        if (this.k == 1) {
            return this.y[nearest(neighbors).index];
        }
        int[] count = vote(neighbors);
        int label = MathEx.whichMax(count);
        if (count[label] == 0) {
            throw new IllegalStateException("No neighbor found.");
        }
        return this.classes.valueOf(label);
    }

    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of a sample and estimates the posterior
     * probabilities as the fraction of found neighbors voting for each class.
     *
     * @param x the sample.
     * @param posteriori the output array of posterior probabilities, one
     *        entry per class index.
     * @return the predicted class label.
     * @throws IllegalStateException if no neighbor is found.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        Neighbor<T, T>[] neighbors = this.knn.search(x, this.k);
        if (this.k == 1) {
            int label = this.y[nearest(neighbors).index];
            Arrays.fill(posteriori, 0.0);
            posteriori[this.classes.indexOf(label)] = 1.0;
            return label;
        }
        // Skip null slots like predict(T) does, instead of dereferencing them.
        int[] count = vote(neighbors);
        int found = 0;
        for (int c : count) {
            found += c;
        }
        if (found == 0) {
            throw new IllegalStateException("No neighbor found.");
        }
        int label = MathEx.whichMax(count);
        // Normalize by the neighbors actually found so the estimate sums to 1;
        // this equals count / k whenever all k neighbors are present.
        for (int i = 0; i < count.length; ++i) {
            posteriori[i] = (double) count[i] / (double) found;
        }
        return this.classes.valueOf(label);
    }
}

