/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.mkldnn;

import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * {@link SubsamplingHelper} that executes 2D max / average pooling through ND4J custom ops
 * ({@link MaxPooling2D}, {@link AvgPooling2D} for the forward pass and
 * {@link Pooling2DDerivative} for the backward pass).
 *
 * <p>Support is reported via {@link BaseMKLDNNHelper#mklDnnEnabled()}. Only
 * {@link PoolingType#MAX} and {@link PoolingType#AVG} are handled; for any other pooling type
 * both {@link #activate} and {@link #backpropGradient} return {@code null} (presumably so the
 * calling layer falls back to its built-in implementation — confirm against the caller).
 *
 * <p>NOTE(review): the forward pass reuses one {@link OpContext} held in a mutable field and
 * rewrites its inputs/outputs on every call, so a single instance should not be used from
 * several threads concurrently without external synchronization.
 */
public class MKLDNNSubsamplingHelper
implements SubsamplingHelper {
    /** Op context created lazily on the first forward pass and reused for later forward passes. */
    protected OpContext context;

    /**
     * Creates the helper.
     *
     * @param dataType not used by this helper
     */
    public MKLDNNSubsamplingHelper(DataType dataType) {
    }

    /**
     * Reports whether this helper can be used.
     *
     * @return the value of {@link BaseMKLDNNHelper#mklDnnEnabled()}
     */
    @Override
    public boolean checkSupported() {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Computes the gradient of the pooling operation with respect to its input.
     *
     * @param input           forward-pass input; indexed as [minibatch, channels, height, width]
     *                        (NCHW, matching {@code isNHWC(false)} below)
     * @param epsilon         gradient with respect to the pooling output, same layout
     * @param kernel          kernel size {height, width}
     * @param strides         strides {height, width}
     * @param pad             padding {height, width}; ignored and recomputed in Same mode
     * @param poolingType     pooling type; only MAX and AVG are supported
     * @param convolutionMode convolution mode
     * @param dilation        dilation {height, width}
     * @param workspaceMgr    workspace manager used to allocate the returned gradient array
     * @return an empty {@link DefaultGradient} (pooling has no parameters) paired with the
     *         gradient at the input, or {@code null} if {@code poolingType} is not MAX or AVG
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
        // Reject unsupported types before allocating anything from the workspace.
        if (poolingType != PoolingType.MAX && poolingType != PoolingType.AVG) {
            return null;
        }
        INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());
        boolean sameMode = convolutionMode == ConvolutionMode.Same;
        if (sameMode) {
            // In Same mode the top-left padding depends on the actual output (epsilon) and input sizes.
            int[] outSize = new int[]{(int) epsilon.size(2), (int) epsilon.size(3)};
            int[] inSize = new int[]{(int) input.size(2), (int) input.size(3)};
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, inSize, kernel, strides, dilation);
        }
        // NOTE(review): defensive copies — presumably so the native op receives contiguous arrays
        // that are not views of the caller's data; confirm whether this is still required.
        input = input.dup();
        epsilon = epsilon.dup();
        Pooling2DConfig conf = Pooling2DConfig.builder()
                .isSameMode(sameMode)
                .kH(kernel[0]).kW(kernel[1])
                .sH(strides[0]).sW(strides[1])
                .dH(dilation[0]).dW(dilation[1])
                .pH(pad[0]).pW(pad[1])
                .isNHWC(false)
                .build();
        conf.setType(poolingType == PoolingType.MAX ? Pooling2D.Pooling2DType.MAX : Pooling2D.Pooling2DType.AVG);
        Pooling2DDerivative d = Pooling2DDerivative.derivativeBuilder()
                .config(conf)
                .arrayInputs(new INDArray[]{input, epsilon})
                .arrayOutputs(new INDArray[]{gradAtInput})
                .build();
        Nd4j.exec((CustomOp) d);
        return new Pair<>(new DefaultGradient(), gradAtInput);
    }

    /**
     * Runs the forward pooling pass.
     *
     * @param input           input array; indexed as [minibatch, channels, height, width]
     * @param training        not used by this helper
     * @param kernel          kernel size {height, width}
     * @param strides         strides {height, width}
     * @param pad             padding {height, width}; ignored and recomputed in Same mode
     * @param poolingType     pooling type; only MAX and AVG are supported
     * @param convolutionMode convolution mode
     * @param dilation        dilation {height, width}
     * @param workspaceMgr    workspace manager used to allocate the output activations
     * @return the pooled activations of shape [minibatch, channels, outH, outW], or {@code null}
     *         if {@code poolingType} is not MAX or AVG
     */
    @Override
    public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) {
        // Pick the op first so unsupported types return before any workspace allocation,
        // consistent with backpropGradient.
        CustomOp op;
        switch (poolingType) {
            case MAX:
                op = new MaxPooling2D();
                break;
            case AVG:
                op = new AvgPooling2D();
                break;
            default:
                return null;
        }
        boolean sameMode = convolutionMode == ConvolutionMode.Same;
        int[] outSize;
        if (sameMode) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation);
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation);
        } else {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation);
        }
        long[] outShape = new long[]{input.size(0), input.size(1), outSize[0], outSize[1]};
        INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);
        if (this.context == null) {
            this.context = Nd4j.getExecutioner().buildContext();
        }
        // Refresh the integer arguments on every call. Setting them only when the context was
        // created would pin the first call's values (notably the Same-mode padding, which depends
        // on the input's spatial size) for all later calls.
        // NOTE(review): the two trailing zeros are passed through unchanged from the original;
        // presumably an unused extra parameter and a data-format flag (0 = NCHW?) — confirm.
        this.context.setIArguments(new long[]{
                kernel[0], kernel[1],
                strides[0], strides[1],
                pad[0], pad[1],
                dilation[0], dilation[1],
                ArrayUtil.fromBoolean(sameMode),
                0L, 0L});
        this.context.getInputArrays().clear();
        this.context.getOutputArrays().clear();
        this.context.setInputArray(0, input);
        this.context.setOutputArray(0, output);
        Nd4j.exec(op, this.context);
        return output;
    }

    /**
     * Reports helper-specific memory use.
     *
     * @return an empty map; this helper does not report any memory use
     */
    @Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }
}

