/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.Measure;
import smile.data.measure.NominalScale;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.IntVector;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.sequence.Trellis;
import smile.util.Strings;

/**
 * Linear-chain conditional random field whose potential functions are
 * ensembles of regression trees.
 *
 * <p>For every state {@code i} the model keeps an array of regression trees
 * ({@code potentials[i]}). The log-potential of entering state {@code i} is the
 * shrinkage-weighted sum of the tree predictions on the observation extended
 * with one extra nominal column, {@code "s(t-1)"}, holding the previous state.
 * The first position of a sequence has no previous state and uses the extra
 * level {@code k} (the number of states) in that column.
 */
public class CRF implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(CRF.class);

    /** Name of the synthetic column that carries the previous state. */
    private static final String PREVIOUS_STATE = "s(t-1)";

    /** Observation schema with the previous-state column appended. */
    private final StructType schema;
    /** {@code potentials[i]} is the tree ensemble for state {@code i}. */
    private final RegressionTree[][] potentials;
    /** Weight applied to every tree prediction. */
    private final double shrinkage;

    /**
     * Creates a model from already trained potential functions.
     *
     * @param schema     schema of the observations (without the previous-state column).
     * @param potentials one tree ensemble per state; its length is the number of states.
     * @param shrinkage  weight applied to each tree prediction.
     */
    public CRF(StructType schema, RegressionTree[][] potentials, double shrinkage) {
        this.potentials = potentials;
        this.shrinkage = shrinkage;

        // The trees were trained on observations plus the previous-state column,
        // so the schema exposed by extended tuples needs that column too.
        int length = schema.length();
        StructField[] fields = new StructField[length + 1];
        System.arraycopy(schema.fields(), 0, fields, 0, length);
        fields[length] = previousStateField(potentials.length);
        this.schema = new StructType(fields);
    }

    /**
     * Builds the nominal previous-state field with levels {@code "0".."k"};
     * level {@code k} marks the start of a sequence.
     */
    private static StructField previousStateField(int k) {
        String[] levels = IntStream.range(0, k + 1).mapToObj(String::valueOf).toArray(String[]::new);
        return new StructField(PREVIOUS_STATE, DataTypes.IntegerType, new NominalScale(levels));
    }

    /**
     * Labels a sequence with the single most likely state path (Viterbi decoding).
     *
     * <p>Path scores are accumulated in log space, i.e. as sums of log-potentials.
     * This matches the multiplicative potentials used by {@link #predict(Tuple[])}
     * and cannot overflow the way sums of exponentiated scores can.
     *
     * @param x the observation sequence.
     * @return the state of each position; an empty array for an empty sequence.
     */
    public int[] viterbi(Tuple[] x) {
        int n = x.length;
        if (n == 0) {
            return new int[0];
        }

        int k = potentials.length;
        double[][] score = new double[n][k]; // best log-score of a path ending in state i at time t
        int[][] psy = new int[n][k];         // back pointers; row 0 is never read
        double[] delta = new double[k];

        Tuple x0 = extend(x[0], k);
        for (int i = 0; i < k; ++i) {
            score[0][i] = logPotential(potentials[i], x0);
        }

        Tuple[] xt = new Tuple[k];
        for (int t = 1; t < n; ++t) {
            double[] current = score[t];
            double[] previous = score[t - 1];
            int[] back = psy[t];

            // One extended tuple per candidate previous state, shared by all target states.
            for (int j = 0; j < k; ++j) {
                xt[j] = extend(x[t], j);
            }

            for (int i = 0; i < k; ++i) {
                RegressionTree[] pi = potentials[i];
                for (int j = 0; j < k; ++j) {
                    delta[j] = logPotential(pi, xt[j]) + previous[j];
                }
                back[i] = MathEx.whichMax(delta);
                current[i] = delta[back[i]];
            }
        }

        // Trace the best path backwards from the best final state.
        int[] label = new int[n];
        label[n - 1] = MathEx.whichMax(score[n - 1]);
        for (int t = n - 2; t >= 0; --t) {
            label[t] = psy[t + 1][label[t + 1]];
        }
        return label;
    }

    /**
     * Labels each position independently with the state that maximises
     * {@code alpha * beta} in the trellis after the forward and backward passes.
     *
     * @param x the observation sequence.
     * @return the state of each position; an empty array for an empty sequence.
     */
    public int[] predict(Tuple[] x) {
        int n = x.length;
        if (n == 0) {
            return new int[0];
        }

        int k = potentials.length;
        Trellis trellis = new Trellis(n, k);
        f(x, trellis);

        double[] scaling = new double[n];
        trellis.forward(scaling);
        trellis.backward();

        int[] label = new int[n];
        double[] p = new double[k];
        for (int t = 0; t < n; ++t) {
            Trellis.Cell[] cells = trellis.table[t];
            for (int i = 0; i < k; ++i) {
                p[i] = cells[i].alpha * cells[i].beta;
            }
            label[t] = MathEx.whichMax(p);
        }
        return label;
    }

    /**
     * Fills {@code expf} of every trellis cell with the potential of the
     * corresponding (previous state, state) pair. Expects a non-empty sequence.
     */
    private void f(Tuple[] x, Trellis trellis) {
        int n = x.length;
        int k = potentials.length;

        // The first position has no predecessor: only expf[0] is set, using level k.
        Tuple x0 = extend(x[0], k);
        for (int i = 0; i < k; ++i) {
            trellis.table[0][i].expf[0] = f(potentials[i], x0);
        }

        Tuple[] xt = new Tuple[k];
        for (int t = 1; t < n; ++t) {
            for (int j = 0; j < k; ++j) {
                xt[j] = extend(x[t], j);
            }
            for (int i = 0; i < k; ++i) {
                for (int j = 0; j < k; ++j) {
                    trellis.table[t][i].expf[j] = f(potentials[i], xt[j]);
                }
            }
        }
    }

    /** Returns the potential {@code exp(F)} of one tree ensemble on an extended tuple. */
    private double f(RegressionTree[] potential, Tuple x) {
        return Math.exp(logPotential(potential, x));
    }

    /** Returns the log-potential {@code F}: the shrinkage-weighted sum of the tree predictions. */
    private double logPotential(RegressionTree[] potential, Tuple x) {
        double sum = 0.0;
        for (RegressionTree tree : potential) {
            sum += shrinkage * tree.predict(x);
        }
        return sum;
    }

    /**
     * Fits a CRF with the default hyper-parameters.
     *
     * @param sequences the training observation sequences.
     * @param labels    the state of every position of every sequence.
     * @return the fitted model.
     */
    public static CRF fit(Tuple[][] sequences, int[][] labels) {
        return fit(sequences, labels, new Properties());
    }

    /**
     * Fits a CRF with hyper-parameters read from properties.
     *
     * <p>Recognised keys and their defaults: {@code smile.crf.trees} (100),
     * {@code smile.crf.max_depth} (20), {@code smile.crf.max_nodes} (100),
     * {@code smile.crf.node_size} (5) and {@code smile.crf.shrinkage} (1.0).
     *
     * @param sequences the training observation sequences.
     * @param labels    the state of every position of every sequence.
     * @param params    the hyper-parameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value is not a valid number.
     */
    public static CRF fit(Tuple[][] sequences, int[][] labels, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.crf.trees", "100"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.crf.max_depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.crf.max_nodes", "100"));
        int nodeSize = Integer.parseInt(params.getProperty("smile.crf.node_size", "5"));
        double shrinkage = Double.parseDouble(params.getProperty("smile.crf.shrinkage", "1.0"));
        return fit(sequences, labels, ntrees, maxDepth, maxNodes, nodeSize, shrinkage);
    }

    /**
     * Fits a CRF by boosting one regression tree per state per iteration.
     *
     * <p>The number of states is {@code max(labels) + 1}. The training frame has
     * one row for the first position of each sequence (previous state {@code k})
     * and {@code k} rows, one per candidate previous state, for every later position.
     *
     * @param sequences the training observation sequences; none may be empty.
     * @param labels    the state of every position of every sequence.
     * @param ntrees    the number of boosting iterations.
     * @param maxDepth  passed to every {@link RegressionTree}.
     * @param maxNodes  passed to every {@link RegressionTree}.
     * @param nodeSize  passed to every {@link RegressionTree}.
     * @param shrinkage weight applied to each tree prediction.
     * @return the fitted model.
     * @throws IllegalArgumentException if there is no training data, a sequence
     *         is empty, or sequences and labels disagree in size.
     */
    public static CRF fit(Tuple[][] sequences, int[][] labels, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage) {
        checkTrainingData(sequences, labels);

        final int k = MathEx.max(labels) + 1;
        double[][] scaling = new double[sequences.length][];
        Trellis[] trellis = new Trellis[sequences.length];
        for (int s = 0; s < sequences.length; ++s) {
            scaling[s] = new double[sequences[s].length];
            trellis[s] = new Trellis(sequences[s].length, k);
        }

        // Total number of training rows: 1 + (length - 1) * k per sequence.
        final int n = Arrays.stream(sequences).mapToInt(seq -> seq.length).map(ni -> 1 + (ni - 1) * k).sum();

        ArrayList<Tuple> x = new ArrayList<>(n);
        int[] state = new int[n];
        int row = 0;
        for (Tuple[] sequence : sequences) {
            x.add(sequence[0]);
            state[row++] = k;
            for (int t = 1; t < sequence.length; ++t) {
                for (int j = 0; j < k; ++j) {
                    x.add(sequence[t]);
                    state[row++] = j;
                }
            }
        }

        DataFrame data = DataFrame.of(x).merge(IntVector.of(previousStateField(k), state));
        StructField field = new StructField("residual", DataTypes.DoubleType);

        RegressionTree[][] potentials = new RegressionTree[k][ntrees];
        double[][] h = new double[k][n];        // current log-potentials, one row per state
        double[][] response = new double[k][n]; // current regression targets, one row per state
        Loss[] loss = new Loss[k];
        for (int j = 0; j < k; ++j) {
            // Each loss reads its targets straight from response[j], refreshed every iteration.
            loss[j] = new PotentialLoss(response[j]);
        }

        int[] samples = new int[n];
        Arrays.fill(samples, 1);
        int[][] order = CART.order(data);

        for (int iter = 0; iter < ntrees; ++iter) {
            logger.info("Training {} tree", Strings.ordinal(iter + 1));

            // Each task touches only the cells/rows of its own state j.
            IntStream.range(0, k).parallel().forEach(j -> loadPotentials(trellis, j, k, h[j]));

            IntStream.range(0, sequences.length).parallel().forEach(s -> {
                trellis[s].forward(scaling[s]);
                trellis[s].backward();
                trellis[s].gradient(scaling[s], labels[s]);
            });

            IntStream.range(0, k).parallel().forEach(j -> collectResiduals(trellis, j, k, response[j]));

            for (int j = 0; j < k; ++j) {
                RegressionTree tree = new RegressionTree(data, loss[j], field, maxDepth, maxNodes, nodeSize, data.ncol(), samples, order);
                potentials[j][iter] = tree;

                double[] hj = h[j];
                for (int i = 0; i < n; ++i) {
                    hj[i] += shrinkage * tree.predict(data.get(i));
                }
            }
        }

        return new CRF(sequences[0][0].schema(), potentials, shrinkage);
    }

    /** Rejects training data that would otherwise fail late with an index error. */
    private static void checkTrainingData(Tuple[][] sequences, int[][] labels) {
        if (sequences.length == 0) {
            throw new IllegalArgumentException("No training sequences");
        }
        if (sequences.length != labels.length) {
            throw new IllegalArgumentException(String.format(
                    "The number of sequences and label arrays differ: %d != %d", sequences.length, labels.length));
        }
        for (int s = 0; s < sequences.length; ++s) {
            if (sequences[s].length == 0) {
                throw new IllegalArgumentException("Training sequence " + s + " is empty");
            }
            if (sequences[s].length != labels[s].length) {
                throw new IllegalArgumentException(String.format(
                        "Sequence %d has %d observations but %d labels", s, sequences[s].length, labels[s].length));
            }
        }
    }

    /**
     * Writes {@code exp(f[...])} into {@code expf} of state {@code j}'s cells,
     * walking {@code f} in training-row order.
     */
    private static void loadPotentials(Trellis[] trellis, int j, int k, double[] f) {
        int pos = 0;
        for (Trellis grid : trellis) {
            grid.table[0][j].expf[0] = Math.exp(f[pos++]);
            for (int t = 1; t < grid.table.length; ++t) {
                for (int i = 0; i < k; ++i) {
                    grid.table[t][j].expf[i] = Math.exp(f[pos++]);
                }
            }
        }
    }

    /**
     * Copies {@code residual} of state {@code j}'s cells into {@code r},
     * in training-row order.
     */
    private static void collectResiduals(Trellis[] trellis, int j, int k, double[] r) {
        int pos = 0;
        for (Trellis grid : trellis) {
            r[pos++] = grid.table[0][j].residual[0];
            for (int t = 1; t < grid.table.length; ++t) {
                for (int i = 0; i < k; ++i) {
                    r[pos++] = grid.table[t][j].residual[i];
                }
            }
        }
    }

    /**
     * Wraps an observation in a view with one extra trailing column holding
     * the previous state.
     *
     * @param x     the observation.
     * @param state the value reported for the extra column (index {@code x.length()}).
     * @return a tuple view using this model's extended schema.
     */
    Tuple extend(final Tuple x, final int state) {
        return new Tuple() {

            @Override
            public StructType schema() {
                return CRF.this.schema;
            }

            @Override
            public Object get(int j) {
                return j == x.length() ? Integer.valueOf(state) : x.get(j);
            }

            @Override
            public int getInt(int j) {
                return j == x.length() ? state : x.getInt(j);
            }
        };
    }

    /**
     * Loss whose node output is the weighted mean of externally supplied
     * responses. The {@code response} array is shared with the caller, which
     * overwrites it between boosting iterations.
     */
    record PotentialLoss(double[] response) implements Loss {
        /**
         * Returns the mean of {@code response} over the node's samples, weighted
         * by {@code sampleCount}. The result is NaN if the total count is zero.
         */
        @Override
        public double output(int[] nodeSamples, int[] sampleCount) {
            int n = 0;
            double output = 0.0;
            for (int i : nodeSamples) {
                n += sampleCount[i];
                output += response[i] * (double) sampleCount[i];
            }
            return output / (double) n;
        }

        /** Always 0: the boosted potentials start from zero. */
        @Override
        public double intercept(double[] y) {
            return 0.0;
        }

        /** Not supported; the targets are provided through {@code response}. */
        @Override
        public double[] residual() {
            throw new IllegalStateException();
        }
    }
}

