/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.Classifier;
import smile.classification.DiscriminantAnalysis;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.projection.Projection;
import smile.util.IntSet;

/**
 * Fisher's linear discriminant (FLD), usable both as a classifier and as a linear projection.
 *
 * <p>The model holds a projection matrix whose columns are discriminant directions. An input
 * vector is projected onto these directions, centered by the projected overall mean, and
 * classified as the class whose projected mean is nearest (per {@code MathEx.distance}).
 */
public class FLD implements Classifier<double[]>, Projection<double[]> {
    private static final long serialVersionUID = 2L;
    /** The dimensionality of the input space (length of the unprojected mean vector). */
    private final int p;
    /** The number of classes. */
    private final int k;
    /** The projection matrix; it has {@code p} rows and one column per discriminant direction. */
    private final Matrix scaling;
    /** The overall mean vector, already projected by {@code scaling}. */
    private final double[] mean;
    /** The class mean vectors, each already projected by {@code scaling}. */
    private final double[][] mu;
    /** Maps internal class indices {@code 0..k-1} back to the original class labels. */
    private final IntSet labels;

    /**
     * Constructor using the default label set {@code IntSet.of(mu.length)}.
     *
     * @param mean the overall mean vector in the input space.
     * @param mu the class mean vectors in the input space; {@code mu.length} is the number of classes.
     * @param scaling the projection matrix with {@code mean.length} rows.
     * @throws IllegalArgumentException if {@code scaling} does not have {@code mean.length} rows.
     */
    public FLD(double[] mean, double[][] mu, Matrix scaling) {
        this(mean, mu, scaling, IntSet.of(mu.length));
    }

    /**
     * Constructor. The mean vectors are projected once here so that prediction only has to
     * project the query vector.
     *
     * @param mean the overall mean vector in the input space.
     * @param mu the class mean vectors in the input space; {@code mu.length} is the number of classes.
     * @param scaling the projection matrix with {@code mean.length} rows.
     * @param labels the mapping from class indices to class labels.
     * @throws IllegalArgumentException if {@code scaling} does not have {@code mean.length} rows.
     */
    public FLD(double[] mean, double[][] mu, Matrix scaling, IntSet labels) {
        if (scaling.nrows() != mean.length) {
            throw new IllegalArgumentException(String.format("Invalid projection matrix: %d rows, expected: %d", scaling.nrows(), mean.length));
        }

        this.k = mu.length;
        this.p = mean.length;
        this.scaling = scaling;
        this.labels = labels;

        int L = scaling.ncols();
        this.mean = new double[L];
        scaling.tv(mean, this.mean);

        this.mu = new double[k][L];
        for (int i = 0; i < k; i++) {
            scaling.tv(mu[i], this.mu[i]);
        }
    }

    /**
     * Fits an FLD model with default hyper-parameters.
     *
     * @param formula the model formula selecting the response and predictors.
     * @param data the training data.
     * @return the fitted model.
     */
    public static FLD fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits an FLD model. Recognized properties:
     * <ul>
     *   <li>{@code smile.fld.dimension} - dimensionality of the projected space (default -1,
     *       meaning k - 1).</li>
     *   <li>{@code smile.fld.tolerance} - tolerance passed to the fitting routines (default 1E-4).</li>
     * </ul>
     * Categorical predictors are dummy-encoded.
     *
     * @param formula the model formula selecting the response and predictors.
     * @param data the training data.
     * @param prop the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static FLD fit(Formula formula, DataFrame data, Properties prop) {
        int L = Integer.parseInt(prop.getProperty("smile.fld.dimension", "-1"));
        double tol = Double.parseDouble(prop.getProperty("smile.fld.tolerance", "1E-4"));
        double[][] x = formula.x(data).toArray(false, CategoricalEncoder.DUMMY);
        int[] y = formula.y(data).toIntArray();
        return fit(x, y, L, tol);
    }

