/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.rng.distribution;

import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.UnivariateSolverUtils;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Skeleton implementation of {@link Distribution} that supplies interval
 * probabilities, a numerical inverse of the cumulative distribution function,
 * and sampling built on top of that inverse.
 *
 * <p>The methods below call {@code cumulativeProbability},
 * {@code getSupportLowerBound}, {@code getSupportUpperBound},
 * {@code getNumericalMean}, {@code getNumericalVariance} and
 * {@code isSupportConnected}; none of them is defined in this class, so they
 * are presumably declared by {@link Distribution} and must be provided by
 * concrete subclasses.
 */
public abstract class BaseDistribution implements Distribution {
    /** Random number generator used by the {@code sample} methods. */
    protected Random random;

    /**
     * Absolute accuracy passed to the root solver in
     * {@link #inverseCumulativeProbability(double)}. It is never assigned in
     * this class, so it stays at {@code 0.0} unless a subclass sets it.
     */
    protected double solverAbsoluteAccuracy;

    /**
     * Creates a distribution that draws its randomness from {@code rng}.
     *
     * @param rng the random number generator to use for sampling
     */
    public BaseDistribution(Random rng) {
        this.random = rng;
    }

    /**
     * Creates a distribution backed by the generator returned from
     * {@link Nd4j#getRandom()}.
     */
    public BaseDistribution() {
        this(Nd4j.getRandom());
    }

    /**
     * Computes the difference of the cumulative probabilities at the two
     * endpoints, {@code cumulativeProbability(x1) - cumulativeProbability(x0)}.
     *
     * @param x0 lower endpoint
     * @param x1 upper endpoint
     * @return the difference of the cumulative probabilities at {@code x1} and {@code x0}
     * @throws NumberIsTooLargeException if {@code x0 > x1}
     */
    public double probability(double x0, double x1) {
        if (x0 > x1) {
            throw new NumberIsTooLargeException(
                    LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true);
        }
        return cumulativeProbability(x1) - cumulativeProbability(x0);
    }

    /**
     * Numerically inverts the cumulative distribution function.
     *
     * <p>{@code p == 0} yields the support lower bound and {@code p == 1} the
     * support upper bound. Otherwise a finite bracket is established and the
     * root of {@code cumulativeProbability(x) - p} is located with
     * {@link UnivariateSolverUtils#solve(UnivariateFunction, double, double, double)}.
     * When the support is reported as not connected, the result is moved to the
     * left end of a flat stretch of the cumulative probability by bisection.
     *
     * @param p the cumulative probability, in {@code [0, 1]}
     * @return the value found for {@code p}
     * @throws OutOfRangeException if {@code p < 0} or {@code p > 1}
     */
    @Override
    public double inverseCumulativeProbability(final double p) throws OutOfRangeException {
        if (p < 0.0 || p > 1.0) {
            throw new OutOfRangeException(p, 0, 1);
        }

        double lowerBound = getSupportLowerBound();
        if (p == 0.0) {
            return lowerBound;
        }

        double upperBound = getSupportUpperBound();
        if (p == 1.0) {
            return upperBound;
        }

        final double mu = getNumericalMean();
        final double sig = FastMath.sqrt(getNumericalVariance());
        // The mean/standard-deviation based bracket is only usable when both
        // moments are finite, well-defined numbers.
        final boolean chebyshevApplies = !Double.isInfinite(mu) && !Double.isNaN(mu)
                && !Double.isInfinite(sig) && !Double.isNaN(sig);

        if (lowerBound == Double.NEGATIVE_INFINITY) {
            if (chebyshevApplies) {
                lowerBound = mu - sig * FastMath.sqrt((1.0 - p) / p);
            } else {
                // No usable moments: keep doubling a negative guess until the
                // cumulative probability there drops below p.
                lowerBound = -1.0;
                while (cumulativeProbability(lowerBound) >= p) {
                    lowerBound *= 2.0;
                }
            }
        }

        if (upperBound == Double.POSITIVE_INFINITY) {
            if (chebyshevApplies) {
                upperBound = mu + sig * FastMath.sqrt(p / (1.0 - p));
            } else {
                // Mirror image of the search above, on the positive side.
                upperBound = 1.0;
                while (cumulativeProbability(upperBound) < p) {
                    upperBound *= 2.0;
                }
            }
        }

        final UnivariateFunction toSolve = new UnivariateFunction() {
            @Override
            public double value(double x) {
                return BaseDistribution.this.cumulativeProbability(x) - p;
            }
        };

        final double x = UnivariateSolverUtils.solve(
                toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy());

        if (!isSupportConnected()) {
            final double dx = getSolverAbsoluteAccuracy();
            if (x - dx >= getSupportLowerBound()) {
                final double px = cumulativeProbability(x);
                if (cumulativeProbability(x - dx) == px) {
                    // The cumulative probability is flat just left of x, so
                    // bisect towards the leftmost point that still reaches px.
                    upperBound = x;
                    while (upperBound - lowerBound > dx) {
                        final double midPoint = 0.5 * (lowerBound + upperBound);
                        if (cumulativeProbability(midPoint) < px) {
                            if (midPoint <= lowerBound) {
                                // The midpoint rounded onto the lower bound, so the
                                // bracket cannot shrink; stop instead of spinning
                                // forever (happens when dx is below double spacing,
                                // e.g. the default accuracy of 0.0).
                                break;
                            }
                            lowerBound = midPoint;
                        } else {
                            if (midPoint >= upperBound) {
                                // Same stall, on the upper side.
                                break;
                            }
                            upperBound = midPoint;
                        }
                    }
                    return upperBound;
                }
            }
        }
        return x;
    }

