/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.training;

import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;

/**
 * Output layer that combines the configured (inter-class) loss with a center-loss (intra-class) term.
 *
 * <p>Besides the weight ({@code "W"}) and bias ({@code "b"}) parameters, the layer holds a matrix of
 * class centers under the {@code "cL"} key. For a minibatch, the score is
 * {@code (interClassLoss + lambda / 2 * sum_i ||x_i - c_{y_i}||^2) / minibatchSize + fullNetRegTerm},
 * where {@code x_i} is the i-th input row and {@code c_{y_i}} is the center selected for it by
 * {@code labels.mmul(centers)}.
 *
 * <p>NOTE(review): selecting centers with a matrix product assumes one-hot labels of shape
 * [minibatch, numClasses] and centers of shape [numClasses, nIn] — confirm against the parameter
 * initializer.
 */
public class CenterLossOutputLayer
        extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer> {
    /** Parameter / gradient-view key of the weight matrix. */
    private static final String WEIGHT_KEY = "W";
    /** Parameter / gradient-view key of the bias vector. */
    private static final String BIAS_KEY = "b";
    /** Parameter / gradient-view key of the class-center matrix. */
    private static final String CENTER_KEY = "cL";

    /** Regularization term last passed to {@link #computeScore}; reused by {@link #computeGradientAndScore}. */
    private double fullNetRegTerm;

    public CenterLossOutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the minibatch score: inter-class loss plus the center-loss term, averaged over the
     * minibatch, plus {@code fullNetRegTerm}. Also stores {@code fullNetRegTerm} for later use by
     * {@link #computeGradientAndScore}.
     *
     * @throws IllegalStateException if the input or labels have not been set
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        this.fullNetRegTerm = fullNetRegTerm;
        // Must run before reading this.input below: the pre-output pass may replace the input.
        INDArray preOut = this.preOutput2d(training, workspaceMgr);
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();

        double sumSquaredDistances = this.squaredDistancesToCenters().sumNumber().doubleValue();
        double intraClassScore = cfg.getLambda() / 2.0 * sumSquaredDistances;
        // average=false: the total is divided by the minibatch size below instead.
        double interClassScore = cfg.getLossFn().computeScore(
                this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
                cfg.getActivationFn(), this.maskArray, false);

        double score = (interClassScore + intraClassScore) / this.getInputMiniBatchSize() + fullNetRegTerm;
        this.score = score;
        return score;
    }

    /**
     * Computes the score of each example individually: the loss function's per-example score plus
     * {@code lambda / 2 * ||x_i - c_{y_i}||^2}, plus {@code fullNetRegTerm} when it is non-zero.
     * This is the same intra-class formula {@link #computeScore} sums over the minibatch.
     *
     * @throws IllegalStateException if the input or labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        // Must run before reading this.input below: the pre-output pass may replace the input.
        INDArray preOut = this.preOutput2d(false, workspaceMgr);
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();

        INDArray intraClassScoreArray = this.squaredDistancesToCenters();
        INDArray scoreArray = cfg.getLossFn().computeScoreArray(
                this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
                cfg.getActivationFn(), this.maskArray);
        // One value per example in both arrays; match the loss function's layout before the in-place add.
        intraClassScoreArray = intraClassScoreArray.reshape(scoreArray.shape()).muli(cfg.getLambda() / 2.0);
        scoreArray.addi(intraClassScoreArray);
        if (fullNetRegTerm != 0.0) {
            scoreArray.addi(fullNetRegTerm);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreArray);
    }

    /**
     * Computes the gradients (stored in {@code this.gradient}) and the score. The score reuses the
     * regularization term from the last {@link #computeScore} call. Does nothing when the input or
     * labels are missing.
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            return;
        }
        INDArray preOut = this.preOutput2d(true, workspaceMgr);
        Pair<Gradient, INDArray> gradientAndDelta = this.getGradientsAndDelta(preOut, workspaceMgr);
        this.gradient = gradientAndDelta.getFirst();
        this.score = this.computeScore(this.fullNetRegTerm, true, workspaceMgr);
    }

    /** Not supported by this layer. */
    @Override
    protected void setScoreWithZ(INDArray z) {
        throw new UnsupportedOperationException("Not supported " + this.layerId());
    }

    /** Returns the current gradient together with the current score. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(this.gradient(), this.score());
    }

    /**
     * Backpropagates through this layer.
     *
     * @return the parameter gradients and the gradient with respect to the input. The input gradient is
     *     {@code W * delta^T} (transposed) plus the center-loss term {@code lambda * (x_i - c_{y_i})}.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        // Returns the gradient and this layer's delta, not epsilon for the layer below.
        Pair<Gradient, INDArray> gradientAndDelta =
                this.getGradientsAndDelta(this.preOutput2d(true, workspaceMgr), workspaceMgr);
        INDArray delta = gradientAndDelta.getSecond();
        INDArray dLcdai = this.differenceFromCenters();

        INDArray w = this.getParamWithNoise(WEIGHT_KEY, true, workspaceMgr);
        INDArray epsilonNext = workspaceMgr.createUninitialized(
                ArrayType.ACTIVATION_GRAD, w.dataType(), new long[]{w.size(0), delta.size(0)}, 'f');
        epsilonNext = w.mmuli(delta.transpose(), epsilonNext).transpose();
        // Center-loss contribution to the input gradient.
        epsilonNext.addi(dLcdai.muli(this.centerLossConf().getLambda()));

        this.weightNoiseParams.clear();
        return new Pair<>(gradientAndDelta.getFirst(), epsilonNext);
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    /**
     * Computes the loss delta and fills the weight, bias and center gradient views.
     *
     * <p>The center gradient is {@code labels^T * (alpha * (c_{y_i} - x_i))}. Outside gradient-check
     * mode it is divided row-wise by (number of examples labelled with that class + 1). In gradient-check
     * mode it is instead scaled by {@code lambda}.
     *
     * @return the gradient and the delta (not the epsilon for the layer below)
     * @throws DL4JInvalidInputException if the label column count does not match the layer's nOut
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer cfg = this.centerLossConf();
        INDArray labels2d = this.getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        if (labels2d.size(1) != preOut.size(1)) {
            throw new DL4JInvalidInputException("Labels array numColumns (size(1) = " + labels2d.size(1) + ") does not match output layer number of outputs (nOut = " + preOut.size(1) + ") " + this.layerId());
        }
        INDArray delta = cfg.getLossFn().computeGradient(labels2d, preOut, cfg.getActivationFn(), this.maskArray);

        INDArray weightGradView = (INDArray) this.gradientViews.get(WEIGHT_KEY);
        INDArray biasGradView = (INDArray) this.gradientViews.get(BIAS_KEY);
        INDArray centersGradView = (INDArray) this.gradientViews.get(CENTER_KEY);

        INDArray centers = this.centers();
        INDArray l = this.labelsAsDataTypeOf(centers);
        INDArray centersForExamples = l.mmul(centers);
        INDArray diff = centersForExamples.sub(this.input).muli(cfg.getAlpha());
        INDArray numerator = l.transpose().mmul(diff);
        INDArray deltaC;
        if (cfg.getGradientCheck()) {
            // Gradient checks need dL/dc_j, i.e. the raw term scaled by lambda.
            deltaC = numerator.muli(cfg.getLambda());
        } else {
            // Per-class example counts + 1, as a column vector.
            INDArray denominator = l.sum(0).reshape(l.size(1), 1L).addi(1.0);
            deltaC = numerator.diviColumnVector(denominator);
        }
        centersGradView.assign(deltaC);

        Nd4j.gemm(this.input, delta, weightGradView, true, false, 1.0, 0.0);
        delta.sum(biasGradView, 0);

        Gradient grad = new DefaultGradient();
        grad.gradientForVariable().put(WEIGHT_KEY, weightGradView);
        grad.gradientForVariable().put(BIAS_KEY, biasGradView);
        grad.gradientForVariable().put(CENTER_KEY, centersGradView);
        return new Pair<>(grad, delta);
    }

    /** Labels are used as-is; no reshaping to 2d is performed by this layer. */
    @Override
    protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return this.labels;
    }

    /** Returns the layer configuration typed as the center-loss configuration. */
    private org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer centerLossConf() {
        return (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) this.layerConf();
    }

    /** Returns the class-center parameter matrix. */
    private INDArray centers() {
        return (INDArray) this.params.get(CENTER_KEY);
    }

    /** Returns the labels cast to the data type of {@code centers}, so they can be multiplied with them. */
    private INDArray labelsAsDataTypeOf(INDArray centers) {
        return this.labels.castTo(centers.dataType());
    }

    /** Returns {@code x_i - c_{y_i}} for every example as a new array (input minus the selected centers). */
    private INDArray differenceFromCenters() {
        INDArray centers = this.centers();
        INDArray selectedCenters = this.labelsAsDataTypeOf(centers).mmul(centers);
        return this.input.sub(selectedCenters);
    }

    /** Returns {@code ||x_i - c_{y_i}||^2} for every example (norm2 over dimension 1, squared). */
    private INDArray squaredDistancesToCenters() {
        INDArray distances = this.differenceFromCenters().norm2(1);
        distances.muli(distances);
        return distances;
    }
}

