/*
 * Decompiled with CFR 0.152.
 */
package smile.base.svm;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;
import smile.regression.KernelMachine;

/**
 * Epsilon-insensitive support vector regression trained with a sequential
 * minimal optimization (SMO) solver.
 *
 * <p>Every training sample carries two Lagrange multipliers, {@code alpha[0]}
 * and {@code alpha[1]}, each kept inside {@code [0, C]}. After convergence,
 * samples whose two multipliers are equal are dropped. Each remaining sample
 * becomes a support vector with weight {@code alpha[1] - alpha[0]}.
 *
 * <p>An instance keeps mutable per-training state in its fields. Do not call
 * {@link #fit} concurrently on the same instance.
 *
 * @param <T> the type of input objects accepted by the kernel.
 */
public class SVR<T> {
    private static final Logger logger = LoggerFactory.getLogger(SVR.class);
    /** Curvature used instead of {@code K_ii + K_jj - 2 K_ij} when that value is not positive. */
    private static final double TAU = 1.0E-12;
    /** The kernel function. */
    private final MercerKernel<T> kernel;
    /** The loss function error threshold (epsilon). */
    private final double eps;
    /** The soft margin penalty, i.e. the upper bound of every multiplier. */
    private final double C;
    /** Convergence threshold on the gap {@code gmax - gmin}. */
    private final double tol;
    /** Per-sample state during training; released when {@link #fit} completes. */
    private List<SupportVector> vectors;
    /** Bias of the decision function. */
    private double b = 0.0;
    /** Sample holding the current minimal gradient. */
    private SupportVector svmin = null;
    /** Sample holding the current maximal gradient. */
    private SupportVector svmax = null;
    /** The current minimal gradient over eligible multipliers. */
    private double gmin = Double.MAX_VALUE;
    /** The current maximal gradient over eligible multipliers. */
    private double gmax = -Double.MAX_VALUE;
    /** Which multiplier (0 or 1) of {@link #svmin} attains {@link #gmin}. */
    private int gminindex;
    /** Which multiplier (0 or 1) of {@link #svmax} attains {@link #gmax}. */
    private int gmaxindex;
    /** Lazily filled kernel rows, indexed by sample index; released when {@link #fit} completes. */
    private double[][] K;

    /**
     * Constructor.
     *
     * @param kernel the kernel function.
     * @param eps the loss function error threshold; must be positive.
     * @param C the soft margin penalty; must be positive.
     * @param tol the tolerance of the convergence test; must be positive.
     * @throws IllegalArgumentException if {@code eps}, {@code C} or {@code tol}
     *         is not a positive number (NaN included).
     */
    public SVR(MercerKernel<T> kernel, double eps, double C, double tol) {
        // Written as !(x > 0) so that NaN is rejected as well.
        if (!(eps > 0.0)) {
            throw new IllegalArgumentException("Invalid error threshold: " + eps);
        }
        // C == 0 pins every multiplier at 0. minmax() then selects no working
        // pair, and smo() would dereference a null support vector.
        if (!(C > 0.0)) {
            throw new IllegalArgumentException("Invalid soft margin penalty: " + C);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
        }
        this.kernel = kernel;
        this.eps = eps;
        this.C = C;
        this.tol = tol;
    }