    /**
     * Fits an FLD model with dimensionality k - 1 and tolerance 1E-4.
     *
     * @param x the training samples, one row per sample.
     * @param y the class labels of the samples.
     * @return the fitted model.
     */
    public static FLD fit(double[][] x, int[] y) {
        return fit(x, y, -1, 1.0E-4);
    }

    /**
     * Fits an FLD model.
     *
     * <p>When {@code n - k < p} (too few samples relative to the input dimension for the
     * within-class scatter estimate), an SVD-based routine is used; otherwise the scatter
     * matrices are formed explicitly.
     *
     * @param x the training samples, one row per sample.
     * @param y the class labels of the samples.
     * @param L the dimensionality of the projected space; values {@code <= 0} select k - 1.
     * @param tol the tolerance passed to {@code DiscriminantAnalysis.fit} and the fitting routines.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length, or {@code L >= k}.
     */
    public static FLD fit(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, null, tol);

        int n = x.length;
        int k = da.k;
        int p = da.mean.length;

        if (L >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, k));
        }

        if (L <= 0) {
            L = k - 1;
        }

        double[] mean = da.mean;
        // Both routines center these class means in place; the constructor below relies on
        // receiving the centered version.
        double[][] mu = da.mu;
        Matrix scaling = n - k < p
                ? small(L, x, mean, mu, da.priori, tol)
                : fld(L, x, mean, mu, tol);
        return new FLD(mean, mu, scaling, da.labels);
    }

    /**
     * Computes the projection from explicit scatter matrices: the leading {@code L} right
     * eigenvectors of {@code Sw^-1 Sb}.
     *
     * <p>Side effect: subtracts {@code mean} from every row of {@code mu} in place.
     *
     * @return the {@code p x L} projection matrix.
     */
    private static Matrix fld(int L, double[][] x, double[] mean, double[][] mu, double tol) {
        int k = mu.length;
        int p = mean.length;

        Matrix St = DiscriminantAnalysis.St(x, mean, k, tol);

        centerClassMeans(mu, mean);

        // Between-class scatter: average outer product of the centered class means.
        // Only the upper triangle is accumulated, then mirrored below.
        Matrix Sb = new Matrix(p, p);
        for (double[] mui : mu) {
            for (int j = 0; j < p; j++) {
                for (int i = 0; i <= j; i++) {
                    Sb.add(i, j, mui[i] * mui[j]);
                }
            }
        }

        for (int j = 0; j < p; j++) {
            for (int i = 0; i <= j; i++) {
                Sb.div(i, j, (double) k);
                Sb.set(j, i, Sb.get(i, j));
            }
        }

        // Within-class scatter Sw derived from St and Sb.
        // NOTE(review): confirm Matrix.sub(1.0, Sb) semantics (St - Sb, possibly in place on St).
        Matrix Sw = St.sub(1.0, Sb);
        Matrix SwInvSb = Sw.inverse().mm(Sb);
        // NOTE(review): assumes eigen(false, true, true) returns eigenvectors sorted by
        // decreasing eigenvalue, so the first L columns are the leading directions.
        return SwInvSb.eigen(false, true, true).Vr.submatrix(0, 0, p - 1, L - 1);
    }

    /**
     * Computes the projection via the SVD of the centered data. This route is used when
     * {@code n - k < p}.
     *
     * <p>Side effect: subtracts {@code mean} from every row of {@code mu} in place.
     *
     * @return the projection matrix with {@code p} rows and {@code L} columns.
     */
    private static Matrix small(int L, double[][] x, double[] mean, double[][] mu, double[] priori, double tol) {
        int k = mu.length;
        int p = mean.length;
        int n = x.length;

        // Centered data scaled by 1/sqrt(n), one sample per column.
        double sqrtn = Math.sqrt(n);
        Matrix X = new Matrix(p, n);
        for (int i = 0; i < n; i++) {
            double[] xi = x[i];
            for (int j = 0; j < p; j++) {
                X.set(j, i, (xi[j] - mean[j]) / sqrtn);
            }
        }

        centerClassMeans(mu, mean);

        // Centered class means weighted by sqrt(prior), one class per column.
        Matrix M = new Matrix(p, k);
        for (int i = 0; i < k; i++) {
            double pi = Math.sqrt(priori[i]);
            double[] mui = mu[i];
            for (int j = 0; j < p; j++) {
                M.set(j, i, pi * mui[j]);
            }
        }

        Matrix.SVD svd = X.svd(true, true);
        Matrix U = svd.U;
        double[] s = svd.s;

        // The SVD of the p x n matrix X yields min(p, n) singular values. The condition
        // n - k < p does not imply p >= n, so the loops are bounded by s.length rather than n
        // (bounding by n overran s and the rows of U^T M when n - k < p < n).
        int rank = s.length;
        double threshold = tol * tol;

        Matrix UTM = U.tm(M);
        for (int i = 0; i < rank; i++) {
            double si = invSqrtAbove(s[i], threshold);
            for (int j = 0; j < k; j++) {
                UTM.mul(i, j, si);
            }
        }

        Matrix StInvM = U.mm(UTM);
        Matrix U2 = U.tm(StInvM.svd(true, true).U.submatrix(0, 0, p - 1, L - 1));
        for (int i = 0; i < rank; i++) {
            double si = invSqrtAbove(s[i], threshold);
            for (int j = 0; j < L; j++) {
                U2.mul(i, j, si);
            }
        }

        return U.mm(U2);
    }

    /** Subtracts {@code mean} from every class mean vector in {@code mu}, in place. */
    private static void centerClassMeans(double[][] mu, double[] mean) {
        for (double[] mui : mu) {
            for (int j = 0; j < mean.length; j++) {
                mui[j] -= mean[j];
            }
        }
    }

    /**
     * Returns {@code 1 / sqrt(value)} when {@code value > threshold}, otherwise 0.
     * Singular values at or below the threshold are treated as null directions and dropped.
     */
    private static double invSqrtAbove(double value, double threshold) {
        return value > threshold ? 1.0 / Math.sqrt(value) : 0.0;
    }

    /** Throws if {@code x} does not have the input dimensionality the model was built for. */
    private void checkDimension(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
    }

    /**
     * Predicts the class label of {@code x}: the label of the class whose projected mean is
     * nearest to the projection of {@code x}. Ties keep the lowest class index.
     *
     * @param x the input vector of length {@code p}.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public int predict(double[] x) {
        checkDimension(x);

        double[] wx = project(x);
        int best = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            double d = MathEx.distance(wx, mu[i]);
            if (d < nearest) {
                nearest = d;
                best = i;
            }
        }

        return labels.valueOf(best);
    }

    /**
     * Projects {@code x} onto the discriminant directions and subtracts the projected
     * overall mean.
     *
     * @param x the input vector of length {@code p}.
     * @return a new array holding the projected vector.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public double[] project(double[] x) {
        checkDimension(x);

        double[] y = scaling.tv(x);
        MathEx.sub(y, mean);
        return y;
    }

    /**
     * Projects each row of {@code x} as in {@link #project(double[])}.
     *
     * @param x the input vectors, each of length {@code p}.
     * @return the projected vectors, one row per input.
     * @throws IllegalArgumentException if any row has length other than {@code p}.
     */
    public double[][] project(double[][] x) {
        double[][] y = new double[x.length][scaling.ncols()];
        for (int i = 0; i < x.length; i++) {
            checkDimension(x[i]);
            scaling.tv(x[i], y[i]);
            MathEx.sub(y[i], mean);
        }

        return y;
    }

    /**
     * Returns the projection matrix. Its columns are the discriminant directions.
     *
     * @return the projection matrix; this is the internal instance, not a copy.
     */
    public Matrix getProjection() {
        return scaling;
    }
}

