/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.BBDTree;
import smile.clustering.CentroidClustering;
import smile.clustering.KMeans;
import smile.math.MathEx;
import smile.sort.QuickSort;

/**
 * Centroid-based clustering that chooses the number of clusters on its own,
 * up to a caller-supplied maximum.
 *
 * <p>Fitting starts from a single cluster centered on the column means of the
 * data. Each round tentatively bisects every sufficiently large cluster with
 * {@link KMeans}, compares a BIC score of the two-way split against the BIC
 * score of the unsplit cluster, and accepts the splits that improve the score
 * (best improvement first) as long as the total stays within the maximum.
 * The accepted centers are then refined over the whole data set with a
 * {@link BBDTree}. Fitting stops when no cluster is split or the maximum is
 * reached.
 */
public class XMeans extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(XMeans.class);

    /** Precomputed log(2 * pi), used by the likelihood terms. */
    private static final double LOG2PI = Math.log(Math.PI * 2);

    /** Clusters with fewer observations than this are never considered for a split. */
    private static final int MIN_SPLIT_SIZE = 25;

    /**
     * Creates a clustering result.
     *
     * @param distortion the total distortion of the clustering.
     * @param centroids the cluster centroids.
     * @param y the cluster label of each observation.
     */
    public XMeans(double distortion, double[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /**
     * Returns the squared Euclidean distance between an observation and a centroid.
     */
    @Override
    protected double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }

    /**
     * Fits the model with at most 100 refinement iterations per round and a
     * convergence tolerance of 1E-4.
     *
     * @param data the observations, one row per observation.
     * @param kmax the maximum number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static XMeans fit(double[][] data, int kmax) {
        return XMeans.fit(data, kmax, 100, 1.0E-4);
    }

    /**
     * Fits the model.
     *
     * @param data the observations, one row per observation.
     * @param kmax the maximum number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations, both for each tentative
     *                two-way split and for each global refinement.
     * @param tol the convergence tolerance on the decrease of the distortion.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code kmax < 2} or {@code data} is empty.
     */
    public static XMeans fit(double[][] data, int kmax, int maxIter, double tol) {
        if (kmax < 2) {
            throw new IllegalArgumentException("Invalid parameter kmax = " + kmax);
        }
        if (data.length == 0) {
            // Fail with a clear message instead of an index error on data[0].
            throw new IllegalArgumentException("Empty data set");
        }

        int n = data.length;
        int d = data[0].length;
        int k = 1;

        // Buffers are sized for the largest possible model and reused every round.
        int[] size = new int[kmax];
        size[0] = n;
        int[] y = new int[n];
        double[][] sum = new double[kmax][d];

        // Initial model: one cluster located at the column means.
        double[] mean = MathEx.colMeans(data);
        double[][] centroids = {mean};
        double distortion = Arrays.stream(data).parallel()
                .mapToDouble(x -> MathEx.squaredDistance(x, mean))
                .sum();
        double[] distortions = new double[kmax];
        distortions[0] = distortion;

        BBDTree bbd = new BBDTree(data);
        KMeans[] kmeans = new KMeans[kmax];
        ArrayList<double[]> centers = new ArrayList<>();

        while (k < kmax) {
            centers.clear();
            double[] score = new double[k];

            // Tentatively bisect every cluster and score the improvement.
            for (int i = 0; i < k; i++) {
                int ni = size[i];
                if (ni < MIN_SPLIT_SIZE) {
                    logger.info("Cluster {} too small to split: {} observations", i, ni);
                    score[i] = 0.0;
                    kmeans[i] = null;
                    continue;
                }

                double[][] subset = new double[ni][];
                int l = 0;
                for (int j = 0; j < n; j++) {
                    if (y[j] == i) {
                        subset[l++] = data[j];
                    }
                }

                kmeans[i] = KMeans.fit(subset, 2, maxIter, tol);
                double newBIC = XMeans.bic(2, ni, d, kmeans[i].distortion, kmeans[i].size);
                double oldBIC = XMeans.bic(ni, d, distortions[i]);
                score[i] = newBIC - oldBIC;
                logger.info("Cluster {} BIC: {}, BIC after split: {}, improvement: {}", i, oldBIC, newBIC, score[i]);

                // A zero-distortion cluster makes both BIC values infinite and their
                // difference NaN. NaN fails both the "<= 0" and the "> 0" test below,
                // which would silently drop the centroid, so treat it as "no improvement".
                if (Double.isNaN(score[i])) {
                    score[i] = 0.0;
                }
            }

            // Sorts score ascending in place; index[i] is the original cluster of score[i].
            int[] index = QuickSort.sort(score);

            // Clusters whose split does not help are carried over unchanged.
            for (int i = 0; i < k; i++) {
                if (score[i] <= 0.0) {
                    centers.add(centroids[index[i]]);
                }
            }

            // Visit the remaining clusters from the largest improvement down.
            int m = centers.size();
            for (int i = k - 1; i >= 0; i--) {
                if (score[i] > 0.0) {
                    // (i - m + 1) candidates, this one included, still need at least one
                    // slot each; split only if the extra center keeps the total <= kmax.
                    if (centers.size() + i - m + 1 < kmax) {
                        logger.info("Split cluster {}", index[i]);
                        double[][] halves = kmeans[index[i]].centroids;
                        centers.add(halves[0]);
                        centers.add(halves[1]);
                    } else {
                        centers.add(centroids[index[i]]);
                    }
                }
            }

            if (centers.size() == k) {
                logger.info("No more split. Finish with {} clusters", k);
                break;
            }

            k = centers.size();
            centroids = centers.toArray(new double[k][]);

            // Refine the new set of centers over the whole data set until the
            // distortion stops decreasing by more than tol.
            double diff = Double.MAX_VALUE;
            for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
                double wcss = bbd.clustering(centroids, sum, size, y);
                diff = distortion - wcss;
                distortion = wcss;
            }

            // Per-cluster distortions feed the BIC of the unsplit clusters next round.
            // Each parallel task writes only its own slot of the array.
            Arrays.fill(distortions, 0.0);
            IntStream.range(0, k).parallel().forEach(cluster -> {
                double[] centroid = centers.get(cluster);
                for (int i = 0; i < n; i++) {
                    if (y[i] == cluster) {
                        distortions[cluster] += MathEx.squaredDistance(data[i], centroid);
                    }
                }
            });

            logger.info("Distortion with {} clusters: {}", k, distortion);
        }

        return new XMeans(distortion, centroids, y);
    }

    /**
     * Computes the BIC score of a cluster kept as a single component.
     *
     * @param n the number of observations in the cluster.
     * @param d the dimension of the data.
     * @param distortion the distortion of the cluster.
     * @return the BIC score.
     */
    private static double bic(int n, int d, double distortion) {
        double variance = distortion / (n - 1);
        double p1 = -n * LOG2PI;
        // Widen before multiplying: n * d can exceed the int range.
        double p2 = -((double) n * d) * Math.log(variance);
        double p3 = -(n - 1);
        double L = (p1 + p2 + p3) / 2.0;
        int numParameters = d + 1;
        return L - 0.5 * numParameters * Math.log(n);
    }

    /**
     * Computes the BIC score of a cluster partitioned into {@code k} sub-clusters
     * that share one pooled variance estimate.
     *
     * @param k the number of sub-clusters.
     * @param n the total number of observations.
     * @param d the dimension of the data.
     * @param distortion the total distortion of the partition.
     * @param clusterSize the number of observations in each sub-cluster.
     * @return the BIC score.
     */
    private static double bic(int k, int n, int d, double distortion, int[] clusterSize) {
        double variance = distortion / (n - k);
        double L = 0.0;
        for (int i = 0; i < k; i++) {
            L += XMeans.logLikelihood(k, n, clusterSize[i], d, variance);
        }
        int numParameters = k + k * d;
        return L - 0.5 * numParameters * Math.log(n);
    }

    /**
     * Computes the log-likelihood contribution of one sub-cluster.
     *
     * @param k the number of sub-clusters.
     * @param n the total number of observations.
     * @param ni the number of observations in this sub-cluster.
     * @param d the dimension of the data.
     * @param variance the pooled variance estimate.
     * @return the log-likelihood term of this sub-cluster.
     */
    private static double logLikelihood(int k, int n, int ni, int d, double variance) {
        double p1 = -ni * LOG2PI;
        // Widen before multiplying: ni * d can exceed the int range.
        double p2 = -((double) ni * d) * Math.log(variance);
        double p3 = -(ni - k);
        double p4 = ni * Math.log(ni);
        double p5 = -ni * Math.log(n);
        return (p1 + p2 + p3) / 2.0 + p4 + p5;
    }
}

