/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Layer implementation for one-dimensional convolution over rank-3 sequence data.
 *
 * <p>Non-causal convolution modes reuse the 2D {@link ConvolutionLayer} implementation by
 * temporarily giving the input (and epsilon) a trailing size-1 axis. {@link ConvolutionMode#Causal}
 * is executed directly with the ND4J {@link Conv1D} / {@link Conv1DDerivative} ops. When the
 * configured {@link RNNFormat} is {@code NWC}, data is permuted to NCW internally and results are
 * permuted back before being returned.
 */
public class Convolution1DLayer extends ConvolutionLayer {

    /**
     * @param conf     configuration passed through to the parent layer
     * @param dataType data type passed through to the parent layer
     */
    public Convolution1DLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Backpropagates through this layer.
     *
     * <p>If a mask array is set, the (possibly permuted) epsilon is multiplied in place by the
     * reduced mask before gradients are computed. {@code this.input} is restored to the array it
     * held on entry, even if an exception is thrown.
     *
     * @param epsilon      gradient w.r.t. this layer's activations; must be rank 3, laid out
     *                     according to the configured RNN format
     * @param workspaceMgr workspace manager used for the epsilon allocation (causal mode) and
     *                     passed to the parent implementation (other modes)
     * @return the parameter gradients and the gradient w.r.t. this layer's input, in the
     *         configured RNN format
     * @throws DL4JInvalidInputException if {@code epsilon} is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        if (epsilon.rank() != 3) {
            throw new DL4JInvalidInputException("Got rank " + epsilon.rank() + " array as epsilon for Convolution1DLayer backprop with shape " + Arrays.toString(epsilon.shape()) + ". Expected rank 3 array with shape [minibatchSize, features, length]. " + this.layerId());
        }
        boolean nwc = this.getRnnDataFormat() == RNNFormat.NWC;
        INDArray entryInput = this.input;
        if (nwc) {
            // All internal computation below works on NCW data.
            epsilon = swapLastTwoAxes(epsilon);
            this.input = swapLastTwoAxes(this.input);
        }
        try {
            if (this.maskArray != null) {
                this.applyReducedMask(epsilon, "Activation gradients");
            }
            Pair<Gradient, INDArray> result;
            if (this.convLayerConf().getConvolutionMode() == ConvolutionMode.Causal) {
                result = this.causalBackprop(epsilon, workspaceMgr);
            } else {
                result = this.delegatedBackprop(epsilon, workspaceMgr);
            }
            INDArray epsNext = result.getSecond();
            if (nwc) {
                epsNext = swapLastTwoAxes(epsNext);
            }
            return new Pair<>(result.getFirst(), epsNext);
        } finally {
            // Never leave the layer holding a permuted/reshaped input, even on failure.
            this.input = entryInput;
        }
    }

    /**
     * Causal-mode backprop using the ND4J {@link Conv1DDerivative} op.
     *
     * <p>Expects {@code this.input} and {@code epsilon} to already be in NCW layout. Weight and
     * bias gradients are written into this layer's gradient views by the op.
     *
     * @return gradients (backed by the gradient views) and the NCW epsilon for the layer below
     */
    private Pair<Gradient, INDArray> causalBackprop(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        Pair<INDArray, INDArray> fwd = this.causalConv1dForward();
        IActivation afn = this.convLayerConf().getActivationFn();
        INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst();

        INDArray w = toConv1dWeightLayout(this.getParam("W"));
        // Same layout transform on the gradient view so the op writes into it.
        INDArray wg = toConv1dWeightLayout(this.gradientViews.get("W"));
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, this.input.dataType(), this.input.shape());
        INDArray castInput = this.input.castTo(w.dataType());

        boolean hasBias = this.convLayerConf().hasBias();
        INDArray[] inputArrs;
        INDArray[] outputArrs;
        if (hasBias) {
            INDArray b = this.getParam("b");
            INDArray bg = this.gradientViews.get("b");
            inputArrs = new INDArray[]{castInput, w, b.reshape(b.length()), delta};
            outputArrs = new INDArray[]{epsOut, wg, bg.reshape(bg.length())};
        } else {
            inputArrs = new INDArray[]{castInput, w, delta};
            outputArrs = new INDArray[]{epsOut, wg};
        }
        Nd4j.exec((CustomOp) new Conv1DDerivative(inputArrs, outputArrs, this.causalConv1dConfig()));

        DefaultGradient retGradient = new DefaultGradient();
        if (hasBias) {
            retGradient.setGradientFor("b", this.gradientViews.get("b"));
        }
        retGradient.setGradientFor("W", this.gradientViews.get("W"), 'c');
        return new Pair<>(retGradient, epsOut);
    }

    /**
     * Non-causal backprop: appends a size-1 axis to the NCW epsilon and input, runs the parent
     * 2D implementation, and drops the extra axis from the resulting epsilon.
     *
     * <p>{@code this.input} is restored to the NCW array it held on entry.
     */
    private Pair<Gradient, INDArray> delegatedBackprop(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray ncwInput = this.input;
        INDArray epsilon4d = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1L);
        this.input = ncwInput.reshape(ncwInput.size(0), ncwInput.size(1), ncwInput.size(2), 1L);
        try {
            Pair<Gradient, INDArray> gradientEpsNext = super.backpropGradient(epsilon4d, workspaceMgr);
            INDArray epsNext = gradientEpsNext.getSecond();
            epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
            return new Pair<>(gradientEpsNext.getFirst(), epsNext);
        } finally {
            this.input = ncwInput;
        }
    }

    /**
     * Pre-activation hook used by the parent 2D implementation.
     *
     * <p>Calls the parent's {@code preOutput} directly (not this class's 3D override). This is
     * presumably because {@code this.input} already carries the trailing size-1 axis when this hook
     * is reached from {@link #delegatedBackprop} — TODO confirm no other caller reaches it with 3D input.
     *
     * @return the parent's pre-output pair, with the first element reshaped to rank 4
     */
    @Override
    protected Pair<INDArray, INDArray> preOutput4d(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        // Forward the caller's training flag instead of hard-coding true.
        Pair<INDArray, INDArray> preOutput = super.preOutput(training, forBackprop, workspaceMgr);
        INDArray z = preOutput.getFirst();
        preOutput.setFirst(z.reshape(z.size(0), z.size(1), z.size(2), 1L));
        return preOutput;
    }

    /**
     * Computes the rank-3 pre-activation output.
     *
     * <p>Causal mode runs {@link #causalConv1dForward()}. Other modes temporarily reshape
     * {@code this.input} to rank 4 and delegate to the parent. {@code this.input} is restored
     * afterwards, also on failure.
     */
    @Override
    protected Pair<INDArray, INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        if (this.convLayerConf().getConvolutionMode() == ConvolutionMode.Causal) {
            return this.causalConv1dForward();
        }
        INDArray input3d = this.input;
        this.input = input3d.reshape(input3d.size(0), input3d.size(1), input3d.size(2), 1L);
        Pair<INDArray, INDArray> preOutput;
        try {
            preOutput = super.preOutput(training, forBackprop, workspaceMgr);
        } finally {
            this.input = input3d;
        }
        INDArray z = preOutput.getFirst();
        preOutput.setFirst(z.reshape(z.size(0), z.size(1), z.size(2)));
        return preOutput;
    }

    /**
     * Runs the causal 1D convolution forward pass with the ND4J {@link Conv1D} op.
     *
     * <p>Expects {@code this.input} to be in NCW layout. {@code activate} and
     * {@code backpropGradient} permute NWC input before reaching here. The output array is
     * allocated with {@code Nd4j.create}.
     *
     * @return the pre-activation output and {@code null} as the second element
     */
    protected Pair<INDArray, INDArray> causalConv1dForward() {
        INDArray w = toConv1dWeightLayout(this.getParam("W"));
        INDArray castInput = this.input.castTo(w.dataType());
        INDArray[] inputs;
        if (this.convLayerConf().hasBias()) {
            INDArray b = this.getParam("b");
            inputs = new INDArray[]{castInput, w, b.reshape(b.length())};
        } else {
            inputs = new INDArray[]{castInput, w};
        }
        Conv1D op = new Conv1D(inputs, null, this.causalConv1dConfig());
        List<LongShapeDescriptor> outShape = op.calculateOutputShape();
        op.setOutputArgument(0, Nd4j.create(outShape.get(0), false));
        Nd4j.exec((CustomOp) op);
        return new Pair<>(op.getOutputArgument(0), null);
    }

    /**
     * Computes activations, then masks them and restores the configured RNN layout.
     *
     * @return rank-3 activations in the configured RNN format, leveraged to the ACTIVATIONS workspace
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        boolean nwc = this.getRnnDataFormat() == RNNFormat.NWC;
        if (nwc) {
            this.input = swapLastTwoAxes(this.input);
        }
        INDArray act3d;
        try {
            INDArray act4d = super.activate(training, workspaceMgr);
            act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));
            if (this.maskArray != null) {
                this.applyReducedMask(act3d, "Activations");
            }
        } finally {
            if (nwc) {
                // Permute back whatever this.input holds now rather than restoring the entry
                // reference. The parent activate may have replaced it, and that replacement is kept.
                this.input = swapLastTwoAxes(this.input);
            }
        }
        if (nwc) {
            act3d = swapLastTwoAxes(act3d);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d);
    }

    /**
     * Reduces a per-timestep input mask to the output length of this convolution.
     *
     * @param maskArray        input mask; if {@code null}, it is passed through unchanged
     * @param currentMaskState returned unchanged
     * @param minibatchSize    unused
     * @return the reduced mask (or {@code null}) paired with {@code currentMaskState}
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        if (maskArray == null) {
            // Nothing to reduce: do not hand a null mask to cnn1dMaskReduction.
            return new Pair<>(null, currentMaskState);
        }
        org.deeplearning4j.nn.conf.layers.ConvolutionLayer c = this.convLayerConf();
        INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(maskArray, c.getKernelSize()[0], c.getStride()[0], c.getPadding()[0], c.getDilation()[0], c.getConvolutionMode());
        return new Pair<>(reduced, currentMaskState);
    }

    /**
     * Multiplies a rank-3 NCW array in place by this layer's mask, after reducing the mask via
     * {@link #feedForwardMaskArray}.
     *
     * @param arr   array to mask (minibatch on axis 0, time on axis 2)
     * @param label name of {@code arr} used in the shape-mismatch error message
     */
    private void applyReducedMask(INDArray arr, String label) {
        INDArray maskOut = this.feedForwardMaskArray(this.maskArray, MaskState.Active, (int) arr.size(0)).getFirst();
        Preconditions.checkState(arr.size(0) == maskOut.size(0) && arr.size(2) == maskOut.size(1),
                label + " dimensions (0,2) and mask dimensions (0,1) don't match: " + label + " %s, Mask %s",
                arr.shape(), maskOut.shape());
        Broadcast.mul(arr, maskOut, arr, 0, 2);
    }

    /** Builds the config shared by the causal forward and backward ops; data is always NCW here. */
    private Conv1DConfig causalConv1dConfig() {
        org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = (org.deeplearning4j.nn.conf.layers.Convolution1DLayer) this.layerConf();
        return Conv1DConfig.builder()
                .k(c.getKernelSize()[0])
                .s(c.getStride()[0])
                .d(c.getDilation()[0])
                .p(c.getPadding()[0])
                .dataFormat("NCW")
                .paddingMode(PaddingMode.CAUSAL)
                .build();
    }

    /**
     * Converts a weight array (or its gradient view) to the layout passed to the ND4J Conv1D ops.
     * It keeps the first three axes, which requires any further axes to have size 1 so the element
     * count is unchanged. It then reverses the axis order.
     */
    private static INDArray toConv1dWeightLayout(INDArray w) {
        return w.reshape(w.ordering(), w.size(0), w.size(1), w.size(2)).permute(2, 1, 0);
    }

    /** Swaps axes 1 and 2 of a rank-3 array (NWC to NCW or back); the permutation is its own inverse. */
    private static INDArray swapLastTwoAxes(INDArray arr) {
        return arr.permute(0, 2, 1);
    }

    /** Returns this layer's configuration viewed as the base 2D convolution configuration. */
    private org.deeplearning4j.nn.conf.layers.ConvolutionLayer convLayerConf() {
        return (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) this.layerConf();
    }

    private RNNFormat getRnnDataFormat() {
        return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) this.layerConf()).getRnnDataFormat();
    }
}

