/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.math.BFGS;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.validation.ModelSelection;

public abstract class Maxent
extends AbstractClassifier<int[]> {
    private static final long serialVersionUID = 2L;
    final int p;
    final int k;
    final double L;
    final double lambda;
    double eta = 0.1;

    /**
     * Initializes the state shared by the concrete model variants.
     *
     * @param p the dimension reported by {@link #dimension()}.
     * @param L the value reported by {@link #loglikelihood()}.
     * @param lambda the regularization factor associated with the model.
     * @param labels the class labels; the number of classes {@code k} is taken from its size.
     */
    public Maxent(int p, double L, double lambda, IntSet labels) {
        super(labels);
        this.p = p;
        this.k = labels.size();
        this.lambda = lambda;
        this.L = L;
    }

    /**
     * Fits a model with default hyper-parameters by delegating to
     * {@link #fit(int, int[][], int[], Properties)} with an empty property set.
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted model.
     */
    public static Maxent fit(int p, int[][] x, int[] y) {
        Properties defaults = new Properties();
        return Maxent.fit(p, x, y, defaults);
    }

    /**
     * Fits a model with hyper-parameters read from a property set.
     * <ul>
     * <li>{@code smile.maxent.lambda} - regularization factor, default {@code 0.1}</li>
     * <li>{@code smile.maxent.tolerance} - tolerance, default {@code 1E-5}</li>
     * <li>{@code smile.maxent.iterations} - maximum number of iterations, default {@code 500}</li>
     * </ul>
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @param params the property set holding the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed as a number.
     */
    public static Maxent fit(int p, int[][] x, int[] y, Properties params) {
        String lambdaText = params.getProperty("smile.maxent.lambda", "0.1");
        String toleranceText = params.getProperty("smile.maxent.tolerance", "1E-5");
        String iterationsText = params.getProperty("smile.maxent.iterations", "500");

        // Parsed in the order lambda, tolerance, iterations.
        double lambda = Double.parseDouble(lambdaText);
        double tol = Double.parseDouble(toleranceText);
        int maxIter = Integer.parseInt(iterationsText);
        return Maxent.fit(p, x, y, lambda, tol, maxIter);
    }

    /**
     * Fits either a binomial or a multinomial model depending on how many
     * distinct classes {@code ClassLabels.fit(y)} reports.
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @param lambda the regularization factor.
     * @param tol the tolerance handed to the optimizer.
     * @param maxIter the maximum number of optimizer iterations.
     * @return a {@link Binomial} model when exactly two classes are found, otherwise
     *         whatever {@link #multinomial(int, int[][], int[], double, double, int)} returns.
     */
    public static Maxent fit(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        ClassLabels codec = ClassLabels.fit(y);
        // Only the class count is used here; the chosen routine validates the remaining arguments.
        return codec.k == 2
                ? Maxent.binomial(p, x, y, lambda, tol, maxIter)
                : Maxent.multinomial(p, x, y, lambda, tol, maxIter);
    }

    /**
     * Fits a two-class model with default hyper-parameters by delegating to
     * {@link #binomial(int, int[][], int[], Properties)} with an empty property set.
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted two-class model.
     */
    public static Binomial binomial(int p, int[][] x, int[] y) {
        Properties defaults = new Properties();
        return Maxent.binomial(p, x, y, defaults);
    }

    /**
     * Fits a two-class model with hyper-parameters read from a property set.
     * <ul>
     * <li>{@code smile.maxent.lambda} - regularization factor, default {@code 0.1}</li>
     * <li>{@code smile.maxent.tolerance} - tolerance, default {@code 1E-5}</li>
     * <li>{@code smile.maxent.iterations} - maximum number of iterations, default {@code 500}</li>
     * </ul>
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @param params the property set holding the hyper-parameters.
     * @return the fitted two-class model.
     * @throws NumberFormatException if a property value cannot be parsed as a number.
     */
    public static Binomial binomial(int p, int[][] x, int[] y, Properties params) {
        String lambdaText = params.getProperty("smile.maxent.lambda", "0.1");
        String toleranceText = params.getProperty("smile.maxent.tolerance", "1E-5");
        String iterationsText = params.getProperty("smile.maxent.iterations", "500");

        // Parsed in the order lambda, tolerance, iterations.
        double lambda = Double.parseDouble(lambdaText);
        double tol = Double.parseDouble(toleranceText);
        int maxIter = Integer.parseInt(iterationsText);
        return Maxent.binomial(p, x, y, lambda, tol, maxIter);
    }

    /**
     * Fits a two-class model by minimizing a {@code BinomialObjective} with
     * {@code BFGS.minimize}.
     *
     * <p>The weight vector has {@code p + 1} entries and starts at all zeros. The
     * negated minimum returned by the optimizer is stored as the model's
     * log-likelihood, and the model's learning rate is set to {@code 0.1 / x.length}.
     *
     * @param p the dimension; must not be negative.
     * @param x the training samples.
     * @param y the training labels; must contain exactly two distinct classes.
     * @param lambda the regularization factor; must be a non-negative number (NaN is rejected).
     * @param tol the tolerance handed to the optimizer; must be a positive number (NaN is rejected).
     * @param maxIter the maximum number of optimizer iterations; must be positive.
     * @return the fitted two-class model.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y} differ, if any
     *         numeric argument is out of range, or if {@code y} does not hold exactly two classes.
     */
    public static Binomial binomial(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (p < 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        // Negated comparisons: NaN fails every ordered comparison, so "lambda < 0.0"
        // and "tol <= 0.0" would silently accept it.
        if (!(lambda >= 0.0)) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        ClassLabels codec = ClassLabels.fit(y);
        if (codec.k != 2) {
            throw new IllegalArgumentException("Fits binomial model on multi-class data.");
        }

        BinomialObjective objective = new BinomialObjective(x, codec.y, p, lambda);
        // p feature weights followed by one bias term, all starting at zero.
        double[] w = new double[p + 1];
        // NOTE(review): the literal 5 is the second argument of BFGS.minimize, presumably the
        // number of stored corrections - confirm against smile.math.BFGS.
        double L = -BFGS.minimize(objective, 5, w, tol, maxIter);

        Binomial model = new Binomial(w, L, lambda, codec.classes);
        model.setLearningRate(0.1 / (double) x.length);
        return model;
    }

    /**
     * Fits a multi-class model with default hyper-parameters by delegating to
     * {@link #multinomial(int, int[][], int[], Properties)} with an empty property set.
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @return the fitted multi-class model.
     */
    public static Multinomial multinomial(int p, int[][] x, int[] y) {
        Properties defaults = new Properties();
        return Maxent.multinomial(p, x, y, defaults);
    }

    /**
     * Fits a multi-class model with hyper-parameters read from a property set.
     * <ul>
     * <li>{@code smile.maxent.lambda} - regularization factor, default {@code 0.1}</li>
     * <li>{@code smile.maxent.tolerance} - tolerance, default {@code 1E-5}</li>
     * <li>{@code smile.maxent.iterations} - maximum number of iterations, default {@code 500}</li>
     * </ul>
     *
     * @param p the dimension handed on to the fitting routine.
     * @param x the training samples.
     * @param y the training labels.
     * @param params the property set holding the hyper-parameters.
     * @return the fitted multi-class model.
     * @throws NumberFormatException if a property value cannot be parsed as a number.
     */
    public static Multinomial multinomial(int p, int[][] x, int[] y, Properties params) {
        String lambdaText = params.getProperty("smile.maxent.lambda", "0.1");
        String toleranceText = params.getProperty("smile.maxent.tolerance", "1E-5");
        String iterationsText = params.getProperty("smile.maxent.iterations", "500");

        // Parsed in the order lambda, tolerance, iterations.
        double lambda = Double.parseDouble(lambdaText);
        double tol = Double.parseDouble(toleranceText);
        int maxIter = Integer.parseInt(iterationsText);
        return Maxent.multinomial(p, x, y, lambda, tol, maxIter);
    }

    /**
     * Fits a multi-class model by minimizing a {@code MultinomialObjective} with
     * {@code BFGS.minimize}.
     *
     * <p>The optimizer works on one flat vector of {@code (k - 1) * (p + 1)} weights that
     * starts at all zeros; afterwards it is split into {@code k - 1} rows of {@code p + 1}
     * weights each. The negated minimum returned by the optimizer is stored as the model's
     * log-likelihood, and the model's learning rate is set to {@code 0.1 / x.length}.
     *
     * @param p the dimension; must not be negative.
     * @param x the training samples.
     * @param y the training labels; must contain more than two distinct classes.
     * @param lambda the regularization factor; must be a non-negative number (NaN is rejected).
     * @param tol the tolerance handed to the optimizer; must be a positive number (NaN is rejected).
     * @param maxIter the maximum number of optimizer iterations; must be positive.
     * @return the fitted multi-class model.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y} differ, if any
     *         numeric argument is out of range, or if {@code y} holds two or fewer classes.
     */
    public static Multinomial multinomial(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (p < 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        // Negated comparisons: NaN fails every ordered comparison, so "lambda < 0.0"
        // and "tol <= 0.0" would silently accept it.
        if (!(lambda >= 0.0)) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        if (k <= 2) {
            throw new IllegalArgumentException("Fits multinomial model on binary class data.");
        }

        MultinomialObjective objective = new MultinomialObjective(x, codec.y, k, p, lambda);
        // Each of the k - 1 rows holds p feature weights followed by one bias term.
        int stride = p + 1;
        double[] w = new double[(k - 1) * stride];
        // NOTE(review): the literal 5 is the second argument of BFGS.minimize, presumably the
        // number of stored corrections - confirm against smile.math.BFGS.
        double L = -BFGS.minimize(objective, 5, w, tol, maxIter);

        // Unpack the flat optimizer vector into one row per class.
        double[][] W = new double[k - 1][stride];
        for (int i = 0; i < k - 1; ++i) {
            System.arraycopy(w, i * stride, W[i], 0, stride);
        }

        Multinomial model = new Multinomial(W, L, lambda, codec.classes);
        model.setLearningRate(0.1 / (double) x.length);
        return model;
    }

    /**
     * Computes the linear response of a sample given as an array of weight indices.
     * The last element of {@code w} is always included (it acts as the bias), and
     * {@code w[i]} is added once for every index {@code i} listed in {@code x}.
     *
     * @param x the indices into {@code w} to add up.
     * @param w the weight vector, whose last element is the bias.
     * @return the bias plus the sum of the selected weights.
     */
    private static double dot(int[] x, double[] w) {
        double sum = w[w.length - 1];
        for (int t = 0; t < x.length; ++t) {
            sum += w[x[t]];
        }
        return sum;
    }

    /**
     * Computes the linear response of a sample against the {@code j}-th segment of a
     * flat weight vector. Each segment spans {@code p + 1} entries; the entry at offset
     * {@code p} within the segment is always included (it acts as the bias), and the entry
     * at offset {@code i} is added once for every index {@code i} listed in {@code x}.
     *
     * @param x the offsets within the segment to add up.
     * @param w the flat weight vector.
     * @param j the segment number.
     * @param p the number of non-bias entries per segment.
     * @return the segment's bias plus the sum of the selected weights.
     */
    private static double dot(int[] x, double[] w, int j, int p) {
        int base = j * (p + 1);
        double sum = w[base + p];
        for (int t = 0; t < x.length; ++t) {
            sum += w[base + x[t]];
        }
        return sum;
    }

    /**
     * Returns the dimension {@code p} this model was constructed with.
     *
     * @return the dimension.
     */
    public int dimension() {
        return p;
    }

    /**
     * Always reports {@code true}.
     *
     * @return {@code true}.
     */
    @Override
    public boolean soft() {
        final boolean supported = true;
        return supported;
    }

    /**
     * Always reports {@code true}.
     *
     * @return {@code true}.
     */
    @Override
    public boolean online() {
        final boolean supported = true;
        return supported;
    }

    /**
     * Sets the learning rate {@code eta} used by the {@code update} methods.
     *
     * @param rate the new learning rate; must be a positive, finite number.
     * @throws IllegalArgumentException if {@code rate} is zero, negative, NaN or infinite.
     */
    public void setLearningRate(double rate) {
        // NaN fails every ordered comparison, so a plain "rate <= 0.0" check would let it
        // through and every later update would turn the weights into NaN.
        if (!(rate > 0.0) || Double.isInfinite(rate)) {
            throw new IllegalArgumentException("Invalid learning rate: " + rate);
        }
        this.eta = rate;
    }

    /**
     * Returns the learning rate {@code eta} currently in effect.
     *
     * @return the learning rate.
     */
    public double getLearningRate() {
        return eta;
    }

    /**
     * Returns the value {@code L} this model was constructed with.
     *
     * @return the stored log-likelihood value.
     */
    public double loglikelihood() {
        return L;
    }

    /**
     * Returns the result of {@code ModelSelection.AIC} for the stored log-likelihood
     * and a parameter count of {@code (k - 1) * (p + 1)}.
     *
     * @return the AIC value.
     */
    public double AIC() {
        int parameters = (k - 1) * (p + 1);
        return ModelSelection.AIC(L, parameters);
    }

    /**
     * Two-class model backed by a single weight vector whose last element is the bias.
     */
    public static class Binomial extends Maxent {
        /** The weights; the element at index {@code p} is the bias. */
        private final double[] w;

        /**
         * Wraps an existing weight vector.
         *
         * @param w the weights; the dimension is taken to be {@code w.length - 1}.
         * @param L the log-likelihood value to report.
         * @param lambda the regularization factor.
         * @param labels the class labels.
         */
        public Binomial(double[] w, double L, double lambda, IntSet labels) {
            super(w.length - 1, L, lambda, labels);
            this.w = w;
        }

        /**
         * Returns the weight vector itself (not a copy).
         *
         * @return the weights.
         */
        public double[] coefficients() {
            return w;
        }

        /**
         * Returns the logistic function of the sample's linear response.
         *
         * @param x the sample, as indices into the weight vector.
         * @return a value between 0 and 1.
         */
        @Override
        public double score(int[] x) {
            double z = Maxent.dot(x, w);
            return 1.0 / (1.0 + Math.exp(-z));
        }

        /**
         * Predicts the class label of a sample.
         *
         * @param x the sample, as indices into the weight vector.
         * @return the label at class index 0 when the logistic response is below 0.5,
         *         otherwise the label at class index 1.
         */
        @Override
        public int predict(int[] x) {
            double z = Maxent.dot(x, w);
            double f = 1.0 / (1.0 + Math.exp(-z));
            if (f < 0.5) {
                return classes.valueOf(0);
            }
            return classes.valueOf(1);
        }

        /**
         * Predicts the class label of a sample and writes both class probabilities.
         *
         * @param x the sample, as indices into the weight vector.
         * @param posteriori output array of length {@code k}; element 1 receives the
         *        logistic response and element 0 its complement.
         * @return the label at class index 0 when the logistic response is below 0.5,
         *         otherwise the label at class index 1.
         * @throws IllegalArgumentException if {@code posteriori} does not have length {@code k}.
         */
        @Override
        public int predict(int[] x, double[] posteriori) {
            if (posteriori.length != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
            }
            double z = Maxent.dot(x, w);
            double f = 1.0 / (1.0 + Math.exp(-z));
            posteriori[0] = 1.0 - f;
            posteriori[1] = f;
            if (f < 0.5) {
                return classes.valueOf(0);
            }
            return classes.valueOf(1);
        }

        /**
         * Applies one online update step for a single labelled sample: the bias and every
         * weight selected by {@code x} move by {@code eta * (classIndex - sigmoid(response))}.
         *
         * @param x the sample, as indices into the weight vector.
         * @param y the sample's class label.
         */
        @Override
        public void update(int[] x, int y) {
            int target = classes.indexOf(y);
            double response = Maxent.dot(x, w);
            double err = (double) target - MathEx.sigmoid(response);
            double step = eta * err;

            w[p] += step;
            for (int feature : x) {
                w[feature] += step;
            }
        }
    }

    /**
     * Multi-class model backed by {@code k - 1} weight rows; the last class has no row
     * and its linear response is fixed at zero.
     */
    public static class Multinomial extends Maxent {
        /** One row of weights per class except the last; the element at index {@code p} of each row is the bias. */
        private final double[][] w;

        /**
         * Wraps an existing weight matrix.
         *
         * @param w the weight rows; the dimension is taken to be {@code w[0].length - 1}.
         * @param L the log-likelihood value to report.
         * @param lambda the regularization factor.
         * @param labels the class labels.
         */
        public Multinomial(double[][] w, double L, double lambda, IntSet labels) {
            super(w[0].length - 1, L, lambda, labels);
            this.w = w;
        }

        /**
         * Returns the weight matrix itself (not a copy).
         *
         * @return the weight rows.
         */
        public double[][] coefficients() {
            return w;
        }

        /**
         * Predicts the class label of a sample, discarding the class probabilities.
         *
         * @param x the sample, as indices into each weight row.
         * @return the predicted class label.
         */
        @Override
        public int predict(int[] x) {
            double[] scratch = new double[k];
            return predict(x, scratch);
        }

        /**
         * Predicts the class label of a sample. The linear responses (zero for the last
         * class) are written into {@code posteriori}, passed through {@code MathEx.softmax},
         * and the label at the index chosen by {@code MathEx.whichMax} is returned.
         *
         * @param x the sample, as indices into each weight row.
         * @param posteriori output array of length {@code k}.
         * @return the predicted class label.
         * @throws IllegalArgumentException if {@code posteriori} does not have length {@code k}.
         */
        @Override
        public int predict(int[] x, double[] posteriori) {
            if (posteriori.length != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
            }
            int last = k - 1;
            for (int c = 0; c < last; ++c) {
                posteriori[c] = Maxent.dot(x, w[c]);
            }
            posteriori[last] = 0.0;
            MathEx.softmax(posteriori);
            int best = MathEx.whichMax(posteriori);
            return classes.valueOf(best);
        }

        /**
         * Applies one online update step for a single labelled sample: for each weight row,
         * the bias and every weight selected by {@code x} move by
         * {@code eta * (indicator - probability)}, where the indicator is 1 for the row of
         * the sample's class and 0 otherwise.
         *
         * @param x the sample, as indices into each weight row.
         * @param y the sample's class label.
         */
        @Override
        public void update(int[] x, int y) {
            int target = classes.indexOf(y);
            int last = k - 1;

            // The entry for the last class keeps its initial value of zero before softmax.
            double[] prob = new double[k];
            for (int c = 0; c < last; ++c) {
                prob[c] = Maxent.dot(x, w[c]);
            }
            MathEx.softmax(prob);

            for (int c = 0; c < last; ++c) {
                double[] row = w[c];
                double indicator = target == c ? 1.0 : 0.0;
                double step = eta * (indicator - prob[c]);
                row[p] += step;
                for (int feature : x) {
                    row[feature] += step;
                }
            }
        }
    }

    /**
     * Objective minimized when fitting the two-class model: the sum over all samples of
     * {@code log1pe(wx) - y * wx}, plus {@code 0.5 * lambda} times the squared norm of the
     * first {@code p} weights when {@code lambda} is positive (the bias is not penalized).
     */
    static class BinomialObjective implements DifferentiableMultivariateFunction {
        /** Training samples, each an array of weight indices. */
        final int[][] x;
        /** Training targets, multiplied directly into the loss. */
        final int[] y;
        /** Number of non-bias weights. */
        final int p;
        /** Regularization factor. */
        final double lambda;
        /** Number of samples per partition, read from the system property {@code smile.data.partition.size}. */
        final int partitionSize;
        /** Number of partitions needed to cover every sample. */
        final int partitions;
        /** One gradient buffer per partition, so partitions can be processed in parallel. */
        final double[][] gradients;

        /**
         * Stores the training data and allocates one gradient buffer per partition.
         *
         * @param x the training samples.
         * @param y the training targets.
         * @param p the number of non-bias weights.
         * @param lambda the regularization factor.
         */
        BinomialObjective(int[][] x, int[] y, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.p = p;
            this.lambda = lambda;

            int size = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
            this.partitionSize = size;
            // Ceiling division: a trailing partial partition still counts as one.
            int whole = x.length / size;
            this.partitions = whole + (x.length % size == 0 ? 0 : 1);
            this.gradients = new double[this.partitions][p + 1];
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w the weight vector, bias last.
         * @return the objective value.
         */
        public double f(double[] w) {
            double loss = IntStream.range(0, x.length).parallel().mapToDouble(i -> {
                double response = Maxent.dot(x[i], w);
                return MathEx.log1pe(response) - (double) y[i] * response;
            }).sum();

            if (lambda > 0.0) {
                double squared = 0.0;
                for (int j = 0; j < p; ++j) {
                    squared += w[j] * w[j];
                }
                loss += 0.5 * lambda * squared;
            }
            return loss;
        }

        /**
         * Evaluates the objective at {@code w} and writes its gradient into {@code g}.
         * Partitions are processed in parallel, each into its own buffer; the buffers are
         * then added up in partition order.
         *
         * @param w the weight vector, bias last.
         * @param g output array receiving the gradient.
         * @return the objective value.
         */
        public double g(double[] w, double[] g) {
            double loss = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] buffer = gradients[r];
                Arrays.fill(buffer, 0.0);

                int from = r * partitionSize;
                int to = Math.min((r + 1) * partitionSize, x.length);
                return IntStream.range(from, to).sequential().mapToDouble(i -> {
                    int[] sample = x[i];
                    double response = Maxent.dot(sample, w);
                    double err = (double) y[i] - MathEx.sigmoid(response);
                    for (int feature : sample) {
                        buffer[feature] -= err;
                    }
                    buffer[p] -= err;
                    return MathEx.log1pe(response) - (double) y[i] * response;
                }).sum();
            }).sum();

            Arrays.fill(g, 0.0);
            for (double[] buffer : gradients) {
                for (int j = 0; j < g.length; ++j) {
                    g[j] += buffer[j];
                }
            }

            if (lambda > 0.0) {
                double squared = 0.0;
                for (int j = 0; j < p; ++j) {
                    squared += w[j] * w[j];
                    g[j] += lambda * w[j];
                }
                loss += 0.5 * lambda * squared;
            }
            return loss;
        }
    }

    /**
     * Objective minimized when fitting the multi-class model: the sum over all samples of
     * {@code -log(posteriori[y])}, where the posteriori come from {@code MathEx.softmax}
     * applied to the per-class linear responses (zero for the last class), plus
     * {@code 0.5 * lambda} times the squared norm of the non-bias weights when
     * {@code lambda} is positive.
     */
    static class MultinomialObjective implements DifferentiableMultivariateFunction {
        /** Training samples, each an array of offsets into a weight segment. */
        final int[][] x;
        /** Training targets, used as indices into the posteriori array. */
        final int[] y;
        /** Number of classes. */
        final int k;
        /** Number of non-bias weights per class. */
        final int p;
        /** Regularization factor. */
        final double lambda;
        /** Number of samples per partition, read from the system property {@code smile.data.partition.size}. */
        final int partitionSize;
        /** Number of partitions needed to cover every sample. */
        final int partitions;
        /** One gradient buffer per partition, so partitions can be processed in parallel. */
        final double[][] gradients;
        /** One posteriori scratch array per partition. */
        final double[][] posterioris;

        /**
         * Stores the training data and allocates the per-partition work buffers.
         *
         * @param x the training samples.
         * @param y the training targets.
         * @param k the number of classes.
         * @param p the number of non-bias weights per class.
         * @param lambda the regularization factor.
         */
        MultinomialObjective(int[][] x, int[] y, int k, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.p = p;
            this.lambda = lambda;

            int size = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
            this.partitionSize = size;
            // Ceiling division: a trailing partial partition still counts as one.
            int whole = x.length / size;
            this.partitions = whole + (x.length % size == 0 ? 0 : 1);
            this.gradients = new double[this.partitions][(k - 1) * (p + 1)];
            this.posterioris = new double[this.partitions][k];
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w the flat weight vector: {@code k - 1} segments of {@code p + 1} weights, bias last in each.
         * @return the objective value.
         */
        public double f(double[] w) {
            double loss = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] scratch = posterioris[r];
                int from = r * partitionSize;
                int to = Math.min((r + 1) * partitionSize, x.length);
                return IntStream.range(from, to).sequential().mapToDouble(i -> {
                    for (int c = 0; c < k - 1; ++c) {
                        scratch[c] = Maxent.dot(x[i], w, c, p);
                    }
                    scratch[k - 1] = 0.0;
                    MathEx.softmax(scratch);
                    return -MathEx.log(scratch[y[i]]);
                }).sum();
            }).sum();

            if (lambda > 0.0) {
                double squared = 0.0;
                for (int c = 0; c < k - 1; ++c) {
                    int base = c * (p + 1);
                    // Stops before offset p, so the bias of each segment is not penalized.
                    for (int j = 0; j < p; ++j) {
                        double weight = w[base + j];
                        squared += weight * weight;
                    }
                }
                loss += 0.5 * lambda * squared;
            }
            return loss;
        }

        /**
         * Evaluates the objective at {@code w} and writes its gradient into {@code g}.
         * Partitions are processed in parallel, each into its own buffers; the gradient
         * buffers are then added up in partition order.
         *
         * @param w the flat weight vector: {@code k - 1} segments of {@code p + 1} weights, bias last in each.
         * @param g output array receiving the gradient.
         * @return the objective value.
         */
        public double g(double[] w, double[] g) {
            double loss = IntStream.range(0, partitions).parallel().mapToDouble(r -> {
                double[] scratch = posterioris[r];
                double[] buffer = gradients[r];
                Arrays.fill(buffer, 0.0);

                int from = r * partitionSize;
                int to = Math.min((r + 1) * partitionSize, x.length);
                return IntStream.range(from, to).sequential().mapToDouble(i -> {
                    int[] sample = x[i];
                    for (int c = 0; c < k - 1; ++c) {
                        scratch[c] = Maxent.dot(sample, w, c, p);
                    }
                    scratch[k - 1] = 0.0;
                    MathEx.softmax(scratch);

                    for (int c = 0; c < k - 1; ++c) {
                        double indicator = y[i] == c ? 1.0 : 0.0;
                        double err = indicator - scratch[c];
                        int base = c * (p + 1);
                        for (int feature : sample) {
                            buffer[base + feature] -= err;
                        }
                        buffer[base + p] -= err;
                    }
                    return -MathEx.log(scratch[y[i]]);
                }).sum();
            }).sum();

            Arrays.fill(g, 0.0);
            for (double[] buffer : gradients) {
                for (int j = 0; j < g.length; ++j) {
                    g[j] += buffer[j];
                }
            }

            if (lambda > 0.0) {
                double squared = 0.0;
                for (int c = 0; c < k - 1; ++c) {
                    int base = c * (p + 1);
                    // Stops before offset p, so the bias of each segment is not penalized.
                    for (int j = 0; j < p; ++j) {
                        double weight = w[base + j];
                        squared += weight * weight;
                        g[base + j] += lambda * weight;
                    }
                }
                loss += 0.5 * lambda * squared;
            }
            return loss;
        }
    }
}

