/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.autodiff.samediff.config.SDValue;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.FrameIter;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link InferenceSession} with training-specific behavior added on top of plain graph execution.
 *
 * <p>While the graph is being executed, {@link #getOutputs} intercepts the outputs of every op and:
 * <ul>
 *   <li>accumulates the value of each loss variable into {@link #currIterLoss};</li>
 *   <li>for each gradient of a trained variable: applies the configured regularization, applies the
 *       gradient updater, notifies listeners and then updates the parameter array in place.</li>
 * </ul>
 * {@link #trainingIteration} sets up the per-iteration state, runs the graph and packages the result
 * as a {@link Loss}.
 *
 * <p>All per-iteration state is held in instance fields that are reassigned on every call to
 * {@link #trainingIteration}.
 */
public class TrainingSession
extends InferenceSession {
    private static final Logger log = LoggerFactory.getLogger(TrainingSession.class);
    /** Training configuration of the iteration currently being executed. */
    protected TrainingConfig config;
    /** Gradient variable name -> name of the trained variable that gradient belongs to. */
    protected Map<String, String> gradVarToVarMap;
    /** Trained variable name -> updater applied to that variable's gradient. */
    protected Map<String, GradientUpdater> updaters;
    /** Loss variable name -> index of its slot in {@link #currIterLoss}. */
    protected Map<String, Integer> lossVarsToLossIdx;
    /** Loss accumulated per loss variable during the current iteration. */
    protected double[] currIterLoss;
    /** Regularization score accumulated during the current iteration, keyed by regularization class. */
    protected Map<Class<?>, AtomicDouble> currIterRegLoss;
    /** Listeners active for the current operation, or {@code null} when there are none. */
    protected List<Listener> listeners;

    /**
     * Creates a training session for the given graph.
     *
     * @param sameDiff the SameDiff instance to train
     */
    public TrainingSession(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Performs one training iteration: executes the graph for all required gradients (which, via
     * {@link #getOutputs}, also updates the trained variables in place) and returns the loss values
     * collected along the way.
     *
     * @param config        training configuration
     * @param placeholders  placeholder values for this iteration
     * @param paramsToTrain names of the variables to train; each must exist and be of type
     *                      {@link VariableType#VARIABLE}. Variables without a gradient are skipped.
     * @param updaters      updater for each trained variable, keyed by variable name
     * @param batch         current batch, passed through to graph execution and to listeners
     * @param lossVariables names of the loss variables
     * @param listeners     listeners, may be {@code null}
     * @param at            current epoch / iteration / operation
     * @return the loss for this iteration: one value per loss variable, followed by one value per
     *         regularization class for which a score was accumulated
     * @throws IllegalStateException if a name in {@code paramsToTrain} is unknown or is not a VARIABLE
     */
    public Loss trainingIteration(TrainingConfig config, Map<String, INDArray> placeholders, Set<String> paramsToTrain, Map<String, GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at) {
        this.config = config;
        this.updaters = updaters;

        // Keep only the listeners that are active for the current operation; null means "none".
        if (listeners == null) {
            this.listeners = null;
        } else {
            List<Listener> filtered = new ArrayList<Listener>();
            for (Listener l : listeners) {
                if (l.isActive(at.operation())) {
                    filtered.add(l);
                }
            }
            this.listeners = filtered.isEmpty() ? null : filtered;
        }

        // Gradients of all trained variables are required; remember which variable each one belongs to.
        Set<String> requiredActivations = new HashSet<String>();
        this.gradVarToVarMap = new HashMap<String, String>();
        for (String s : paramsToTrain) {
            Preconditions.checkState(this.sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
            SDVariable v = this.sameDiff.getVariable(s);
            Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s", s, v.getVariableType());
            SDVariable grad = v.getGradient();
            if (grad == null) {
                // No gradient for this variable: nothing to compute or update.
                continue;
            }
            requiredActivations.add(grad.name());
            this.gradVarToVarMap.put(grad.name(), s);
        }

        // Variables needed for training evaluations must also be computed.
        if (config.getTrainEvaluations() != null) {
            requiredActivations.addAll(config.getTrainEvaluations().keySet());
        }

        // Reset the per-iteration loss accumulators.
        this.lossVarsToLossIdx = new LinkedHashMap<String, Integer>();
        this.currIterLoss = new double[lossVariables.size()];
        this.currIterRegLoss = new HashMap<Class<?>, AtomicDouble>();
        for (int i = 0; i < lossVariables.size(); ++i) {
            this.lossVarsToLossIdx.put(lossVariables.get(i), i);
        }

        // Execute the graph. The returned arrays are not needed here: losses are accumulated and
        // parameters are updated as a side effect, inside getOutputs.
        List<String> outputVars = new ArrayList<String>(this.gradVarToVarMap.keySet());
        this.output(outputVars, placeholders, batch, requiredActivations, listeners, at);

        // Final loss layout: [loss variables..., regularization scores...].
        double[] finalLoss = new double[this.currIterLoss.length + this.currIterRegLoss.size()];
        System.arraycopy(this.currIterLoss, 0, finalLoss, 0, this.currIterLoss.length);
        List<String> lossVars;
        if (!this.currIterRegLoss.isEmpty()) {
            lossVars = new ArrayList<String>(lossVariables.size() + this.currIterRegLoss.size());
            lossVars.addAll(lossVariables);
            // Regularization scores go directly after the loss variable values, one slot per entry,
            // so that finalLoss[i] always corresponds to lossVars.get(i).
            int idx = this.currIterLoss.length;
            for (Map.Entry<Class<?>, AtomicDouble> entry : this.currIterRegLoss.entrySet()) {
                lossVars.add(entry.getKey().getSimpleName());
                finalLoss[idx++] = entry.getValue().get();
            }
        } else {
            lossVars = lossVariables;
        }

        Loss loss = new Loss(lossVars, finalLoss);
        if (listeners != null) {
            for (Listener listener : listeners) {
                if (listener.isActive(Operation.TRAINING)) {
                    listener.iterationDone(this.sameDiff, at, batch, loss);
                }
            }
        }
        return loss;
    }

    /**
     * Executes the op via {@link InferenceSession#getOutputs} and then post-processes its outputs:
     * outputs that are loss variables are added to {@link #currIterLoss}, and outputs that are
     * gradients of trained variables are regularized, passed through the variable's updater and
     * applied to the parameter array in place (subtracted when minimizing, added otherwise).
     *
     * <p>All arguments are forwarded unchanged to the superclass implementation; {@code listeners}
     * and {@code at} are additionally used here for the pre-update callback and for the
     * iteration/epoch dependent updater and regularization calls.
     *
     * @return the execution result produced by the superclass, unchanged
     * @throws IllegalStateException if the gradient variable fails the inputs-for-op check below, or
     *                               if no updater is registered for a trained variable
     */
    @Override
    public ExecutionResult getOutputs(Pair<SameDiffOp, OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String, SDValue> otherPlaceHolders) {
        ExecutionResult out = super.getOutputs(opPair, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables, otherPlaceHolders);
        SameDiffOp op = opPair.getFirst();
        List<String> outputs = op.getOutputsOfOp();
        int outIdx = 0;
        for (String s : outputs) {
            // Loss variable: accumulate its value (non-scalar losses are summed).
            Integer lossIdx = this.lossVarsToLossIdx.get(s);
            if (lossIdx != null) {
                INDArray arr = out.resultAt(outIdx);
                double l = arr.isScalar() ? arr.getDouble(0L) : arr.sumNumber().doubleValue();
                this.currIterLoss[lossIdx] += l;
            }

            // Gradient of a trained variable: regularize, run the updater, then update the parameter.
            String varName = this.gradVarToVarMap.get(s);
            if (varName != null) {
                Variable gradVar = this.sameDiff.getVariables().get(s);
                // NOTE(review): this throws when the inputs-for-op list is present but EMPTY, whereas
                // the message suggests the intent is to reject gradients that ARE consumed by another
                // op (i.e. a non-empty list). Behavior intentionally left unchanged - confirm intent.
                if (gradVar.getInputsForOp() != null && gradVar.getInputsForOp().isEmpty()) {
                    throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName);
                }
                GradientUpdater u = this.updaters.get(varName);
                Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName);
                Variable var = this.sameDiff.getVariables().get(varName);
                INDArray gradArr = out.resultAt(outIdx);
                INDArray paramArr = var.getVariable().getArr();
                List<Regularization> r = this.config.getRegularization();

                this.applyRegularization(r, Regularization.ApplyStep.BEFORE_UPDATER, paramArr, gradArr, at);
                u.applyUpdater(gradArr, at.iteration(), at.epoch());
                this.applyRegularization(r, Regularization.ApplyStep.POST_UPDATER, paramArr, gradArr, at);

                // Listeners see the final update (after updater and regularization) before it is applied.
                if (listeners != null) {
                    for (Listener l : listeners) {
                        if (l.isActive(at.operation())) {
                            l.preUpdate(this.sameDiff, at, var, gradArr);
                        }
                    }
                }
                if (this.config.isMinimize()) {
                    paramArr.subi(gradArr);
                } else {
                    paramArr.addi(gradArr);
                }
                log.trace("Applied updater to gradient and updated variable: {}", varName);
            }
            ++outIdx;
        }
        return out;
    }

    /**
     * Applies every regularization whose apply step equals {@code step} to the given parameter and
     * gradient arrays. When there are active listeners, the regularization score is also accumulated
     * into {@link #currIterRegLoss} (keyed by regularization class) before the regularization is applied.
     *
     * @param regularizations configured regularizations; {@code null} or empty means nothing to do
     * @param step            the apply step to process in this call
     * @param paramArr        parameter array of the trained variable
     * @param gradArr         gradient (or update) array for that variable
     * @param at              current epoch / iteration
     */
    private void applyRegularization(List<Regularization> regularizations, Regularization.ApplyStep step, INDArray paramArr, INDArray gradArr, At at) {
        if (regularizations == null || regularizations.isEmpty()) {
            return;
        }
        // Updaters without a learning rate use 1.0.
        double lr = this.config.getUpdater().hasLearningRate()
                ? this.config.getUpdater().getLearningRate(at.iteration(), at.epoch())
                : 1.0;
        for (Regularization reg : regularizations) {
            if (reg.applyStep() != step) {
                continue;
            }
            if (this.listeners != null) {
                // Scores are only tracked when a listener is active for this operation.
                double score = reg.score(paramArr, at.iteration(), at.epoch());
                AtomicDouble total = this.currIterRegLoss.get(reg.getClass());
                if (total == null) {
                    total = new AtomicDouble();
                    this.currIterRegLoss.put(reg.getClass(), total);
                }
                total.addAndGet(score);
            }
            reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
        }
    }
}

