/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.NoParamLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

public class GlobalPoolingLayer
extends NoParamLayer {
    private PoolingType poolingType;
    private int[] poolingDimensions;
    private int pnorm;
    private boolean collapseDimensions = true;

    private GlobalPoolingLayer(Builder builder) {
        super(builder);
        this.poolingType = builder.poolingType;
        this.poolingDimensions = builder.poolingDimensions;
        this.collapseDimensions = builder.collapseDimensions;
        this.pnorm = builder.pnorm;
        this.layerName = builder.layerName;
    }

    public GlobalPoolingLayer() {
        this(PoolingType.MAX);
    }

    public GlobalPoolingLayer(PoolingType poolingType) {
        this(new Builder().poolingType(poolingType));
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer ret = new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return EmptyParamInitializer.getInstance();
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        switch (inputType.getType()) {
            case FF: {
                throw new UnsupportedOperationException("Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType);
            }
            case RNN: {
                InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent)inputType;
                if (this.collapseDimensions) {
                    return InputType.feedForward(recurrent.getSize());
                }
                return recurrent;
            }
            case CNN: {
                InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional)inputType;
                if (this.collapseDimensions) {
                    return InputType.feedForward(conv.getChannels());
                }
                return InputType.convolutional(1L, 1L, conv.getChannels());
            }
            case CNN3D: {
                InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D)inputType;
                if (this.collapseDimensions) {
                    return InputType.feedForward(conv3d.getChannels());
                }
                return InputType.convolutional3D(1L, 1L, 1L, conv3d.getChannels());
            }
            case CNNFlat: {
                InputType.InputTypeConvolutionalFlat convFlat = (InputType.InputTypeConvolutionalFlat)inputType;
                if (this.collapseDimensions) {
                    return InputType.feedForward(convFlat.getDepth());
                }
                return InputType.convolutional(1L, 1L, convFlat.getDepth());
            }
        }
        throw new UnsupportedOperationException("Unknown or not supported input type: " + inputType);
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        switch (inputType.getType()) {
            case FF: {
                throw new UnsupportedOperationException("Global max pooling cannot be applied to feed-forward input type. Got input type = " + inputType);
            }
            case RNN: 
            case CNN: 
            case CNN3D: {
                return null;
            }
            case CNNFlat: {
                InputType.InputTypeConvolutionalFlat cFlat = (InputType.InputTypeConvolutionalFlat)inputType;
                return new FeedForwardToCnnPreProcessor(cFlat.getHeight(), cFlat.getWidth(), cFlat.getDepth());
            }
        }
        return null;
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        throw new UnsupportedOperationException("Global pooling layer does not contain parameters");
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = this.getOutputType(-1, inputType);
        long fwdTrainInferenceWorkingPerEx = 0L;
        if (this.poolingType == PoolingType.PNORM) {
            fwdTrainInferenceWorkingPerEx = inputType.arrayElementsPerExample();
        }
        return new LayerMemoryReport.Builder(this.layerName, GlobalPoolingLayer.class, inputType, outputType).standardMemory(0L, 0L).workingMemory(0L, fwdTrainInferenceWorkingPerEx, 0L, fwdTrainInferenceWorkingPerEx).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS).build();
    }

    public PoolingType getPoolingType() {
        return this.poolingType;
    }

    public int[] getPoolingDimensions() {
        return this.poolingDimensions;
    }

    public int getPnorm() {
        return this.pnorm;
    }

    public boolean isCollapseDimensions() {
        return this.collapseDimensions;
    }

    public void setPoolingType(PoolingType poolingType) {
        this.poolingType = poolingType;
    }

    public void setPoolingDimensions(int[] poolingDimensions) {
        this.poolingDimensions = poolingDimensions;
    }

    public void setPnorm(int pnorm) {
        this.pnorm = pnorm;
    }

    public void setCollapseDimensions(boolean collapseDimensions) {
        this.collapseDimensions = collapseDimensions;
    }

    @Override
    public String toString() {
        return "GlobalPoolingLayer(poolingType=" + (Object)((Object)this.getPoolingType()) + ", poolingDimensions=" + Arrays.toString(this.getPoolingDimensions()) + ", pnorm=" + this.getPnorm() + ", collapseDimensions=" + this.isCollapseDimensions() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof GlobalPoolingLayer)) {
            return false;
        }
        GlobalPoolingLayer other = (GlobalPoolingLayer)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        PoolingType this$poolingType = this.getPoolingType();
        PoolingType other$poolingType = other.getPoolingType();
        if (this$poolingType == null ? other$poolingType != null : !((Object)((Object)this$poolingType)).equals((Object)other$poolingType)) {
            return false;
        }
        if (!Arrays.equals(this.getPoolingDimensions(), other.getPoolingDimensions())) {
            return false;
        }
        if (this.getPnorm() != other.getPnorm()) {
            return false;
        }
        return this.isCollapseDimensions() == other.isCollapseDimensions();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof GlobalPoolingLayer;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        PoolingType $poolingType = this.getPoolingType();
        result = result * 59 + ($poolingType == null ? 43 : ((Object)((Object)$poolingType)).hashCode());
        result = result * 59 + Arrays.hashCode(this.getPoolingDimensions());
        result = result * 59 + this.getPnorm();
        result = result * 59 + (this.isCollapseDimensions() ? 79 : 97);
        return result;
    }

    public static class Builder
    extends Layer.Builder<Builder> {
        private PoolingType poolingType = PoolingType.MAX;
        private int[] poolingDimensions;
        private int pnorm = 2;
        private boolean collapseDimensions = true;

        public Builder() {
        }

        public Builder(PoolingType poolingType) {
            this.setPoolingType(poolingType);
        }

        public Builder poolingDimensions(int ... poolingDimensions) {
            this.setPoolingDimensions(poolingDimensions);
            return this;
        }

        public Builder poolingType(PoolingType poolingType) {
            this.setPoolingType(poolingType);
            return this;
        }

        public Builder collapseDimensions(boolean collapseDimensions) {
            this.setCollapseDimensions(collapseDimensions);
            return this;
        }

        public Builder pnorm(int pnorm) {
            if (pnorm <= 0) {
                throw new IllegalArgumentException("Invalid input: p-norm value must be greater than 0. Got: " + pnorm);
            }
            this.setPnorm(pnorm);
            return this;
        }

        public void setPnorm(int pnorm) {
            ValidationUtils.validateNonNegative(pnorm, "pnorm");
            this.pnorm = pnorm;
        }

        @Override
        public GlobalPoolingLayer build() {
            return new GlobalPoolingLayer(this);
        }

        public PoolingType getPoolingType() {
            return this.poolingType;
        }

        public int[] getPoolingDimensions() {
            return this.poolingDimensions;
        }

        public int getPnorm() {
            return this.pnorm;
        }

        public boolean isCollapseDimensions() {
            return this.collapseDimensions;
        }

        public void setPoolingType(PoolingType poolingType) {
            this.poolingType = poolingType;
        }

        public void setPoolingDimensions(int[] poolingDimensions) {
            this.poolingDimensions = poolingDimensions;
        }

        public void setCollapseDimensions(boolean collapseDimensions) {
            this.collapseDimensions = collapseDimensions;
        }
    }
}