    /**
     * Trains the regression model.
     *
     * @param x the training samples.
     * @param y the response values, one per sample.
     * @return the kernel machine built from the support vectors, their weights
     *         {@code alpha[1] - alpha[0]} and the bias.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length or are empty.
     */
    public KernelMachine<T> fit(T[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        int n = x.length;
        if (n == 0) {
            // Without samples no working pair exists and the log period would be 0.
            throw new IllegalArgumentException("Empty training data");
        }

        K = new double[n][];
        vectors = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            vectors.add(new SupportVector(i, x[i], y[i]));
        }

        minmax();
        int phase = Math.min(n, 1000);
        int count = 1;
        while (smo(tol)) {
            if (count % phase == 0) {
                logger.info("{} SMO iterations", count);
            }
            count++;
        }

        // Drop samples whose multipliers cancel out. Count the survivors and
        // those that sit at the upper bound C.
        int nsv = 0;
        int bsv = 0;
        for (int i = 0; i < n; i++) {
            SupportVector v = vectors.get(i);
            if (v.alpha[0] == v.alpha[1]) {
                vectors.set(i, null);
            } else {
                nsv++;
                if (v.alpha[0] == C || v.alpha[1] == C) {
                    bsv++;
                }
            }
        }

        double[] alpha = new double[nsv];
        // The array is created with x's runtime component type and filled
        // only with elements taken from x, so the cast is safe.
        @SuppressWarnings("unchecked")
        T[] sv = (T[]) Array.newInstance(x.getClass().getComponentType(), nsv);
        int k = 0;
        for (SupportVector v : vectors) {
            if (v != null) {
                sv[k] = v.x;
                alpha[k++] = v.alpha[1] - v.alpha[0];
            }
        }

        logger.info("{} samples, {} support vectors, {} bounded", n, nsv, bsv);
        KernelMachine<T> model = new KernelMachine<>(kernel, sv, alpha, b);

        // Release the kernel cache (up to n x n doubles) and the per-sample
        // state, so this instance does not keep them alive after training.
        K = null;
        vectors = null;
        svmin = null;
        svmax = null;
        return model;
    }

    /**
     * Scans all multipliers and records the extreme gradients, together with
     * the sample and multiplier index that attain each one.
     *
     * <p>For {@code alpha[0]} the gradient is {@code -g[0]}. It can become the
     * minimum only when {@code alpha[0] > 0}, and the maximum only when
     * {@code alpha[0] < C}. For {@code alpha[1]} the gradient is {@code g[1]},
     * and the two conditions are swapped.
     */
    private void minmax() {
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
        for (SupportVector v : vectors) {
            double g = -v.g[0];
            double a = v.alpha[0];
            if (g < gmin && a > 0.0) {
                svmin = v;
                gmin = g;
                gminindex = 0;
            }
            if (g > gmax && a < C) {
                svmax = v;
                gmax = g;
                gmaxindex = 0;
            }

            g = v.g[1];
            a = v.alpha[1];
            if (g < gmin && a < C) {
                svmin = v;
                gmin = g;
                gminindex = 1;
            }
            if (g > gmax && a > 0.0) {
                svmax = v;
                gmax = g;
                gmaxindex = 1;
            }
        }
    }

    /**
     * Returns the kernel row of {@code v} against every training sample,
     * computing it on first use and caching it in {@link #K}.
     *
     * <p>The row is filled from a parallel stream, so {@code kernel.k} must
     * tolerate concurrent calls. Each task writes its own slot of the row.
     */
    private double[] gram(SupportVector v) {
        if (K[v.i] == null) {
            double[] ki = new double[vectors.size()];
            vectors.parallelStream().forEach(vi -> ki[vi.i] = kernel.k(v.x, vi.x));
            K[v.i] = ki;
        }
        return K[v.i];
    }

