/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.ocnn;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Runtime implementation of the one-class neural network (OC-NN) output layer.
 *
 * <p>The forward pass computes {@code output = g(input x v) x w}, where {@code g} is the configured
 * activation function and {@code v}/{@code w} are the layer parameters of the same names. The
 * layer does not need labels ({@link #needsLabels()} returns {@code false}): {@code doOutput}
 * stores its own output in {@code labels}, and the loss ({@link OCNNLossFunction}) is built only
 * from the parameters and that output.
 *
 * <p>During training, outputs are collected into a fixed-size window. At the first backward pass
 * of each new epoch, the scalar parameter "r" is set to the {@code 100 * nu} percentile of the
 * collected values.
 */
public class OCNNOutputLayer
extends BaseOutputLayer<org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer> {
    /** Activation exposed via {@link #getActivation()}; the forward pass uses the configured activation instead. */
    private IActivation activation = new ActivationReLU();
    /** Shared ReLU used by the loss for {@code max(0, r - output)} and its derivative. */
    private static final IActivation RELU = new ActivationReLU();
    /** Loss function installed into the layer configuration by the constructor. */
    private ILossFunction lossFunction = new OCNNLossFunction();
    /** Next write position in {@link #window}; saturates at the window length. */
    private int batchWindowSizeIndex;
    /** Number of leading slots of {@link #window} that have ever been written (high-water mark). */
    private int windowFill;
    /** Lazily allocated buffer (length = configured window size) of outputs used to estimate r. */
    private INDArray window;

    /**
     * Creates the layer and installs the OC-NN loss function into its configuration.
     *
     * @param conf     network configuration whose layer is an OCNN output layer configuration
     * @param dataType data type used for the layer
     */
    public OCNNOutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer)conf.getLayer();
        ocnnOutputLayer.setLossFn(this.lossFunction);
    }

    /** Returns this layer's configuration, typed as the OCNN output layer configuration. */
    private org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnConf() {
        return (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer)this.layerConf();
    }

    /**
     * Ignores externally supplied labels. This layer derives its "labels" from its own output
     * in {@code doOutput}.
     */
    @Override
    public void setLabels(INDArray labels) {
    }

    /**
     * Computes the OC-NN loss for the current input.
     *
     * @param fullNetRegTerm regularization term added to the score
     * @param training       whether the forward pass runs in training mode
     * @param workspaceMgr   workspace manager
     * @return the score, divided by the minibatch size when the configuration is in minibatch mode
     * @throws IllegalStateException if no input has been set
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        INDArray preOut = this.preOutput2d(training, workspaceMgr);
        ILossFunction lossFunction = this.ocnnConf().getLossFn();
        double score = lossFunction.computeScore(this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut, this.ocnnConf().getActivationFn(), this.maskArray, false);
        if (this.conf().isMiniBatch()) {
            score /= (double)this.getInputMiniBatchSize();
        }
        score += fullNetRegTerm;
        this.score = score;
        return score;
    }

    /** OC-NN is unsupervised: no labels are required. */
    @Override
    public boolean needsLabels() {
        return false;
    }

    /**
     * Backpropagates through the layer.
     *
     * <p>The returned epsilon is built from the per-example loss delta broadcast across the nIn
     * input columns. It is allocated as {@code [nIn, delta.length()]} and then transposed.
     *
     * @param epsilon      ignored; the gradient is derived from the layer's own loss
     * @param workspaceMgr workspace manager
     * @return parameter gradients and the epsilon for the previous layer
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        Pair<Gradient, INDArray> pair = this.getGradientsAndDelta(this.preOutput2d(true, workspaceMgr), workspaceMgr);
        long inputShape = ((org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer)this.getConf().getLayer()).getNIn();
        INDArray delta = pair.getSecond();
        INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, this.input.dataType(), new long[]{inputShape, delta.length()}, 'f');
        epsilonNext = epsilonNext.assign(delta.broadcast(epsilonNext.shape())).transpose();
        return new Pair<>(pair.getFirst(), epsilonNext);
    }

    /**
     * Updates the r-estimation window and computes the gradients for "w", "v" and "r".
     *
     * <p>Window handling:
     * <ul>
     *   <li>first epoch: outputs are recorded into the window; r is not changed;</li>
     *   <li>first call of a later epoch: r is set to the {@code 100 * nu} percentile of the
     *       recorded outputs, and recording restarts at the beginning of the window;</li>
     *   <li>other calls: outputs are recorded into the window.</li>
     * </ul>
     *
     * @param preOut       output of {@code doOutput} for the current input
     * @param workspaceMgr workspace manager
     * @return the gradients and the (dropout-adjusted) loss delta
     */
    private Pair<Gradient, INDArray> getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) {
        ILossFunction lossFunction = this.ocnnConf().getLossFn();
        INDArray labels2d = this.getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM);
        INDArray delta = lossFunction.computeGradient(labels2d, preOut, this.ocnnConf().getActivationFn(), this.maskArray);
        org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer conf = (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer)this.conf().getLayer();
        if (conf.getLastEpochSinceRUpdated() == 0 && this.epochCount == 0) {
            this.recordWindowValues(this.doOutput(false, workspaceMgr), conf.getWindowSize());
            conf.setLastEpochSinceRUpdated(this.epochCount);
        } else if (conf.getLastEpochSinceRUpdated() != this.epochCount) {
            // A new epoch started: re-estimate r from what was collected, then restart collection.
            this.updateRFromWindow(conf.getNu());
            conf.setLastEpochSinceRUpdated(this.epochCount);
            this.batchWindowSizeIndex = 0;
        } else {
            this.recordWindowValues(this.doOutput(false, workspaceMgr), conf.getWindowSize());
        }
        DefaultGradient gradient = new DefaultGradient();
        INDArray vGradView = this.gradientViews.get("v");
        double oneDivNu = 1.0 / this.ocnnConf().getNu();
        INDArray xTimesV = this.input.mmul(this.getParam("v"));

        // dL/dw = mean over the minibatch of (-g(xV) scaled per example by delta) / nu + w
        INDArray derivW = this.ocnnConf().getActivationFn().getActivation(xTimesV.dup(), true).negi();
        INDArray w = this.getParam("w");
        derivW = derivW.muliColumnVector(delta).mean(new int[]{0}).muli((Number)oneDivNu).addi(w.reshape(new long[]{w.length()}));
        gradient.setGradientFor("w", this.gradientViews.get("w").assign(derivW));

        // dL/dv: g'(xV) * (-w), scaled per example by delta, combined with the input via a broadcast
        // multiply, averaged over the minibatch, divided by nu, plus v.
        INDArray firstVertDerivV = this.ocnnConf().getActivationFn().backprop(xTimesV.dup(), Nd4j.ones(this.input.dataType(), xTimesV.shape())).getFirst().muliRowVector(this.getParam("w").neg());
        firstVertDerivV = firstVertDerivV.muliColumnVector(delta).reshape('f', new long[]{this.input.size(0), 1L, this.ocnnConf().getHiddenSize()});
        INDArray secondTermDerivV = this.input.reshape('f', new long[]{this.input.size(0), this.getParam("v").size(0), 1L});
        long[] shape = new long[firstVertDerivV.shape().length];
        for (int i = 0; i < firstVertDerivV.rank(); ++i) {
            shape[i] = Math.max(firstVertDerivV.size(i), secondTermDerivV.size(i));
        }
        INDArray firstDerivVBroadcast = Nd4j.createUninitialized(this.input.dataType(), shape);
        INDArray mulResult = firstVertDerivV.broadcast(firstDerivVBroadcast);
        int[] bcDims = new int[]{0, 1};
        Broadcast.mul(mulResult, secondTermDerivV, mulResult, bcDims);
        INDArray derivV = mulResult.mean(new int[]{0}).muli((Number)oneDivNu).addi(this.getParam("v"));
        gradient.setGradientFor("v", vGradView.assign(derivV));

        // dL/dr = mean(delta) / nu - 1
        INDArray derivR = Nd4j.scalar(delta.meanNumber()).muli((Number)oneDivNu).addi((Number)-1);
        gradient.setGradientFor("r", this.gradientViews.get("r").assign(derivR));
        this.clearNoiseWeightParams();
        delta = this.backpropDropOutIfPresent(delta);
        return new Pair<>(gradient, delta);
    }

    /**
     * Copies as many values of {@code currentR} as still fit into {@link #window}, starting at
     * {@link #batchWindowSizeIndex}, then advances that index. Values that do not fit are dropped.
     * The window is allocated (zero-filled) on first use.
     *
     * @param currentR   1-D layer output for the current minibatch
     * @param windowSize configured window length, used only when the window is first allocated
     */
    private void recordWindowValues(INDArray currentR, long windowSize) {
        if (this.window == null) {
            this.window = Nd4j.createUninitializedDetached(currentR.dataType(), new long[]{windowSize}).assign((Number)0.0);
        }
        long windowLength = this.window.length();
        long start = this.batchWindowSizeIndex;
        long available = windowLength - start;
        long batchLength = currentR.length();
        if (available > 0L && batchLength > 0L) {
            long count = Math.min(batchLength, available);
            INDArray values = count == batchLength
                    ? currentR
                    : currentR.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, count)});
            this.window.put(new INDArrayIndex[]{NDArrayIndex.interval(start, start + count)}, values);
            this.windowFill = (int)Math.max((long)this.windowFill, start + count);
        }
        // Saturate at the window length: only "index < length" matters, and this prevents int overflow.
        this.batchWindowSizeIndex = (int)Math.min(windowLength, start + batchLength);
    }

    /**
     * Sets the "r" parameter to the {@code 100 * nu} percentile of the outputs recorded in
     * {@link #window}. Slots that were never written are excluded, so the zero padding does not
     * bias the estimate. If nothing has been recorded yet, r is left unchanged.
     *
     * @param nu the OC-NN nu hyper-parameter
     */
    private void updateRFromWindow(double nu) {
        if (this.window == null || this.windowFill <= 0) {
            return;
        }
        INDArray observed = this.windowFill < this.window.length()
                ? this.window.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, (long)this.windowFill)})
                : this.window;
        double percentile = observed.percentileNumber((Number)(100.0 * nu)).doubleValue();
        this.getParam("r").putScalar(0L, percentile);
    }

    /**
     * Sets the input and runs the forward pass.
     *
     * @return the 1-D layer output (one value per example)
     */
    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.input = input;
        return this.doOutput(training, workspaceMgr);
    }

    /** Not supported for a one-class layer. */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /** Returns the layer output; there is no separate pre-activation for this layer. */
    @Override
    protected INDArray preOutput2d(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.doOutput(training, workspaceMgr);
    }

    /** Returns the self-generated labels (the last output of {@code doOutput}). */
    @Override
    protected INDArray getLabels2d(LayerWorkspaceMgr workspaceMgr, ArrayType arrayType) {
        return this.labels;
    }

    /** Runs the forward pass on the current input. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.doOutput(training, workspaceMgr);
    }

    /**
     * Forward pass: {@code g(input x v) x w}, where {@code w} is flattened to 1-D.
     *
     * <p>Side effect: the output is also stored in {@code labels}.
     *
     * @return an activations array of shape {@code [input.size(0)]}
     */
    private INDArray doOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        INDArray w = this.getParamWithNoise("w", training, workspaceMgr);
        INDArray v = this.getParamWithNoise("v", training, workspaceMgr);
        this.applyDropOutIfNecessary(training, workspaceMgr);
        INDArray first = Nd4j.createUninitialized(this.input.dataType(), new long[]{this.input.size(0), v.size(1)});
        this.input.mmuli(v, first);
        INDArray act2d = this.ocnnConf().getActivationFn().getActivation(first, training);
        INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, this.input.dataType(), new long[]{this.input.size(0)});
        act2d.mmuli(w.reshape(new long[]{w.length()}), output);
        this.labels = output;
        return output;
    }

    /**
     * Computes the loss's score array, summed over dimension 1, plus the regularization term.
     *
     * @throws IllegalStateException if input or labels have not been set
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (this.input == null || this.labels == null) {
            throw new IllegalStateException("Cannot calculate score without input and labels " + this.layerId());
        }
        INDArray preOut = this.preOutput2d(false, workspaceMgr);
        ILossFunction lossFunction = this.ocnnConf().getLossFn();
        INDArray scoreArray = lossFunction.computeScoreArray(this.getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut, this.ocnnConf().getActivationFn(), this.maskArray);
        INDArray summedScores = scoreArray.sum(new int[]{1});
        if (fullNetRegTerm != 0.0) {
            summedScores.addi((Number)fullNetRegTerm);
        }
        return summedScores;
    }

    public void setActivation(IActivation activation) {
        this.activation = activation;
    }

    public IActivation getActivation() {
        return this.activation;
    }

    /**
     * OC-NN objective:
     * {@code 0.5*||w||^2 + 0.5*||v||^2 + (1/nu) * mean(max(0, r - output)) - r}.
     * Labels, activation function, mask and the {@code average} flag are ignored. The
     * parameters are read from the enclosing layer.
     */
    public class OCNNLossFunction
    implements ILossFunction {
        @Override
        public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
            double wSum = Transforms.pow(OCNNOutputLayer.this.getParam("w"), (Number)2).sumNumber().doubleValue() * 0.5;
            double vSum = Transforms.pow(OCNNOutputLayer.this.getParam("v"), (Number)2).sumNumber().doubleValue() * 0.5;
            org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = (org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer)OCNNOutputLayer.this.conf().getLayer();
            double r = OCNNOutputLayer.this.getParam("r").getDouble(0L);
            // hinge term: max(0, r - output), averaged over all examples
            INDArray rMeanSub = RELU.getActivation(preOutput.rsub((Number)r), true);
            double rMean = rMeanSub.meanNumber().doubleValue();
            double nuDiv = 1.0 / ocnnOutputLayer.getNu() * rMean;
            return wSum + vSum + nuDiv - r;
        }

        /** Returns {@code r - preOutput}. */
        @Override
        public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
            return OCNNOutputLayer.this.getParam("r").sub(preOutput);
        }

        /** Returns the ReLU derivative of {@code r - preOutput}, i.e. the hinge-term indicator per example. */
        @Override
        public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
            INDArray preAct = preOutput.rsub((Number)OCNNOutputLayer.this.getParam("r").getDouble(0L));
            return RELU.backprop(preAct, Nd4j.ones(preOutput.dataType(), preAct.shape())).getFirst();
        }

        @Override
        public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
            return new Pair<>(this.computeScore(labels, preOutput, activationFn, mask, average), this.computeGradient(labels, preOutput, activationFn, mask));
        }

        @Override
        public String name() {
            return "OCNNLossFunction";
        }
    }
}

