/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Layer configuration that wraps an {@code underlying} layer configuration so that it can be used
 * on recurrent (RNN) input.
 *
 * <p>During shape inference ({@link #getOutputType} and {@link #setNIn}) the underlying layer is
 * presented with a feed-forward input type whose size equals the RNN input's feature size. Its
 * output size is then wrapped back into a recurrent type with the same time-series length. The
 * runtime layer produced by {@link #instantiate} is a {@link TimeDistributedLayer} built from the
 * instantiated underlying layer and {@link #getRnnDataFormat()}.
 */
public class TimeDistributed
extends BaseWrapperLayer {
    /**
     * Layout of the recurrent data. Defaults to {@link RNNFormat#NCW}. It is replaced by the input
     * type's format whenever {@link #setNIn} is called.
     */
    private RNNFormat rnnDataFormat = RNNFormat.NCW;

    /**
     * Creates a wrapper with an explicit RNN data format. Parameters are bound by name through
     * {@code @JsonProperty} (presumably for JSON deserialization).
     *
     * @param underlying    layer configuration to wrap; must not be {@code null}
     * @param rnnDataFormat RNN data layout; {@code null} (e.g. a serialized config without an
     *                      {@code rnnDataFormat} property) falls back to {@link RNNFormat#NCW}
     * @throws NullPointerException if {@code underlying} is {@code null}
     */
    public TimeDistributed(@JsonProperty(value="underlying") @NonNull Layer underlying, @JsonProperty(value="rnnDataFormat") RNNFormat rnnDataFormat) {
        // Validate before the superclass ever sees the argument (same message as before).
        super(Objects.requireNonNull(underlying, "underlying is marked non-null but is null"));
        // Without this fallback a null argument would erase the NCW field default and later be
        // passed to InputType.recurrent(...) and TimeDistributedLayer.
        this.rnnDataFormat = rnnDataFormat != null ? rnnDataFormat : RNNFormat.NCW;
    }

    /**
     * Creates a wrapper that uses the default {@link RNNFormat#NCW} data format.
     *
     * @param underlying layer configuration to wrap
     */
    public TimeDistributed(Layer underlying) {
        super(underlying);
    }

    /**
     * Builds the runtime layer.
     *
     * <p>The network configuration is cloned. The clone's layer (expected to be this wrapper type)
     * is replaced by the layer it wraps, so the underlying layer is instantiated against a
     * configuration describing itself rather than the wrapper. All other arguments are passed
     * through unchanged. The result is wrapped in a {@link TimeDistributedLayer} with this
     * wrapper's {@link #getRnnDataFormat() data format}.
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        NeuralNetConfiguration conf2 = conf.clone();
        conf2.setLayer(((TimeDistributed)conf2.getLayer()).getUnderlying());
        return new TimeDistributedLayer(this.underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType), this.rnnDataFormat);
    }

    /**
     * Computes the recurrent output type.
     *
     * <p>The underlying layer's output type is computed for a feed-forward input of the RNN
     * feature size. Its per-example element count becomes the output size. The input's
     * time-series length is kept, and the format is this wrapper's {@link #getRnnDataFormat()},
     * not the input's.
     *
     * @throws IllegalStateException if {@code inputType} is not {@link InputType.Type#RNN}
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer (layer #" + layerIndex + ")");
        }
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent)inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        InputType ffOut = this.underlying.getOutputType(layerIndex, ff);
        return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), this.rnnDataFormat);
    }

    /**
     * Sets the underlying layer's input size from a recurrent input type.
     *
     * <p>Side effect: this wrapper's data format is always replaced by the input's format,
     * independently of {@code override}. Only the underlying layer's {@code setNIn} receives
     * {@code override}.
     *
     * @throws IllegalStateException if {@code inputType} is not {@link InputType.Type#RNN}
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Only RNN input type is supported as input to TimeDistributed layer");
        }
        InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent)inputType;
        InputType ff = InputType.feedForward(rnn.getSize());
        this.rnnDataFormat = rnn.getFormat();
        this.underlying.setNIn(ff, override);
    }

    /**
     * Returns {@code null}: this configuration requests no input preprocessor. Any reshaping
     * between recurrent and feed-forward layouts is presumably done inside
     * {@link TimeDistributedLayer} (TODO confirm).
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    /** @return the RNN data layout used for this layer's output type and runtime layer */
    public RNNFormat getRnnDataFormat() {
        return this.rnnDataFormat;
    }

    /** @param rnnDataFormat the RNN data layout to use (stored as given) */
    public void setRnnDataFormat(RNNFormat rnnDataFormat) {
        this.rnnDataFormat = rnnDataFormat;
    }

    @Override
    public String toString() {
        return "TimeDistributed(rnnDataFormat=" + this.getRnnDataFormat() + ")";
    }

    /**
     * Equal when {@code o} is a {@code TimeDistributed} that accepts this object via
     * {@link #canEqual}, the superclass state is equal, and the data formats are equal
     * (null-safe).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TimeDistributed)) {
            return false;
        }
        TimeDistributed other = (TimeDistributed)o;
        return other.canEqual(this)
                && super.equals(o)
                && Objects.equals(this.getRnnDataFormat(), other.getRnnDataFormat());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof TimeDistributed;
    }

    /** Combines the superclass hash with the data format's hash (43 stands in for {@code null}). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        RNNFormat format = this.getRnnDataFormat();
        return result * prime + (format == null ? 43 : format.hashCode());
    }
}

