/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.ClassLabels;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.distance.HammingDistance;

/**
 * K-Modes clustering for categorical data.
 *
 * <p>Each sample is an {@code int[]} of categorical codes. Distances between a sample
 * and a centroid are Hamming distances. A centroid is updated attribute by attribute
 * to the most frequent value (the mode) among the samples assigned to its cluster.
 */
public class KModes extends CentroidClustering<int[], int[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(KModes.class);

    /**
     * Constructor.
     *
     * @param distortion the distortion of the clustering (the within-cluster cost
     *                   computed by {@code fit}).
     * @param centroids the cluster centroids (per-attribute modes).
     * @param y the cluster label of each sample.
     */
    public KModes(double distortion, int[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /**
     * Returns the Hamming distance between a sample and a centroid.
     *
     * @param x a sample.
     * @param y a centroid.
     * @return the number of attributes on which {@code x} and {@code y} differ.
     */
    @Override
    public double distance(int[] x, int[] y) {
        return HammingDistance.d(x, y);
    }

    /**
     * Clusters categorical data with at most 100 iterations.
     *
     * @param data the samples; each row is a vector of categorical codes.
     * @param k the number of clusters.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, or {@code data} is empty or ragged.
     */
    public static KModes fit(int[][] data, int k) {
        return KModes.fit(data, k, 100);
    }

    /**
     * Clusters categorical data.
     *
     * <p>Each attribute is first encoded with {@link ClassLabels}. Initial assignments
     * come from {@code seed}. The algorithm then alternates between recomputing the
     * per-attribute modes and reassigning samples. It stops when the distortion no
     * longer decreases or {@code maxIter} iterations have run.
     *
     * @param data the samples; each row is a vector of categorical codes and all rows
     *             must have the same length.
     * @param k the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *         {@code data} is empty, or its rows have different lengths.
     */
    public static KModes fit(int[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int n = data.length;
        int d = data[0].length;
        // The codecs below read data[i][j] for j < d, so every row must have exactly d attributes.
        for (int i = 1; i < n; i++) {
            if (data[i].length != d) {
                throw new IllegalArgumentException(String.format(
                        "Row %d has %d attributes, expected %d", i, data[i].length, d));
            }
        }

        // One codec per attribute maps raw category values to dense indices and back.
        ClassLabels[] codec = IntStream.range(0, d).parallel().mapToObj(j -> {
            int[] x = new int[n];
            for (int i = 0; i < n; i++) {
                x[i] = data[i][j];
            }
            return ClassLabels.fit(x);
        }).toArray(ClassLabels[]::new);

        int[] y = new int[n];
        int[][] medoids = new int[k][];
        int[][] centroids = new int[k][d];

        double distortion = MathEx.sum(seed(data, medoids, y, HammingDistance::d));
        logger.info(String.format("Distortion after initialization: %d", (int) distortion));

        double diff = Integer.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > 0.0; iter++) {
            updateCentroids(centroids, data, y, codec);
            double wcss = assign(y, data, centroids, HammingDistance::d);
            logger.info(String.format("Distortion after %3d iterations: %d", iter, (int) wcss));
            diff = distortion - wcss;
            distortion = wcss;
        }

        // diff > 0 here means the loop hit maxIter while still improving. In that case the
        // centroids were computed before the last reassignment, so refresh them to be the
        // modes of the final clusters. NOTE(review): the reported distortion is not recomputed
        // against these refreshed centroids.
        if (diff > 0.0) {
            updateCentroids(centroids, data, y, codec);
        }

        return new KModes(distortion, centroids, y);
    }

    /**
     * Sets each centroid attribute to the most frequent encoded value among the samples
     * currently assigned to that cluster.
     *
     * <p>Work is split by attribute. For each attribute, a single pass over the samples
     * builds a cluster-by-category count table, giving O(n·d) total work rather than
     * rescanning the data once per cluster. Each parallel task writes only column
     * {@code j} of every centroid, so tasks never write the same array element.
     *
     * <p>NOTE(review): an empty cluster has an all-zero count row. Its mode then depends
     * on how {@code MathEx.whichMax} breaks ties.
     *
     * @param centroids the centroids to overwrite in place.
     * @param data the samples.
     * @param y the cluster label of each sample; assumed to be in {@code [0, k)}.
     * @param codec the per-attribute label encoders built from {@code data}.
     */
    private static void updateCentroids(int[][] centroids, int[][] data, int[] y, ClassLabels[] codec) {
        int n = data.length;
        int k = centroids.length;
        int d = centroids[0].length;
        IntStream.range(0, d).parallel().forEach(j -> {
            int[] x = codec[j].y;
            int[][] count = new int[k][codec[j].k];
            for (int i = 0; i < n; i++) {
                count[y[i]][x[i]]++;
            }
            for (int cluster = 0; cluster < k; cluster++) {
                centroids[cluster][j] = codec[j].labels.valueOf(MathEx.whichMax(count[cluster]));
            }
        });
    }
}