    /**
     * Performs one SMO step on a pair of multipliers.
     *
     * <p>The first multiplier comes from {@link #svmax}. The partner is the
     * eligible multiplier that minimizes {@code -(gi - gj)^2 / curv}; if none
     * qualifies, {@link #svmin} is kept. After the pair is updated and clipped
     * into {@code [0, C]}, all gradients, the extremes and the bias are
     * refreshed.
     *
     * @param epsgr the convergence tolerance.
     * @return {@code true} while {@code gmax - gmin > epsgr}, i.e. another step is needed.
     */
    private boolean smo(double epsgr) {
        SupportVector v1 = svmax;
        int i = gmaxindex;
        double oldAlphaI = v1.alpha[i];
        double[] k1 = gram(v1);

        SupportVector v2 = svmin;
        int j = gminindex;
        double oldAlphaJ = v2.alpha[j];

        // Choose the partner of v1 by the largest (gi - gj)^2 / curv.
        double best = 0.0;
        double gi = i == 0 ? -v1.g[0] : v1.g[1];
        for (SupportVector v : vectors) {
            double curv = v1.k + v.k - 2.0 * k1[v.i];
            if (curv <= 0.0) {
                curv = TAU;
            }

            double gj = -v.g[0];
            if (v.alpha[0] > 0.0 && gj < gi) {
                double gain = -MathEx.pow2(gi - gj) / curv;
                if (gain < best) {
                    best = gain;
                    v2 = v;
                    j = 0;
                    oldAlphaJ = v2.alpha[0];
                }
            }

            gj = v.g[1];
            if (v.alpha[1] < C && gj < gi) {
                double gain = -MathEx.pow2(gi - gj) / curv;
                if (gain < best) {
                    best = gain;
                    v2 = v;
                    j = 1;
                    oldAlphaJ = v2.alpha[1];
                }
            }
        }

        double[] k2 = gram(v2);
        double curv = v1.k + v2.k - 2.0 * k1[v2.i];
        if (curv <= 0.0) {
            curv = TAU;
        }

        if (i != j) {
            // Both multipliers move by +delta, which preserves their difference.
            // The clipping below also preserves it.
            double delta = (-v1.g[i] - v2.g[j]) / curv;
            double diff = v1.alpha[i] - v2.alpha[j];
            v1.alpha[i] += delta;
            v2.alpha[j] += delta;

            if (diff > 0.0) {
                if (v2.alpha[j] < 0.0) {
                    v2.alpha[j] = 0.0;
                    v1.alpha[i] = diff;
                }
            } else if (v1.alpha[i] < 0.0) {
                v1.alpha[i] = 0.0;
                v2.alpha[j] = -diff;
            }

            if (diff > 0.0) {
                if (v1.alpha[i] > C) {
                    v1.alpha[i] = C;
                    v2.alpha[j] = C - diff;
                }
            } else if (v2.alpha[j] > C) {
                v2.alpha[j] = C;
                v1.alpha[i] = C + diff;
            }
        } else {
            // The multipliers move in opposite directions, which preserves their sum.
            // The clipping below also preserves it.
            double delta = (v1.g[i] - v2.g[j]) / curv;
            double sum = v1.alpha[i] + v2.alpha[j];
            v1.alpha[i] -= delta;
            v2.alpha[j] += delta;

            if (sum > C) {
                if (v1.alpha[i] > C) {
                    v1.alpha[i] = C;
                    v2.alpha[j] = sum - C;
                }
            } else if (v2.alpha[j] < 0.0) {
                v2.alpha[j] = 0.0;
                v1.alpha[i] = sum;
            }

            if (sum > C) {
                if (v2.alpha[j] > C) {
                    v2.alpha[j] = C;
                    v1.alpha[i] = sum - C;
                }
            } else if (v1.alpha[i] < 0.0) {
                v1.alpha[i] = 0.0;
                v2.alpha[j] = sum;
            }
        }

        double deltaAlphaI = v1.alpha[i] - oldAlphaI;
        double deltaAlphaJ = v2.alpha[j] - oldAlphaJ;
        // Index 0 contributes with sign -1 and index 1 with sign +1.
        int si = 2 * i - 1;
        int sj = 2 * j - 1;
        for (SupportVector v : vectors) {
            double dg = si * k1[v.i] * deltaAlphaI + sj * k2[v.i] * deltaAlphaJ;
            v.g[0] -= dg;
            v.g[1] += dg;
        }

        minmax();
        b = -(gmax + gmin) / 2.0;
        return gmax - gmin > epsgr;
    }

    /** Per-sample optimization state. */
    class SupportVector {
        /** Index of the sample in the training set. */
        final int i;
        /** The training sample. */
        final T x;
        /** The two Lagrange multipliers; the model weight is {@code alpha[1] - alpha[0]}. */
        final double[] alpha = new double[2];
        /** Gradients paired with {@link #alpha}; initialized to {@code eps + y} and {@code eps - y}. */
        final double[] g = new double[2];
        /** Cached self-similarity {@code k(x, x)}. */
        final double k;

        SupportVector(int i, T x, double y) {
            this.i = i;
            this.x = x;
            this.g[0] = eps + y;
            this.g[1] = eps - y;
            this.k = kernel.k(x, x);
        }
    }
}

