/*
 * Decompiled with CFR 0.152.
 */
package smile.base.svm;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.svm.KernelMachine;
import smile.base.svm.SupportVector;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/**
 * Online solver for a binary kernel SVM following the LASVM scheme: each
 * training sample is offered to {@link #process} (insert it and run one SMO
 * step against the current support vectors), followed by {@link #reprocess}
 * steps (SMO restricted to the current support vectors, then eviction of
 * vectors that can no longer contribute).
 *
 * <p>Labels must be {@code +1} or {@code -1}. Each support vector carries a
 * coefficient {@code alpha} that the SMO step keeps inside the box
 * {@code [cmin, cmax]} of that vector, and a gradient {@code g}; the
 * intercept is the midpoint of the extreme movable gradients.
 *
 * <p>Instances are mutable and {@code fit} is not synchronized, so a single
 * instance must not be trained from several threads at once.
 *
 * @param <T> the type of the input samples.
 */
public class LASVM<T>
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(LASVM.class);
    /** Value substituted for a non-positive curvature to avoid division by zero. */
    private static final double TAU = 1.0E-12;
    /** Gradient gap above which {@code fit} keeps reprocessing before taking the next sample. */
    private static final double REPROCESS_GAP = 1000.0;
    /** Absolute tolerance under which an alpha is treated as zero when evicting. */
    private static final double EVICT_EPS = 1.0E-4;
    /** Number of iterations between progress log messages. */
    private static final int LOG_PERIOD = 1000;
    /** Number of positive and of negative samples used to seed the solver. */
    private static final int SEED_PER_CLASS = 5;
    private final MercerKernel<T> kernel;
    private final double Cp;
    private final double Cn;
    private final double tol;
    /** The current support vectors. */
    private final ArrayList<SupportVector<T>> vectors = new ArrayList<>();
    /** The intercept, updated after every SMO step. */
    private double b = 0.0;
    /** True while gmin/gmax/svmin/svmax reflect the current vectors. */
    private boolean minmaxflag = false;
    private SupportVector<T> svmin = null;
    private SupportVector<T> svmax = null;
    private double gmin = Double.MAX_VALUE;
    private double gmax = -Double.MAX_VALUE;
    /** The training samples of the ongoing fit (null outside of fit). */
    private T[] x;
    /** Lazily filled kernel rows, one per support vector (null outside of fit). */
    private double[][] K;

    /**
     * Constructor with the same soft-margin penalty for both classes.
     *
     * @param kernel the Mercer kernel.
     * @param C the soft margin penalty parameter (non-negative).
     * @param tol the tolerance of the convergence test (positive).
     * @throws IllegalArgumentException if {@code C} is negative/NaN or {@code tol} is not positive.
     */
    public LASVM(MercerKernel<T> kernel, double C, double tol) {
        this(kernel, C, C, tol);
    }

    /**
     * Constructor with per-class soft-margin penalties.
     *
     * @param kernel the Mercer kernel.
     * @param Cp the soft margin penalty handed to each support vector as {@code Cp}
     *           (presumably the bound for positive samples — see SupportVector).
     * @param Cn the soft margin penalty handed to each support vector as {@code Cn}
     *           (presumably the bound for negative samples — see SupportVector).
     * @param tol the tolerance of the convergence test (positive).
     * @throws IllegalArgumentException if a penalty is negative/NaN or {@code tol} is not positive.
     */
    public LASVM(MercerKernel<T> kernel, double Cp, double Cn, double tol) {
        // Written as !(v >= 0) so that NaN is rejected as well.
        if (!(Cp >= 0.0)) {
            throw new IllegalArgumentException("Invalid C: " + Cp);
        }
        if (!(Cn >= 0.0)) {
            throw new IllegalArgumentException("Invalid C: " + Cn);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        this.kernel = kernel;
        this.Cp = Cp;
        this.Cn = Cn;
        this.tol = tol;
    }

    /**
     * Trains a binary SVM. Any state left by a previous call is discarded first.
     *
     * @param x the training samples.
     * @param y the class labels, each {@code +1} or {@code -1}.
     * @param epochs the number of passes over the shuffled training data.
     * @return the trained kernel machine.
     * @throws IllegalArgumentException if {@code x} is empty, if {@code x} and
     *         {@code y} differ in length, or if a processed label is not +1/-1.
     */
    public KernelMachine<T> fit(T[] x, int[] y, int epochs) {
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        reset(x);
        init(x, y);
        int phase = Math.min(x.length, LOG_PERIOD);
        int iter = 0;
        for (int epoch = 0; epoch < epochs; ++epoch) {
            for (int i : MathEx.permutate(x.length)) {
                process(i, x[i], y[i]);
                // Reprocess at least once; continue while the gradient gap stays very large.
                do {
                    reprocess(tol);
                    minmax();
                } while (gmax - gmin > REPROCESS_GAP);
                if (++iter % phase == 0) {
                    logger.info("{} iterations, {} support vectors", iter, vectors.size());
                }
            }
        }
        finish();
        KernelMachine<T> model = toModel(x);
        // The kernel cache and the training data are only needed while training.
        this.K = null;
        this.x = null;
        return model;
    }

    /** Clears all solver state and binds the solver to a new training set. */
    private void reset(T[] x) {
        this.x = x;
        this.K = new double[x.length][];
        vectors.clear();
        b = 0.0;
        minmaxflag = false;
        svmin = null;
        svmax = null;
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
    }

    /** Packs the current support vectors and intercept into a KernelMachine. */
    private KernelMachine<T> toModel(T[] x) {
        int n = vectors.size();
        // Safe: the array is created with the runtime component type of x.
        @SuppressWarnings("unchecked")
        T[] sv = (T[]) Array.newInstance(x.getClass().getComponentType(), n);
        double[] alpha = new double[n];
        for (int i = 0; i < n; ++i) {
            SupportVector<T> v = vectors.get(i);
            sv[i] = v.x;
            alpha[i] = v.alpha;
        }
        return new KernelMachine<>(kernel, sv, alpha, b);
    }

    /**
     * Seeds the solver by processing randomly ordered samples until
     * {@link #SEED_PER_CLASS} positives and as many negatives were inserted,
     * or the data is exhausted. Labels other than +1/-1 are skipped here.
     */
    private void init(T[] x, int[] y) {
        int cp = 0;
        int cn = 0;
        for (int i : MathEx.permutate(x.length)) {
            if (y[i] == 1 && cp < SEED_PER_CLASS) {
                if (process(i, x[i], y[i])) {
                    ++cp;
                }
            } else if (y[i] == -1 && cn < SEED_PER_CLASS) {
                if (process(i, x[i], y[i])) {
                    ++cn;
                }
            }
            if (cp >= SEED_PER_CLASS && cn >= SEED_PER_CLASS) {
                break;
            }
        }
    }

    /**
     * Recomputes, unless cached, the smallest gradient among vectors whose
     * alpha can still decrease (svmin/gmin) and the largest gradient among
     * vectors whose alpha can still increase (svmax/gmax). When no vector
     * qualifies, the corresponding reference is null and the bound is
     * +/-Double.MAX_VALUE.
     */
    private void minmax() {
        if (minmaxflag) {
            return;
        }
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
        // Reset so a vector from an earlier scan (possibly evicted since) is never returned.
        svmin = null;
        svmax = null;
        for (SupportVector<T> v : vectors) {
            double gi = v.g;
            double ai = v.alpha;
            if (gi < gmin && ai > v.cmin) {
                svmin = v;
                gmin = gi;
            }
            if (gi > gmax && ai < v.cmax) {
                svmax = v;
                gmax = gi;
            }
        }
        minmaxflag = true;
    }

    /**
     * Returns the kernel value between training samples i and j, reading and
     * filling row K[i] of the cache when that row exists (NaN marks a missing
     * entry); without a row the value is computed and not stored.
     */
    private double k(int i, int j) {
        double[] ki = K[i];
        double k = ki == null ? Double.NaN : ki[j];
        if (Double.isNaN(k)) {
            k = kernel.k(x[i], x[j]);
            if (ki != null) {
                ki[j] = k;
            }
        }
        return k;
    }

    /**
     * Performs one SMO step moving alpha from v1 to v2. A null argument is
     * chosen among the support vectors to maximize the gain; when both are
     * null the extreme vector found by {@link #minmax} is used as the anchor.
     *
     * @return false if no improving pair exists; otherwise whether the gap
     *         {@code gmax - gmin} after the step still exceeds {@code epsgr}.
     */
    private boolean smo(SupportVector<T> v1, SupportVector<T> v2, double epsgr) {
        if (v1 == null && v2 == null) {
            minmax();
            if (gmax > -gmin) {
                v2 = svmax;
            } else {
                v1 = svmin;
            }
            if (v1 == null && v2 == null) {
                // No vector can move inside its box (e.g. the set is empty).
                return false;
            }
        }
        double k12 = Double.NaN;
        if (v2 == null) {
            // v1 is fixed: pick the v2 with the largest gain.
            double km = v1.k;
            double gm = v1.g;
            double best = 0.0;
            for (SupportVector<T> v : vectors) {
                double Z = v.g - gm;
                double k = k(v1.i, v.i);
                double curv = km + v.k - 2.0 * k;
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double mu = Z / curv;
                if ((mu > 0.0 && v.alpha < v.cmax) || (mu < 0.0 && v.alpha > v.cmin)) {
                    double gain = Z * mu;
                    if (gain > best) {
                        best = gain;
                        v2 = v;
                        k12 = k;
                    }
                }
            }
        }
        if (v1 == null) {
            // v2 is fixed: pick the v1 with the largest gain.
            double km = v2.k;
            double gm = v2.g;
            double best = 0.0;
            for (SupportVector<T> v : vectors) {
                double Z = gm - v.g;
                double k = k(v2.i, v.i);
                double curv = km + v.k - 2.0 * k;
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double mu = Z / curv;
                if ((mu > 0.0 && v.alpha > v.cmin) || (mu < 0.0 && v.alpha < v.cmax)) {
                    double gain = Z * mu;
                    if (gain > best) {
                        best = gain;
                        v1 = v;
                        k12 = k;
                    }
                }
            }
        }
        if (v1 == null || v2 == null) {
            return false;
        }
        if (Double.isNaN(k12)) {
            k12 = kernel.k(v1.x, v2.x);
        }
        double step = getStep(v1, v2, k12);
        v1.alpha -= step;
        v2.alpha += step;
        for (SupportVector<T> v : vectors) {
            v.g -= step * (k(v2.i, v.i) - k(v1.i, v.i));
        }
        minmaxflag = false;
        minmax();
        b = (gmax + gmin) / 2.0;
        return gmax - gmin > epsgr;
    }

    /**
     * Returns the Newton step along the direction (v1 down, v2 up), clipped so
     * that v1.alpha stays within [v1.cmin, v1.cmax] and v2.alpha within
     * [v2.cmin, v2.cmax] after the update in {@link #smo}.
     */
    private double getStep(SupportVector<T> v1, SupportVector<T> v2, double k12) {
        double curv = v1.k + v2.k - 2.0 * k12;
        if (curv <= 0.0) {
            curv = TAU;
        }
        double step = (v2.g - v1.g) / curv;
        if (step >= 0.0) {
            step = Math.min(step, v1.alpha - v1.cmin);
            step = Math.min(step, v2.cmax - v2.alpha);
        } else {
            step = Math.max(step, v2.cmin - v2.alpha);
            step = Math.max(step, v1.alpha - v1.cmax);
        }
        return step;
    }

    /**
     * PROCESS step: tries to insert sample i as a new support vector with
     * alpha 0 and runs one SMO step anchored on it.
     *
     * @return false if the same object is already a support vector, or if
     *         (with gmin &lt; gmax) a positive sample's gradient is below gmin or
     *         a negative sample's gradient is above gmax; true once inserted.
     * @throws IllegalArgumentException if {@code y} is not +1 or -1.
     */
    private boolean process(int i, T x, int y) {
        if (y != 1 && y != -1) {
            throw new IllegalArgumentException("Invalid label: " + y);
        }
        for (SupportVector<T> v : vectors) {
            // Reference identity: the same sample object is never inserted twice.
            if (v.x == x) {
                return false;
            }
        }
        double[] cache = new double[K.length];
        Arrays.fill(cache, Double.NaN);
        double g = y;
        for (SupportVector<T> v : vectors) {
            double k = kernel.k(v.x, x);
            cache[v.i] = k;
            g -= v.alpha * k;
        }
        minmax();
        if (gmin < gmax && ((y > 0 && g < gmin) || (y < 0 && g > gmax))) {
            return false;
        }
        SupportVector<T> v = new SupportVector<T>(i, x, y, 0.0, g, Cp, Cn, kernel.k(x, x));
        vectors.add(v);
        K[i] = cache;
        if (y > 0) {
            smo(null, v, 0.0);
        } else {
            smo(v, null, 0.0);
        }
        minmaxflag = false;
        return true;
    }

    /** REPROCESS step: one SMO step among the support vectors, then eviction. */
    private boolean reprocess(double epsgr) {
        boolean status = smo(null, null, epsgr);
        evict();
        return status;
    }

    /**
     * Final optimization pass, then logs the number of support vectors and
     * of those whose alpha sits exactly on either bound of its box.
     */
    private void finish() {
        finish(tol, vectors.size());
        int bsv = 0;
        for (SupportVector<T> v : vectors) {
            if (v.alpha == v.cmin || v.alpha == v.cmax) {
                ++bsv;
            }
        }
        logger.info("{} samples, {} support vectors, {} bounded", x.length, vectors.size(), bsv);
    }

    /**
     * Runs at most {@code maxIter} SMO steps, stopping early once the gap
     * falls to {@code epsgr} or no improving pair exists, then evicts.
     */
    private void finish(double epsgr, int maxIter) {
        logger.info("Finalizing the training by reprocess.");
        for (int count = 1; count <= maxIter && smo(null, null, epsgr); ++count) {
            if (count % LOG_PERIOD == 0) {
                logger.info("{} reprocess iterations.", count);
            }
        }
        evict();
    }

    /**
     * Removes support vectors whose alpha is (near) zero and whose gradient
     * lies at or beyond the extreme bound on the side where alpha could not
     * become non-zero, and drops their kernel-cache rows.
     */
    private void evict() {
        minmax();
        boolean removed = vectors.removeIf(v -> {
            boolean idle = MathEx.isZero(v.alpha, EVICT_EPS)
                    && ((v.g >= gmax && 0.0 >= v.cmax) || (v.g <= gmin && 0.0 <= v.cmin));
            if (idle) {
                K[v.i] = null;
            }
            return idle;
        });
        if (removed) {
            // The evicted vectors sit at the gradient extremes, so svmin/svmax
            // and gmin/gmax may now refer to removed vectors: force a rescan.
            minmaxflag = false;
        }
    }
}

