/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dimensionalityreduction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Gaussian random projection for dimensionality reduction.
 *
 * <p>Each {@code project*} call right-multiplies the input, indexed as
 * {@code [nSamples, nFeatures]}, by a random matrix of shape {@code [nFeatures, nComponents]}.
 * The entries of that matrix are drawn from {@code N(0, 1 / nComponents)}.
 *
 * <p>The number of output components is chosen in one of two ways:
 * <ul>
 *   <li>given explicitly to the constructor, or</li>
 *   <li>in "auto" mode, estimated from the Johnson-Lindenstrauss lemma for a relative error
 *       {@code eps} and the number of input rows.</li>
 * </ul>
 *
 * <p>The projection matrix is generated lazily and cached. It is regenerated only when the
 * required shape changes, so successive calls that need the same target shape reuse the same
 * projection. The cache is not synchronized.
 */
public class RandomProjection {
    /** Explicit number of output components; only used when {@link #autoMode} is false. */
    private final int components;
    /** Random generator used to fill the projection matrix. */
    private final Random rng;
    /** Relative error for Johnson-Lindenstrauss estimation; only used when {@link #autoMode} is true. */
    private final double eps;
    /** True when the number of components is estimated from {@link #eps}. */
    private final boolean autoMode;
    /** Shape of {@link #projectionMatrix}, or null before the first projection. */
    private long[] projectionMatrixShape;
    /** Lazily generated projection matrix, reused while the required shape is unchanged. */
    private INDArray projectionMatrix;

    /**
     * Creates a projection whose target dimension is estimated from the Johnson-Lindenstrauss lemma.
     *
     * @param eps relative error to tolerate; must lie strictly between 0 and 1 (validated when
     *            projecting)
     * @param rng random generator used to build the projection matrix
     */
    public RandomProjection(double eps, Random rng) {
        this.rng = rng;
        this.eps = eps;
        this.components = 0;
        this.autoMode = true;
    }

    /**
     * Creates a Johnson-Lindenstrauss-estimated projection using {@link Nd4j#getRandom()}.
     *
     * @param eps relative error to tolerate; must lie strictly between 0 and 1
     */
    public RandomProjection(double eps) {
        this(eps, Nd4j.getRandom());
    }

    /**
     * Creates a projection onto a fixed number of components.
     *
     * @param components number of output columns
     * @param rng        random generator used to build the projection matrix
     */
    public RandomProjection(int components, Random rng) {
        this.rng = rng;
        this.components = components;
        this.eps = 0.0;
        this.autoMode = false;
    }

    /**
     * Creates a fixed-dimension projection using {@link Nd4j#getRandom()}.
     *
     * @param components number of output columns
     */
    public RandomProjection(int components) {
        this(components, Nd4j.getRandom());
    }

    /**
     * Estimates the minimum safe number of components for every combination of sample count and
     * relative error.
     *
     * <p>The Johnson-Lindenstrauss bound is {@code 4 * ln(n) / (eps^2 / 2 - eps^3 / 3)}. Results
     * are ordered with {@code eps} as the outer loop and {@code n} as the inner loop. Non-positive
     * entries of {@code n} produce meaningless values because {@code ln(n)} is not finite there.
     *
     * @param n   sample counts
     * @param eps relative errors, each strictly between 0 and 1
     * @return the estimated dimensions, narrowed to {@code int} (truncating toward zero)
     * @throws IllegalArgumentException if {@code n} or {@code eps} is null or empty, or if any
     *                                  {@code eps} is outside {@code ]0, 1[} (NaN included)
     */
    public static List<Integer> johnsonLindenstraussMinDim(int[] n, double... eps) {
        checkJohnsonLindenstraussArgs(n == null || n.length == 0, eps);
        List<Integer> res = new ArrayList<>(n.length * eps.length);
        for (double epsilon : eps) {
            for (int samples : n) {
                res.add((int) johnsonLindenstraussBound(samples, epsilon));
            }
        }
        return res;
    }

