/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradientCheckUtil {
    private static Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);
    private static final List<Class<? extends IActivation>> VALID_ACTIVATION_FUNCTIONS = Arrays.asList(Activation.CUBE.getActivationFunction().getClass(), Activation.ELU.getActivationFunction().getClass(), Activation.IDENTITY.getActivationFunction().getClass(), Activation.RATIONALTANH.getActivationFunction().getClass(), Activation.SIGMOID.getActivationFunction().getClass(), Activation.SOFTMAX.getActivationFunction().getClass(), Activation.SOFTPLUS.getActivationFunction().getClass(), Activation.SOFTSIGN.getActivationFunction().getClass(), Activation.TANH.getActivationFunction().getClass());

    private GradientCheckUtil() {
    }

    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        if (!(mln.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        int layerCount = 0;
        for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
            org.deeplearning4j.nn.conf.Updater u = n.getLayer().getUpdater();
            if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
                double lr = n.getLayer().getLearningRate();
                if (lr != 1.0) {
                    throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + (Object)((Object)u) + " with lr=" + lr + " for layer \"" + n.getLayer().getLayerName() + "\"");
                }
            } else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
                throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + (Object)((Object)u));
            }
            double dropout = n.getLayer().getDropOut();
            if (n.isUseRegularization() && dropout != 0.0) {
                throw new IllegalStateException("Must have dropout == 0.0 for gradient checks - got dropout = " + dropout + " for layer " + layerCount);
            }
            IActivation activation = n.getLayer().getActivationFn();
            if (activation == null || VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) continue;
            log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
        }
        mln.setInput(input);
        mln.setLabels(labels);
        mln.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();
        Updater updater = UpdaterCreator.getUpdater(mln);
        updater.update(mln, gradAndScore.getFirst(), 0, mln.batchSize());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
        INDArray originalParams = mln.params().dup();
        int nParams = originalParams.length();
        Map<String, INDArray> paramTable = mln.paramTable();
        ArrayList<String> paramNames = new ArrayList<String>(paramTable.keySet());
        int[] paramEnds = new int[paramNames.size()];
        paramEnds[0] = paramTable.get(paramNames.get(0)).length();
        for (int i = 1; i < paramEnds.length; ++i) {
            paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length();
        }
        int totalNFailures = 0;
        double maxError = 0.0;
        DataSet ds = new DataSet(input, labels);
        int currParamNameIdx = 0;
        INDArray params = mln.params();
        for (int i = 0; i < nParams; ++i) {
            if (i >= paramEnds[currParamNameIdx]) {
                ++currParamNameIdx;
            }
            String paramName = (String)paramNames.get(currParamNameIdx);
            double origValue = params.getDouble(i);
            params.putScalar(i, origValue + epsilon);
            double scorePlus = mln.score(ds, true);
            params.putScalar(i, origValue - epsilon);
            double scoreMinus = mln.score(ds, true);
            params.putScalar(i, origValue);
            double scoreDelta = scorePlus - scoreMinus;
            double numericalGradient = scoreDelta / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
            }
            double backpropGradient = gradientToCheck.getDouble(i);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0;
            }
            if (relError > maxError) {
                maxError = relError;
            }
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < minAbsoluteError) {
                    log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
                    continue;
                }
                if (print) {
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                }
                if (exitOnFirstError) {
                    return false;
                }
                ++totalNFailures;
                continue;
            }
            if (!print) continue;
            log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
        }
        if (print) {
            int nPass = nParams - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs, INDArray[] labels) {
        int i;
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        if (graph.getNumInputArrays() != inputs.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + graph.getNumInputArrays() + " inputs");
        }
        if (graph.getNumOutputArrays() != labels.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + graph.getNumOutputArrays() + " outputs");
        }
        int layerCount = 0;
        for (String vertexName : graph.getConfiguration().getVertices().keySet()) {
            GraphVertex gv = graph.getConfiguration().getVertices().get(vertexName);
            if (!(gv instanceof LayerVertex)) continue;
            LayerVertex lv = (LayerVertex)gv;
            org.deeplearning4j.nn.conf.Updater u = lv.getLayerConf().getLayer().getUpdater();
            if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
                double lr = lv.getLayerConf().getLayer().getLearningRate();
                if (lr != 1.0) {
                    throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer \"" + vertexName + "\"; got " + (Object)((Object)u));
                }
            } else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
                throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer \"" + vertexName + "\"; got " + (Object)((Object)u));
            }
            double dropout = lv.getLayerConf().getLayer().getDropOut();
            if (lv.getLayerConf().isUseRegularization() && dropout != 0.0) {
                throw new IllegalStateException("Must have dropout == 0.0 for gradient checks - got dropout = " + dropout + " for layer " + layerCount);
            }
            IActivation activation = lv.getLayerConf().getLayer().getActivationFn();
            if (activation == null || VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) continue;
            log.warn("Layer \"" + vertexName + "\" is possibly using an unsuitable activation function: " + activation.getClass() + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
        }
        for (i = 0; i < inputs.length; ++i) {
            graph.setInput(i, inputs[i]);
        }
        for (i = 0; i < labels.length; ++i) {
            graph.setLabel(i, labels[i]);
        }
        graph.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = graph.gradientAndScore();
        ComputationGraphUpdater updater = new ComputationGraphUpdater(graph);
        updater.update(graph, gradAndScore.getFirst(), 0, graph.batchSize());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
        INDArray originalParams = graph.params().dup();
        int nParams = originalParams.length();
        Map<String, INDArray> paramTable = graph.paramTable();
        ArrayList<String> paramNames = new ArrayList<String>(paramTable.keySet());
        int[] paramEnds = new int[paramNames.size()];
        paramEnds[0] = paramTable.get(paramNames.get(0)).length();
        for (int i2 = 1; i2 < paramEnds.length; ++i2) {
            paramEnds[i2] = paramEnds[i2 - 1] + paramTable.get(paramNames.get(i2)).length();
        }
        int currParamNameIdx = 0;
        int totalNFailures = 0;
        double maxError = 0.0;
        MultiDataSet mds = new MultiDataSet(inputs, labels);
        INDArray params = graph.params();
        for (int i3 = 0; i3 < nParams; ++i3) {
            if (i3 >= paramEnds[currParamNameIdx]) {
                ++currParamNameIdx;
            }
            String paramName = (String)paramNames.get(currParamNameIdx);
            double origValue = params.getDouble(i3);
            params.putScalar(i3, origValue + epsilon);
            double scorePlus = graph.score((org.nd4j.linalg.dataset.api.MultiDataSet)mds, true);
            params.putScalar(i3, origValue - epsilon);
            double scoreMinus = graph.score((org.nd4j.linalg.dataset.api.MultiDataSet)mds, true);
            params.putScalar(i3, origValue);
            double scoreDelta = scorePlus - scoreMinus;
            double numericalGradient = scoreDelta / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i3 + " of " + nParams);
            }
            double backpropGradient = gradientToCheck.getDouble(i3);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0;
            }
            if (relError > maxError) {
                maxError = relError;
            }
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < minAbsoluteError) {
                    log.info("Param " + i3 + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
                    continue;
                }
                if (print) {
                    log.info("Param " + i3 + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                }
                if (exitOnFirstError) {
                    return false;
                }
                ++totalNFailures;
                continue;
            }
            if (!print) continue;
            log.info("Param " + i3 + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
        }
        if (print) {
            int nPass = nParams - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    public static boolean checkGradientsPretrainLayer(Layer layer, double epsilon, double maxRelError, double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, int rngSeed) {
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        boolean layerCount = false;
        layer.setInput(input);
        Nd4j.getRandom().setSeed(rngSeed);
        layer.computeGradientAndScore();
        Pair<Gradient, Double> gradAndScore = layer.gradientAndScore();
        Updater updater = UpdaterCreator.getUpdater(layer);
        updater.update(layer, gradAndScore.getFirst(), 0, layer.batchSize());
        INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup();
        INDArray originalParams = layer.params().dup();
        int nParams = originalParams.length();
        Map<String, INDArray> paramTable = layer.paramTable();
        ArrayList<String> paramNames = new ArrayList<String>(paramTable.keySet());
        int[] paramEnds = new int[paramNames.size()];
        paramEnds[0] = paramTable.get(paramNames.get(0)).length();
        for (int i = 1; i < paramEnds.length; ++i) {
            paramEnds[i] = paramEnds[i - 1] + paramTable.get(paramNames.get(i)).length();
        }
        int totalNFailures = 0;
        double maxError = 0.0;
        int currParamNameIdx = 0;
        INDArray params = layer.params();
        for (int i = 0; i < nParams; ++i) {
            if (i >= paramEnds[currParamNameIdx]) {
                ++currParamNameIdx;
            }
            String paramName = (String)paramNames.get(currParamNameIdx);
            double origValue = params.getDouble(i);
            params.putScalar(i, origValue + epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore();
            double scorePlus = layer.score();
            params.putScalar(i, origValue - epsilon);
            Nd4j.getRandom().setSeed(rngSeed);
            layer.computeGradientAndScore();
            double scoreMinus = layer.score();
            params.putScalar(i, origValue);
            double scoreDelta = scorePlus - scoreMinus;
            double numericalGradient = scoreDelta / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + i + " of " + nParams);
            }
            double backpropGradient = gradientToCheck.getDouble(i);
            double relError = Math.abs(backpropGradient - numericalGradient) / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0;
            }
            if (relError > maxError) {
                maxError = relError;
            }
            if (relError > maxRelError || Double.isNaN(relError)) {
                double absError = Math.abs(backpropGradient - numericalGradient);
                if (absError < minAbsoluteError) {
                    log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
                    continue;
                }
                if (print) {
                    log.info("Param " + i + " (" + paramName + ") FAILED: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                }
                if (exitOnFirstError) {
                    return false;
                }
                ++totalNFailures;
                continue;
            }
            if (!print) continue;
            log.info("Param " + i + " (" + paramName + ") passed: grad= " + backpropGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
        }
        if (print) {
            int nPass = nParams - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }
}