    /**
     * Returns the absolute accuracy used by the solver in
     * {@link #inverseCumulativeProbability(double)}.
     *
     * @return the current value of {@link #solverAbsoluteAccuracy}
     */
    protected double getSolverAbsoluteAccuracy() {
        return solverAbsoluteAccuracy;
    }

    /**
     * Forwards {@code seed} to the underlying random number generator.
     *
     * @param seed the new seed
     */
    @Override
    public void reseedRandomGenerator(long seed) {
        random.setSeed(seed);
    }

    /**
     * Draws one value by feeding {@code random.nextDouble()} through
     * {@link #inverseCumulativeProbability(double)}.
     *
     * @return a single sample
     */
    @Override
    public double sample() {
        return inverseCumulativeProbability(random.nextDouble());
    }

    /**
     * Draws {@code sampleSize} values by calling {@link #sample()} repeatedly.
     *
     * @param sampleSize number of values to draw
     * @return an array of length {@code sampleSize} filled with samples
     * @throws NotStrictlyPositiveException if {@code sampleSize <= 0}
     */
    @Override
    public double[] sample(int sampleSize) {
        if (sampleSize <= 0) {
            throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
        }
        final double[] out = new double[sampleSize];
        for (int i = 0; i < sampleSize; ++i) {
            out[i] = sample();
        }
        return out;
    }

    /**
     * Always returns {@code 0.0} in this base class, regardless of {@code x}.
     *
     * @param x the point of interest (ignored here)
     * @return {@code 0.0}
     */
    @Override
    public double probability(double x) {
        return 0.0;
    }

    /**
     * Creates an array of the given shape and fills every position visited by
     * an {@link NdIndexIterator} over that shape with a fresh {@link #sample()}.
     *
     * @param shape the shape of the array to create
     * @return the newly created array of samples
     */
    @Override
    public INDArray sample(int[] shape) {
        final INDArray ret = Nd4j.create(shape);
        final NdIndexIterator idxIter = new NdIndexIterator(shape);
        final int len = ret.length();
        for (int i = 0; i < len; ++i) {
            ret.putScalar((int[])idxIter.next(), sample());
        }
        return ret;
    }
}

