/*
 * Decompiled with CFR 0.152.
 */
package smile.base.cart;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import smile.base.cart.InternalNode;
import smile.base.cart.LeafNode;
import smile.base.cart.SplitRule;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.MathEx;

/**
 * A leaf node of a classification tree.
 *
 * <p>The node keeps the number of samples of each class that reached it and
 * predicts the class with the largest count.
 */
public class DecisionNode
extends LeafNode {
    private static final long serialVersionUID = 2L;
    /** The predicted class label, i.e. the index selected by {@code MathEx.whichMax(count)}. */
    private final int output;
    /** The sample count of each class in this node, indexed by class label. */
    private final int[] count;

    /**
     * Constructor.
     *
     * @param count the sample count of each class. The node size is the sum of
     *              these counts. The array is stored as-is (not copied), so the
     *              caller must not modify it afterwards.
     */
    public DecisionNode(int[] count) {
        super((int) MathEx.sum(count));
        this.output = MathEx.whichMax(count);
        this.count = count;
    }

    /**
     * Returns the predicted class label.
     *
     * @return the class label with the largest count.
     */
    public int output() {
        return this.output;
    }

    /**
     * Returns the sample count of each class.
     *
     * @return the internal count array (not a copy).
     */
    public int[] count() {
        return this.count;
    }

    /**
     * Returns the deviance of this node, evaluated with the smoothed
     * posteriori probabilities from {@link #posteriori(int[], double[])}.
     */
    @Override
    public double deviance() {
        double[] prob = posteriori(this.count, new double[this.count.length]);
        return deviance(this.count, prob);
    }

    /**
     * Returns a Graphviz DOT statement that renders this leaf as an ellipse.
     * The label shows the predicted response value, the node size and the deviance.
     */
    @Override
    public String dot(StructType schema, StructField response, int id) {
        return String.format(" %d [label=<%s = %s<br/>size = %d<br/>deviance = %.4f>, fillcolor=\"#00000000\", shape=ellipse];\n",
                id, response.name, response.toString((Object) this.output), this.size, this.deviance());
    }

    /**
     * Appends a one-line text description of this leaf to {@code lines}.
     *
     * <p>The line contains, separated by spaces: the indentation, the node id,
     * the parent's split condition (or "root"), the size, the deviance, the
     * predicted response value, the posteriori probabilities in parentheses,
     * and a trailing {@code *} marker.
     *
     * @return the sample count of each class in this node.
     */
    @Override
    public int[] toString(StructType schema, StructField response, InternalNode parent, int depth, BigInteger id, List<String> lines) {
        double[] prob = posteriori(this.count, new double[this.count.length]);

        StringBuilder line = new StringBuilder();
        line.append(" ".repeat(depth));
        line.append(id).append(") ");
        line.append(parent == null ? "root" : parent.toString(schema, this == parent.trueChild)).append(" ");
        line.append(this.size).append(" ");
        line.append(String.format("%.5g", deviance(this.count, prob))).append(" ");
        line.append(response.toString((Object) this.output)).append(" ");
        line.append(Arrays.stream(prob)
                .mapToObj(p -> String.format("%.5g", p))
                .collect(Collectors.joining(" ", "(", ")")));
        line.append(" *");

        lines.add(line.toString());
        return this.count;
    }

    /**
     * Returns the impurity of this node.
     *
     * @param rule the impurity measure.
     * @return the impurity of this node's class counts under {@code rule}.
     */
    public double impurity(SplitRule rule) {
        return impurity(rule, this.size, this.count);
    }

    /**
     * Returns the impurity of a set of class counts.
     *
     * <p>When {@code size} is 0, the GINI and CLASSIFICATION_ERROR rules divide
     * by zero and the result is NaN. A rule without a matching case yields 0.0.
     *
     * @param rule  the impurity measure.
     * @param size  the total number of samples, expected to equal the sum of {@code count}.
     * @param count the sample count of each class.
     * @return the impurity value.
     */
    public static double impurity(SplitRule rule, int size, int[] count) {
        double impurity = 0.0;

        switch (rule) {
            case GINI: {
                double squaredSum = 0.0;
                for (int c : count) {
                    if (c > 0) {
                        squaredSum += (double) c * c;
                    }
                }
                // Multiply in double to avoid int overflow of size * size.
                impurity = 1.0 - squaredSum / ((double) size * size);
                break;
            }

            case ENTROPY: {
                for (int c : count) {
                    // Empty classes contribute 0 (lim p->0 of p log p) and would make log2 return -inf.
                    if (c > 0) {
                        double p = (double) c / size;
                        impurity -= p * MathEx.log2(p);
                    }
                }
                break;
            }

            case CLASSIFICATION_ERROR: {
                impurity = Math.abs(1.0 - MathEx.max(count) / (double) size);
                break;
            }
        }

        return impurity;
    }

    /**
     * Two decision nodes are equal if they predict the same class label.
     * Callers may use this to detect sibling leaves with identical predictions.
     */
    @Override
    public boolean equals(Object o) {
        if (o instanceof DecisionNode) {
            DecisionNode a = (DecisionNode) o;
            return this.output == a.output;
        }
        return false;
    }

    /**
     * Returns a hash code consistent with {@link #equals(Object)}, which
     * compares only {@code output}.
     */
    @Override
    public int hashCode() {
        return Integer.hashCode(this.output);
    }

    /**
     * Fills {@code prob} with this node's posteriori probabilities.
     *
     * @param prob the output array, of length at least the number of classes.
     * @return {@code prob}.
     */
    public double[] posteriori(double[] prob) {
        return posteriori(this.count, prob);
    }

    /**
     * Computes add-one (Laplace) smoothed posteriori probabilities.
     *
     * <p>The formula is {@code prob[i] = (count[i] + 1) / (sum(count) + k)},
     * where {@code k} is the number of classes. Every class therefore gets a
     * strictly positive probability.
     *
     * @param count the sample count of each class.
     * @param prob  the output array, of length at least {@code count.length}.
     * @return {@code prob}.
     */
    public static double[] posteriori(int[] count, double[] prob) {
        int k = count.length;
        double n = MathEx.sum(count) + k;
        for (int i = 0; i < k; ++i) {
            prob[i] = (count[i] + 1) / n;
        }
        return prob;
    }

    /**
     * Computes the deviance {@code -2 * sum_i count[i] * ln(prob[i])}.
     *
     * @param count the sample count of each class.
     * @param prob  the probability of each class, of length at least {@code count.length}.
     * @return the deviance.
     */
    public static double deviance(int[] count, double[] prob) {
        int k = count.length;
        double d = 0.0;
        for (int i = 0; i < k; ++i) {
            d -= count[i] * Math.log(prob[i]);
        }
        return 2.0 * d;
    }
}

