/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.AbstractClassifier;
import smile.classification.DiscriminantAnalysis;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Quadratic discriminant analysis (QDA) classifier.
 * <p>
 * Each class {@code i} is described by a prior probability {@code priori[i]}, a mean
 * vector {@code mu[i]} and its covariance matrix in eigen-decomposed form. The covariance
 * is given as eigenvalues {@code eigen[i]} and eigenvectors {@code scaling[i]}, as
 * produced by {@link #fit(double[][], int[], double[], double)}. For an input {@code x},
 * the score of class {@code i} is
 * <pre>
 *     log(priori[i]) - 0.5 * sum_j log(eigen[i][j]) - 0.5 * sum_j ux[j]^2 / eigen[i][j]
 * </pre>
 * where {@code ux = scaling[i]' (x - mu[i])}. The predicted label is the class with the
 * largest score.
 */
public class QDA
extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    /** The dimensionality of the input vectors. */
    private final int p;
    /** The number of classes. */
    private final int k;
    /** Per-class constant term: {@code log(priori[i]) - 0.5 * sum(log(eigen[i]))}. */
    private final double[] logppriori;
    /** The a priori probability of each class. */
    private final double[] priori;
    /** The mean vector of each class. */
    private final double[][] mu;
    /** The eigenvalues of each class covariance matrix. */
    private final double[][] eigen;
    /** The eigenvectors of each class covariance matrix, used to project {@code x - mu}. */
    private final Matrix[] scaling;

    /**
     * Creates a model whose class labels are {@code 0, 1, ..., priori.length - 1}.
     *
     * @param priori  the a priori probability of each class.
     * @param mu      the mean vector of each class.
     * @param eigen   the eigenvalues of each class covariance matrix.
     * @param scaling the eigenvectors of each class covariance matrix.
     * @throws IllegalArgumentException if the arguments are inconsistent.
     */
    public QDA(double[] priori, double[][] mu, double[][] eigen, Matrix[] scaling) {
        this(priori, mu, eigen, scaling, IntSet.of(priori.length));
    }

    /**
     * Creates a model with explicit class labels.
     *
     * @param priori  the a priori probability of each class.
     * @param mu      the mean vector of each class; all rows must share one length.
     * @param eigen   the eigenvalues of each class covariance matrix; all must be positive.
     * @param scaling the eigenvectors of each class covariance matrix.
     * @param labels  the class label encoder.
     * @throws IllegalArgumentException if there are no classes, if the per-class arrays
     *         have inconsistent lengths, if a prior is negative or NaN, or if an
     *         eigenvalue is not strictly positive.
     */
    public QDA(double[] priori, double[][] mu, double[][] eigen, Matrix[] scaling, IntSet labels) {
        super(labels);
        this.k = priori.length;
        if (this.k == 0) {
            throw new IllegalArgumentException("priori is empty: at least one class is required");
        }
        if (mu.length != this.k || eigen.length != this.k || scaling.length != this.k) {
            throw new IllegalArgumentException(String.format(
                    "Inconsistent number of classes: priori=%d, mu=%d, eigen=%d, scaling=%d",
                    this.k, mu.length, eigen.length, scaling.length));
        }
        this.p = mu[0].length;
        this.priori = priori;
        this.mu = mu;
        this.eigen = eigen;
        this.scaling = scaling;
        this.logppriori = new double[this.k];
        for (int i = 0; i < this.k; ++i) {
            if (mu[i].length != this.p || eigen[i].length != this.p) {
                throw new IllegalArgumentException(String.format(
                        "Class %d: mean/eigenvalue dimension (%d/%d) does not match %d",
                        i, mu[i].length, eigen[i].length, this.p));
            }
            // Written as !(x >= 0) so that NaN is rejected too.
            if (!(priori[i] >= 0.0)) {
                throw new IllegalArgumentException(String.format(
                        "Class %d: invalid prior probability %f", i, priori[i]));
            }
            double logev = 0.0;
            for (int j = 0; j < this.p; ++j) {
                double ev = eigen[i][j];
                // A zero, negative or NaN eigenvalue would make log|Sigma| and Sigma^-1 undefined.
                if (!(ev > 0.0)) {
                    throw new IllegalArgumentException(String.format(
                            "Class %d: eigenvalue %d is not positive: %f", i, j, ev));
                }
                logev += Math.log(ev);
            }
            // logev is the log-determinant of the class covariance (sum of log eigenvalues).
            this.logppriori[i] = Math.log(priori[i]) - 0.5 * logev;
        }
    }

    /**
     * Fits a QDA model, estimating the priors from the data, with tolerance {@code 1E-4}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted model.
     */
    public static QDA fit(double[][] x, int[] y) {
        return QDA.fit(x, y, null, 1.0E-4);
    }

    /**
     * Fits a QDA model configured by properties.
     * <ul>
     *   <li>{@code smile.qda.priori}: the class priors. When absent, this is presumably
     *       parsed to {@code null}, so the priors are estimated from the data.</li>
     *   <li>{@code smile.qda.tolerance}: the singularity tolerance (default {@code 1E-4}).</li>
     * </ul>
     *
     * @param x      the training samples.
     * @param y      the training labels.
     * @param params the hyper-parameters.
     * @return the fitted model.
     */
    public static QDA fit(double[][] x, int[] y, Properties params) {
        double[] priori = Strings.parseDoubleArray(params.getProperty("smile.qda.priori"));
        double tol = Double.parseDouble(params.getProperty("smile.qda.tolerance", "1E-4"));
        return QDA.fit(x, y, priori, tol);
    }

    /**
     * Fits a QDA model.
     *
     * @param x      the training samples.
     * @param y      the training labels.
     * @param priori the class priors, or {@code null} to estimate them from the data.
     * @param tol    the tolerance used to decide whether a class covariance matrix is singular.
     * @return the fitted model.
     * @throws IllegalArgumentException if a class covariance matrix is close to singular.
     */
    public static QDA fit(double[][] x, int[] y, double[] priori, double tol) {
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, priori, tol);
        Matrix[] cov = DiscriminantAnalysis.cov(x, y, da.mu, da.ni);
        int k = cov.length;
        int p = cov[0].nrow();
        double[][] eigen = new double[k][];
        Matrix[] scaling = new Matrix[k];
        // Variances and eigenvalues are compared against tol squared
        // (tol presumably being on the standard-deviation scale).
        double threshold = tol * tol;
        for (int i = 0; i < k; ++i) {
            // Quick singularity test on the diagonal before the (costlier) eigen-decomposition.
            for (int j = 0; j < p; ++j) {
                if (cov[i].get(j, j) < threshold) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (column %d) is close to singular.", i, j));
                }
            }
            Matrix.EVD evd = cov[i].eigen(false, true, true).sort();
            for (double s : evd.wr) {
                if (s < threshold) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", i));
                }
            }
            eigen[i] = evd.wr;
            scaling[i] = evd.Vr;
        }
        return new QDA(da.priori, da.mu, eigen, scaling, da.labels);
    }

    /**
     * Returns the a priori probabilities of the classes.
     * The internal array is returned; it should not be modified.
     *
     * @return the class priors.
     */
    public double[] priori() {
        return this.priori;
    }

    /**
     * Predicts the class label of an input vector.
     *
     * @param x the input vector of length {@code p}.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, new double[this.k]);
    }

    /**
     * Returns {@code true}: this model supports soft classification.
     *
     * @return {@code true}.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of an input vector and fills in the class posteriors.
     *
     * @param x          the input vector of length {@code p}.
     * @param posteriori the output array of length {@code k}. On return it holds the
     *                   discriminant scores after {@code MathEx.softmax} has processed
     *                   them (the posterior probabilities).
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x} or {@code posteriori} has the wrong length.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        // softmax runs over the whole array, so any extra slots would corrupt the result.
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        double[] d = new double[this.p];
        double[] ux = new double[this.p];
        for (int i = 0; i < this.k; ++i) {
            double[] mui = this.mu[i];
            for (int j = 0; j < this.p; ++j) {
                d[j] = x[j] - mui[j];
            }
            // Project the centered input onto the class eigenvectors.
            this.scaling[i].tv(d, ux);
            // Mahalanobis distance in the eigenbasis: sum of ux_j^2 / lambda_j.
            double f = 0.0;
            double[] ev = this.eigen[i];
            for (int j = 0; j < this.p; ++j) {
                f += ux[j] * ux[j] / ev[j];
            }
            posteriori[i] = this.logppriori[i] - 0.5 * f;
        }
        return this.classes.valueOf(MathEx.softmax(posteriori));
    }
}

