/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/**
 * Deterministic annealing clustering of real-valued vectors under squared Euclidean distance.
 *
 * <p>{@link #fit(double[][], int, double, int, double, double)} soft-assigns every sample to every
 * centroid with weights proportional to {@code prior * exp(-distance / temperature)}, re-estimates
 * centroids and priors from those weights, and then changes the temperature. Centroids are tracked
 * in pairs: slot {@code 2i} holds a centroid and slot {@code 2i + 1} a copy scaled by 1.01. When a
 * pair drifts apart by more than {@code splitTol}, it is turned into two clusters, up to
 * {@code Kmax} clusters in total.
 */
public class DeterministicAnnealing
extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(DeterministicAnnealing.class);

    /** Element-wise scale applied to a centroid to create its perturbed partner. */
    private static final double PERTURBATION = 1.01;

    /**
     * Constructor.
     *
     * @param distortion the distortion of the clustering.
     * @param centroids the centroid of each cluster.
     * @param y the cluster label of each sample.
     */
    public DeterministicAnnealing(double distortion, double[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /** Returns the squared Euclidean distance between {@code x} and {@code y}. */
    @Override
    protected double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }

    /**
     * Clusters the data with default annealing parameters: {@code alpha = 0.9},
     * {@code maxIter = 100}, {@code tol = 1E-4} and {@code splitTol = 0.01}.
     *
     * @param data the data, one row per sample.
     * @param Kmax the maximum number of clusters.
     * @return the clustering model.
     */
    public static DeterministicAnnealing fit(double[][] data, int Kmax) {
        return fit(data, Kmax, 0.9, 100, 1.0E-4, 0.01);
    }

    /**
     * Clusters the data by deterministic annealing.
     *
     * @param data the data, one row per sample.
     * @param Kmax the maximum number of clusters.
     * @param alpha the temperature scale factor in (0, 1); the temperature is multiplied (cooling)
     *              or divided (heating) by it after each annealing step.
     * @param maxIter the maximum number of soft EM iterations at each temperature.
     * @param tol the convergence tolerance on the decrease of {@code D - T * H} at each temperature.
     * @param splitTol the squared distance between a centroid and its perturbed partner above
     *                 which the pair is split into two clusters.
     * @return the clustering model.
     * @throws IllegalArgumentException if {@code alpha} is not in (0, 1), {@code data} is empty,
     *                                  or {@code Kmax < 1}.
     */
    public static DeterministicAnnealing fit(double[][] data, int Kmax, double alpha, int maxIter, double tol, double splitTol) {
        if (alpha <= 0.0 || alpha >= 1.0) {
            throw new IllegalArgumentException("Invalid alpha: " + alpha);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        if (Kmax < 1) {
            throw new IllegalArgumentException("Invalid Kmax: " + Kmax);
        }

        int n = data.length;
        int d = data[0].length;
        // Room for Kmax (centroid, perturbed partner) pairs.
        int capacity = 2 * Kmax;
        double[][] centroids = new double[capacity][d];
        double[][] posteriori = new double[n][capacity];
        double[] priori = new double[capacity];

        // Start from a single pair: the overall mean and its perturbed copy.
        centroids[0] = MathEx.colMeans(data);
        perturb(centroids[0], centroids[1]);
        priori[0] = 0.5;
        priori[1] = 0.5;

        // Initial temperature is derived from the value returned by the eigen iteration
        // on the data covariance matrix.
        Matrix cov = Matrix.of(MathEx.cov(data, centroids[0]));
        double[] ev = new double[d];
        Arrays.fill(ev, 1.0);
        double lambda = cov.eigen(ev, 0.0, 1.0E-4, Math.max(20, 2 * cov.nrow()));
        double temperature = 2.0 * lambda + 0.01;

        int k = 2;
        boolean stop = false;
        boolean split = false;
        while (!stop) {
            update(data, temperature, k, centroids, posteriori, priori, maxIter, tol);

            if (k >= capacity && split) {
                stop = true;
            }

            int currentK = k;
            for (int i = 0; i < currentK; i += 2) {
                if (MathEx.squaredDistance(centroids[i], centroids[i + 1]) > splitTol) {
                    if (k < capacity) {
                        // The drifted partner becomes a new pair at slots (k, k + 1).
                        System.arraycopy(centroids[i + 1], 0, centroids[k], 0, d);
                        perturb(centroids[i + 1], centroids[k + 1]);
                        priori[k] = priori[i + 1] / 2.0;
                        priori[k + 1] = priori[i + 1] / 2.0;
                        // Split slot i's mass evenly within its pair as well, so the total
                        // prior mass of both pairs is preserved.
                        priori[i] /= 2.0;
                        priori[i + 1] = priori[i];
                        k += 2;
                    }
                    if (currentK >= capacity) {
                        split = true;
                    }
                }
                // Re-seed the partner next to its centroid for the next temperature.
                perturb(centroids[i], centroids[i + 1]);
            }

            // Note: 5 * 10^(log10(1 - alpha) - 1) == 0.5 * (1 - alpha), i.e. alpha moves halfway
            // towards 1, making subsequent temperature steps finer.
            if (split) {
                // Since alpha < 1, this raises the temperature.
                temperature /= alpha;
            } else if (k - currentK > 2) {
                // More than one pair split in a single step: back off and refine the step.
                temperature /= alpha;
                alpha += 5.0 * Math.pow(10.0, Math.log10(1.0 - alpha) - 1.0);
            } else {
                if (k > currentK && k == capacity - 2) {
                    alpha += 5.0 * Math.pow(10.0, Math.log10(1.0 - alpha) - 1.0);
                }
                temperature *= alpha;
            }

            // Also stop once the temperature has underflowed to 0 (degenerate data that never
            // splits); otherwise the loop would never terminate.
            if (alpha >= 1.0 || !(temperature > 0.0)) {
                break;
            }
        }

        // The first member of each pair is a cluster center.
        int numClusters = k / 2;
        double[][] centers = new double[numClusters][];
        for (int i = 0; i < numClusters; ++i) {
            centers[i] = centroids[2 * i];
        }

        int[] y = new int[n];
        double distortion = DeterministicAnnealing.assign(y, data, centers, MathEx::squaredDistance);
        return new DeterministicAnnealing(distortion, clusterMeans(data, y, centers), y);
    }

    /** Writes {@code source} scaled element-wise by {@link #PERTURBATION} into {@code target}. */
    private static void perturb(double[] source, double[] target) {
        for (int j = 0; j < source.length; ++j) {
            target[j] = source[j] * PERTURBATION;
        }
    }

    /**
     * Returns the mean of the samples assigned to each cluster. A cluster with no assigned
     * sample keeps a copy of its annealed center instead of a 0/0 (NaN) mean.
     */
    private static double[][] clusterMeans(double[][] data, int[] y, double[][] centers) {
        int k = centers.length;
        int d = data[0].length;
        int[] size = new int[k];
        double[][] means = new double[k][d];
        for (int i = 0; i < data.length; ++i) {
            int c = y[i];
            ++size[c];
            double[] mean = means[c];
            double[] x = data[i];
            for (int j = 0; j < d; ++j) {
                mean[j] += x[j];
            }
        }
        for (int c = 0; c < k; ++c) {
            if (size[c] == 0) {
                means[c] = centers[c].clone();
            } else {
                for (int j = 0; j < d; ++j) {
                    means[c][j] /= size[c];
                }
            }
        }
        return means;
    }

    /**
     * Runs soft EM iterations at a fixed temperature over the first {@code k} centroids.
     * Updates {@code posteriori}, {@code priori} and {@code centroids} in place.
     *
     * @param temperature the annealing temperature.
     * @param k the number of active centroid slots.
     * @return the last value of {@code D - T * H} (soft distortion minus temperature times entropy).
     */
    private static double update(double[][] data, double temperature, int k, double[][] centroids, double[][] posteriori, double[] priori, int maxIter, double tol) {
        int n = data.length;
        int d = data[0].length;
        double distortion = Double.MAX_VALUE;
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; ++iter) {
            // log(0) = -Infinity, which correctly yields a zero posterior below.
            double[] logPriori = new double[k];
            for (int j = 0; j < k; ++j) {
                logPriori[j] = Math.log(priori[j]);
            }

            // E-step: posteriors p_j ~ prior_j * exp(-dist_j / T), computed in the log domain and
            // shifted by the maximum so that low temperatures do not underflow every term to 0.
            double softDistortion = IntStream.range(0, n).parallel().mapToDouble(i -> {
                double[] p = posteriori[i];
                double[] dist = new double[k];
                double maxLog = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < k; ++j) {
                    dist[j] = MathEx.squaredDistance(data[i], centroids[j]);
                    p[j] = logPriori[j] - dist[j] / temperature;
                    maxLog = Math.max(maxLog, p[j]);
                }
                double z = 0.0;
                for (int j = 0; j < k; ++j) {
                    p[j] = Math.exp(p[j] - maxLog);
                    z += p[j];
                }
                double sum = 0.0;
                for (int j = 0; j < k; ++j) {
                    p[j] /= z;
                    sum += p[j] * dist[j];
                }
                return sum;
            }).sum();

            // Entropy of the posteriors; zero probabilities contribute 0 (0 * log 0 := 0),
            // instead of 0 * -Infinity = NaN.
            double entropy = IntStream.range(0, n).parallel().mapToDouble(i -> {
                double[] p = posteriori[i];
                double sum = 0.0;
                for (int j = 0; j < k; ++j) {
                    if (p[j] > 0.0) {
                        sum -= p[j] * Math.log(p[j]);
                    }
                }
                return sum;
            }).sum();

            // M-step: priors are the average posteriors.
            Arrays.fill(priori, 0.0);
            for (double[] p : posteriori) {
                for (int j = 0; j < k; ++j) {
                    priori[j] += p[j];
                }
            }
            for (int j = 0; j < k; ++j) {
                priori[j] /= n;
            }

            // M-step: centroids are posterior-weighted means of the data.
            IntStream.range(0, k).parallel().forEach(c -> {
                double mass = n * priori[c];
                if (mass <= 0.0) {
                    // No posterior weight on this centroid: keep its previous position rather
                    // than dividing 0 by 0 and turning it into NaN.
                    return;
                }
                double[] centroid = centroids[c];
                Arrays.fill(centroid, 0.0);
                for (int m = 0; m < n; ++m) {
                    double w = posteriori[m][c];
                    double[] x = data[m];
                    for (int j = 0; j < d; ++j) {
                        centroid[j] += w * x[j];
                    }
                }
                for (int j = 0; j < d; ++j) {
                    centroid[j] /= mass;
                }
            });

            double freeEnergy = softDistortion - temperature * entropy;
            diff = distortion - freeEnergy;
            distortion = freeEnergy;
            logger.info("Entropy after {} iterations at temperature {} and k = {}: {} (soft distortion = {})", new Object[]{iter, temperature, k / 2, entropy, softDistortion});
        }
        return distortion;
    }
}

