/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.AbstractClassifier;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.SparseArray;

public class DiscreteNaiveBayes
extends AbstractClassifier<int[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(DiscreteNaiveBayes.class);
    private static final double EPSILON = 1.0E-20;
    private final Model model;
    private final int k;
    private final int p;
    private final double[] priori;
    private final double sigma;
    private final boolean fixedPriori;
    private int n;
    private final int[] nc;
    private final int[] nt;
    private final int[][] ntc;
    private final double[][] logcondprob;

    public DiscreteNaiveBayes(Model model, int k, int p) {
        this(model, k, p, 1.0, IntSet.of((int)k));
    }

    public DiscreteNaiveBayes(Model model, int k, int p, double sigma, IntSet labels) {
        super(labels);
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + k);
        }
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        if (sigma < 0.0) {
            throw new IllegalArgumentException("Invalid add-k smoothing parameter: " + sigma);
        }
        this.model = model;
        this.k = k;
        this.p = p;
        this.sigma = sigma;
        this.fixedPriori = false;
        this.priori = new double[k];
        this.n = 0;
        this.nc = new int[k];
        this.nt = new int[k];
        this.ntc = new int[k][p];
        this.logcondprob = new double[k][p];
    }

    public DiscreteNaiveBayes(Model model, double[] priori, int p) {
        this(model, priori, p, 1.0, IntSet.of((int)priori.length));
    }

    public DiscreteNaiveBayes(Model model, double[] priori, int p, double sigma, IntSet labels) {
        super(labels);
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        if (sigma < 0.0) {
            throw new IllegalArgumentException("Invalid add-k smoothing parameter: " + sigma);
        }
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + priori.length);
        }
        double sum = 0.0;
        for (double pr : priori) {
            if (pr <= 0.0 || pr >= 1.0) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }
        if (Math.abs(sum - 1.0) > 1.0E-5) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
        this.model = model;
        this.k = priori.length;
        this.p = p;
        this.sigma = sigma;
        this.priori = priori;
        this.fixedPriori = true;
        this.n = 0;
        this.nc = new int[this.k];
        this.nt = new int[this.k];
        this.ntc = new int[this.k][p];
        this.logcondprob = new double[this.k][p];
    }

    public double[] priori() {
        return this.priori;
    }

    @Override
    public void update(int[] x, int y) {
        if (!this.isGoodInstance(x)) {
            return;
        }
        if (this.model == Model.TWCNB) {
            throw new UnsupportedOperationException("TWCNB supports only batch learning");
        }
        y = this.classes.indexOf(y);
        switch (this.model.ordinal()) {
            case 0: 
            case 3: 
            case 4: {
                for (int i = 0; i < this.p; ++i) {
                    int[] nArray = this.ntc[y];
                    int n = i;
                    nArray[n] = nArray[n] + x[i];
                    int n2 = y;
                    this.nt[n2] = this.nt[n2] + x[i];
                }
                break;
            }
            case 2: {
                for (int i = 0; i < this.p; ++i) {
                    int[] nArray = this.ntc[y];
                    int n = i;
                    nArray[n] = nArray[n] + x[i] * 2;
                    int n3 = y;
                    this.nt[n3] = this.nt[n3] + x[i] * 2;
                }
                break;
            }
            case 1: {
                for (int i = 0; i < this.p; ++i) {
                    if (x[i] <= 0) continue;
                    int[] nArray = this.ntc[y];
                    int n = i;
                    nArray[n] = nArray[n] + 1;
                }
                break;
            }
            default: {
                throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
            }
        }
        ++this.n;
        int n = y;
        this.nc[n] = this.nc[n] + 1;
        this.update();
    }

    @Override
    public void update(SparseArray x, int y) {
        if (!this.isGoodInstance(x)) {
            return;
        }
        if (this.model == Model.TWCNB) {
            throw new UnsupportedOperationException("TWCNB supports only batch learning");
        }
        y = this.classes.indexOf(y);
        switch (this.model.ordinal()) {
            case 0: 
            case 3: 
            case 4: {
                for (SparseArray.Entry e : x) {
                    int ex = (int)e.x;
                    int[] nArray = this.ntc[y];
                    int n = e.i;
                    nArray[n] = nArray[n] + ex;
                    int n2 = y;
                    this.nt[n2] = this.nt[n2] + ex;
                }
                break;
            }
            case 2: {
                for (SparseArray.Entry e : x) {
                    int ex = (int)e.x;
                    int[] nArray = this.ntc[y];
                    int n = e.i;
                    nArray[n] = nArray[n] + ex * 2;
                    int n3 = y;
                    this.nt[n3] = this.nt[n3] + ex * 2;
                }
                break;
            }
            case 1: {
                for (SparseArray.Entry e : x) {
                    if (!(e.x > 0.0)) continue;
                    int[] nArray = this.ntc[y];
                    int n = e.i;
                    nArray[n] = nArray[n] + 1;
                }
                break;
            }
            default: {
                throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
            }
        }
        ++this.n;
        int n = y;
        this.nc[n] = this.nc[n] + 1;
        this.update();
    }

    public void update(int[][] x, int[] y) {
        switch (this.model.ordinal()) {
            case 0: 
            case 3: 
            case 4: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (int j = 0; j < this.p; ++j) {
                        int[] nArray = this.ntc[yi];
                        int n = j;
                        nArray[n] = nArray[n] + x[i][j];
                        int n2 = yi;
                        this.nt[n2] = this.nt[n2] + x[i][j];
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            case 5: {
                int c;
                int[] ni = new int[this.p];
                double[] d = new double[this.p];
                for (int[] doc : x) {
                    for (int i = 0; i < this.p; ++i) {
                        if (doc[i] <= 0) continue;
                        int n = i;
                        ni[n] = ni[n] + 1;
                    }
                }
                double N = 0.0;
                for (int[] xi : x) {
                    if (!this.isGoodInstance(xi)) continue;
                    N += 1.0;
                }
                for (int i = 0; i < x.length; ++i) {
                    int[] xi = x[i];
                    if (!this.isGoodInstance(xi)) continue;
                    Arrays.fill(d, 0.0);
                    for (int t = 0; t < this.p; ++t) {
                        if (xi[t] <= 0) continue;
                        d[t] = Math.log(1 + xi[t]) * Math.log(N / (double)ni[t]);
                    }
                    MathEx.unitize2((double[])d);
                    int yi = y[i];
                    for (int t = 0; t < this.p; ++t) {
                        double[] dArray = this.logcondprob[yi];
                        int n = t;
                        dArray[n] = dArray[n] + d[t];
                    }
                }
                double[] rsum = MathEx.rowSums((double[][])this.logcondprob);
                double[] csum = MathEx.colSums((double[][])this.logcondprob);
                double sum = MathEx.sum((double[])csum);
                for (c = 0; c < this.k; ++c) {
                    for (int t = 0; t < this.p; ++t) {
                        this.logcondprob[c][t] = Math.log((csum[t] - this.logcondprob[c][t] + this.sigma) / (sum - rsum[c] + this.sigma * (double)this.p));
                    }
                }
                for (c = 0; c < this.k; ++c) {
                    MathEx.unitize1((double[])this.logcondprob[c]);
                }
                break;
            }
            case 2: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (int j = 0; j < this.p; ++j) {
                        int[] nArray = this.ntc[yi];
                        int n = j;
                        nArray[n] = nArray[n] + x[i][j] * 2;
                        int n3 = yi;
                        this.nt[n3] = this.nt[n3] + x[i][j] * 2;
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            case 1: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (int j = 0; j < this.p; ++j) {
                        if (x[i][j] <= 0) continue;
                        int[] nArray = this.ntc[yi];
                        int n = j;
                        nArray[n] = nArray[n] + 1;
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            default: {
                throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
            }
        }
        this.update();
    }

    public void update(SparseArray[] x, int[] y) {
        switch (this.model.ordinal()) {
            case 0: 
            case 3: 
            case 4: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (SparseArray.Entry e : x[i]) {
                        int ex = (int)e.x;
                        int[] nArray = this.ntc[yi];
                        int n = e.i;
                        nArray[n] = nArray[n] + ex;
                        int n2 = yi;
                        this.nt[n2] = this.nt[n2] + ex;
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            case 5: {
                int c;
                int[] ni = new int[this.p];
                double[] d = new double[this.p];
                for (SparseArray doc : x) {
                    for (SparseArray.Entry e : doc) {
                        if (!(e.x > 0.0)) continue;
                        int n = e.i;
                        ni[n] = ni[n] + 1;
                    }
                }
                double N = 0.0;
                for (SparseArray xi : x) {
                    if (!this.isGoodInstance(xi)) continue;
                    N += 1.0;
                }
                for (int i = 0; i < x.length; ++i) {
                    SparseArray xi = x[i];
                    if (!this.isGoodInstance(xi)) continue;
                    Arrays.fill(d, 0.0);
                    for (SparseArray.Entry e : xi) {
                        if (!(e.x > 0.0)) continue;
                        d[e.i] = Math.log(1.0 + e.x) * Math.log(N / (double)ni[e.i]);
                    }
                    MathEx.unitize2((double[])d);
                    int yi = y[i];
                    for (int t = 0; t < this.p; ++t) {
                        double[] dArray = this.logcondprob[yi];
                        int n = t;
                        dArray[n] = dArray[n] + d[t];
                    }
                }
                double[] rsum = MathEx.rowSums((double[][])this.logcondprob);
                double[] csum = MathEx.colSums((double[][])this.logcondprob);
                double sum = MathEx.sum((double[])csum);
                for (c = 0; c < this.k; ++c) {
                    for (int t = 0; t < this.p; ++t) {
                        this.logcondprob[c][t] = Math.log((csum[t] - this.logcondprob[c][t] + this.sigma) / (sum - rsum[c] + this.sigma * (double)this.p));
                    }
                }
                for (c = 0; c < this.k; ++c) {
                    MathEx.unitize1((double[])this.logcondprob[c]);
                }
                break;
            }
            case 2: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (SparseArray.Entry e : x[i]) {
                        int ex = (int)e.x;
                        int[] nArray = this.ntc[yi];
                        int n = e.i;
                        nArray[n] = nArray[n] + ex * 2;
                        int n3 = yi;
                        this.nt[n3] = this.nt[n3] + ex * 2;
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            case 1: {
                for (int i = 0; i < x.length; ++i) {
                    if (!this.isGoodInstance(x[i])) continue;
                    int yi = this.classes.indexOf(y[i]);
                    for (SparseArray.Entry e : x[i]) {
                        if (!(e.x > 0.0)) continue;
                        int[] nArray = this.ntc[yi];
                        int n = e.i;
                        nArray[n] = nArray[n] + 1;
                    }
                    ++this.n;
                    int n = yi;
                    this.nc[n] = this.nc[n] + 1;
                }
                break;
            }
            default: {
                throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
            }
        }
        this.update();
    }

    private void update() {
        int c;
        if (!this.fixedPriori) {
            for (c = 0; c < this.k; ++c) {
                this.priori[c] = ((double)this.nc[c] + 1.0E-20) / ((double)this.n + (double)this.k * 1.0E-20);
            }
        }
        switch (this.model.ordinal()) {
            case 0: 
            case 2: {
                for (c = 0; c < this.k; ++c) {
                    for (int t = 0; t < this.p; ++t) {
                        this.logcondprob[c][t] = Math.log(((double)this.ntc[c][t] + this.sigma) / ((double)this.nt[c] + this.sigma * (double)this.p));
                    }
                }
                break;
            }
            case 1: {
                for (c = 0; c < this.k; ++c) {
                    for (int t = 0; t < this.p; ++t) {
                        this.logcondprob[c][t] = Math.log(((double)this.ntc[c][t] + this.sigma) / ((double)this.nc[c] + this.sigma * 2.0));
                    }
                }
                break;
            }
            case 3: 
            case 4: {
                int c2;
                long ntsum = MathEx.sum((int[])this.nt);
                long[] ntcsum = MathEx.colSums((int[][])this.ntc);
                for (c2 = 0; c2 < this.k; ++c2) {
                    for (int t = 0; t < this.p; ++t) {
                        this.logcondprob[c2][t] = Math.log(((double)(ntcsum[t] - (long)this.ntc[c2][t]) + this.sigma) / ((double)(ntsum - (long)this.nt[c2]) + this.sigma * (double)this.p));
                    }
                }
                if (this.model != Model.WCNB) break;
                for (c2 = 0; c2 < this.k; ++c2) {
                    MathEx.unitize1((double[])this.logcondprob[c2]);
                }
                break;
            }
            case 5: {
                break;
            }
            default: {
                throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
            }
        }
    }

    @Override
    public boolean soft() {
        return true;
    }

    @Override
    public boolean online() {
        return true;
    }

    @Override
    public int predict(int[] x) {
        return this.predict(x, new double[this.k]);
    }

    @Override
    public int predict(int[] x, double[] posteriori) {
        if (!this.isGoodInstance(x)) {
            return Integer.MIN_VALUE;
        }
        for (int i = 0; i < this.k; ++i) {
            double logprob;
            switch (this.model.ordinal()) {
                case 0: 
                case 2: {
                    int j;
                    logprob = Math.log(this.priori[i]);
                    for (j = 0; j < this.p; ++j) {
                        if (x[j] <= 0) continue;
                        logprob += (double)x[j] * this.logcondprob[i][j];
                    }
                    break;
                }
                case 1: {
                    int j;
                    logprob = Math.log(this.priori[i]);
                    for (j = 0; j < this.p; ++j) {
                        if (x[j] > 0) {
                            logprob += this.logcondprob[i][j];
                            continue;
                        }
                        logprob += Math.log(1.0 - Math.exp(this.logcondprob[i][j]));
                    }
                    break;
                }
                case 3: 
                case 4: 
                case 5: {
                    int j;
                    logprob = 0.0;
                    for (j = 0; j < this.p; ++j) {
                        if (x[j] <= 0) continue;
                        logprob -= (double)x[j] * this.logcondprob[i][j];
                    }
                    break;
                }
                default: {
                    throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
                }
            }
            posteriori[i] = logprob;
        }
        MathEx.softmax((double[])posteriori);
        return MathEx.whichMax((double[])posteriori);
    }

    private boolean isGoodInstance(int[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid vector size: %d", x.length));
        }
        boolean any = false;
        for (int xi : x) {
            if (xi <= 0) continue;
            any = true;
            break;
        }
        return any;
    }

    private boolean isGoodInstance(SparseArray x) {
        return !x.isEmpty();
    }

    @Override
    public int predict(SparseArray x) {
        return this.predict(x, new double[this.k]);
    }

    @Override
    public int predict(SparseArray x, double[] posteriori) {
        if (!this.isGoodInstance(x)) {
            return Integer.MIN_VALUE;
        }
        for (int i = 0; i < this.k; ++i) {
            double logprob;
            switch (this.model.ordinal()) {
                case 0: 
                case 2: {
                    logprob = Math.log(this.priori[i]);
                    for (SparseArray.Entry e : x) {
                        if (!(e.x > 0.0)) continue;
                        logprob += e.x * this.logcondprob[i][e.i];
                    }
                    break;
                }
                case 1: {
                    logprob = Math.log(this.priori[i]);
                    for (SparseArray.Entry e : x) {
                        if (e.x > 0.0) {
                            logprob += this.logcondprob[i][e.i];
                            continue;
                        }
                        logprob += Math.log(1.0 - Math.exp(this.logcondprob[i][e.i]));
                    }
                    break;
                }
                case 3: 
                case 4: 
                case 5: {
                    logprob = 0.0;
                    for (SparseArray.Entry e : x) {
                        if (!(e.x > 0.0)) continue;
                        logprob -= e.x * this.logcondprob[i][e.i];
                    }
                    break;
                }
                default: {
                    throw new IllegalStateException("Unknown model: " + String.valueOf((Object)this.model));
                }
            }
            posteriori[i] = logprob;
        }
        MathEx.softmax((double[])posteriori);
        return this.classes.valueOf(MathEx.whichMax((double[])posteriori));
    }

    public static enum Model {
        MULTINOMIAL,
        BERNOULLI,
        POLYAURN,
        CNB,
        WCNB,
        TWCNB;

    }
}

