/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.extraction;

import java.util.Objects;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.feature.extraction.Projection;
import smile.math.MathEx;
import smile.math.TimeFunction;
import smile.math.matrix.Matrix;

/**
 * Generalized Hebbian Algorithm (Sanger's rule): an online learner that
 * incrementally updates a {@code p x n} projection matrix {@code W} from a
 * stream of {@code n}-dimensional samples.
 *
 * <p>Each sample {@code x} applies, with {@code y = W x},
 * <pre>
 *     w_j += eta(t) * y_j * (x - sum_{l <= j} y_l * w_l)    for j = 0..p-1
 * </pre>
 * where {@code w_j} is the j-th row of {@code W} and all rows on the right-hand
 * side are taken from before the update. The learning rate {@code eta(t)} is a
 * {@link TimeFunction} of the number of samples processed so far.
 *
 * <p>Instances are not thread-safe: {@link #update(double[])} mutates the
 * projection matrix, the sample counter and internal scratch buffers.
 */
public class GHA
extends Projection {
    private static final long serialVersionUID = 2L;
    /** Dimension of the feature (output) space: number of rows of the projection. */
    private final int p;
    /** Dimension of the input space: number of columns of the projection. */
    private final int n;
    /** Learning-rate schedule, evaluated at the current sample count {@link #t}. */
    private final TimeFunction r;
    /** Scratch buffer for the projected sample {@code W x}. */
    private final double[] y;
    /** Scratch buffer for the reconstruction {@code W' y}. */
    private final double[] wy;
    /** Number of samples processed so far; the argument passed to the learning-rate schedule. */
    protected int t = 0;

    /**
     * Constructor with random initial weights, each set to {@code 0.1 * MathEx.random()}.
     *
     * @param n the dimension of the input space; must be at least 2.
     * @param p the dimension of the feature space; must be in {@code [1, n]}.
     * @param r the learning-rate schedule.
     * @param columns the input column names; the {@code Tuple} and
     *                {@code DataFrame} overloads of {@code update} use them
     *                to extract input values.
     * @throws IllegalArgumentException if {@code n} or {@code p} is out of range.
     * @throws NullPointerException if {@code r} is null.
     */
    public GHA(int n, int p, TimeFunction r, String ... columns) {
        // Dimensions are validated inside randomWeights, i.e. before any matrix
        // is allocated, so bad sizes yield IllegalArgumentException rather than
        // an allocation failure.
        super(GHA.randomWeights(n, p), "GHA", columns);
        this.n = n;
        this.p = p;
        this.r = Objects.requireNonNull(r, "r");
        this.y = new double[p];
        this.wy = new double[n];
    }

    /**
     * Constructor with given initial weights.
     *
     * @param w the initial {@code p x n} projection matrix; must be non-empty and
     *          rectangular, with {@code n >= 2} columns and {@code 1 <= p <= n} rows.
     * @param r the learning-rate schedule.
     * @param columns the input column names; the {@code Tuple} and
     *                {@code DataFrame} overloads of {@code update} use them
     *                to extract input values.
     * @throws IllegalArgumentException if {@code w} is empty, ragged, or has
     *         out-of-range dimensions.
     * @throws NullPointerException if {@code w} or {@code r} is null.
     */
    public GHA(double[][] w, TimeFunction r, String ... columns) {
        super(Matrix.of(GHA.checkWeights(w)), "GHA", columns);
        this.p = w.length;
        this.n = w[0].length;
        this.r = Objects.requireNonNull(r, "r");
        this.y = new double[this.p];
        this.wy = new double[this.n];
    }

    /** Validates the input/feature space dimensions shared by both constructors. */
    private static void checkDimensions(int n, int p) {
        if (n < 2) {
            throw new IllegalArgumentException("Invalid dimension of input space: " + n);
        }
        if (p < 1 || p > n) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }
    }

    /** Validates the dimensions, then builds a {@code p x n} matrix of small random weights. */
    private static Matrix randomWeights(int n, int p) {
        GHA.checkDimensions(n, p);
        Matrix w = new Matrix(p, n);
        for (int i = 0; i < p; ++i) {
            for (int j = 0; j < n; ++j) {
                w.set(i, j, 0.1 * MathEx.random());
            }
        }
        return w;
    }

    /** Validates that {@code w} is a non-empty rectangular array with legal dimensions. */
    private static double[][] checkWeights(double[][] w) {
        Objects.requireNonNull(w, "w");
        if (w.length == 0) {
            throw new IllegalArgumentException("Invalid dimension of feature space: 0");
        }
        int cols = w[0].length;
        for (int i = 1; i < w.length; ++i) {
            if (w[i].length != cols) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, expected %d", i, w[i].length, cols));
            }
        }
        GHA.checkDimensions(cols, w.length);
        return w;
    }

    /**
     * Updates the model with a single sample.
     *
     * <p>If an {@link IllegalStateException} is thrown, some weights may already
     * have been modified and the sample counter is not advanced.
     *
     * @param x the input sample of length {@code n}.
     * @return the squared distance between {@code x} and its reconstruction
     *         {@code W' W x}, computed with the updated projection.
     * @throws IllegalArgumentException if {@code x.length != n}.
     * @throws IllegalStateException if any weight becomes infinite or NaN.
     */
    public double update(double[] x) {
        if (x.length != this.n) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.n));
        }
        this.projection.mv(x, this.y);
        // The schedule depends only on t, which is fixed for this sample.
        double eta = this.r.apply(this.t);
        // Visit rows from last to first: row j's delta reads rows 0..j, none of
        // which has been modified yet in this pass (row j's column i is read
        // before it is written). Every delta therefore uses the pre-update
        // weights, as Sanger's rule requires, without copying the matrix.
        for (int j = this.p - 1; j >= 0; --j) {
            double gain = eta * this.y[j];
            for (int i = 0; i < this.n; ++i) {
                double delta = x[i];
                for (int l = 0; l <= j; ++l) {
                    delta -= this.projection.get(l, i) * this.y[l];
                }
                this.projection.add(j, i, gain * delta);
                // isFinite also catches NaN, which isInfinite would miss.
                if (!Double.isFinite(this.projection.get(j, i))) {
                    throw new IllegalStateException("GHA lost convergence. Lower learning rate?");
                }
            }
        }
        ++this.t;
        this.projection.mv(x, this.y);
        this.projection.tv(this.y, this.wy);
        return MathEx.squaredDistance(x, this.wy);
    }

    /**
     * Updates the model with a single sample whose values are extracted from
     * {@code x} using this model's input columns.
     *
     * @param x the input tuple.
     * @return the squared reconstruction error, as in {@link #update(double[])}.
     */
    public double update(Tuple x) {
        return this.update(x.toArray(this.columns));
    }

    /**
     * Updates the model with each row of {@code data}, in order.
     *
     * @param data the input samples, one per row.
     */
    public void update(double[][] data) {
        for (double[] x : data) {
            this.update(x);
        }
    }

    /**
     * Updates the model with each row of {@code data}, using this model's
     * input columns to extract the sample values.
     *
     * @param data the input data frame.
     */
    public void update(DataFrame data) {
        this.update(data.toArray(this.columns));
    }
}