    /**
     * Estimates the minimum safe number of components for every combination of sample count and
     * relative error.
     *
     * <p>This is the {@code long} variant of {@link #johnsonLindenstraussMinDim(int[], double...)}.
     *
     * @param n   sample counts
     * @param eps relative errors, each strictly between 0 and 1
     * @return the estimated dimensions, narrowed to {@code long} (truncating toward zero)
     * @throws IllegalArgumentException if {@code n} or {@code eps} is null or empty, or if any
     *                                  {@code eps} is outside {@code ]0, 1[} (NaN included)
     */
    public static List<Long> johnsonLindenstraussMinDim(long[] n, double... eps) {
        checkJohnsonLindenstraussArgs(n == null || n.length == 0, eps);
        List<Long> res = new ArrayList<>(n.length * eps.length);
        for (double epsilon : eps) {
            for (long samples : n) {
                res.add((long) johnsonLindenstraussBound(samples, epsilon));
            }
        }
        return res;
    }

    /**
     * Single-sample-count convenience overload of
     * {@link #johnsonLindenstraussMinDim(int[], double...)}.
     */
    public static List<Integer> johnsonLindenStraussMinDim(int n, double... eps) {
        return johnsonLindenstraussMinDim(new int[]{n}, eps);
    }

    /**
     * Single-sample-count convenience overload of
     * {@link #johnsonLindenstraussMinDim(long[], double...)}.
     */
    public static List<Long> johnsonLindenStraussMinDim(long n, double... eps) {
        return johnsonLindenstraussMinDim(new long[]{n}, eps);
    }

    /** Rejects missing sample counts and relative errors outside the open interval ]0, 1[. */
    private static void checkJohnsonLindenstraussArgs(boolean noSamples, double[] eps) {
        if (noSamples || eps == null || eps.length == 0) {
            throw new IllegalArgumentException("Johnson-Lindenstrauss dimension estimation requires > 0 components and at least a relative error");
        }
        for (double epsilon : eps) {
            // Negated in-range test so that NaN is rejected too.
            if (!(epsilon > 0.0 && epsilon < 1.0)) {
                throw new IllegalArgumentException("A relative error should be in ]0, 1[");
            }
        }
    }

    /** Computes the un-narrowed Johnson-Lindenstrauss bound for one sample count and error. */
    private static double johnsonLindenstraussBound(double samples, double epsilon) {
        double denom = Math.pow(epsilon, 2.0) / 2.0 - Math.pow(epsilon, 3.0) / 3.0;
        return 4.0 * Math.log(samples) / denom;
    }

    /**
     * Builds a dense Gaussian random matrix whose entries are drawn from {@code N(0, 1 / shape[1])}.
     *
     * <p>{@code shape} is {@code [nFeatures, nComponents]}. Scaling the variance by the number of
     * output components makes {@code E[||x R||^2] = ||x||^2}, which is what the
     * Johnson-Lindenstrauss estimate assumes.
     */
    private static INDArray gaussianRandomMatrix(long[] shape, Random rng) {
        Nd4j.checkShapeValues(shape);
        INDArray res = Nd4j.create(shape);
        double stdDev = 1.0 / Math.sqrt(shape[1]);
        Nd4j.getExecutioner().exec(new GaussianDistribution(res, 0.0, stdDev), rng);
        return res;
    }

    /** Returns the cached projection matrix, regenerating it when the required shape differs. */
    private INDArray getProjectionMatrix(long[] shape, Random rng) {
        if (this.projectionMatrix == null || !Arrays.equals(this.projectionMatrixShape, shape)) {
            this.projectionMatrix = gaussianRandomMatrix(shape, rng);
            // Remember the shape so later calls needing the same shape reuse this matrix.
            this.projectionMatrixShape = shape.clone();
        }
        return this.projectionMatrix;
    }

