/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Base configuration for a computation-graph vertex whose computation is defined with SameDiff.
 *
 * <p>Subclasses declare their parameters and inputs in {@link #defineParametersAndInputs},
 * build the SameDiff graph in {@link #defineVertex}, and fill initial parameter values in
 * {@link #initializeParameters}. {@link #instantiate} turns this configuration into a runtime
 * {@link SameDiffGraphVertex}.
 *
 * <p>Training settings (regularization, updaters, gradient normalization) may be set on the vertex
 * directly; any left unset are filled from the network-wide builder by
 * {@link #applyGlobalConfig}.
 */
public abstract class SameDiffVertex
extends org.deeplearning4j.nn.conf.graph.GraphVertex
implements TrainingConfig {
    /** Lazily created by {@link #getVertexParams()}; may also be injected via the setter. */
    private SDVertexParams vertexParams;
    /** Vertex name; assigned when the vertex is instantiated in a graph. */
    private String name;
    protected List<Regularization> regularization;
    protected List<Regularization> regularizationBias;
    protected IUpdater updater;
    protected IUpdater biasUpdater;
    protected GradientNormalization gradientNormalization;
    /** NaN means "not set"; {@link #applyGlobalConfig} then substitutes the global value. */
    protected double gradientNormalizationThreshold = Double.NaN;

    /**
     * Defines the vertex's forward computation.
     *
     * <p>NOTE(review): the three maps are presumably the vertex inputs, the parameter variables and
     * the mask variables, keyed by name; confirm against {@link SameDiffGraphVertex}.
     *
     * @param var1 the SameDiff instance to build the graph in
     * @return the output variable of the vertex
     */
    public abstract SDVariable defineVertex(SameDiff var1, Map<String, SDVariable> var2, Map<String, SDVariable> var3, Map<String, SDVariable> var4);

    /**
     * Declares parameter names/shapes and inputs of this vertex on the supplied params object.
     *
     * @param var1 the (initially empty) params holder to populate
     */
    public abstract void defineParametersAndInputs(SDVertexParams var1);

    /**
     * Initializes parameter values in place.
     *
     * @param var1 parameter arrays keyed by parameter name
     */
    public abstract void initializeParameters(Map<String, INDArray> var1);

    /**
     * Returns the vertex parameter definitions. On the first call the params object is created and
     * populated by {@link #defineParametersAndInputs}; later calls return the cached instance.
     */
    public SDVertexParams getVertexParams() {
        if (this.vertexParams == null) {
            this.vertexParams = new SDVertexParams();
            this.defineParametersAndInputs(this.vertexParams);
        }
        return this.vertexParams;
    }

    /** Not supported for SameDiff vertices. */
    @Override
    public org.deeplearning4j.nn.conf.graph.GraphVertex clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Returns the total number of parameters: the sum of the element counts of every declared
     * parameter shape. The {@code backprop} flag is ignored.
     */
    @Override
    public long numParams(boolean backprop) {
        SDVertexParams params = this.getVertexParams();
        long count = 0L;
        for (long[] shape : params.getParamShapes().values()) {
            count += ArrayUtil.prodLong(shape);
        }
        // Return the full long; the previous (int) cast truncated counts above Integer.MAX_VALUE.
        return count;
    }

    @Override
    public int minVertexInputs() {
        return 1;
    }

    /** Returns -1, i.e. no upper limit on the number of inputs is enforced here. */
    @Override
    public int maxVertexInputs() {
        return -1;
    }

    /**
     * Creates the runtime vertex. Side effect: records {@code name} on this configuration object.
     */
    @Override
    public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
        this.name = name;
        return new SameDiffGraphVertex(this, graph, name, idx, paramsView, initializeParams, networkDatatype);
    }

    /** Not supported yet; subclasses are expected to override. */
    @Override
    public InputType getOutputType(int layerIndex, InputType ... vertexInputs) throws InvalidInputTypeException {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not supported yet; subclasses may override to propagate masks. */
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /** Hook for subclasses to validate inputs; the default accepts anything. */
    public void validateInput(INDArray[] input) {
    }

    /** No memory report is provided; returns {@code null}. */
    @Override
    public MemoryReport getMemoryReport(InputType ... inputTypes) {
        return null;
    }

    /** Returns the reshape order used for the named parameter; 'c' for all parameters by default. */
    public char paramReshapeOrder(String paramName) {
        return 'c';
    }

    /**
     * Fills every training setting that is unset on this vertex from the global builder, then calls
     * {@link #applyGlobalConfigToLayer} so subclasses can apply their own defaults.
     * Regularization lists count as unset when null or empty; the threshold counts as unset when
     * NaN; other settings count as unset when null.
     *
     * @param b the network-wide configuration builder
     */
    public void applyGlobalConfig(NeuralNetConfiguration.Builder b) {
        if (this.regularization == null || this.regularization.isEmpty()) {
            this.regularization = b.getRegularization();
        }
        if (this.regularizationBias == null || this.regularizationBias.isEmpty()) {
            this.regularizationBias = b.getRegularizationBias();
        }
        if (this.updater == null) {
            this.updater = b.getIUpdater();
        }
        if (this.biasUpdater == null) {
            this.biasUpdater = b.getBiasUpdater();
        }
        if (this.gradientNormalization == null) {
            this.gradientNormalization = b.getGradientNormalization();
        }
        if (Double.isNaN(this.gradientNormalizationThreshold)) {
            this.gradientNormalizationThreshold = b.getGradientNormalizationThreshold();
        }
        this.applyGlobalConfigToLayer(b);
    }

    /** Subclass hook invoked at the end of {@link #applyGlobalConfig}; no-op by default. */
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
    }

    @Override
    public String getLayerName() {
        return this.name;
    }

    /**
     * Returns the regularization for a parameter.
     *
     * @return {@code null} if neither weight nor bias regularization is configured; otherwise the
     *     weight list for weight params or the bias list for bias params
     * @throws IllegalStateException if {@code paramName} is neither a weight nor a bias parameter
     */
    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        if ((this.regularization == null || this.regularization.isEmpty()) && (this.regularizationBias == null || this.regularizationBias.isEmpty())) {
            return null;
        }
        if (this.getVertexParams().isWeightParam(paramName)) {
            return this.regularization;
        }
        if (this.getVertexParams().isBiasParam(paramName)) {
            return this.regularizationBias;
        }
        throw new IllegalStateException("Unknown parameter name: " + paramName + " - not in weights (" + this.getVertexParams().getWeightParameterKeys() + ") or biases (" + this.getVertexParams().getBiasParameterKeys() + ")");
    }

    /** No parameter of this vertex is a pretrain parameter. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /**
     * Returns the updater for a parameter. Weight params use {@code updater}; bias params use
     * {@code biasUpdater}, falling back to {@code updater} when no bias updater is set.
     *
     * @throws IllegalStateException if {@code paramName} is neither a weight nor a bias parameter
     */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        if (this.getVertexParams().isWeightParam(paramName)) {
            return this.updater;
        }
        if (this.getVertexParams().isBiasParam(paramName)) {
            if (this.biasUpdater == null) {
                return this.updater;
            }
            return this.biasUpdater;
        }
        throw new IllegalStateException("Unknown parameter name: " + paramName + " - not in weights (" + this.getVertexParams().getWeightParameterKeys() + ") or biases (" + this.getVertexParams().getBiasParameterKeys() + ")");
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return this.gradientNormalization;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return this.gradientNormalizationThreshold;
    }

    public String getName() {
        return this.name;
    }

    public List<Regularization> getRegularization() {
        return this.regularization;
    }

    public List<Regularization> getRegularizationBias() {
        return this.regularizationBias;
    }

    public IUpdater getUpdater() {
        return this.updater;
    }

    public IUpdater getBiasUpdater() {
        return this.biasUpdater;
    }

    public void setVertexParams(SDVertexParams vertexParams) {
        this.vertexParams = vertexParams;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setRegularization(List<Regularization> regularization) {
        this.regularization = regularization;
    }

    public void setRegularizationBias(List<Regularization> regularizationBias) {
        this.regularizationBias = regularizationBias;
    }

    public void setUpdater(IUpdater updater) {
        this.updater = updater;
    }

    public void setBiasUpdater(IUpdater biasUpdater) {
        this.biasUpdater = biasUpdater;
    }

    public void setGradientNormalization(GradientNormalization gradientNormalization) {
        this.gradientNormalization = gradientNormalization;
    }

    public void setGradientNormalizationThreshold(double gradientNormalizationThreshold) {
        this.gradientNormalizationThreshold = gradientNormalizationThreshold;
    }

    /**
     * Field-wise equality over all configuration fields. Note that this goes through
     * {@link #getVertexParams()}, so it may trigger lazy parameter definition on either object.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SameDiffVertex)) {
            return false;
        }
        SameDiffVertex other = (SameDiffVertex)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getVertexParams(), other.getVertexParams())
                && Objects.equals(this.getName(), other.getName())
                && Objects.equals(this.getRegularization(), other.getRegularization())
                && Objects.equals(this.getRegularizationBias(), other.getRegularizationBias())
                && Objects.equals(this.getUpdater(), other.getUpdater())
                && Objects.equals(this.getBiasUpdater(), other.getBiasUpdater())
                && Objects.equals(this.getGradientNormalization(), other.getGradientNormalization())
                // Double.compare treats NaN == NaN, matching hashCode's doubleToLongBits semantics.
                && Double.compare(this.getGradientNormalizationThreshold(), other.getGradientNormalizationThreshold()) == 0;
    }

    /** Symmetry hook for {@link #equals}; subclasses adding state should override. */
    protected boolean canEqual(Object other) {
        return other instanceof SameDiffVertex;
    }

    /** Hash consistent with {@link #equals}; uses 59 as multiplier and 43 for null fields. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        SDVertexParams params = this.getVertexParams();
        result = result * prime + (params == null ? 43 : params.hashCode());
        String n = this.getName();
        result = result * prime + (n == null ? 43 : n.hashCode());
        List<Regularization> reg = this.getRegularization();
        result = result * prime + (reg == null ? 43 : reg.hashCode());
        List<Regularization> regBias = this.getRegularizationBias();
        result = result * prime + (regBias == null ? 43 : regBias.hashCode());
        IUpdater upd = this.getUpdater();
        result = result * prime + (upd == null ? 43 : upd.hashCode());
        IUpdater biasUpd = this.getBiasUpdater();
        result = result * prime + (biasUpd == null ? 43 : biasUpd.hashCode());
        GradientNormalization gn = this.getGradientNormalization();
        result = result * prime + (gn == null ? 43 : gn.hashCode());
        result = result * prime + Double.hashCode(this.getGradientNormalizationThreshold());
        return result;
    }

    @Override
    public String toString() {
        return "SameDiffVertex(vertexParams=" + this.getVertexParams() + ", name=" + this.getName() + ", regularization=" + this.getRegularization() + ", regularizationBias=" + this.getRegularizationBias() + ", updater=" + this.getUpdater() + ", biasUpdater=" + this.getBiasUpdater() + ", gradientNormalization=" + this.getGradientNormalization() + ", gradientNormalizationThreshold=" + this.getGradientNormalizationThreshold() + ")";
    }
}

