/*
 * Decompiled with CFR 0.152.
 */
package smile.vq;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.TimeFunction;
import smile.mds.MDS;
import smile.sort.QuickSort;
import smile.vq.Neighborhood;
import smile.vq.VectorQuantizer;

/**
 * Self-organizing map (SOM) over a rectangular {@code nrows x ncols} lattice of
 * neurons, each holding a weight vector.
 *
 * <p>Each call to {@link #update(double[])} finds the best matching unit (BMU) for
 * the sample and moves every neuron toward the sample by a step of
 * {@code alpha(t) * theta(di, dj, t)}, where {@code (di, dj)} is the neuron's
 * lattice offset from the BMU and {@code t} is the number of updates performed so far.
 *
 * <p>Instances are not safe for concurrent use: the BMU search writes into a shared
 * distance buffer, and {@code update} mutates the weights and the step counter
 * without synchronization.
 */
public class SOM
implements VectorQuantizer {
    private static final long serialVersionUID = 2L;
    /** Number of lattice rows. */
    private final int nrows;
    /** Number of lattice columns. */
    private final int ncols;
    /** Neurons indexed by lattice position, {@code map[i][j]}. */
    private final Neuron[][] map;
    /** The same neurons as {@code map}, flattened in row-major order. */
    private final Neuron[] neurons;
    /** Scratch buffer: distance of the latest query to each entry of {@code neurons}. */
    private final double[] dist;
    /** Learning-rate schedule, evaluated at the current step {@code t}. */
    private final TimeFunction alpha;
    /** Neighborhood function of the lattice offset from the BMU and the step {@code t}. */
    private final Neighborhood theta;
    /** Number of {@link #update(double[])} calls performed so far. */
    private int t = 0;
    /** Neuron updates whose step size is not above this threshold are skipped. */
    private double eps = 1.0E-5;

    /**
     * Creates a map from an initial lattice of weight vectors.
     *
     * @param neurons initial weight vectors, {@code neurons[i][j]} for lattice row
     *                {@code i} and column {@code j}. Each vector is copied, so later
     *                changes to the argument do not affect the map.
     * @param alpha   learning-rate schedule, evaluated at the update step.
     * @param theta   neighborhood function of the lattice offset from the BMU and the step.
     * @throws IllegalArgumentException if the lattice is empty, its rows have different
     *                                  lengths, or its weight vectors have different lengths.
     */
    public SOM(double[][][] neurons, TimeFunction alpha, Neighborhood theta) {
        if (neurons.length == 0 || neurons[0].length == 0) {
            throw new IllegalArgumentException("Empty neuron lattice");
        }
        this.alpha = alpha;
        this.theta = theta;
        this.nrows = neurons.length;
        this.ncols = neurons[0].length;
        this.map = new Neuron[nrows][ncols];
        this.neurons = new Neuron[nrows * ncols];
        this.dist = new double[this.neurons.length];

        // update() applies the BMU's dimension to every neuron, so all vectors must agree.
        int dim = neurons[0][0].length;
        int k = 0;
        for (int i = 0; i < nrows; ++i) {
            if (neurons[i].length != ncols) {
                throw new IllegalArgumentException(String.format(
                        "Row %d has %d neurons, expected %d", i, neurons[i].length, ncols));
            }
            for (int j = 0; j < ncols; ++j, ++k) {
                if (neurons[i][j].length != dim) {
                    throw new IllegalArgumentException(String.format(
                            "Neuron (%d, %d) has dimension %d, expected %d", i, j, neurons[i][j].length, dim));
                }
                Neuron neuron = new Neuron(i, j, neurons[i][j].clone());
                this.map[i][j] = neuron;
                this.neurons[k] = neuron;
            }
        }
    }

    /**
     * Builds an initial {@code nrows x ncols} lattice of weight vectors from samples.
     *
     * <p>{@code nrows * ncols} medoids are seeded from the samples with
     * {@code CentroidClustering.seed} (using {@code MathEx::squaredDistance}). The medoids
     * are embedded with MDS on their pairwise distances. They are then sorted by the
     * first MDS coordinate and cut into consecutive rows of {@code ncols}, and each row
     * is sorted by the second MDS coordinate. The intent is that lattice neighbours
     * start out near each other.
     *
     * <p>NOTE(review): presumably {@code samples} must contain at least
     * {@code nrows * ncols} points for seeding to be meaningful; confirm against
     * {@code CentroidClustering.seed}.
     *
     * @param nrows   number of lattice rows.
     * @param ncols   number of lattice columns.
     * @param samples training samples.
     * @return the lattice {@code [nrows][ncols][]}. The vectors are the seeded medoid
     *         arrays themselves, not copies; the constructor copies them.
     * @throws IllegalArgumentException if {@code nrows} or {@code ncols} is not positive.
     */
    public static double[][][] lattice(int nrows, int ncols, double[][] samples) {
        if (nrows <= 0 || ncols <= 0) {
            throw new IllegalArgumentException(String.format("Invalid lattice size: %d x %d", nrows, ncols));
        }
        int k = nrows * ncols;
        int[] clusters = new int[samples.length];
        double[][] medoids = new double[k][];
        CentroidClustering.seed(samples, medoids, clusters, MathEx::squaredDistance);

        double[][] coordinates = MDS.of(MathEx.pdist(medoids)).coordinates;
        double[] x = Arrays.stream(coordinates).mapToDouble(p -> p[0]).toArray();
        // Sorts x in place and returns the permutation of medoid indices.
        int[] index = QuickSort.sort(x);

        double[][][] lattice = new double[nrows][ncols][];
        // Row buffers are reused: every slot is overwritten for each row.
        double[] y = new double[ncols];
        int[] row = new int[ncols];
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                int point = index[i * ncols + j];
                y[j] = coordinates[point][1];
                row[j] = point;
            }
            QuickSort.sort(y, row);
            for (int j = 0; j < ncols; ++j) {
                lattice[i][j] = medoids[row[j]];
            }
        }
        return lattice;
    }

    /**
     * Performs one online training step with sample {@code x}.
     *
     * <p>Every neuron moves toward {@code x} by {@code delta = alpha(t) * theta(di, dj, t)},
     * where {@code (di, dj)} is its lattice offset from the BMU. Neurons with
     * {@code delta <= eps} are left untouched. The step counter {@code t} is then
     * incremented.
     *
     * @param x the training sample.
     */
    @Override
    public void update(double[] x) {
        Neuron bmu = this.bmu(x);
        int bi = bmu.i;
        int bj = bmu.j;
        int d = bmu.w.length;
        int step = this.t;
        double rate = this.alpha.apply(step);
        // Each neuron updates only its own weight vector, so the parallel loop does not race.
        Arrays.stream(this.neurons).parallel().forEach(neuron -> {
            double delta = rate * this.theta.of(neuron.i - bi, neuron.j - bj, step);
            if (delta > this.eps) {
                double[] w = neuron.w;
                for (int k = 0; k < d; ++k) {
                    w[k] += delta * (x[k] - w[k]);
                }
            }
        });
        ++this.t;
    }

    /**
     * Returns the current weight vectors arranged as the lattice.
     *
     * @return {@code [nrows][ncols][]} array whose vectors are live references to the
     *         neurons' weights (not copies), so they change as the map is trained.
     */
    public double[][][] neurons() {
        double[][][] lattice = new double[nrows][ncols][];
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                lattice[i][j] = map[i][j].w;
            }
        }
        return lattice;
    }

    /**
     * Computes the unified distance matrix (U-matrix).
     *
     * <p>Each cell holds the largest distance between that neuron's weights and the
     * weights of its horizontal or vertical lattice neighbours. Single-row and
     * single-column maps are supported; a 1x1 map yields a single zero.
     *
     * <p>{@code MathEx.distance} already returns a non-squared distance, so no further
     * square root is applied.
     *
     * @return the {@code [nrows][ncols]} U-matrix.
     */
    public double[][] umatrix() {
        double[][] umatrix = new double[nrows][ncols];
        for (int i = 0; i < nrows; ++i) {
            for (int j = 0; j < ncols; ++j) {
                if (j + 1 < ncols) {
                    relax(umatrix, i, j, i, j + 1);
                }
                if (i + 1 < nrows) {
                    relax(umatrix, i, j, i + 1, j);
                }
            }
        }
        return umatrix;
    }

    /** Raises both endpoints of the lattice edge (i1, j1)-(i2, j2) to at least its weight distance. */
    private void relax(double[][] umatrix, int i1, int j1, int i2, int j2) {
        double d = MathEx.distance(map[i1][j1].w, map[i2][j2].w);
        umatrix[i1][j1] = Math.max(umatrix[i1][j1], d);
        umatrix[i2][j2] = Math.max(umatrix[i2][j2], d);
    }

    /**
     * Returns the weight vector of the best matching unit for {@code x}.
     *
     * @param x the query vector.
     * @return the BMU's weight vector. This is a live reference, not a copy.
     */
    @Override
    public double[] quantize(double[] x) {
        return this.bmu(x).w;
    }

    /**
     * Finds the neuron whose weights are closest to {@code x} under {@code MathEx.distance}.
     *
     * <p>Distances are computed in parallel into the shared {@code dist} buffer, followed
     * by a linear arg-min scan. Ties go to the lowest row-major index. The neuron array
     * is not reordered.
     */
    private Neuron bmu(double[] x) {
        int n = this.neurons.length;
        IntStream.range(0, n).parallel().forEach(idx -> this.dist[idx] = MathEx.distance(this.neurons[idx].w, x));
        int best = 0;
        for (int k = 1; k < n; ++k) {
            if (this.dist[k] < this.dist[best]) {
                best = k;
            }
        }
        return this.neurons[best];
    }

    /**
     * A lattice node: its position {@code (i, j)} and its mutable weight vector.
     *
     * <p>NOTE(review): no explicit serialVersionUID is declared. Any UID added later must
     * equal the computed default, or previously serialized models will not deserialize.
     */
    private static class Neuron
    implements Serializable {
        /** Weight vector, updated in place by training. */
        public final double[] w;
        /** Lattice row. */
        public final int i;
        /** Lattice column. */
        public final int j;

        public Neuron(int i, int j, double[] w) {
            this.i = i;
            this.j = j;
            this.w = w;
        }
    }
}

