/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.CentroidClustering;
import smile.clustering.KMeans;
import smile.math.MathEx;
import smile.sort.QuickSort;
import smile.stat.distribution.GaussianDistribution;

/**
 * G-Means clustering: a k-means variant that starts from a single cluster and
 * keeps splitting clusters in two as long as the data assigned to a cluster
 * fails an Anderson-Darling normality test along the axis connecting the two
 * candidate child centroids, or until {@code kmax} clusters are reached.
 */
public class GMeans
extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(GMeans.class);

    /** Clusters with fewer observations than this are never considered for a split. */
    private static final int MIN_SPLIT_SIZE = 25;

    /**
     * Threshold on the adjusted Anderson-Darling statistic: a cluster whose
     * statistic exceeds this value is a candidate for splitting.
     * NOTE(review): presumably the critical value for significance level
     * 0.0001 from the G-means paper -- confirm against the reference.
     */
    private static final double CRITICAL_VALUE = 1.8692;

    /**
     * Constructor.
     *
     * @param distortion the total distortion of the clustering.
     * @param centroids the centroids of each cluster.
     * @param y the cluster label of each observation.
     */
    public GMeans(double distortion, double[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /**
     * Returns the squared Euclidean distance between two vectors.
     *
     * @param x the first vector.
     * @param y the second vector.
     * @return the squared distance as computed by {@code MathEx.squaredDistance}.
     */
    @Override
    public double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }

    /**
     * Clusters the data with at most 100 iterations per refinement and a
     * convergence tolerance of 1E-4.
     *
     * @param data the input data; each row is an observation.
     * @param kmax the maximum number of clusters.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static GMeans fit(double[][] data, int kmax) {
        return GMeans.fit(data, kmax, 100, 1.0E-4);
    }

    /**
     * Clusters the data, automatically choosing the number of clusters (up to {@code kmax}).
     *
     * @param data the input data; each row is an observation.
     * @param kmax the maximum number of clusters.
     * @param maxIter the maximum number of iterations for each k-means refinement.
     * @param tol the refinement stops once the decrease in distortion is not greater than this value.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static GMeans fit(double[][] data, int kmax, int maxIter, double tol) {
        if (kmax < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + kmax);
        }
        if (data.length == 0) {
            // Fail with a descriptive error instead of an index-out-of-bounds on data[0].
            throw new IllegalArgumentException("Empty data");
        }

        int n = data.length;
        int d = data[0].length;
        int k = 1;

        int[] size = new int[kmax];
        size[0] = n;
        int[] y = new int[n];
        double[][] sum = new double[kmax][d];

        // Start with one cluster centred on the column means.
        double[] mean = MathEx.colMeans(data);
        double[][] centroids = {mean};
        double distortion = Arrays.stream(data).parallel()
                .mapToDouble(x -> MathEx.squaredDistance(x, mean))
                .sum();

        BBDTree bbd = new BBDTree(data);
        KMeans[] kmeans = new KMeans[kmax];
        ArrayList<double[]> centers = new ArrayList<>();

        while (k < kmax) {
            centers.clear();
            double[] score = new double[k];

            for (int i = 0; i < k; ++i) {
                int ni = size[i];
                if (ni < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} observations", i, ni);
                    score[i] = 0.0;
                    kmeans[i] = null;
                    continue;
                }

                // Gather the observations currently assigned to cluster i.
                double[][] subset = new double[ni][];
                int l = 0;
                for (int j = 0; j < n; ++j) {
                    if (y[j] == i) {
                        subset[l++] = data[j];
                    }
                }

                kmeans[i] = KMeans.fit(subset, 2, maxIter, tol);
                score[i] = splitScore(subset, kmeans[i].centroids, d);
                if (logger.isInfoEnabled()) {
                    logger.info(String.format("Cluster %d Anderson-Darling adjusted test statistic: %7.4f", i, score[i]));
                }
            }

            // From here on score[i] is used together with index[i] as the id of
            // the cluster it belongs to, i.e. the call is expected to sort score
            // ascending in place and return the original positions.
            int[] index = QuickSort.sort(score);

            // Clusters that pass the normality test keep their centroid.
            for (int i = 0; i < k; ++i) {
                if (score[i] <= CRITICAL_VALUE) {
                    centers.add(centroids[index[i]]);
                }
            }

            // Walk the remaining clusters from the highest score down, splitting
            // while there is still room: a split adds two centers and each of the
            // (i - m) clusters still to be visited needs at least one slot.
            int m = centers.size();
            for (int i = k - 1; i >= 0; --i) {
                if (!(score[i] > CRITICAL_VALUE)) {
                    continue;
                }
                if (centers.size() + i - m + 1 < kmax) {
                    logger.info("Split cluster {}", index[i]);
                    centers.add(kmeans[index[i]].centroids[0]);
                    centers.add(kmeans[index[i]].centroids[1]);
                } else {
                    centers.add(centroids[index[i]]);
                }
            }

            if (centers.size() == k) {
                logger.info("No more split. Finish with {} clusters", k);
                break;
            }

            k = centers.size();
            centroids = centers.toArray(new double[k][]);

            // Refine the new set of centroids until the distortion stops improving.
            // NOTE(review): relies on bbd.clustering updating centroids, sum, size
            // and y in place -- confirm against BBDTree.
            double diff = Double.MAX_VALUE;
            for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
                double wcss = bbd.clustering(centroids, sum, size, y);
                diff = distortion - wcss;
                distortion = wcss;
            }
            if (logger.isInfoEnabled()) {
                logger.info(String.format("Distortion with %d clusters: %.5f%n", k, distortion));
            }
        }

        return new GMeans(distortion, centroids, y);
    }

    /**
     * Projects the observations of a cluster onto the axis connecting its two
     * candidate child centroids, standardizes the projection and returns its
     * adjusted Anderson-Darling statistic.
     *
     * @param subset the observations of the cluster.
     * @param children the two candidate child centroids.
     * @param d the dimension of the data.
     * @return the adjusted statistic, or 0 (meaning "do not split") when the
     *         children coincide or the statistic cannot be computed.
     */
    private static double splitScore(double[][] subset, double[][] children, int d) {
        double[] v = new double[d];
        for (int j = 0; j < d; ++j) {
            v[j] = children[0][j] - children[1][j];
        }

        double vp = MathEx.dot(v, v);
        if (!(vp > 0.0)) {
            // Coincident (or NaN) children: there is no axis to project on and
            // dividing by vp would yield NaN, which would make the cluster vanish
            // from both the "keep" and the "split" branch of the caller.
            return 0.0;
        }

        double[] projection = new double[subset.length];
        for (int j = 0; j < projection.length; ++j) {
            projection[j] = MathEx.dot(subset[j], v) / vp;
        }
        MathEx.standardize(projection);

        double statistic = andersonDarling(projection);
        return Double.isNaN(statistic) ? 0.0 : statistic;
    }

    /**
     * Computes the Anderson-Darling statistic of a sample against the Gaussian
     * distribution returned by {@code GaussianDistribution.getInstance()},
     * multiplied by the correction factor {@code 1 + 4/n - 25/n^2}.
     * The input array is sorted and overwritten with CDF values.
     *
     * @param x the sample; modified in place.
     * @return the adjusted test statistic.
     */
    private static double andersonDarling(double[] x) {
        int n = x.length;
        GaussianDistribution gaussian = GaussianDistribution.getInstance();
        Arrays.sort(x);

        for (int i = 0; i < n; ++i) {
            double p = gaussian.cdf(x[i]);
            // Keep the CDF strictly inside (0, 1) so both logarithms below stay finite.
            if (p == 0.0) {
                p = 1.0E-7;
            } else if (p == 1.0) {
                p = 0.9999999;
            }
            x[i] = p;
        }

        double a = 0.0;
        for (int i = 0; i < n; ++i) {
            a -= (double)(2 * i + 1) * (Math.log(x[i]) + Math.log(1.0 - x[n - i - 1]));
        }
        a = a / (double)n - (double)n;

        // Square in double precision: n * n overflows int for n >= 46341.
        double nn = (double)n * (double)n;
        return a * (1.0 + 4.0 / (double)n - 25.0 / nn);
    }
}

