/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.SplitRule;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.DataFrameClassifier;
import smile.classification.DecisionTree;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.importance.TreeSHAP;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.Strings;

public class AdaBoost
extends AbstractClassifier<Tuple>
implements DataFrameClassifier,
TreeSHAP {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(AdaBoost.class);
    private final Formula formula;
    private final int k;
    private DecisionTree[] trees;
    private double[] alpha;
    private double[] error;
    private final double[] importance;

    public AdaBoost(Formula formula, int k, DecisionTree[] trees, double[] alpha, double[] error, double[] importance) {
        this(formula, k, trees, alpha, error, importance, IntSet.of((int)k));
    }

    public AdaBoost(Formula formula, int k, DecisionTree[] trees, double[] alpha, double[] error, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = k;
        this.trees = trees;
        this.alpha = alpha;
        this.error = error;
        this.importance = importance;
    }

    public static AdaBoost fit(Formula formula, DataFrame data) {
        return AdaBoost.fit(formula, data, new Properties());
    }

    public static AdaBoost fit(Formula formula, DataFrame data, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.adaboost.trees", "500"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.adaboost.max_depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.adaboost.max_nodes", "6"));
        int nodeSize = Integer.parseInt(params.getProperty("smile.adaboost.node_size", "1"));
        return AdaBoost.fit(formula, data, ntrees, maxDepth, maxNodes, nodeSize);
    }

    public static AdaBoost fit(Formula formula, DataFrame data, int ntrees, int maxDepth, int maxNodes, int nodeSize) {
        int i;
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        formula = formula.expand(data.schema());
        DataFrame x = formula.x(data);
        BaseVector y = formula.y(data);
        ClassLabels codec = ClassLabels.fit(y);
        int[][] order = CART.order(x);
        int k = codec.k;
        int n = data.size();
        int[] samples = new int[n];
        double[] w = new double[n];
        boolean[] wrong = new boolean[n];
        Arrays.fill(w, 1.0);
        double guess = 1.0 / (double)k;
        double b = Math.log(k - 1);
        int failures = 0;
        DecisionTree[] trees = new DecisionTree[ntrees];
        double[] alpha = new double[ntrees];
        double[] error = new double[ntrees];
        for (int t = 0; t < ntrees; ++t) {
            int[] rand;
            double W = MathEx.sum((double[])w);
            int i2 = 0;
            while (i2 < n) {
                int n2 = i2++;
                w[n2] = w[n2] / W;
            }
            Arrays.fill(samples, 0);
            int[] nArray = rand = MathEx.random((double[])w, (int)n);
            int n3 = nArray.length;
            for (int j = 0; j < n3; ++j) {
                int s;
                int n4 = s = nArray[j];
                samples[n4] = samples[n4] + 1;
            }
            trees[t] = new DecisionTree(x, codec.y, y.field(), k, SplitRule.GINI, maxDepth, maxNodes, nodeSize, -1, samples, order);
            for (int i3 = 0; i3 < n; ++i3) {
                wrong[i3] = trees[t].predict(x.get(i3)) != codec.y[i3];
            }
            double e = 0.0;
            for (i = 0; i < n; ++i) {
                if (!wrong[i]) continue;
                e += w[i];
            }
            logger.info("Training {} tree, weighted error = {}%", (Object)Strings.ordinal((int)(t + 1)), (Object)(100.0 * e));
            if (!(1.0 - e > guess)) {
                logger.error("Skip the weak classifier");
                if (++failures > 3) {
                    trees = Arrays.copyOf(trees, t);
                    alpha = Arrays.copyOf(alpha, t);
                    error = Arrays.copyOf(error, t);
                    break;
                }
                --t;
                continue;
            }
            failures = 0;
            error[t] = e;
            alpha[t] = Math.log((1.0 - e) / Math.max(1.0E-10, e)) + b;
            double a = Math.exp(alpha[t]);
            for (int i4 = 0; i4 < n; ++i4) {
                if (!wrong[i4]) continue;
                int n5 = i4;
                w[n5] = w[n5] * a;
            }
        }
        double[] importance = new double[x.ncol()];
        for (DecisionTree tree : trees) {
            double[] imp = tree.importance();
            for (i = 0; i < imp.length; ++i) {
                int n6 = i;
                importance[n6] = importance[n6] + imp[i];
            }
        }
        return new AdaBoost(formula, k, trees, alpha, error, importance, codec.classes);
    }

    @Override
    public Formula formula() {
        return this.formula;
    }

    @Override
    public StructType schema() {
        return this.trees[0].schema();
    }

    public double[] importance() {
        return this.importance;
    }

    public int size() {
        return this.trees.length;
    }

    public DecisionTree[] trees() {
        return this.trees;
    }

    public void trim(int ntrees) {
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        if (ntrees < this.trees.length) {
            this.trees = Arrays.copyOf(this.trees, ntrees);
            this.alpha = Arrays.copyOf(this.alpha, ntrees);
            this.error = Arrays.copyOf(this.error, ntrees);
        }
    }

    @Override
    public boolean soft() {
        return true;
    }

    @Override
    public int predict(Tuple x) {
        Tuple xt = this.formula.x(x);
        double[] y = new double[this.k];
        for (int i = 0; i < this.trees.length; ++i) {
            int n = this.trees[i].predict(xt);
            y[n] = y[n] + this.alpha[i];
        }
        return this.classes.valueOf(MathEx.whichMax((double[])y));
    }

    @Override
    public int predict(Tuple x, double[] posteriori) {
        Tuple xt = this.formula.x(x);
        Arrays.fill(posteriori, 0.0);
        for (int i = 0; i < this.trees.length; ++i) {
            int n = this.trees[i].predict(xt);
            posteriori[n] = posteriori[n] + this.alpha[i];
        }
        double sum = MathEx.sum((double[])posteriori);
        int i = 0;
        while (i < this.k) {
            int n = i++;
            posteriori[n] = posteriori[n] / sum;
        }
        return this.classes.valueOf(MathEx.whichMax((double[])posteriori));
    }

    public int[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.size();
        int ntrees = this.trees.length;
        int[][] prediction = new int[ntrees][n];
        if (this.k == 2) {
            for (int j = 0; j < n; ++j) {
                Tuple xj = x.get(j);
                double base = 0.0;
                for (int i = 0; i < ntrees; ++i) {
                    prediction[i][j] = (base += this.alpha[i] * (double)this.trees[i].predict(xj)) > 0.0 ? 1 : 0;
                }
            }
        } else {
            double[] p = new double[this.k];
            for (int j = 0; j < n; ++j) {
                Tuple xj = x.get(j);
                Arrays.fill(p, 0.0);
                for (int i = 0; i < ntrees; ++i) {
                    int n2 = this.trees[i].predict(xj);
                    p[n2] = p[n2] + this.alpha[i];
                    prediction[i][j] = MathEx.whichMax((double[])p);
                }
            }
        }
        return prediction;
    }
}

