/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.IntStream;
import smile.data.Dataset;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.vector.BaseVector;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * Encodes the class labels of a classification response as contiguous
 * integer indices in {@code [0, k)} and records the per-class sample
 * counts and empirical priors.
 *
 * <p>Instances are usually created with one of the {@code fit} factory
 * methods. Those methods map arbitrary integer labels to indices. The
 * constructor can be used directly when {@code y} is already encoded.
 */
public class ClassLabels implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The number of classes. */
    public final int k;
    /**
     * The encoder between class indices and labels. In {@link #fit(int[])}
     * it is built from the sorted distinct labels. For a nominal response
     * in {@link #fit(BaseVector)} it is built from {@code 0..k-1}.
     */
    public final IntSet classes;
    /**
     * The class index of each sample. Each value is in {@code [0, k)}.
     * The array is stored by reference, not copied.
     */
    public final int[] y;
    /** The number of samples in each class. */
    public final int[] ni;
    /** The empirical prior of each class, {@code ni[i] / y.length}. */
    public final double[] priori;

    /**
     * Constructor.
     *
     * @param k the number of classes.
     * @param y the class index of each sample. Every value must be in
     *          {@code [0, k)}. The array is stored without copying.
     * @param classes the encoder between class indices and labels. It is
     *                stored as-is and not validated here.
     * @throws IllegalArgumentException if some {@code y[i]} is outside {@code [0, k)}.
     */
    public ClassLabels(int k, int[] y, IntSet classes) {
        this.k = k;
        this.y = y;
        this.classes = classes;
        this.ni = count(y, k);
        this.priori = new double[k];

        // With an empty y every prior is 0.0 / 0.0 = NaN: no samples, no estimate.
        double n = y.length;
        for (int c = 0; c < k; c++) {
            priori[c] = (double) ni[c] / n;
        }
    }

    /**
     * Returns a nominal scale with one level per class. Level {@code i} is
     * the string form of {@code classes.valueOf(i)}, presumably the label
     * that was mapped to class index {@code i}.
     *
     * @return the nominal scale of the class labels.
     */
    public NominalScale scale() {
        int size = classes.size();
        String[] levels = new String[size];
        for (int i = 0; i < size; i++) {
            levels[i] = String.valueOf(classes.valueOf(i));
        }
        return new NominalScale(levels);
    }

    /**
     * Maps each label in {@code y} to its class index via
     * {@code classes.indexOf}. How labels unknown to {@code classes} are
     * handled is up to {@code IntSet.indexOf} and is not checked here.
     *
     * @param y the labels to encode.
     * @return a new array with the class index of each label.
     */
    public int[] indexOf(int[] y) {
        int[] index = new int[y.length];
        for (int i = 0; i < y.length; i++) {
            index[i] = classes.indexOf(y[i]);
        }
        return index;
    }

    /**
     * Fits the label encoding from the response of every sample in a dataset.
     *
     * @param data the training samples.
     * @return the fitted class labels.
     * @throws IllegalArgumentException if there are fewer than two distinct labels.
     */
    public static ClassLabels fit(Dataset<?, Integer> data) {
        int n = data.size();
        int[] y = new int[n];
        for (int i = 0; i < n; i++) {
            y[i] = (Integer) data.get(i).y();
        }
        return fit(y);
    }

    /**
     * Fits the label encoding from raw integer labels.
     *
     * <p>If the distinct labels are exactly {@code 0..k-1}, {@code y} is
     * used as-is and not copied. Otherwise every label is re-encoded through
     * an {@code IntSet} built from the sorted distinct labels.
     *
     * @param y the class label of each sample.
     * @return the fitted class labels.
     * @throws IllegalArgumentException if there are fewer than two distinct
     *         labels. This includes an empty {@code y}.
     */
    public static ClassLabels fit(int[] y) {
        int[] labels = MathEx.unique(y);
        Arrays.sort(labels);
        int k = labels.length;
        if (k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }

        IntSet encoder = new IntSet(labels);
        // The labels are distinct and sorted. Matching endpoints 0 and k-1
        // therefore imply that every value in between is present, so no
        // re-encoding is needed.
        if (labels[0] == 0 && labels[k - 1] == k - 1) {
            return new ClassLabels(k, y, encoder);
        }
        return new ClassLabels(k, Arrays.stream(y).map(encoder::indexOf).toArray(), encoder);
    }

    /**
     * Fits the label encoding from a response vector.
     *
     * <p>If the vector's measure is a {@link NominalScale}, the number of
     * classes is the number of levels. The values from {@code toIntArray()}
     * are then used directly as class indices. NOTE(review): this assumes
     * those values are level indices in {@code [0, k)}; values outside that
     * range are rejected by the constructor. Unlike {@link #fit(int[])}, this
     * branch does not reject a scale with fewer than two levels. Any other
     * measure falls back to {@link #fit(int[])}.
     *
     * @param response the response variable.
     * @return the fitted class labels.
     * @throws IllegalArgumentException if a nominal value is outside the
     *         level range, or, for non-nominal data, if there are fewer
     *         than two distinct labels.
     */
    public static ClassLabels fit(BaseVector<?, ?, ?> response) {
        int[] y = response.toIntArray();
        Measure measure = response.measure();
        if (measure instanceof NominalScale) {
            NominalScale scale = (NominalScale) measure;
            int k = scale.size();
            IntSet encoder = new IntSet(IntStream.range(0, k).toArray());
            return new ClassLabels(k, y, encoder);
        }
        return fit(y);
    }

    /**
     * Counts the samples of each class.
     *
     * @param y the class index of each sample.
     * @param k the number of classes.
     * @return an array of length {@code k} whose element {@code c} is the
     *         number of occurrences of {@code c} in {@code y}.
     * @throws IllegalArgumentException if some {@code y[i]} is outside {@code [0, k)}.
     */
    private static int[] count(int[] y, int k) {
        int[] ni = new int[k];
        for (int i = 0; i < y.length; i++) {
            int yi = y[i];
            // Validate explicitly so a bad label surfaces as a descriptive
            // error, not an ArrayIndexOutOfBoundsException on ni.
            if (yi < 0 || yi >= k) {
                throw new IllegalArgumentException(String.format(
                    "Invalid class label %d at index %d, expected a value in [0, %d)", yi, i, k));
            }
            ni[yi]++;
        }
        return ni;
    }
}

