/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.classification.Classifier;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.Strings;

/**
 * Multilayer perceptron classifier built on {@link MultilayerPerceptron}.
 *
 * <p>A network whose output layer has a single neuron is treated as a binary
 * classifier (two classes, posterior of class 1 = the neuron's output).
 * Otherwise each output neuron corresponds to one class. Predicted output
 * indices are mapped back to user-facing labels through {@link #classes}.
 */
public class MLP extends MultilayerPerceptron implements Classifier<double[]>, Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MLP.class);

    /** The number of classes; 2 when the output layer has a single neuron. */
    private final int k;
    /** The class labels; output index i corresponds to label classes.valueOf(i). */
    private final IntSet classes;

    /**
     * Constructs a classifier whose class labels are 0, 1, ..., k-1.
     *
     * @param builders the layer builders in network order; each layer is built
     *                 with the previous builder's neuron count as its input size.
     */
    public MLP(LayerBuilder... builders) {
        super(net(builders));
        this.k = classCount(this.output.getOutputSize());
        this.classes = IntSet.of(this.k);
    }

    /**
     * Constructs a classifier with explicit class labels.
     *
     * @param classes  the class labels; must contain exactly as many labels as
     *                 the network has classes (2 for a single output neuron,
     *                 otherwise the output layer size).
     * @param builders the layer builders in network order.
     * @throws IllegalArgumentException if the number of labels does not match
     *                                  the number of classes of the network.
     */
    public MLP(IntSet classes, LayerBuilder... builders) {
        super(net(builders));
        this.k = classCount(this.output.getOutputSize());
        // A mismatch would make numClasses() disagree with the network and let
        // classes.valueOf(...) be called with an out-of-range output index.
        if (classes.size() != this.k) {
            throw new IllegalArgumentException(
                "The number of class labels " + classes.size()
                + " does not match the number of network classes " + this.k);
        }
        this.classes = classes;
    }

    /** Returns the number of classes for an output layer of the given size. */
    private static int classCount(int outputSize) {
        // A single sigmoid-style output neuron encodes a binary problem.
        return outputSize == 1 ? 2 : outputSize;
    }

    /**
     * Builds the layers in order, feeding each builder the neuron count of the
     * previous one as its input size (0 for the first builder).
     */
    private static Layer[] net(LayerBuilder... builders) {
        Layer[] net = new Layer[builders.length];
        int inputSize = 0;
        for (int i = 0; i < builders.length; i++) {
            net[i] = builders[i].build(inputSize);
            inputSize = builders[i].neurons();
        }
        return net;
    }

    @Override
    public int numClasses() {
        return this.classes.size();
    }

    /** Returns the class label array held by the label set (not a copy). */
    @Override
    public int[] classes() {
        return this.classes.values;
    }

    /**
     * Predicts the class label of {@code x} and fills in the posterior estimates.
     *
     * @param x          the input sample.
     * @param posteriori output array of at least {@code numClasses()} entries;
     *                   for a single-neuron binary network it receives
     *                   {@code [1 - p, p]}, otherwise a copy of the output layer.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        this.propagate(x, false);
        int n = this.output.getOutputSize();
        if (n == 1 && this.k == 2) {
            posteriori[1] = this.output.output()[0];
            posteriori[0] = 1.0 - posteriori[1];
        } else {
            System.arraycopy(this.output.output(), 0, posteriori, 0, n);
        }
        return this.classes.valueOf(MathEx.whichMax(posteriori));
    }

    /**
     * Predicts the class label of {@code x}. For a single-neuron binary network
     * the label at index 1 is chosen when the output exceeds 0.5.
     */
    @Override
    public int predict(double[] x) {
        this.propagate(x, false);
        int n = this.output.getOutputSize();
        if (n == 1 && this.k == 2) {
            return this.classes.valueOf(this.output.output()[0] > 0.5 ? 1 : 0);
        }
        return this.classes.valueOf(MathEx.whichMax(this.output.output()));
    }

    @Override
    public boolean soft() {
        return true;
    }

    @Override
    public boolean online() {
        return true;
    }

    /**
     * Updates the model with a single labelled sample.
     *
     * <p>NOTE(review): {@code backpropagate(true)} presumably applies the weight
     * update immediately, whereas the mini-batch path passes {@code false} and
     * calls {@code update(m)} once; confirm against MultilayerPerceptron.
     *
     * @param x the input sample.
     * @param y the class label of {@code x}; must be one of {@link #classes()}.
     */
    @Override
    public void update(double[] x, int y) {
        this.propagate(x, true);
        this.setTarget(this.classes.indexOf(y));
        this.backpropagate(true);
        ++this.t;
    }

    /**
     * Updates the model with a mini-batch of labelled samples.
     *
     * @param x the input samples.
     * @param y the class labels; must have the same length as {@code x}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void update(double[][] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                "The sizes of X and Y don't match: " + x.length + " != " + y.length);
        }
        if (x.length == 0) {
            // Nothing to learn from; a zero-sized update(m) call is meaningless.
            return;
        }

        for (int i = 0; i < x.length; i++) {
            this.propagate(x[i], true);
            this.setTarget(this.classes.indexOf(y[i]));
            this.backpropagate(false);
        }
        this.update(x.length);
        ++this.t;
    }

    /**
     * Writes the training target for output index {@code y} into the shared
     * target buffer. Targets are 1.0/0.0 for the likelihood cost and 0.9/0.1
     * otherwise (presumably to keep non-likelihood outputs off their
     * saturation limits).
     */
    private void setTarget(int y) {
        int n = this.output.getOutputSize();
        double positive = this.output.cost() == Cost.LIKELIHOOD ? 1.0 : 0.9;
        double negative = 1.0 - positive;
        double[] target = this.target.get();
        if (n == 1) {
            target[0] = y == 1 ? positive : negative;
        } else {
            Arrays.fill(target, negative);
            target[y] = positive;
        }
    }

    /**
     * Trains an MLP classifier with shuffled mini-batch updates.
     *
     * <p>The model has {@code max(y) + 1} classes with labels 0..max(y). Recognized
     * properties: {@code smile.mlp.layers} (default {@code "ReLU(100)"}),
     * {@code smile.mlp.epochs} (default 100), {@code smile.mlp.mini_batch}
     * (default 32). The full property set is also passed to
     * {@code setParameters}.
     *
     * @param x      the training samples; must be non-empty.
     * @param y      the class labels; non-negative, same length as {@code x}.
     * @param params the training hyper-parameters.
     * @return the trained model.
     * @throws IllegalArgumentException if the data are empty or mismatched, a
     *                                  label is negative, or the mini-batch size
     *                                  is not positive.
     */
    public static MLP fit(double[][] x, int[] y, Properties params) {
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                "The sizes of X and Y don't match: " + x.length + " != " + y.length);
        }
        for (int label : y) {
            // The default label set is 0..k-1, so negative labels cannot be indexed.
            if (label < 0) {
                throw new IllegalArgumentException("Negative class label: " + label);
            }
        }

        int p = x[0].length;
        int k = MathEx.max(y) + 1;
        LayerBuilder[] layers = Layer.of(k, p, params.getProperty("smile.mlp.layers", "ReLU(100)"));
        MLP model = new MLP(layers);
        model.setParameters(params);

        int epochs = Integer.parseInt(params.getProperty("smile.mlp.epochs", "100"));
        int batch = Integer.parseInt(params.getProperty("smile.mlp.mini_batch", "32"));
        if (batch <= 0) {
            // batch == 0 would never advance the sample index (infinite loop).
            throw new IllegalArgumentException("Invalid mini-batch size: " + batch);
        }
        // A batch larger than the data set yields one full batch either way;
        // clamping avoids oversized buffers.
        batch = Math.min(batch, x.length);

        double[][] batchx = new double[batch][];
        int[] batchy = new int[batch];
        for (int epoch = 1; epoch <= epochs; epoch++) {
            logger.info("{} epoch", Strings.ordinal(epoch));
            int[] permutation = MathEx.permutate(x.length);
            for (int i = 0; i < x.length; i += batch) {
                int size = Math.min(batch, x.length - i);
                for (int j = 0; j < size; j++) {
                    int index = permutation[i + j];
                    batchx[j] = x[index];
                    batchy[j] = y[index];
                }
                if (size < batch) {
                    // Trailing partial batch: pass only the filled prefix.
                    model.update(Arrays.copyOf(batchx, size), Arrays.copyOf(batchy, size));
                } else {
                    model.update(batchx, batchy);
                }
            }
        }
        return model;
    }
}

