/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.util.concurrent.AtomicDouble;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.plot.Tsne;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BarnesHutTsne
implements Model {
    private static final Logger log = LoggerFactory.getLogger(BarnesHutTsne.class);
    // Workspace id for a loop cache workspace; not referenced by the methods visible in this chunk.
    public static final String workspaceCache = "LOOP_CACHE";
    // Workspace id used together with workspaceConfigurationExternal whenever workspaceMode != NONE.
    public static final String workspaceExternal = "LOOP_EXTERNAL";
    // Number of optimisation iterations run by fit(); also forwarded to the exact Tsne when theta == 0.
    protected int maxIter = 1000;
    // Forwarded to the exact Tsne implementation when theta == 0.
    protected double realMin = Nd4j.EPS_THRESHOLD;
    // Momentum used until switchMomentumIteration, after which fit() switches to finalMomentum.
    protected double initialMomentum = 0.5;
    protected double finalMomentum = 0.8;
    // Lower clamp applied to the per-element gains in update().
    protected double minGain = 0.01;
    // Current momentum; initialised from initialMomentum (field order matters here).
    protected double momentum = this.initialMomentum;
    protected int switchMomentumIteration = 100;
    // normalize / usePca are only forwarded to the exact Tsne implementation (theta == 0).
    protected boolean normalize = true;
    protected boolean usePca = false;
    // Iteration at which fit() removes the x12 "early exaggeration" applied to the input similarities.
    protected int stopLyingIteration = 250;
    // Entropy tolerance for the per-row beta binary search in computeGaussianPerplexity().
    protected double tolerance = 1.0E-5;
    protected double learningRate = 500.0;
    // Lazily created in update() when useAdaGrad is true.
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad = true;
    // Target perplexity; k = (int)(3 * perplexity) neighbours are used per point.
    protected double perplexity = 30.0;
    // Low-dimensional embedding being optimised.
    protected INDArray Y;
    // Number of input rows, set by computeGaussianPerplexity().
    private int N;
    // Barnes-Hut accuracy parameter; 0 makes fit() use the exact Tsne implementation.
    private double theta;
    // Sparse input similarities: rows holds N + 1 cumulative row offsets, cols the neighbour index and
    // vals the value of each stored entry.
    private INDArray rows;
    private INDArray cols;
    private INDArray vals;
    // Similarity function name passed to VPTree (the field/accessor spelling is part of the public API).
    private String simiarlityFunction = "cosinesimilarity";
    // Passed to VPTree.
    private boolean invert = true;
    // Input data set by fit(...).
    private INDArray x;
    // Dimensionality of the embedding Y.
    private int numDimensions = 0;
    // Key of the embedding gradient in the Gradient returned by gradient().
    public static final String Y_GRAD = "yIncs";
    // Space-partitioning tree over Y used for the Barnes-Hut force approximation.
    private SpTree tree;
    // Per-element adaptive gains and momentum increments for Y (created lazily in gradient()).
    private INDArray gains;
    private INDArray yIncs;
    // Number of workers passed to VPTree.
    private int vpTreeWorkers;
    // Notified after each iteration of fit(), if non-null.
    protected transient TrainingListener trainingListener;
    protected WorkspaceMode workspaceMode;
    protected static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.3).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected WorkspaceConfiguration workspaceConfigurationFeedForward = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).policyLearning(LearningPolicy.OVER_TIME).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    public static final WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder().overallocationLimit(0.2).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();

    /**
     * Creates a Barnes-Hut t-SNE model that runs without ND4J workspaces.
     *
     * <p>Identical to the full constructor called with {@link WorkspaceMode#NONE}.
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter,
                    double realMin, double initialMomentum, double finalMomentum, double momentum,
                    int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance,
                    double learningRate, boolean useAdaGrad, double perplexity, TrainingListener listener,
                    double minGain, int vpTreeWorkers) {
        this(numDimensions, simiarlityFunction, theta, invert, maxIter, realMin, initialMomentum, finalMomentum,
                        momentum, switchMomentumIteration, normalize, stopLyingIteration, tolerance, learningRate,
                        useAdaGrad, perplexity, listener, minGain, vpTreeWorkers, WorkspaceMode.NONE);
    }

    /**
     * Creates a Barnes-Hut t-SNE model.
     *
     * @param numDimensions           dimensionality of the embedding
     * @param simiarlityFunction      similarity function name handed to the VP-tree
     * @param theta                   Barnes-Hut accuracy parameter; {@code 0} selects the exact {@code Tsne}
     * @param invert                  forwarded to the VP-tree
     * @param maxIter                 number of optimisation iterations
     * @param realMin                 forwarded to the exact {@code Tsne}
     * @param initialMomentum         momentum before {@code switchMomentumIteration}
     * @param finalMomentum           momentum from {@code switchMomentumIteration} on
     * @param momentum                starting value of the current momentum
     * @param switchMomentumIteration iteration at which the momentum switches
     * @param normalize               forwarded to the exact {@code Tsne}
     * @param stopLyingIteration      iteration at which early exaggeration ends
     * @param tolerance               entropy tolerance of the per-row beta search
     * @param learningRate            step size (initial AdaGrad rate when {@code useAdaGrad})
     * @param useAdaGrad              whether updates go through AdaGrad
     * @param perplexity              target perplexity
     * @param listener                notified after each iteration; may be {@code null}
     * @param minGain                 lower bound for the adaptive gains
     * @param vpTreeWorkers           worker count for the VP-tree
     * @param workspaceMode           workspace mode; {@code null} is treated as {@link WorkspaceMode#NONE}
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter,
                    double realMin, double initialMomentum, double finalMomentum, double momentum,
                    int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance,
                    double learningRate, boolean useAdaGrad, double perplexity, TrainingListener listener,
                    double minGain, int vpTreeWorkers, WorkspaceMode workspaceMode) {
        // Embedding and neighbourhood search setup.
        this.numDimensions = numDimensions;
        this.simiarlityFunction = simiarlityFunction;
        this.theta = theta;
        this.invert = invert;
        this.perplexity = perplexity;
        this.vpTreeWorkers = vpTreeWorkers;
        // Optimisation schedule.
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.stopLyingIteration = stopLyingIteration;
        this.tolerance = tolerance;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.minGain = minGain;
        this.normalize = normalize;
        // Infrastructure.
        this.trainingListener = listener;
        this.workspaceMode = workspaceMode != null ? workspaceMode : WorkspaceMode.NONE;
    }

    /** @return the similarity function name handed to the VP-tree */
    public String getSimiarlityFunction() {
        return simiarlityFunction;
    }

    /** @param simiarlityFunction the similarity function name handed to the VP-tree */
    public void setSimiarlityFunction(String simiarlityFunction) {
        this.simiarlityFunction = simiarlityFunction;
    }

    /** @return the {@code invert} flag passed to the VP-tree */
    public boolean isInvert() {
        return invert;
    }

    /** @param invert the {@code invert} flag passed to the VP-tree */
    public void setInvert(boolean invert) {
        this.invert = invert;
    }

    /** @return the Barnes-Hut accuracy parameter ({@code 0} selects the exact implementation) */
    public double getTheta() {
        return theta;
    }

    /** @return the target perplexity */
    public double getPerplexity() {
        return perplexity;
    }

    /** @return the dimensionality of the embedding */
    public int getNumDimensions() {
        return numDimensions;
    }

    /** @param numDimensions the dimensionality of the embedding */
    public void setNumDimensions(int numDimensions) {
        this.numDimensions = numDimensions;
    }

    /**
     * Builds the sparse input-similarity matrix for {@code d}.
     *
     * <p>For each of the {@code N = d.rows()} rows, the VP-tree returns {@code k + 1} nearest results, where
     * {@code k = (int) (3 * u)}. The first result is skipped (presumably the query point itself). A precision
     * {@code beta} is then binary-searched until the entropy of the Gaussian kernel row matches {@code log(u)}
     * within {@link #tolerance} (at most 200 tries). The row normalised for that final beta is stored.
     *
     * <p>Side effects: sets {@code N}, {@code rows} ({@code N + 1} cumulative offsets, {@code k} per row),
     * {@code cols} (neighbour indices) and {@code vals} (normalised kernel values).
     *
     * @param d input data, one point per row
     * @param u target perplexity
     * @return the {@code vals} array
     * @throws IllegalStateException if {@code u > k}, or if the VP-tree search returns nothing for a row
     */
    public INDArray computeGaussianPerplexity(INDArray d, double u) {
        this.N = d.rows();
        int k = (int) (3.0 * u);
        if (u > (double) k) {
            throw new IllegalStateException("Illegal k value " + k + ": perplexity " + u + " is greater than k");
        }
        this.rows = Nd4j.zeros(1L, (long) (this.N + 1));
        this.cols = Nd4j.zeros(1L, (long) (this.N * k));
        this.vals = Nd4j.zeros(1L, (long) (this.N * k));
        for (int n = 0; n < this.N; ++n) {
            this.rows.putScalar((long) (n + 1), this.rows.getDouble((long) n) + (double) k);
        }
        INDArray beta = Nd4j.ones(this.N, 1);
        double logU = FastMath.log(u);
        VPTree tree = new VPTree(d, this.simiarlityFunction, this.vpTreeWorkers, this.invert);
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            log.info("Calculating probabilities of data similarities...");
            for (int i = 0; i < this.N; ++i) {
                if (i % 500 == 0) {
                    log.info("Handled {} records", i);
                }
                double betaMin = -Double.MAX_VALUE;
                double betaMax = Double.MAX_VALUE;
                List<DataPoint> results = new ArrayList<>();
                tree.search(d.slice((long) i), k + 1, results, new ArrayList<>());
                if (results.isEmpty()) {
                    throw new IllegalStateException("Search returned no values for vector " + i + " - similarity \"" + this.simiarlityFunction + "\" may not be defined (for example, vector is all zeros with cosine similarity)");
                }
                double betas = beta.getDouble((long) i);
                INDArray cArr = VPTree.buildFromData(results);
                Pair<INDArray, Double> pair = this.computeGaussianKernel(cArr, betas, k);
                double hDiff = pair.getSecond() - logU;
                int tries = 0;
                // Binary search on beta until |H - log(u)| < tolerance (or 200 tries).
                while (tries < 200 && !(hDiff < this.tolerance && -hDiff < this.tolerance)) {
                    if (hDiff > 0.0) {
                        betaMin = betas;
                        betas = (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) ? betas * 2.0
                                        : (betas + betaMax) / 2.0;
                    } else {
                        betaMax = betas;
                        betas = (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) ? betas / 2.0
                                        : (betas + betaMin) / 2.0;
                    }
                    pair = this.computeGaussianKernel(cArr, betas, k);
                    hDiff = pair.getSecond() - logU;
                    ++tries;
                }
                // Use the kernel row for the beta the search settled on, not the initial guess.
                INDArray currP = pair.getFirst();
                currP.divi(currP.sum(Integer.MAX_VALUE));
                INDArray indices = Nd4j.create(1, k + 1);
                for (int j = 0; j < indices.length() && j < results.size(); ++j) {
                    indices.putScalar((long) j, results.get(j).getIndex());
                }
                int rowStart = this.rows.getInt(i);
                for (int l = 0; l < k; ++l) {
                    // Entry l + 1 of the search result is the l-th neighbour (entry 0 is skipped).
                    this.cols.putScalar((long) (rowStart + l), indices.getDouble((long) (l + 1)));
                    this.vals.putScalar((long) (rowStart + l), currP.getDouble((long) l));
                }
            }
        }
        return this.vals;
    }

    /** @return the data most recently assigned as input (via {@code fit(...)} or {@code setX}) */
    public INDArray input() {
        return x;
    }

    /** This model has no convex optimizer; always returns {@code null}. */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /** Parameters are not exposed by name; always returns {@code null}. */
    public INDArray getParam(String param) {
        return null;
    }

    /** No-op: listeners are not registered through this method. */
    public void addListeners(TrainingListener... listeners) {
    }

    /** No parameter table is exposed; always returns {@code null}. */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /** No parameter table is exposed; always returns {@code null}. */
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return null;
    }

    /** No-op. */
    public void setParamTable(Map<String, INDArray> paramTable) {
    }

    /** No-op. */
    public void setParam(String key, INDArray val) {
    }

    /** No-op. */
    public void clear() {
    }

    /** No-op: this model applies no constraints. */
    public void applyConstraints(int iteration, int epoch) {
    }

    /** Not supported by this model. */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrises a sparse matrix P and returns the values of (P + P^T) / 2.
     *
     * <p>Input layout: the entries of row {@code n} occupy indices {@code rowP[n] .. rowP[n + 1] - 1} of
     * {@code colP} (column index) and {@code valP} (value). The returned values follow the layout of the
     * symmetrised matrix (its own row offsets and columns). Those offsets and columns are not returned here,
     * but {@link #fit()} uses the full structure.
     *
     * @param rowP row offsets, {@code N + 1} entries
     * @param colP column index of every entry
     * @param valP value of every entry
     * @return the symmetrised values, already divided by two
     */
    public INDArray symmetrized(INDArray rowP, INDArray colP, INDArray valP) {
        return this.symmetrizedSparse(rowP, colP, valP)[2];
    }

    /**
     * Computes the full sparse structure of (P + P^T) / 2, following the reference bhtsne
     * {@code symmetrizeMatrix}.
     *
     * @return {@code {symRowP, symColP, symValP}}: row offsets ({@code N + 1}), columns and halved values
     */
    private INDArray[] symmetrizedSparse(INDArray rowP, INDArray colP, INDArray valP) {
        INDArray rowCounts = Nd4j.create(this.N);
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            // Pass 1: count how many entries each row of the symmetric matrix will hold.
            for (int n = 0; n < this.N; ++n) {
                for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); ++i) {
                    int c = colP.getInt(i);
                    boolean present = false;
                    for (int m = rowP.getInt(c); m < rowP.getInt(c + 1); ++m) {
                        if (colP.getInt(m) == n) {
                            present = true;
                        }
                    }
                    rowCounts.putScalar((long) n, rowCounts.getDouble((long) n) + 1.0);
                    if (!present) {
                        // (c, n) is missing, so row c gains a mirrored entry.
                        rowCounts.putScalar((long) c, rowCounts.getDouble((long) c) + 1.0);
                    }
                }
            }
            int numElements = rowCounts.sum(Integer.MAX_VALUE).getInt(0);
            INDArray offset = Nd4j.create(this.N);
            INDArray symRowP = Nd4j.create(this.N + 1);
            INDArray symColP = Nd4j.create(numElements);
            INDArray symValP = Nd4j.create(numElements);
            for (int n = 0; n < this.N; ++n) {
                symRowP.putScalar((long) (n + 1), symRowP.getDouble((long) n) + rowCounts.getDouble((long) n));
            }
            // Pass 2: fill columns and values. The scan of row c must use the same bounds as pass 1,
            // otherwise the offsets drift away from the counts.
            for (int n = 0; n < this.N; ++n) {
                for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); ++i) {
                    int c = colP.getInt(i);
                    boolean present = false;
                    for (int m = rowP.getInt(c); m < rowP.getInt(c + 1); ++m) {
                        if (colP.getInt(m) != n) {
                            continue;
                        }
                        present = true;
                        // (n, c) and (c, n) both exist: write the pair once, from the lower row.
                        if (n <= c) {
                            int slotN = symRowP.getInt(n) + offset.getInt(n);
                            int slotC = symRowP.getInt(c) + offset.getInt(c);
                            double v = valP.getDouble((long) i) + valP.getDouble((long) m);
                            symColP.putScalar((long) slotN, c);
                            symColP.putScalar((long) slotC, n);
                            symValP.putScalar((long) slotN, v);
                            symValP.putScalar((long) slotC, v);
                        }
                    }
                    if (!present) {
                        // Only (n, c) exists: this is its sole visit, so always mirror it.
                        int slotN = symRowP.getInt(n) + offset.getInt(n);
                        int slotC = symRowP.getInt(c) + offset.getInt(c);
                        double v = valP.getDouble((long) i);
                        symColP.putScalar((long) slotN, c);
                        symColP.putScalar((long) slotC, n);
                        symValP.putScalar((long) slotN, v);
                        symValP.putScalar((long) slotC, v);
                    }
                    if (!present || n <= c) {
                        offset.putScalar((long) n, offset.getInt(n) + 1);
                        if (c != n) {
                            offset.putScalar((long) c, offset.getInt(c) + 1);
                        }
                    }
                }
            }
            symValP.divi(2.0);
            return new INDArray[] {symRowP, symColP, symValP};
        }
    }

    /**
     * Evaluates one Gaussian kernel row and its entropy-like value for precision {@code beta}.
     *
     * <p>Uses {@code distances[1..k]} (entry 0 is skipped): {@code p[m] = exp(-beta * distances[m + 1])} and
     * {@code h = beta * sum(distances[m + 1] * p[m]) / sum(p) + log(sum(p))}. The sum is seeded with
     * {@link Double#MIN_VALUE} so that {@code h} stays finite when every kernel value underflows to zero.
     *
     * @param distances at least {@code k + 1} distances
     * @param beta      kernel precision
     * @param k         number of neighbours
     * @return the unnormalised kernel row and {@code h}
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray distances, double beta, int k) {
        INDArray currP = Nd4j.create(k);
        for (int m = 0; m < k; ++m) {
            currP.putScalar((long) m, FastMath.exp(-beta * distances.getDouble((long) (m + 1))));
        }
        double sum = currP.sum(Integer.MAX_VALUE).getDouble(0L) + Double.MIN_VALUE;
        double h = 0.0;
        for (int m = 0; m < k; ++m) {
            h += beta * (distances.getDouble((long) (m + 1)) * currP.getDouble((long) m));
        }
        h = h / sum + FastMath.log(sum);
        return new Pair<>(currP, h);
    }

    /** No-op. */
    public void init() {
    }

    /** No-op: use {@code setTrainingListener} instead. */
    public void setListeners(Collection<TrainingListener> listeners) {
    }

    /** No-op: use {@code setTrainingListener} instead. */
    public void setListeners(TrainingListener... listeners) {
    }

    /**
     * Fits the embedding {@code Y} to the current input {@code x}.
     *
     * <p>With {@code theta == 0} the exact {@code Tsne} implementation is used. Otherwise the Barnes-Hut path
     * runs the following steps:
     * <ol>
     *   <li>Computes the sparse input similarities and symmetrises them.</li>
     *   <li>Installs the symmetric row offsets, columns and normalised values as {@code rows}, {@code cols}
     *   and {@code vals}.</li>
     *   <li>Applies x12 early exaggeration until {@code stopLyingIteration}.</li>
     *   <li>Runs {@code maxIter} steps.</li>
     * </ol>
     * A {@code null} {@code Y} is initialised with small Gaussian noise.
     */
    public void fit() {
        if (this.theta == 0.0) {
            log.debug("theta == 0, using decomposed version, might be slow");
            Tsne decomposedTsne = new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum,
                            this.minGain, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca,
                            this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad,
                            this.perplexity);
            this.Y = decomposedTsne.calculate(this.x, this.numDimensions, this.perplexity);
            return;
        }
        if (this.Y == null) {
            this.Y = Nd4j.randn((long) this.x.rows(), (long) this.numDimensions, (Random) Nd4j.getRandom())
                            .muli(0.001f);
        }
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            this.computeGaussianPerplexity(this.x, this.perplexity);
            // The symmetric values have their own layout, so the matching row offsets and columns must replace
            // the unsymmetrised ones; otherwise gradient() and score() pair values with the wrong entries.
            INDArray[] sym = this.symmetrizedSparse(this.rows, this.cols, this.vals);
            this.rows = sym[0];
            this.cols = sym[1];
            this.vals = sym[2].divi(sym[2].sum(Integer.MAX_VALUE));
            // Early exaggeration.
            this.vals.muli(12);
            for (int i = 0; i < this.maxIter; ++i) {
                this.step(this.vals, i);
                if (i == this.switchMomentumIteration) {
                    this.momentum = this.finalMomentum;
                }
                if (i == this.stopLyingIteration) {
                    this.vals.divi(12);
                }
                if (this.trainingListener != null) {
                    this.trainingListener.iterationDone(this, i, 0);
                }
            }
        }
    }

    /** No-op: updates are applied through {@code update(INDArray, String)}. */
    public void update(Gradient gradient) {
    }

    /**
     * Runs one optimisation step.
     *
     * <p>It computes {@link #gradient()} and applies the {@code Y_GRAD} entry via {@code update(INDArray, String)}.
     * Both arguments are currently unused.
     *
     * @param p similarity values (unused)
     * @param i iteration index (unused)
     */
    public void step(INDArray p, int i) {
        Gradient current = this.gradient();
        INDArray yGradient = current.getGradientFor(Y_GRAD);
        this.update(yGradient, Y_GRAD);
    }

    /**
     * Applies one gradient-descent update with momentum and adaptive gains to {@code Y}.
     *
     * <p>Gains follow the reference t-SNE rule:
     * <ul>
     *   <li>{@code gain + 0.2} where the sign of the gradient differs from the sign of the previous
     *   increment;</li>
     *   <li>{@code gain * 0.8} where the signs agree;</li>
     *   <li>the result is clamped below at {@code minGain}.</li>
     * </ul>
     * The scaled gradient goes through AdaGrad when {@code useAdaGrad}, otherwise it is multiplied by
     * {@code learningRate}. Then {@code yIncs = momentum * yIncs - change} and {@code Y += yIncs}.
     *
     * <p>Assumes {@code gains} and {@code yIncs} were initialised (as {@link #gradient()} does).
     *
     * @param gradient  gradient with respect to {@code Y}
     * @param paramType unused
     */
    public void update(INDArray gradient, String paramType) {
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            INDArray yGrads = gradient;
            // 1 where sign(gradient) != sign(previous increment), 0 where they agree.
            INDArray signsDiffer = Transforms.sign(yGrads).neq(Transforms.sign(this.yIncs))
                            .castTo(this.gains.dataType());
            INDArray signsAgree = signsDiffer.rsub(1.0);
            this.gains = this.gains.add(0.2).muli(signsDiffer).addi(this.gains.mul(0.8).muli(signsAgree));
            BooleanIndexing.replaceWhere(this.gains, this.minGain, Conditions.lessThan(this.minGain));
            INDArray gradChange = this.gains.mul(yGrads);
            if (this.useAdaGrad) {
                if (this.adaGrad == null) {
                    this.adaGrad = new AdaGrad(ArrayUtil.toInts(gradient.shape()), this.learningRate);
                    this.adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1L, gradChange.length()),
                                    gradChange.shape(), gradient.ordering(), true);
                }
                gradChange = this.adaGrad.getGradient(gradChange, 0);
            } else {
                gradChange.muli(this.learningRate);
            }
            this.yIncs.muli(this.momentum).subi(gradChange);
            this.Y.addi(this.yIncs);
        }
    }

    /**
     * Writes the embedding to {@code path} as text, one line per label.
     *
     * <p>Each line holds the comma-separated coordinates of row {@code i} of {@code Y}, followed by a comma
     * and {@code labels.get(i)}. Rows whose label is {@code null} are skipped. At most
     * {@code min(Y.rows(), labels.size())} rows are written. The file is written with {@link FileWriter}
     * (platform default charset).
     *
     * @param labels label for each row of {@code Y}
     * @param path   output file path
     * @throws IOException if the file cannot be written
     */
    public void saveAsFile(List<String> labels, String path) throws IOException {
        // The writer must be the try-with-resources variable itself; resource variables are final, so the
        // decompiled "= null then reassign" form does not compile and would never close the writer.
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
            for (int i = 0; i < this.Y.rows() && i < labels.size(); ++i) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = this.Y.getRow((long) i);
                long length = wordVector.length();
                for (long j = 0; j < length; ++j) {
                    sb.append(wordVector.getDouble(j));
                    if (j < length - 1L) {
                        sb.append(",");
                    }
                }
                sb.append(",");
                sb.append(word);
                sb.append("\n");
                write.write(sb.toString());
            }
            write.flush();
        }
    }

    /**
     * Fits the embedding to {@code matrix} with {@code nDims} output dimensions, then writes it to
     * {@code path}.
     *
     * @deprecated call {@code fit(...)} and {@code saveAsFile(labels, path)} separately
     */
    @Deprecated
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        fit(matrix, nDims);
        saveAsFile(labels, path);
    }

    /**
     * Estimates the t-SNE cost {@code sum_ij p_ij * log(p_ij / q_ij)} over the stored sparse entries.
     *
     * <p>The normalisation term of Q is estimated with a Barnes-Hut tree built for the current {@code Y}, as
     * in the reference {@code evaluateError}.
     *
     * @return the estimated KL divergence
     */
    public double score() {
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            // A fresh tree avoids a stale (or null, before the first gradient()) this.tree.
            SpTree scoreTree = new SpTree(this.Y);
            scoreTree.setWorkspaceMode(this.workspaceMode);
            INDArray buff = Nd4j.create(this.numDimensions);
            AtomicDouble sumQ = new AtomicDouble(0.0);
            for (int n = 0; n < this.N; ++n) {
                scoreTree.computeNonEdgeForces(n, this.theta, buff, sumQ);
            }
            double cost = 0.0;
            for (int n = 0; n < this.N; ++n) {
                int begin = this.rows.getInt(n);
                int end = this.rows.getInt(n + 1);
                for (int i = begin; i < end; ++i) {
                    int other = this.cols.getInt(i);
                    buff.assign(this.Y.slice((long) n));
                    buff.subi(this.Y.slice((long) other));
                    double q = Transforms.pow(buff, 2).sum(Integer.MAX_VALUE).getDouble(0L);
                    q = 1.0 / (1.0 + q) / sumQ.doubleValue();
                    double p = this.vals.getDouble((long) i);
                    // KL term p * log(p / q); EPS guards against log(0) and division by zero.
                    cost += p * FastMath.log((p + Nd4j.EPS_THRESHOLD) / (q + Nd4j.EPS_THRESHOLD));
                }
            }
            return cost;
        }
    }

    /** No-op for this model. */
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
    }

    /** This model exposes no flattened parameters; always returns {@code null}. */
    public INDArray params() {
        return null;
    }

    /** @return always {@code 0} */
    public long numParams() {
        return 0L;
    }

    /** @return always {@code 0} */
    public long numParams(boolean backwards) {
        return 0L;
    }

    /** No-op. */
    public void setParams(INDArray params) {
    }

    /** @throws UnsupportedOperationException always */
    public void setParamsViewArray(INDArray params) {
        throw new UnsupportedOperationException();
    }

    /** @throws UnsupportedOperationException always */
    public INDArray getGradientsViewArray() {
        throw new UnsupportedOperationException();
    }

    /** @throws UnsupportedOperationException always */
    public void setBackpropGradientsViewArray(INDArray gradients) {
        throw new UnsupportedOperationException();
    }

    /**
     * Sets {@code data} as the input and fits the embedding.
     *
     * @param data input data, one point per row
     */
    public void fit(INDArray data) {
        x = data;
        fit();
    }

    /** Same as {@code fit(data)}; the workspace manager is ignored. */
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        fit(data);
    }

    /**
     * Sets the input and the output dimensionality, then fits the embedding.
     *
     * @deprecated set the dimensionality via {@code setNumDimensions} and call {@code fit(data)}
     */
    @Deprecated
    public void fit(INDArray data, int nDims) {
        numDimensions = nDims;
        x = data;
        fit();
    }

    /**
     * Computes the Barnes-Hut approximation of the t-SNE gradient with respect to {@code Y}.
     *
     * <p>The gradient is {@code posF - negF / sumQ}:
     * <ul>
     *   <li>{@code posF} holds the attractive forces over the stored sparse entries.</li>
     *   <li>{@code negF} holds the repulsive forces estimated with a space-partitioning tree over the
     *   current {@code Y}.</li>
     *   <li>{@code sumQ} is the estimated normalisation term.</li>
     * </ul>
     * Lazily initialises {@code yIncs} (zeros) and {@code gains} (ones).
     *
     * @return a gradient whose {@code Y_GRAD} entry holds the result
     */
    public Gradient gradient() {
        MemoryWorkspace workspace = this.workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationExternal,
                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            if (this.yIncs == null) {
                this.yIncs = Nd4j.zeros(this.Y.shape());
            }
            if (this.gains == null) {
                this.gains = Nd4j.ones(this.Y.shape());
            }
            AtomicDouble sumQ = new AtomicDouble(0.0);
            INDArray posF = Nd4j.create(this.Y.shape());
            INDArray negF = Nd4j.create(this.Y.shape());
            // Y moves after every update, so the tree (cells, centres of mass) must be rebuilt each call;
            // reusing the first tree would approximate repulsive forces from the initial positions.
            this.tree = new SpTree(this.Y);
            this.tree.setWorkspaceMode(this.workspaceMode);
            this.tree.computeEdgeForces(this.rows, this.cols, this.vals, this.N, posF);
            for (int n = 0; n < this.N; ++n) {
                this.tree.computeNonEdgeForces(n, this.theta, negF.slice((long) n), sumQ);
            }
            INDArray dC = posF.subi(negF.divi(sumQ));
            DefaultGradient ret = new DefaultGradient();
            ret.gradientForVariable().put(Y_GRAD, dC);
            return ret;
        }
    }

    /**
     * Computes {@link #gradient()} and then {@link #score()}, in that order.
     *
     * @return the gradient paired with the score
     */
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient g = this.gradient();
        double s = this.score();
        return new Pair<>(g, s);
    }

    /** @return always {@code 0} */
    public int batchSize() {
        return 0;
    }

    /** This model has no {@link NeuralNetConfiguration}; always returns {@code null}. */
    public NeuralNetConfiguration conf() {
        return null;
    }

    /** No-op. */
    public void setConf(NeuralNetConfiguration conf) {
    }

    /** @return the current embedding {@code Y} */
    public INDArray getData() {
        return Y;
    }

    /**
     * Replaces the embedding {@code Y}. The Barnes-Hut path of {@code fit()} uses a non-null {@code Y} as its
     * starting point instead of a random initialisation.
     */
    public void setData(INDArray data) {
        Y = data;
    }

    /** @return the number of optimisation iterations */
    public int getMaxIter() {
        return maxIter;
    }

    /** @return the {@code realMin} value forwarded to the exact implementation */
    public double getRealMin() {
        return realMin;
    }

    /** @return the momentum used before {@code switchMomentumIteration} */
    public double getInitialMomentum() {
        return initialMomentum;
    }

    /** @return the momentum used from {@code switchMomentumIteration} on */
    public double getFinalMomentum() {
        return finalMomentum;
    }

    /** @return the lower bound for the adaptive gains */
    public double getMinGain() {
        return minGain;
    }

    /** @return the current momentum */
    public double getMomentum() {
        return momentum;
    }

    /** @return the iteration at which the momentum switches */
    public int getSwitchMomentumIteration() {
        return switchMomentumIteration;
    }

    /** @return the {@code normalize} flag forwarded to the exact implementation */
    public boolean isNormalize() {
        return normalize;
    }

    /** @return the {@code usePca} flag forwarded to the exact implementation */
    public boolean isUsePca() {
        return usePca;
    }

    /** @return the iteration at which early exaggeration ends */
    public int getStopLyingIteration() {
        return stopLyingIteration;
    }

    /** @return the entropy tolerance of the beta search */
    public double getTolerance() {
        return tolerance;
    }

    /** @return the learning rate */
    public double getLearningRate() {
        return learningRate;
    }

    /** @return the AdaGrad updater, or {@code null} before it is first created */
    public AdaGrad getAdaGrad() {
        return adaGrad;
    }

    /** @return whether updates go through AdaGrad */
    public boolean isUseAdaGrad() {
        return useAdaGrad;
    }

    /** @return the current embedding */
    public INDArray getY() {
        return Y;
    }

    /** @return the number of input rows */
    public int getN() {
        return N;
    }

    /** @return the row offsets of the sparse similarity matrix */
    public INDArray getRows() {
        return rows;
    }

    /** @return the column indices of the sparse similarity matrix */
    public INDArray getCols() {
        return cols;
    }

    /** @return the values of the sparse similarity matrix */
    public INDArray getVals() {
        return vals;
    }

    /** @return the input data */
    public INDArray getX() {
        return x;
    }

    /** @return the space-partitioning tree over the embedding */
    public SpTree getTree() {
        return tree;
    }

    /** @return the per-element adaptive gains */
    public INDArray getGains() {
        return gains;
    }

    /** @return the momentum increments applied to the embedding */
    public INDArray getYIncs() {
        return yIncs;
    }

    /** @return the VP-tree worker count */
    public int getVpTreeWorkers() {
        return vpTreeWorkers;
    }

    /** @return the per-iteration listener, possibly {@code null} */
    public TrainingListener getTrainingListener() {
        return trainingListener;
    }

    /** @return the workspace mode */
    public WorkspaceMode getWorkspaceMode() {
        return workspaceMode;
    }

    /** @return the feed-forward workspace configuration */
    public WorkspaceConfiguration getWorkspaceConfigurationFeedForward() {
        return workspaceConfigurationFeedForward;
    }

    /** @param maxIter number of optimisation iterations */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /** @param realMin value forwarded to the exact implementation */
    public void setRealMin(double realMin) {
        this.realMin = realMin;
    }

    /** @param initialMomentum momentum used before {@code switchMomentumIteration} */
    public void setInitialMomentum(double initialMomentum) {
        this.initialMomentum = initialMomentum;
    }

    /** @param finalMomentum momentum used from {@code switchMomentumIteration} on */
    public void setFinalMomentum(double finalMomentum) {
        this.finalMomentum = finalMomentum;
    }

    /** @param minGain lower bound for the adaptive gains */
    public void setMinGain(double minGain) {
        this.minGain = minGain;
    }

    /** @param momentum current momentum */
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** @param switchMomentumIteration iteration at which the momentum switches */
    public void setSwitchMomentumIteration(int switchMomentumIteration) {
        this.switchMomentumIteration = switchMomentumIteration;
    }

    /** @param normalize flag forwarded to the exact implementation */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** @param usePca flag forwarded to the exact implementation */
    public void setUsePca(boolean usePca) {
        this.usePca = usePca;
    }

    /** @param stopLyingIteration iteration at which early exaggeration ends */
    public void setStopLyingIteration(int stopLyingIteration) {
        this.stopLyingIteration = stopLyingIteration;
    }

    /** @param tolerance entropy tolerance of the beta search */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /** @param learningRate learning rate */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @param adaGrad AdaGrad updater to use */
    public void setAdaGrad(AdaGrad adaGrad) {
        this.adaGrad = adaGrad;
    }

    /** @param useAdaGrad whether updates go through AdaGrad */
    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    /** @param perplexity target perplexity */
    public void setPerplexity(double perplexity) {
        this.perplexity = perplexity;
    }

    /** @param Y embedding */
    public void setY(INDArray Y) {
        this.Y = Y;
    }

    /** @param N number of input rows */
    public void setN(int N) {
        this.N = N;
    }

    /** @param theta Barnes-Hut accuracy parameter ({@code 0} selects the exact implementation) */
    public void setTheta(double theta) {
        this.theta = theta;
    }

    /** @param rows row offsets of the sparse similarity matrix */
    public void setRows(INDArray rows) {
        this.rows = rows;
    }

    /** @param cols column indices of the sparse similarity matrix */
    public void setCols(INDArray cols) {
        this.cols = cols;
    }

    /** @param vals values of the sparse similarity matrix */
    public void setVals(INDArray vals) {
        this.vals = vals;
    }

    /** @param x input data */
    public void setX(INDArray x) {
        this.x = x;
    }

    /** @param tree space-partitioning tree over the embedding */
    public void setTree(SpTree tree) {
        this.tree = tree;
    }

    /** @param gains per-element adaptive gains */
    public void setGains(INDArray gains) {
        this.gains = gains;
    }

    /** @param yIncs momentum increments applied to the embedding */
    public void setYIncs(INDArray yIncs) {
        this.yIncs = yIncs;
    }

    /** @param vpTreeWorkers VP-tree worker count */
    public void setVpTreeWorkers(int vpTreeWorkers) {
        this.vpTreeWorkers = vpTreeWorkers;
    }

    /** @param trainingListener per-iteration listener, may be {@code null} */
    public void setTrainingListener(TrainingListener trainingListener) {
        this.trainingListener = trainingListener;
    }

    /** @param workspaceMode workspace mode */
    public void setWorkspaceMode(WorkspaceMode workspaceMode) {
        this.workspaceMode = workspaceMode;
    }

    /** @param workspaceConfigurationFeedForward feed-forward workspace configuration */
    public void setWorkspaceConfigurationFeedForward(WorkspaceConfiguration workspaceConfigurationFeedForward) {
        this.workspaceConfigurationFeedForward = workspaceConfigurationFeedForward;
    }

    /**
     * Field-by-field equality.
     *
     * <p>Integral and boolean fields are compared with {@code ==}, doubles with
     * {@link Double#compare(double, double)} (so {@code NaN} equals {@code NaN} and
     * {@code 0.0} differs from {@code -0.0}), and reference fields with {@code equals} after a
     * null check. The training listener is not part of the comparison. The other object's
     * {@link #canEqual(Object)} is consulted so a subclass can refuse equality with its parent.
     *
     * @param o the object to compare with
     * @return {@code true} if every compared field matches
     */
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BarnesHutTsne)) {
            return false;
        }
        BarnesHutTsne that = (BarnesHutTsne) o;
        if (!that.canEqual(this)) {
            return false;
        }
        // Short-circuit chain: fields are checked in declaration order, stopping at the first mismatch.
        return this.getMaxIter() == that.getMaxIter()
                && Double.compare(this.getRealMin(), that.getRealMin()) == 0
                && Double.compare(this.getInitialMomentum(), that.getInitialMomentum()) == 0
                && Double.compare(this.getFinalMomentum(), that.getFinalMomentum()) == 0
                && Double.compare(this.getMinGain(), that.getMinGain()) == 0
                && Double.compare(this.getMomentum(), that.getMomentum()) == 0
                && this.getSwitchMomentumIteration() == that.getSwitchMomentumIteration()
                && this.isNormalize() == that.isNormalize()
                && this.isUsePca() == that.isUsePca()
                && this.getStopLyingIteration() == that.getStopLyingIteration()
                && Double.compare(this.getTolerance(), that.getTolerance()) == 0
                && Double.compare(this.getLearningRate(), that.getLearningRate()) == 0
                && equalsOrBothNull(this.getAdaGrad(), that.getAdaGrad())
                && this.isUseAdaGrad() == that.isUseAdaGrad()
                && Double.compare(this.getPerplexity(), that.getPerplexity()) == 0
                && equalsOrBothNull(this.getY(), that.getY())
                && this.getN() == that.getN()
                && Double.compare(this.getTheta(), that.getTheta()) == 0
                && equalsOrBothNull(this.getRows(), that.getRows())
                && equalsOrBothNull(this.getCols(), that.getCols())
                && equalsOrBothNull(this.getVals(), that.getVals())
                && equalsOrBothNull(this.getSimiarlityFunction(), that.getSimiarlityFunction())
                && this.isInvert() == that.isInvert()
                && equalsOrBothNull(this.getX(), that.getX())
                && this.getNumDimensions() == that.getNumDimensions()
                && equalsOrBothNull(this.getTree(), that.getTree())
                && equalsOrBothNull(this.getGains(), that.getGains())
                && equalsOrBothNull(this.getYIncs(), that.getYIncs())
                && this.getVpTreeWorkers() == that.getVpTreeWorkers()
                && equalsOrBothNull(this.getWorkspaceMode(), that.getWorkspaceMode())
                && equalsOrBothNull(this.getWorkspaceConfigurationFeedForward(),
                        that.getWorkspaceConfigurationFeedForward());
    }

    /**
     * Null-safe field comparison used by {@link #equals(Object)}.
     *
     * <p>Intentionally always delegates to {@code mine.equals(theirs)} when {@code mine} is
     * non-null, with no reference-identity shortcut. That shortcut is what
     * {@code java.util.Objects.equals} would add, and it could differ for array types whose
     * {@code equals} is not reflexive (e.g. arrays holding NaN).
     *
     * @param mine   this instance's field value
     * @param theirs the other instance's field value
     * @return {@code true} if both are null, or {@code mine.equals(theirs)}
     */
    private static boolean equalsOrBothNull(Object mine, Object theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Symmetry hook invoked by {@code equals} on the other operand.
     *
     * <p>A subclass that adds state can override this to reject equality with plain
     * {@code BarnesHutTsne} instances.
     *
     * @param other the object being compared; {@code null} yields {@code false}
     * @return whether {@code other} is a {@code BarnesHutTsne} (or subclass) instance
     */
    protected boolean canEqual(Object other) {
        return BarnesHutTsne.class.isInstance(other);
    }

    /**
     * Hash code over the same fields, in the same order, that {@code equals} compares.
     * The training listener is excluded, as it is from equality.
     *
     * <p>The mixing scheme is unchanged: {@code result = result * 59 + fieldHash}, starting from 1.
     * Doubles contribute {@link Double#hashCode(double)}, which is specified as
     * {@code (int) (bits ^ (bits >>> 32))} of {@code Double.doubleToLongBits(value)}. That is
     * exactly the previously hand-expanded expression, so the resulting hash values are identical.
     * Booleans contribute 79 (true) or 97 (false); null references contribute 43.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        // Previously declared but unused (the literal 59 was repeated); now actually used.
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getMaxIter();
        result = result * prime + Double.hashCode(this.getRealMin());
        result = result * prime + Double.hashCode(this.getInitialMomentum());
        result = result * prime + Double.hashCode(this.getFinalMomentum());
        result = result * prime + Double.hashCode(this.getMinGain());
        result = result * prime + Double.hashCode(this.getMomentum());
        result = result * prime + this.getSwitchMomentumIteration();
        result = result * prime + (this.isNormalize() ? 79 : 97);
        result = result * prime + (this.isUsePca() ? 79 : 97);
        result = result * prime + this.getStopLyingIteration();
        result = result * prime + Double.hashCode(this.getTolerance());
        result = result * prime + Double.hashCode(this.getLearningRate());
        AdaGrad adaGradRef = this.getAdaGrad();
        result = result * prime + (adaGradRef == null ? 43 : adaGradRef.hashCode());
        result = result * prime + (this.isUseAdaGrad() ? 79 : 97);
        result = result * prime + Double.hashCode(this.getPerplexity());
        INDArray yArray = this.getY();
        result = result * prime + (yArray == null ? 43 : yArray.hashCode());
        result = result * prime + this.getN();
        result = result * prime + Double.hashCode(this.getTheta());
        INDArray rowsArray = this.getRows();
        result = result * prime + (rowsArray == null ? 43 : rowsArray.hashCode());
        INDArray colsArray = this.getCols();
        result = result * prime + (colsArray == null ? 43 : colsArray.hashCode());
        INDArray valsArray = this.getVals();
        result = result * prime + (valsArray == null ? 43 : valsArray.hashCode());
        String similarity = this.getSimiarlityFunction();
        result = result * prime + (similarity == null ? 43 : similarity.hashCode());
        result = result * prime + (this.isInvert() ? 79 : 97);
        INDArray xArray = this.getX();
        result = result * prime + (xArray == null ? 43 : xArray.hashCode());
        result = result * prime + this.getNumDimensions();
        SpTree treeRef = this.getTree();
        result = result * prime + (treeRef == null ? 43 : treeRef.hashCode());
        INDArray gainsArray = this.getGains();
        result = result * prime + (gainsArray == null ? 43 : gainsArray.hashCode());
        INDArray yIncsArray = this.getYIncs();
        result = result * prime + (yIncsArray == null ? 43 : yIncsArray.hashCode());
        result = result * prime + this.getVpTreeWorkers();
        WorkspaceMode mode = this.getWorkspaceMode();
        result = result * prime + (mode == null ? 43 : mode.hashCode());
        WorkspaceConfiguration feedForwardConf = this.getWorkspaceConfigurationFeedForward();
        result = result * prime + (feedForwardConf == null ? 43 : feedForwardConf.hashCode());
        return result;
    }

    /**
     * Renders every field, including the training listener, as
     * {@code BarnesHutTsne(name=value, ...)}.
     *
     * <p>Array-valued fields are rendered via their own {@code toString}, so the result may be
     * large when big arrays are held.
     *
     * @return a textual dump of this instance's state
     */
    public String toString() {
        StringBuilder sb = new StringBuilder("BarnesHutTsne(");
        sb.append("maxIter=").append(this.getMaxIter());
        sb.append(", realMin=").append(this.getRealMin());
        sb.append(", initialMomentum=").append(this.getInitialMomentum());
        sb.append(", finalMomentum=").append(this.getFinalMomentum());
        sb.append(", minGain=").append(this.getMinGain());
        sb.append(", momentum=").append(this.getMomentum());
        sb.append(", switchMomentumIteration=").append(this.getSwitchMomentumIteration());
        sb.append(", normalize=").append(this.isNormalize());
        sb.append(", usePca=").append(this.isUsePca());
        sb.append(", stopLyingIteration=").append(this.getStopLyingIteration());
        sb.append(", tolerance=").append(this.getTolerance());
        sb.append(", learningRate=").append(this.getLearningRate());
        sb.append(", adaGrad=").append(this.getAdaGrad());
        sb.append(", useAdaGrad=").append(this.isUseAdaGrad());
        sb.append(", perplexity=").append(this.getPerplexity());
        sb.append(", Y=").append(this.getY());
        sb.append(", N=").append(this.getN());
        sb.append(", theta=").append(this.getTheta());
        sb.append(", rows=").append(this.getRows());
        sb.append(", cols=").append(this.getCols());
        sb.append(", vals=").append(this.getVals());
        sb.append(", simiarlityFunction=").append(this.getSimiarlityFunction());
        sb.append(", invert=").append(this.isInvert());
        sb.append(", x=").append(this.getX());
        sb.append(", numDimensions=").append(this.getNumDimensions());
        sb.append(", tree=").append(this.getTree());
        sb.append(", gains=").append(this.getGains());
        sb.append(", yIncs=").append(this.getYIncs());
        sb.append(", vpTreeWorkers=").append(this.getVpTreeWorkers());
        sb.append(", trainingListener=").append(this.getTrainingListener());
        sb.append(", workspaceMode=").append(this.getWorkspaceMode());
        sb.append(", workspaceConfigurationFeedForward=").append(this.getWorkspaceConfigurationFeedForward());
        sb.append(")");
        return sb.toString();
    }

    /**
     * Fluent builder for {@link BarnesHutTsne}.
     *
     * <p>Any option that is not configured keeps the default declared on its field below. Every
     * option method returns this builder for chaining. Some are {@code set}-prefixed
     * ({@code setMaxIter}, {@code setRealMin}, {@code setInitialMomentum},
     * {@code setFinalMomentum}, {@code setMomentum}, {@code setSwitchMomentumIteration}) and the
     * rest are not. No argument validation is performed.
     */
    public static class Builder {
        // Defaults are written as double literals on purpose. A float literal such as 0.1f,
        // once widened to double, becomes 0.10000000149011612 rather than 0.1.
        private int maxIter = 1000;
        private double realMin = 1e-12;
        private double initialMomentum = 0.5;
        private double finalMomentum = 0.8;
        private double momentum = 0.5;
        private int switchMomentumIteration = 100;
        private boolean normalize = true;
        private int stopLyingIteration = 100;
        private double tolerance = 1e-5;
        private double learningRate = 0.1;
        private boolean useAdaGrad = false;
        private double perplexity = 30.0;
        private double minGain = 0.1;
        private double theta = 0.5;
        private boolean invert = true;
        private int numDim = 2;
        private String similarityFunction = "cosinesimilarity";
        private int vpTreeWorkers = 1;
        protected WorkspaceMode workspaceMode = WorkspaceMode.NONE;

        /** Sets {@code vpTreeWorkers} (default 1). */
        public Builder vpTreeWorkers(int vpTreeWorkers) {
            this.vpTreeWorkers = vpTreeWorkers;
            return this;
        }

        /** Sets {@code minGain} (default 0.1). */
        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        /** Sets {@code perplexity} (default 30.0). */
        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        /** Sets {@code useAdaGrad} (default {@code false}). */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets {@code learningRate} (default 0.1). */
        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets {@code tolerance} (default 1e-5). */
        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        /** Sets {@code stopLyingIteration} (default 100). */
        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        /** Sets {@code normalize} (default {@code true}). */
        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        /** Sets {@code maxIter} (default 1000). */
        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        /** Sets {@code realMin} (default 1e-12). */
        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        /** Sets {@code initialMomentum} (default 0.5). */
        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        /** Sets {@code finalMomentum} (default 0.8). */
        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        /** Sets {@code momentum} (default 0.5). */
        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /** Sets {@code switchMomentumIteration} (default 100). */
        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /** Sets the similarity function name (default {@code "cosinesimilarity"}). */
        public Builder similarityFunction(String similarityFunction) {
            this.similarityFunction = similarityFunction;
            return this;
        }

        /** Sets {@code invert} (default {@code true}). */
        public Builder invertDistanceMetric(boolean invert) {
            this.invert = invert;
            return this;
        }

        /** Sets {@code theta} (default 0.5). */
        public Builder theta(double theta) {
            this.theta = theta;
            return this;
        }

        /**
         * Sets {@code numDim} (default 2), passed as the first constructor argument.
         * NOTE(review): presumably the output embedding dimensionality; confirm against the
         * constructor.
         */
        public Builder numDimension(int numDim) {
            this.numDim = numDim;
            return this;
        }

        /** Sets the workspace mode (default {@link WorkspaceMode#NONE}). */
        public Builder workspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        /**
         * Creates a new {@link BarnesHutTsne} from the current builder state.
         *
         * @return a freshly constructed instance on every call
         */
        public BarnesHutTsne build() {
            // NOTE(review): the 17th constructor argument is always passed as null; its meaning
            // is defined by the BarnesHutTsne constructor, which is outside this view.
            return new BarnesHutTsne(this.numDim, this.similarityFunction, this.theta, this.invert, this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.momentum, this.switchMomentumIteration, this.normalize, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity, null, this.minGain, this.vpTreeWorkers, this.workspaceMode);
        }
    }
}

