/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.PReLULayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.impl.ActivationPReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Layer implementation backing the {@link PReLULayer} configuration.
 *
 * <p>The forward and backward passes are delegated to {@link ActivationPReLU}. The layer's
 * learnable parameter stored under key {@code "W"} is passed as the alpha array, together with the
 * shared axes taken from the layer configuration.
 */
public class PReLU
extends BaseLayer<PReLULayer> {
    /** Key of the alpha parameter in the parameter map and in the gradient-view map. */
    private static final String WEIGHT_KEY = "W";

    /**
     * Shared axes read from the layer configuration when this instance is created.
     *
     * <p>Java runs field initializers after the superclass constructor returns, so
     * {@code layerConf()} is available here. NOTE(review): this is a one-time snapshot; if the
     * configuration is replaced later, the axes are not re-read.
     */
    long[] axes = this.layerConf().getSharedAxes();

    /**
     * Creates the layer for the given configuration and data type.
     *
     * @param conf     network configuration holding the {@link PReLULayer} settings
     * @param dataType data type passed through to {@link BaseLayer}
     */
    public PReLU(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /** @return {@link Layer.Type#FEED_FORWARD} */
    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    /**
     * Applies the PReLU activation, with alpha parameter {@code "W"}, to the current input.
     *
     * @param training whether to run in training mode; this controls dropout and whether the input
     *                 is duplicated before activation
     * @param mgr      workspace manager used to place the activations array
     * @return the activated array produced by {@link ActivationPReLU#getActivation}
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr mgr) {
        this.assertInputSet(false);
        this.applyDropOutIfNecessary(training, mgr);
        // backpropGradient reads this.input again, so during training we work on a copy.
        // At inference the input is only leveraged into the activations workspace.
        // NOTE(review): if getActivation operates in place, inference may overwrite this.input.
        // Confirm that this is intended.
        INDArray in = training
                ? mgr.dup(ArrayType.ACTIVATIONS, this.input, this.input.ordering())
                : mgr.leverageTo(ArrayType.ACTIVATIONS, this.input);
        INDArray alpha = this.getParam(WEIGHT_KEY);
        return new ActivationPReLU(alpha, this.axes).getActivation(in, training);
    }

    /**
     * Backpropagates {@code epsilon} through the PReLU activation.
     *
     * <p>The gradient with respect to alpha is copied into the pre-allocated gradient view for
     * {@code "W"}. That view, not a fresh array, is what gets registered in the returned
     * {@link Gradient}.
     *
     * @param epsilon      gradient of the loss with respect to this layer's output
     * @param workspaceMgr workspace manager used for the activation-gradient arrays
     * @return the parameter gradient and the gradient with respect to this layer's input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        INDArray layerInput = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, this.input, this.input.ordering());
        INDArray alpha = this.getParam(WEIGHT_KEY);
        ActivationPReLU prelu = new ActivationPReLU(alpha, this.axes);

        Pair<INDArray, INDArray> deltas = prelu.backprop(layerInput, epsilon);
        INDArray delta = deltas.getFirst();
        INDArray weightGrad = deltas.getSecond();

        // Write into the existing view (assign) instead of replacing the reference, so the
        // view keeps backing the layer's flattened gradient array.
        INDArray weightGradView = (INDArray) this.gradientViews.get(WEIGHT_KEY);
        weightGradView.assign(weightGrad);

        delta = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, delta);
        delta = this.backpropDropOutIfPresent(delta);

        Gradient ret = new DefaultGradient();
        ret.setGradientFor(WEIGHT_KEY, weightGradView, 'c');
        return new Pair<>(ret, delta);
    }

    /** @return {@code false}; this layer has no unsupervised pretraining phase */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }
}

