/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Gradient-checking utilities for {@link SameDiff} graphs.
 *
 * <p>The main entry point,
 * {@link #checkGradients(SameDiff, double, double, double, boolean, boolean, boolean, boolean, Set)}, compares the
 * gradients produced by {@code SameDiff.execBackwards()} against central finite-difference estimates of the graph's
 * single-valued final result. {@link #validateInternalState(SameDiff, boolean)} cross-checks the SameDiff instance's
 * internal bookkeeping maps, which it reads via reflection.
 */
public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);
    private static final boolean DEFAULT_PRINT = true;
    private static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    private static final boolean DEFAULT_DEBUG_MODE = false;
    private static final double DEFAULT_EPS = 1.0E-5;
    private static final double DEFAULT_MAX_REL_ERROR = 1.0E-5;
    private static final double DEFAULT_MIN_ABS_ERROR = 1.0E-6;

    /**
     * Numerically perturbs each element of each input array and compares a central-difference estimate of the
     * change in the graph's last output against a reference value.
     *
     * <p>For each element, the sum of the last array returned by {@code SameDiff.eval} is computed with the element
     * shifted by {@code +epsilon} and by {@code -epsilon}. The estimate is {@code (plus - minus) / (2 * epsilon)}.
     * The element is restored before the next one is perturbed. Perturbation is applied to a copy, so the caller's
     * arrays are not modified. An element fails when its relative error exceeds {@code maxRelError} or is NaN.
     *
     * <p>NOTE(review): the reference value is the sum of the last forward-pass output, not an analytic gradient,
     * and {@code wrt} is never used. Confirm the intended comparison before relying on this overload. The
     * {@link #checkGradients(SameDiff, double, double, double, boolean, boolean, boolean, boolean, Set)} overload
     * compares against backprop gradients.
     *
     * @param function        variable whose owning SameDiff graph is evaluated
     * @param wrt             currently unused
     * @param epsilon         perturbation size, must be in (0, 0.1]
     * @param maxRelError     maximum allowed relative error, must be in (0, 0.25]
     * @param print           whether to log a summary line
     * @param inputParameters input arrays keyed by name, passed to {@code SameDiff.eval}
     * @return {@code true} if no element failed
     * @throws IllegalArgumentException if {@code epsilon} or {@code maxRelError} is out of range
     * @throws IllegalStateException    if the context data type is not double
     */
    public static boolean checkGradients(SDVariable function, SDVariable wrt, double epsilon, double maxRelError, boolean print, Map<String, INDArray> inputParameters) {
        if (epsilon <= 0.0 || epsilon > 0.1) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0 || maxRelError > 0.25) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); before using GradientCheckUtil");
        }
        SameDiff sameDiff = function.getSameDiff();
        SameDiff opExec = SameDiff.create(sameDiff);
        // Loop-invariant reference value, so it is computed once.
        double correctVal = sumOfLastOutput(opExec.eval(inputParameters));

        long totalCount = 0;
        int totalNFailures = 0;
        double maxError = 0.0;
        for (Map.Entry<String, INDArray> entry : inputParameters.entrySet()) {
            // Perturb a private copy so the caller's arrays are never modified.
            INDArray params = entry.getValue().dup();
            long nParams = params.length();
            Map<String, INDArray> evalParams = new HashMap<>(inputParameters);
            evalParams.put(entry.getKey(), params);
            for (long i = 0; i < nParams; i++) {
                totalCount++;
                double origValue = params.getDouble(i);
                double scorePlus;
                double scoreMinus;
                try {
                    params.putScalar(i, origValue + epsilon);
                    scorePlus = sumOfLastOutput(sameDiff.eval(evalParams));
                    params.putScalar(i, origValue - epsilon);
                    scoreMinus = sumOfLastOutput(sameDiff.eval(evalParams));
                } finally {
                    // Restore so that exactly one element is perturbed at a time.
                    params.putScalar(i, origValue);
                }
                double numericalGrad = (scorePlus - scoreMinus) / (2.0 * epsilon);
                double relError = relativeError(correctVal, numericalGrad);
                if (relError > maxError) {
                    maxError = relError;
                }
                if (relError > maxRelError || Double.isNaN(relError)) {
                    totalNFailures++;
                }
            }
        }
        if (print) {
            long nPass = totalCount - totalNFailures;
            log.info("GradientCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    /**
     * Runs the full gradient check using the settings carried by a {@link TestCase}.
     *
     * <p>Internal-state validation is always performed, because {@code skipValidation} is {@code false}.
     *
     * @param t test case supplying the graph and gradient-check settings
     * @return {@code true} if all checked gradients passed
     */
    public static boolean checkGradients(TestCase t) {
        return GradCheckUtil.checkGradients(t.sameDiff(), t.gradCheckEpsilon(), t.gradCheckMaxRelativeError(), t.gradCheckMinAbsError(), t.gradCheckPrint(), t.gradCheckDefaultExitFirstFailure(), false, t.gradCheckDebugMode(), t.gradCheckSkipVariables());
    }

    /**
     * Runs the gradient check with default tolerances, printing enabled and exit-on-first-failure disabled.
     *
     * @param sd graph to check
     * @return {@code true} if all checked gradients passed
     */
    public static boolean checkGradients(SameDiff sd) {
        return GradCheckUtil.checkGradients(sd, DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE);
    }

    /**
     * Runs the gradient check with default settings, skipping the named variables.
     *
     * @param sd            graph to check
     * @param skipVariables names of variables not to check; may be {@code null}
     * @return {@code true} if all checked gradients passed
     */
    public static boolean checkGradients(SameDiff sd, String ... skipVariables) {
        Set<String> skip = null;
        if (skipVariables != null) {
            skip = new HashSet<>();
            Collections.addAll(skip, skipVariables);
        }
        return GradCheckUtil.checkGradients(sd, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE, false, DEFAULT_DEBUG_MODE, skip);
    }

    /**
     * Runs the gradient check with default tolerances.
     *
     * @param sd                 graph to check
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to stop at the first failing element
     * @return {@code true} if all checked gradients passed
     */
    public static boolean checkGradients(SameDiff sd, boolean print, boolean exitOnFirstFailure) {
        return GradCheckUtil.checkGradients(sd, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, print, exitOnFirstFailure);
    }

    /**
     * Runs the gradient check with validation enabled, debug mode off and no skipped variables.
     *
     * @param sd                 graph to check
     * @param eps                finite-difference step
     * @param maxRelError        maximum allowed relative error
     * @param minAbsError        absolute error below which a relative-error failure is ignored
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to stop at the first failing element
     * @return {@code true} if all checked gradients passed
     */
    public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
        return GradCheckUtil.checkGradients(sd, eps, maxRelError, minAbsError, print, exitOnFirstFailure, false, DEFAULT_DEBUG_MODE, null);
    }

    /**
     * Checks the gradient of every non-output variable of {@code sd}, one element at a time.
     *
     * <p>The gradient from {@code execBackwards()} is compared against the central finite-difference estimate
     * {@code (f(x + eps) - f(x - eps)) / (2 * eps)}, where {@code f} is the single value returned by
     * {@code execAndEndResult()}. An element fails when its relative error {@code |a - n| / (|a| + |n|)} exceeds
     * {@code maxRelError} and its absolute error is not below {@code minAbsError}. Each perturbed element is
     * restored after evaluation, even if evaluation throws. If this call enabled debug mode, it is disabled again
     * before returning, including on early exit and on exceptions.
     *
     * @param sd                 graph whose final result must contain exactly one value
     * @param eps                finite-difference step
     * @param maxRelError        maximum allowed relative error
     * @param minAbsError        absolute error below which a relative-error failure is ignored
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure return {@code false} immediately at the first failing element
     * @param skipValidation     if {@code true}, do not run {@link #validateInternalState(SameDiff, boolean)}
     * @param debugMode          enable SameDiff debug mode for the duration of the check
     * @param skipVariables      names of variables not to check; may be {@code null}
     * @return {@code true} if no element failed
     * @throws IllegalStateException if the data type is not double, a non-output variable has no array, the result
     *                               is not a single value, a gradient is missing or mis-shaped, or a numerical or
     *                               analytic gradient is NaN or infinite
     */
    public static boolean checkGradients(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables) {
        boolean debugBefore = sd.isDebugMode();
        if (debugMode) {
            sd.enableDebugMode();
        }
        try {
            return runGradientCheck(sd, eps, maxRelError, minAbsError, print, exitOnFirstFailure, skipValidation, skipVariables);
        } finally {
            // Restore the caller's debug setting regardless of how the check ended.
            if (debugMode && !debugBefore) {
                sd.disableDebugging();
            }
        }
    }

    /** Performs the gradient check itself; the debug-mode lifecycle is handled by the caller. */
    private static boolean runGradientCheck(SameDiff sd, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, Set<String> skipVariables) {
        if (!skipValidation) {
            GradCheckUtil.validateInternalState(sd, true);
        }
        if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Data type must be set to double");
        }
        Set<String> fnOutputs = functionOutputNames(sd);
        for (SDVariable s : sd.variables()) {
            if (!fnOutputs.contains(s.getVarName()) && s.getArr() == null) {
                throw new IllegalStateException("Variable \"" + s.getVarName() + "\" does not have array associated with it");
            }
        }
        INDArray out = sd.execAndEndResult();
        if (out.length() != 1L) {
            throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(out.shape()));
        }
        Map<String, INDArray> grad = computeAnalyticGradients(sd, fnOutputs);

        int totalNFailures = 0;
        int totalCount = 0;
        double maxError = 0.0;
        for (SDVariable s : sd.variables()) {
            String name = s.getVarName();
            if (fnOutputs.contains(name)) {
                continue;
            }
            if (skipVariables != null && skipVariables.contains(name)) {
                log.info("Grad check: skipping variable \"{}\"", name);
                continue;
            }
            INDArray a = s.getArr();
            INDArray aGrad = grad.get(name);
            long n = a.length();
            if (print) {
                log.info("Starting test for variable \"{}\" with {} values", name, n);
            }
            NdIndexIterator iter = new NdIndexIterator('c', a.shape());
            int i = 0;
            while (iter.hasNext()) {
                long[] idx = iter.next();
                // Always computed: it is used in exception messages as well as in logging.
                String strIdx = Arrays.toString(idx).replace(" ", "");
                ++totalCount;
                double orig = a.getDouble(idx);
                double scorePlus;
                double scoreMinus;
                try {
                    a.putScalar(idx, orig + eps);
                    scorePlus = sd.execAndEndResult().getDouble(0L);
                    a.putScalar(idx, orig - eps);
                    scoreMinus = sd.execAndEndResult().getDouble(0L);
                } finally {
                    // Never leave the variable perturbed, even if the forward pass throws.
                    a.putScalar(idx, orig);
                }
                double numericalGrad = (scorePlus - scoreMinus) / (2.0 * eps);
                double analyticGrad = aGrad.getDouble(idx);
                if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                    throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                }
                if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                    throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                }
                double relError = relativeError(analyticGrad, numericalGrad);
                if (relError > maxError) {
                    maxError = relError;
                }
                if (relError > maxRelError || Double.isNaN(relError)) {
                    double absError = Math.abs(analyticGrad - numericalGrad);
                    if (absError < minAbsError) {
                        if (print) {
                            log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                        }
                    } else {
                        if (print) {
                            log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                        }
                        if (exitOnFirstFailure) {
                            return false;
                        }
                        ++totalNFailures;
                    }
                } else if (print) {
                    log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                }
                ++i;
            }
        }
        if (print) {
            int nPass = totalCount - totalNFailures;
            log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
        }
        return totalNFailures == 0;
    }

    /** Returns the names of all variables produced as outputs of some function in {@code sd}. */
    private static Set<String> functionOutputNames(SameDiff sd) {
        Set<String> names = new HashSet<>();
        for (DifferentialFunction f : sd.functions()) {
            for (SDVariable s : f.outputVariables()) {
                names.add(s.getVarName());
            }
        }
        return names;
    }

    /**
     * Runs backprop and returns a copy of the gradient array for every variable not in {@code fnOutputs}.
     *
     * @throws IllegalStateException if a gradient variable or array is missing, or its shape differs from the variable
     */
    private static Map<String, INDArray> computeAnalyticGradients(SameDiff sd, Set<String> fnOutputs) {
        sd.execBackwards();
        Map<String, INDArray> grad = new HashMap<>();
        for (SDVariable v : sd.variables()) {
            String varName = v.getVarName();
            if (fnOutputs.contains(varName)) {
                continue;
            }
            SDVariable g = sd.grad(varName);
            if (g == null) {
                throw new IllegalStateException("Null gradient variable for \"" + varName + "\"");
            }
            INDArray ga = g.getArr();
            if (ga == null) {
                throw new IllegalStateException("Null gradient array encountered for variable: " + varName);
            }
            if (!Arrays.equals(v.getArr().shape(), ga.shape())) {
                throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + varName + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
            }
            // Copy so that later forward passes cannot overwrite the stored gradient.
            grad.put(varName, ga.dup());
        }
        return grad;
    }

    /**
     * Symmetric relative error {@code |a - b| / (|a| + |b|)}, defined as 0 when both values are 0.
     * The result is NaN if either input is NaN.
     */
    private static double relativeError(double a, double b) {
        if (a == 0.0 && b == 0.0) {
            return 0.0;
        }
        return Math.abs(a - b) / (Math.abs(a) + Math.abs(b));
    }

    /** Sum of all elements of the last array in {@code outputs}. */
    private static double sumOfLastOutput(INDArray[] outputs) {
        return outputs[outputs.length - 1].sumNumber().doubleValue();
    }

    /**
     * Cross-checks SameDiff's internal bookkeeping for consistency.
     *
     * <p>The following checks are made:
     * <ul>
     *   <li>Variables and variable names are unique.</li>
     *   <li>Every function appears in {@code incomingArgsReverse} and {@code outgoingArgsReverse}, and those maps
     *       reference only known variable names.</li>
     *   <li>No variable is an output of more than one op.</li>
     *   <li>{@code variableMap} agrees with {@code variables()}.</li>
     * </ul>
     * The private maps are read via reflection.
     *
     * @param sd                     instance to validate
     * @param generateAndCheckGradFn if {@code true}, create the "grad" function when absent, validate it (without
     *                               recursing further), and check that it contains every function of {@code sd}
     * @throws IllegalStateException if a consistency check fails
     */
    public static void validateInternalState(SameDiff sd, boolean generateAndCheckGradFn) {
        DifferentialFunction[] dfs = sd.functions();
        List<SDVariable> vars = sd.variables();

        Set<SDVariable> varsSet = new HashSet<>(vars);
        Preconditions.checkState(vars.size() == varsSet.size(), "Duplicate variables in variables() list");
        Set<String> varSetStr = new HashSet<>();
        for (SDVariable v : vars) {
            if (!varSetStr.add(v.getVarName())) {
                throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered");
            }
        }

        Map<String, String[]> incomingArgsReverse = GradCheckUtil.getObject("incomingArgsReverse", sd, SameDiff.class);
        Map<String, String[]> outgoingArgsReverse = GradCheckUtil.getObject("outgoingArgsReverse", sd, SameDiff.class);
        Preconditions.checkState(dfs.length == incomingArgsReverse.size(), "All functions not present in incomingArgsReverse");
        Preconditions.checkState(dfs.length == outgoingArgsReverse.size(), "All functions not present in outgoingArgsReverse");
        for (DifferentialFunction df : dfs) {
            String ownName = df.getOwnName();
            Preconditions.checkState(incomingArgsReverse.containsKey(ownName), ownName + " not present in incomingArgsReverse");
            Preconditions.checkState(outgoingArgsReverse.containsKey(ownName), ownName + " not present in outgoingArgsReverse");
            for (String s : incomingArgsReverse.get(ownName)) {
                Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in incomingArgsReverse value not a known variable name");
            }
            for (String s : outgoingArgsReverse.get(ownName)) {
                Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in outgoingArgsReverse value not a known variable name");
            }
        }

        // Each variable may be produced by at most one op.
        Map<String, String> seen = new HashMap<>();
        for (Map.Entry<String, String[]> e : outgoingArgsReverse.entrySet()) {
            for (String s : e.getValue()) {
                if (seen.containsKey(s)) {
                    throw new IllegalStateException("Already saw variable \"" + s + "\" as output for op \"" + seen.get(s) + "\": expected variables to be present as an output only once; also seen as output for op \"" + e.getKey() + "\"");
                }
                seen.put(s, e.getKey());
            }
        }

        Map<String, SDVariable> variableMap = GradCheckUtil.getObject("variableMap", sd, SameDiff.class);
        Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed");
        for (Map.Entry<String, SDVariable> e : variableMap.entrySet()) {
            Preconditions.checkState(e.getKey().equals(e.getValue().getVarName()), "Name not equal");
        }

        if (generateAndCheckGradFn) {
            if (sd.getFunction("grad") == null) {
                sd.createGradFunction();
            }
            SameDiff gradFn = sd.getFunction("grad");
            GradCheckUtil.validateInternalState(gradFn, false);
            for (DifferentialFunction dfOrig : dfs) {
                Preconditions.checkNotNull(gradFn.getFunctionById(dfOrig.getOwnName()), "DifferentialFunction " + dfOrig.getOwnName() + " from original SameDiff instance not present in grad fn");
            }
        }
    }

    /**
     * Reads a declared (possibly private) field reflectively.
     *
     * @param fieldName name of the field declared on {@code fromClass}
     * @param from      instance to read from
     * @param fromClass class that declares the field
     * @return the field value, cast unchecked to the caller's expected type
     * @throws RuntimeException wrapping the reflective failure if the field is missing or inaccessible
     */
    private static <T> T getObject(String fieldName, Object from, Class<?> fromClass) {
        try {
            Field f = fromClass.getDeclaredField(fieldName);
            f.setAccessible(true);
            @SuppressWarnings("unchecked")
            T value = (T) f.get(from);
            return value;
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException("Unable to read field \"" + fieldName + "\" of " + fromClass.getName(), e);
        }
    }
}

