/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.classification.OnlineClassifier;
import smile.classification.SoftClassifier;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * Multilayer perceptron classifier built on {@link MultilayerPerceptron}.
 * <p>
 * If the output layer has a single neuron, the network is treated as a binary
 * classifier. That neuron's output is read as the score of class index 1, and
 * class index 0 receives {@code 1 - score}. Otherwise there is one class per
 * output neuron.
 * <p>
 * Class indices are translated to and from the caller's label values through
 * an {@link IntSet}.
 */
public class MLP
extends MultilayerPerceptron
implements OnlineClassifier<double[]>,
SoftClassifier<double[]>,
Serializable {
    private static final long serialVersionUID = 2L;
    /** The number of classes; 2 when the output layer has a single neuron. */
    private int k;
    /** Translates between class indices and the caller's label values. */
    private IntSet labels;

    /**
     * Constructor. Class labels are {@code IntSet.of(k)}, where {@code k} is
     * the number of classes (presumably the labels 0..k-1; verify against IntSet).
     *
     * @param p the number of input variables.
     * @param builders the layer builders, in order. The first layer receives
     *                 {@code p} inputs, and each later layer receives the
     *                 previous layer's neuron count.
     */
    public MLP(int p, LayerBuilder ... builders) {
        super(MLP.net(p, builders));
        this.k = this.numClasses();
        this.labels = IntSet.of(this.k);
    }

    /**
     * Constructor with explicit class labels.
     *
     * @param labels the mapping between class indices and label values.
     * @param p the number of input variables.
     * @param builders the layer builders, in order. The first layer receives
     *                 {@code p} inputs, and each later layer receives the
     *                 previous layer's neuron count.
     */
    public MLP(IntSet labels, int p, LayerBuilder ... builders) {
        super(MLP.net(p, builders));
        this.k = this.numClasses();
        this.labels = labels;
    }

    /**
     * Returns the number of classes implied by the output layer size.
     * A single output neuron models a two-class problem.
     */
    private int numClasses() {
        int n = this.output.getOutputSize();
        return n == 1 ? 2 : n;
    }

    /**
     * Builds the layers from their builders, chaining each layer's input size
     * to the previous layer's neuron count.
     *
     * @param p the number of input variables fed to the first layer.
     * @param builders the layer builders, in order.
     * @return the built layers, in the same order as {@code builders}.
     */
    private static Layer[] net(int p, LayerBuilder ... builders) {
        int l = builders.length;
        Layer[] net = new Layer[l];
        for (int i = 0; i < l; ++i) {
            net[i] = builders[i].build(p);
            p = builders[i].neurons();
        }
        return net;
    }

    /**
     * Predicts the label of {@code x} and fills in the per-class scores.
     *
     * @param x the input sample.
     * @param posteriori output buffer that receives the score of each class
     *                   index. It must hold at least {@code k} entries, and
     *                   only the first {@code k} are written.
     * @return the label whose class index has the largest score.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        this.propagate(x);
        int n = this.output.getOutputSize();
        double[] out = this.output.output();
        if (n == 1 && this.k == 2) {
            // Single-neuron binary output: derive class 0 as the complement.
            posteriori[1] = out[0];
            posteriori[0] = 1.0 - posteriori[1];
        } else {
            System.arraycopy(out, 0, posteriori, 0, n);
        }
        return this.labels.valueOf(MathEx.whichMax(posteriori));
    }

    /**
     * Predicts the label of {@code x}.
     *
     * @param x the input sample.
     * @return the predicted label. With a single output neuron, class index 1
     *         is chosen when the output exceeds 0.5.
     */
    @Override
    public int predict(double[] x) {
        this.propagate(x);
        int n = this.output.getOutputSize();
        double[] out = this.output.output();
        if (n == 1 && this.k == 2) {
            return this.labels.valueOf(out[0] > 0.5 ? 1 : 0);
        }
        return this.labels.valueOf(MathEx.whichMax(out));
    }

    /**
     * Performs one online training step on a single sample. It runs forward
     * propagation, sets the target, calls {@code backpropagate(x, true)} and
     * increments the inherited counter {@code t}.
     *
     * @param x the input sample.
     * @param y the label of {@code x}.
     */
    @Override
    public void update(double[] x, int y) {
        this.propagate(x);
        this.setTarget(this.labels.indexOf(y));
        this.backpropagate(x, true);
        ++this.t;
    }

    /**
     * Performs one mini-batch training step.
     * <p>
     * Each sample is propagated forward, then back-propagated with
     * {@code backpropagate(x, false)}. After the loop, {@code update(m)} is
     * called once with the batch size, and the inherited counter {@code t} is
     * incremented. An empty batch is a no-op.
     *
     * @param x the input samples.
     * @param y the labels; {@code y[i]} is the label of {@code x[i]}.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public void update(double[][] x, int[] y) {
        if (x.length != y.length) {
            // Validate up front so no partial gradients are accumulated
            // before a length mismatch would otherwise surface mid-loop.
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            // Nothing to learn from. NOTE(review): update(m) receives the
            // batch size (presumably to average gradients), so it is not
            // called with m = 0.
            return;
        }
        for (int i = 0; i < x.length; ++i) {
            this.propagate(x[i]);
            this.setTarget(this.labels.indexOf(y[i]));
            this.backpropagate(x[i], false);
        }
        this.update(x.length);
        ++this.t;
    }

    /**
     * Writes the training target for class index {@code y} into the target buffer.
     * <p>
     * With {@link Cost#LIKELIHOOD}, the targets are 1.0 and 0.0. For any other
     * cost they are softened to 0.9 and 0.1, presumably to keep output units
     * away from saturation.
     *
     * @param y the class index, not the label value.
     */
    private void setTarget(int y) {
        int n = this.output.getOutputSize();
        // Named to avoid shadowing the inherited step counter field 't'.
        double positive = this.output.cost() == Cost.LIKELIHOOD ? 1.0 : 0.9;
        double negative = 1.0 - positive;
        double[] target = (double[]) this.target.get();
        if (n == 1) {
            target[0] = y == 1 ? positive : negative;
        } else {
            Arrays.fill(target, negative);
            target[y] = positive;
        }
    }
}

