/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import smile.base.rbf.RBF;
import smile.classification.ClassLabels;
import smile.classification.Classifier;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.IntSet;

/**
 * Radial basis function network classifier. An input is passed through every
 * radial basis function, a constant bias input of 1 is appended, and the
 * resulting feature vector is combined with the weight matrix to give one
 * score per class; the highest-scoring class is returned.
 *
 * @param <T> the type of input objects accepted by the basis functions.
 */
public class RBFNetwork<T>
implements Classifier<T> {
    private static final long serialVersionUID = 2L;
    /** The number of classes (length of the per-input score vector). */
    private final int k;
    /**
     * Weight matrix combined with the feature vector in {@link #predict}.
     * Expected to have {@code rbf.length + 1} rows (basis functions plus a
     * trailing bias row) and {@code k} columns.
     */
    private final Matrix w;
    /** The radial basis functions evaluated on each input. */
    private final RBF<T>[] rbf;
    /** Whether the network was fitted in normalized mode. */
    private final boolean normalized;
    /** Translates a winning class index into the returned class label. */
    private final IntSet labels;

    /**
     * Creates a network whose labels come from {@code IntSet.of(k)}
     * (presumably the labels {@code 0..k-1} — confirm against IntSet).
     *
     * @param k the number of classes.
     * @param rbf the radial basis functions.
     * @param w the weight matrix ({@code rbf.length + 1} rows, {@code k} columns expected).
     * @param normalized whether the network is a normalized RBF network.
     */
    public RBFNetwork(int k, RBF<T>[] rbf, Matrix w, boolean normalized) {
        this(k, rbf, w, normalized, IntSet.of(k));
    }

    /**
     * Creates a network with an explicit label set.
     *
     * @param k the number of classes.
     * @param rbf the radial basis functions.
     * @param w the weight matrix ({@code rbf.length + 1} rows, {@code k} columns expected).
     * @param normalized whether the network is a normalized RBF network.
     * @param labels the mapping from class index to class label.
     */
    public RBFNetwork(int k, RBF<T>[] rbf, Matrix w, boolean normalized, IntSet labels) {
        this.k = k;
        this.rbf = rbf;
        this.w = w;
        this.normalized = normalized;
        this.labels = labels;
    }

    /**
     * Fits a non-normalized RBF network; equivalent to
     * {@code fit(x, y, rbf, false)}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param rbf the radial basis functions.
     * @param <T> the type of input objects.
     * @return the fitted network.
     */
    public static <T> RBFNetwork<T> fit(T[] x, int[] y, RBF<T>[] rbf) {
        return fit(x, y, rbf, false);
    }

    /**
     * Fits an RBF network by solving for the output weights with a QR
     * decomposition of the design matrix.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param rbf the radial basis functions.
     * @param normalized if true, each sample's target entry is the sum of its
     *                   basis-function responses instead of 1.
     * @param <T> the type of input objects.
     * @return the fitted network.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> RBFNetwork<T> fit(T[] x, int[] y, RBF<T>[] rbf, boolean normalized) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        ClassLabels codec = ClassLabels.fit(y);
        int numClasses = codec.k;
        int numSamples = x.length;
        int numBases = rbf.length;

        // Design matrix: one column per basis function plus a trailing constant bias column.
        Matrix design = new Matrix(numSamples, numBases + 1);
        // Targets: a single nonzero entry per row, in the column of that sample's class index.
        Matrix targets = new Matrix(numSamples, numClasses);

        for (int row = 0; row < numSamples; row++) {
            T sample = x[row];
            double response = 0.0;
            for (int col = 0; col < numBases; col++) {
                double phi = rbf[col].f(sample);
                design.set(row, col, phi);
                response += phi;
            }
            design.set(row, numBases, 1.0);
            targets.set(row, codec.y[row], normalized ? response : 1.0);
        }

        // The solution is written back into `targets`; its leading rows hold the weights.
        design.qr(true).solve(targets);
        Matrix weights = targets.submatrix(0, 0, numBases, numClasses - 1);
        return new RBFNetwork<>(numClasses, rbf, weights, normalized, codec.labels);
    }

    /**
     * Returns whether this network was fitted in normalized mode. Note that
     * {@link #predict} does not branch on this flag.
     *
     * @return true if the network is normalized.
     */
    public boolean isNormalized() {
        return normalized;
    }

    /**
     * Predicts the class label of an input as the label of the highest-scoring
     * class.
     *
     * @param x the input object.
     * @return the predicted class label.
     */
    @Override
    public int predict(T x) {
        int numBases = rbf.length;
        // Feature vector: basis-function responses followed by the constant bias input.
        double[] features = new double[numBases + 1];
        for (int j = 0; j < numBases; j++) {
            features[j] = rbf[j].f(x);
        }
        features[numBases] = 1.0;

        double[] scores = new double[k];
        w.tv(features, scores);
        int best = MathEx.whichMax(scores);
        return labels.valueOf(best);
    }
}