    /** Looks up (or generates) the projection matrix appropriate for {@code data}. */
    private INDArray projectionMatrixFor(INDArray data) {
        long[] tShape = targetShape(data.shape(), this.eps, this.components, this.autoMode);
        return getProjectionMatrix(tShape, this.rng);
    }

    /**
     * {@code int[]} variant of the projection-matrix shape computation.
     * NOTE(review): currently unused within this class.
     */
    private static int[] targetShape(int[] shape, double eps, int targetDimension, boolean auto) {
        int components = auto
                ? johnsonLindenStraussMinDim(shape[0], eps).get(0)
                : targetDimension;
        if (auto && (components <= 0 || components > shape[1])) {
            throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", components));
        }
        return new int[]{shape[1], components};
    }

    /**
     * Computes the projection-matrix shape {@code [shape[1], nComponents]}.
     *
     * <p>In auto mode, {@code nComponents} is estimated from {@code shape[0]} and {@code eps}; it
     * must lie in {@code [1, shape[1]]}. Otherwise {@code targetDimension} is used unchecked.
     *
     * @throws ND4JIllegalStateException if the auto-estimated dimension is out of range
     */
    private static long[] targetShape(long[] shape, double eps, int targetDimension, boolean auto) {
        long components = targetDimension;
        if (auto) {
            components = johnsonLindenStraussMinDim(shape[0], eps).get(0);
        }
        if (auto && (components <= 0L || components > shape[1])) {
            throw new ND4JIllegalStateException(String.format("Estimation led to a target dimension of %d, which is invalid", components));
        }
        return new long[]{shape[1], components};
    }

    /**
     * Computes the projection-matrix shape for {@code X} using Johnson-Lindenstrauss estimation.
     *
     * @param X   input whose rows are samples and whose columns are features
     * @param eps relative error, strictly between 0 and 1
     * @return {@code [nFeatures, nComponents]}
     */
    public static long[] targetShape(INDArray X, double eps) {
        return targetShape(X.shape(), eps, -1, true);
    }

    /**
     * Computes the projection-matrix shape for {@code X} with an explicit target dimension.
     *
     * @param X               input whose rows are samples and whose columns are features
     * @param targetDimension number of output components
     * @return {@code [nFeatures, targetDimension]}
     */
    protected static long[] targetShape(INDArray X, int targetDimension) {
        return targetShape(X.shape(), -1.0, targetDimension, false);
    }

    /**
     * Projects {@code data} into the lower-dimensional space.
     *
     * @param data input whose rows are samples
     * @return a new array {@code data x R}
     */
    public INDArray project(INDArray data) {
        return data.mmul(projectionMatrixFor(data));
    }

    /**
     * Projects {@code data}, writing the product into {@code result}.
     *
     * <p>This delegates to {@link INDArray#mmuli(INDArray, INDArray)}, exactly as
     * {@link #projecti(INDArray, INDArray)} does.
     *
     * @param data   input whose rows are samples
     * @param result destination for {@code data x R}
     * @return the array returned by {@code mmuli}
     */
    public INDArray project(INDArray data, INDArray result) {
        return data.mmuli(projectionMatrixFor(data), result);
    }

    /**
     * Projects {@code data} using ND4J's in-place matrix multiply.
     *
     * <p>NOTE(review): {@code mmuli(other)} presumably writes the product into {@code data}. That
     * would only work when the output shape matches {@code data}'s shape — confirm against the
     * ND4J contract.
     *
     * @param data input whose rows are samples
     * @return the array returned by {@code mmuli}
     */
    public INDArray projecti(INDArray data) {
        return data.mmuli(projectionMatrixFor(data));
    }

    /**
     * Projects {@code data}, writing the product into {@code result}.
     *
     * @param data   input whose rows are samples
     * @param result destination for {@code data x R}
     * @return the array returned by {@code mmuli}
     */
    public INDArray projecti(INDArray data, INDArray result) {
        return data.mmuli(projectionMatrixFor(data), result);
    }
}

