/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape.tensorops;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRemove;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Op that creates a tensor array (op name {@code create_list}, TensorFlow name
 * {@code TensorArrayV3}) and offers helpers that build the companion read / write /
 * gather / scatter / size / remove ops against it.
 *
 * <p>Helper ops are created on {@link #getSameDiff()}, which resolves to the child
 * {@link SameDiff} instance when one is present.
 */
public class TensorArray extends BaseTensorOp {
    /** Element data type associated with this tensor array. */
    protected DataType tensorArrayDataType;
    /** Optional override returned by {@link #getVar()} instead of the first output variable. */
    protected SDVariable flow;
    /** Flag taken from the first boolean argument in {@link #configureFromArguments()}; defaults to true. */
    protected boolean clearOnRead = true;

    /** Creates a tensor array op with data type {@link DataType#FLOAT}. */
    public TensorArray() {
        this(DataType.FLOAT);
    }

    /**
     * Creates a tensor array op that is not attached to a {@link SameDiff} instance.
     *
     * @param dataType element data type to record
     */
    public TensorArray(DataType dataType) {
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a named tensor array op with no inputs.
     *
     * @param name     op name passed to the base class
     * @param sameDiff owning graph
     * @param dataType element data type to record
     */
    public TensorArray(String name, SameDiff sameDiff, DataType dataType) {
        super(name, sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a tensor array op with no inputs.
     *
     * @param sameDiff owning graph
     * @param dataType element data type to record
     */
    public TensorArray(SameDiff sameDiff, DataType dataType) {
        super(sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a new input-less tensor array op on the same graph and with the same data type as {@code ta}.
     *
     * @param ta tensor array whose graph and data type are reused
     */
    public TensorArray(TensorArray ta) {
        super(ta.sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = ta.tensorArrayDataType;
    }

    /**
     * Creates a new tensor array op on the same graph and with the same data type as {@code ta},
     * using the given inputs.
     *
     * @param ta     tensor array whose graph and data type are reused
     * @param inputs input variables passed to the base class
     */
    public TensorArray(TensorArray ta, SDVariable[] inputs) {
        super(ta.sameDiff, inputs);
        this.tensorArrayDataType = ta.tensorArrayDataType;
    }

    @Override
    public String tensorflowName() {
        return "TensorArrayV3";
    }

    /** Runs the base configuration, then reads {@code clearOnRead} from the first boolean argument if any exist. */
    @Override
    public void configureFromArguments() {
        super.configureFromArguments();
        if (!this.bArguments.isEmpty()) {
            this.clearOnRead = (Boolean) this.bArguments.get(0);
        }
    }

    /**
     * Initialises this op from a TensorFlow node: looks up the graph node named by the node's last
     * input, adds the first element of its tensor (when one can be extracted) as an integer
     * argument, and reads the data type from the {@code dtype} attribute.
     *
     * @param nodeDef           the TensorFlow node describing this op
     * @param initWith          the graph being imported into (unused here)
     * @param attributesForNode attributes of {@code nodeDef}; must contain {@code dtype}
     * @param graph             the TensorFlow graph searched for the last input node
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        String lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);
        NodeDef lastInputNode = null;
        // No early exit: if several nodes share the name, the last match wins.
        for (int i = 0; i < graph.getNodeCount(); ++i) {
            NodeDef candidate = graph.getNode(i);
            if (candidate.getName().equals(lastInputName)) {
                lastInputNode = candidate;
            }
        }
        INDArray constantValue = TFGraphMapper.getNDArrayFromTensor(lastInputNode);
        if (constantValue != null) {
            this.addIArgument(constantValue.getInt(0));
        }
        this.tensorArrayDataType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String opName() {
        return "create_list";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /**
     * @return the explicitly set {@link #flow} variable when present, otherwise this op's first output variable
     */
    public SDVariable getVar() {
        return this.flow != null ? this.flow : this.outputVariables()[0];
    }

    /**
     * @return the child of this op's {@link SameDiff} instance when it has one, otherwise the instance itself
     */
    @Override
    public SameDiff getSameDiff() {
        SameDiff child = this.sameDiff.getChild();
        return child != null ? child : this.sameDiff;
    }

    /** Wraps the given ints in a constant variable created on this op's own {@code sameDiff} field. */
    private SDVariable intToVar(int... index) {
        return this.sameDiff.constant(Nd4j.createFromArray(index));
    }

    /**
     * Builds a read op for a fixed index.
     *
     * @param index position to read
     * @return output variable of the new {@link TensorArrayRead} op
     */
    public SDVariable read(int index) {
        return this.read(this.getVar(), this.intToVar(index));
    }

    /**
     * Builds a read op against an explicit source variable.
     *
     * @param from  variable to read from
     * @param index position to read
     * @return output variable of the new {@link TensorArrayRead} op
     */
    public SDVariable read(SDVariable from, SDVariable index) {
        return new TensorArrayRead(this.getSameDiff(), new SDVariable[]{from, index}).outputVariable();
    }

    /**
     * Builds a read op for a variable index.
     *
     * @param index position to read
     * @return output variable of the new {@link TensorArrayRead} op
     */
    public SDVariable read(SDVariable index) {
        return this.read(this.getVar(), index);
    }

    /**
     * Builds a gather op for fixed indices.
     *
     * @param flow    flow variable passed as the op's last input
     * @param indices positions to gather
     * @return output variable of the new {@link TensorArrayGather} op
     */
    public SDVariable gather(SDVariable flow, int... indices) {
        return this.gather(flow, this.intToVar(indices));
    }

    /**
     * Builds a gather op for variable indices.
     *
     * @param flow    flow variable passed as the op's last input
     * @param indices positions to gather
     * @return output variable of the new {@link TensorArrayGather} op
     */
    public SDVariable gather(SDVariable flow, SDVariable indices) {
        return new TensorArrayGather(this.getSameDiff(), new SDVariable[]{this.getVar(), indices, flow}).outputVariable();
    }

    /**
     * Builds a gather op with the single index {@code -1}.
     * NOTE(review): -1 presumably selects every element; confirm against TensorArrayGather.
     *
     * @param flow flow variable passed as the op's last input
     * @return output variable of the new {@link TensorArrayGather} op
     */
    public SDVariable stack(SDVariable flow) {
        return this.gather(flow, this.intToVar(-1));
    }

    /**
     * Builds a concat op over this tensor array.
     *
     * @param flow currently not passed to the op (NOTE(review): confirm this is intentional)
     * @return output variable of the new {@link TensorArrayConcat} op
     */
    public SDVariable concat(SDVariable flow) {
        return new TensorArrayConcat(this.getSameDiff(), new SDVariable[]{this.getVar()}).outputVariable();
    }

    /**
     * Builds a write op for a fixed index.
     *
     * @param flow  flow variable passed as the op's last input
     * @param index position to write
     * @param value value to write
     * @return output variable of the new {@link TensorArrayWrite} op
     */
    public SDVariable write(SDVariable flow, int index, SDVariable value) {
        return this.write(flow, this.intToVar(index), value);
    }

    /**
     * Builds a write op for a variable index.
     *
     * @param flow  flow variable passed as the op's last input
     * @param index position to write
     * @param value value to write
     * @return output variable of the new {@link TensorArrayWrite} op
     */
    public SDVariable write(SDVariable flow, SDVariable index, SDVariable value) {
        return new TensorArrayWrite(this.getSameDiff(), new SDVariable[]{this.getVar(), index, value, flow}).outputVariable();
    }

    /**
     * Builds a scatter op for fixed indices.
     *
     * @param flow    flow variable passed as the op's last input
     * @param value   values to scatter
     * @param indices target positions
     * @return output variable of the new {@link TensorArrayScatter} op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, int... indices) {
        return this.scatter(flow, value, this.intToVar(indices));
    }

    /**
     * Builds a scatter op for variable indices.
     *
     * @param flow    flow variable passed as the op's last input
     * @param value   values to scatter
     * @param indices target positions
     * @return output variable of the new {@link TensorArrayScatter} op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, SDVariable indices) {
        return new TensorArrayScatter(this.getSameDiff(), new SDVariable[]{this.getVar(), indices, value, flow}).outputVariable();
    }

    /**
     * Builds a scatter op with the single index {@code -1}.
     * NOTE(review): -1 presumably targets every position; confirm against TensorArrayScatter.
     *
     * @param flow  flow variable passed as the op's last input
     * @param value values to scatter
     * @return output variable of the new {@link TensorArrayScatter} op
     */
    public SDVariable unstack(SDVariable flow, SDVariable value) {
        return this.scatter(flow, value, this.intToVar(-1));
    }

    /**
     * Builds a size op.
     *
     * @param value variable whose size is queried
     * @return output variable of the new {@link TensorArraySize} op
     */
    public SDVariable size(SDVariable value) {
        return new TensorArraySize(this.getSameDiff(), value).outputVariable();
    }

    /**
     * Builds a remove op for a variable index.
     *
     * @param value variable to remove from
     * @param idx   position to remove
     * @return output variable of the new {@link TensorArrayRemove} op
     */
    public SDVariable remove(SDVariable value, SDVariable idx) {
        return new TensorArrayRemove(this.getSameDiff(), value, idx).outputVariable();
    }

    /**
     * Builds a remove op for a fixed index.
     *
     * @param value variable to remove from
     * @param idx   position to remove
     * @return output variable of the new {@link TensorArrayRemove} op
     */
    public SDVariable remove(SDVariable value, int idx) {
        return new TensorArrayRemove(this.getSameDiff(), value, idx).outputVariable();
    }

    /**
     * Builds a remove op with index {@code -1}.
     *
     * @param value variable to remove from
     * @return output variable of the new {@link TensorArrayRemove} op
     */
    public SDVariable remove(SDVariable value) {
        return this.remove(value, -1);
    }

    /**
     * @param inputDataType ignored
     * @return always {@code [BOOL, FLOAT]}, independent of {@link #tensorArrayDataType}
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataType) {
        return Arrays.asList(DataType.BOOL, DataType.FLOAT);
    }

    @Override
    public int getNumOutputs() {
        return 2;
    }

    /**
     * Same as {@link #itemAtIndex(SameDiff, SDVariable[], String)} without renaming the result.
     *
     * @param sd     graph to resolve the tensor array in
     * @param inputs sequence variable, optionally followed by the position
     * @return the read variable
     */
    public static SDVariable itemAtIndex(SameDiff sd, SDVariable[] inputs) {
        return TensorArray.itemAtIndex(sd, inputs, null);
    }

    /**
     * Reads one item from the tensor array that produced {@code inputs[0]}.
     *
     * @param sd            graph to resolve the tensor array in
     * @param inputs        {@code inputs[0]} is the sequence variable; {@code inputs[1]}, when present,
     *                      is the position (a constant {@code -1} is used otherwise)
     * @param outputVarName name for the result, or {@code null} to keep the generated name
     * @return the read variable, with every input registered as a control dependency
     */
    public static SDVariable itemAtIndex(SameDiff sd, SDVariable[] inputs, String outputVarName) {
        SDVariable sequenceVar = inputs[0];
        SDVariable position = inputs.length < 2 ? sd.constant(-1) : inputs[1];
        TensorArray ta = TensorArray.getTensorArray(sd, sequenceVar);
        SDVariable read = ta.read(sequenceVar, position);
        for (SDVariable input : inputs) {
            read.addControlDependency(input);
        }
        if (outputVarName != null) {
            read = read.rename(outputVarName);
        }
        // Dependencies are registered again after the optional rename, as in the original code.
        // NOTE(review): confirm whether this second pass is still required.
        for (SDVariable input : inputs) {
            read.addControlDependency(input);
        }
        return read;
    }

    /**
     * @return the array of the second argument converted to a long vector
     * @throws IllegalStateException if this op has fewer than two arguments
     */
    public long[] requiredShape() {
        Preconditions.checkState(this.args().length > 1, "Missing input shape.");
        INDArray inputShape = this.arg(1).getArr();
        return inputShape.toLongVector();
    }

    /**
     * Finds the {@link TensorArray} op behind a variable by following the first input of each
     * producing op until a {@link TensorArray} is reached.
     *
     * @param sd          graph used to look up producing ops
     * @param sequenceVar variable to start from
     * @return the tensor array op found on the first-input chain
     * @throws IllegalStateException if the chain reaches a variable without a producing op, or a
     *                               producing op without inputs, before a tensor array is found
     */
    public static TensorArray getTensorArray(SameDiff sd, SDVariable sequenceVar) {
        SDVariable current = sequenceVar;
        while (true) {
            DifferentialFunction producer = sd.getVariableOutputOp(current.name());
            if (producer instanceof TensorArray) {
                return (TensorArray) producer;
            }
            if (producer == null) {
                throw new IllegalStateException("Unable to find a TensorArray for variable \""
                        + sequenceVar.name() + "\": variable \"" + current.name() + "\" has no producing op");
            }
            SDVariable[] producerInputs = producer.args();
            if (producerInputs == null || producerInputs.length == 0) {
                // The original loop spun forever here because nothing changed between iterations.
                throw new IllegalStateException("Unable to find a TensorArray for variable \""
                        + sequenceVar.name() + "\": op producing \"" + current.name() + "\" has no inputs");
            }
            // Only the first input is followed, matching the original traversal.
            current = producerInputs[0];
        }
    }

    /**
     * Removes at a constant position of {@code -1}, without renaming the result.
     *
     * @param sameDiff      graph to resolve the tensor array in
     * @param inputSequence sequence variable to remove from
     * @return the remove op's output variable
     */
    public static SDVariable removeFromTensorArray(SameDiff sameDiff, SDVariable inputSequence) {
        return TensorArray.removeFromTensorArray(sameDiff, inputSequence, sameDiff.constant(-1), null);
    }

    /**
     * Removes at the given position, without renaming the result.
     *
     * @param sameDiff      graph to resolve the tensor array in
     * @param inputSequence sequence variable to remove from
     * @param position      position to remove
     * @return the remove op's output variable
     */
    public static SDVariable removeFromTensorArray(SameDiff sameDiff, SDVariable inputSequence, SDVariable position) {
        return TensorArray.removeFromTensorArray(sameDiff, inputSequence, position, null);
    }

    /**
     * Removes at the given position from the tensor array that produced {@code inputSequence}.
     *
     * @param sameDiff      graph to resolve the tensor array in
     * @param inputSequence sequence variable to remove from
     * @param position      position to remove
     * @param outputVarName name for the result, or {@code null} to keep the generated name
     * @return the remove op's output variable, with the sequence and position as control dependencies
     */
    public static SDVariable removeFromTensorArray(SameDiff sameDiff, SDVariable inputSequence, SDVariable position, String outputVarName) {
        TensorArray ta = TensorArray.getTensorArray(sameDiff, inputSequence);
        SDVariable outputVar = ta.remove(inputSequence, position);
        outputVar.addControlDependency(inputSequence);
        outputVar.addControlDependency(position);
        return outputVarName != null ? outputVar.rename(outputVarName) : outputVar;
    }

    /**
     * Same as {@link #sizeOfTensorArray(SameDiff, SDVariable, String)} without renaming the result.
     *
     * @param sd       graph to resolve the tensor array in
     * @param sequence sequence variable to measure
     * @return the size op's output variable
     */
    public static SDVariable sizeOfTensorArray(SameDiff sd, SDVariable sequence) {
        return TensorArray.sizeOfTensorArray(sd, sequence, null);
    }

    /**
     * Builds a size op for the tensor array that produced {@code sequence}.
     *
     * @param sd            graph to resolve the tensor array in
     * @param sequence      sequence variable to measure
     * @param outputVarName name for the result, or {@code null} to keep the generated name
     * @return the size op's output variable, with the sequence as control dependency
     */
    public static SDVariable sizeOfTensorArray(SameDiff sd, SDVariable sequence, String outputVarName) {
        TensorArray tensorArray = TensorArray.getTensorArray(sd, sequence);
        SDVariable outputVar = tensorArray.size(sequence);
        outputVar.addControlDependency(sequence);
        return outputVarName != null ? outputVar.rename(outputVarName) : outputVar;
    }

    /**
     * Same as {@link #createEmpty(SameDiff, DataType, String)} without renaming the result.
     *
     * @param sd       graph to create the tensor array in
     * @param dataType element data type
     * @return the new tensor array's output variable
     */
    public static SDVariable createEmpty(SameDiff sd, DataType dataType) {
        return TensorArray.createEmpty(sd, dataType, null);
    }

    /**
     * Creates a new tensor array and returns its output variable.
     *
     * @param sd            graph to create the tensor array in
     * @param dataType      element data type
     * @param outputVarName name for the result, or {@code null} to keep the generated name
     * @return the new tensor array's output variable
     */
    public static SDVariable createEmpty(SameDiff sd, DataType dataType, String outputVarName) {
        TensorArray ta = sd.tensorArray(dataType);
        SDVariable outputVar = ta.outputVariable();
        // Guard on the requested name (not the variable's current name) so a null request never reaches rename().
        if (outputVarName != null) {
            return outputVar.rename(outputVarName);
        }
        return outputVar;
    }

    /**
     * Same as {@link #createTensorArrayFrom(SameDiff, SDVariable[], String)} without renaming the result.
     *
     * @param sd     graph to create the tensor array in
     * @param inputs values to write, in order; must contain at least one element
     * @return the output of the last write op
     */
    public static SDVariable createTensorArrayFrom(SameDiff sd, SDVariable[] inputs) {
        return TensorArray.createTensorArrayFrom(sd, inputs, null);
    }

    /**
     * Creates a tensor array typed after {@code inputs[0]} and writes each input at its own index,
     * chaining every write to the previous result through a control dependency.
     *
     * @param sd            graph to create the tensor array in
     * @param inputs        values to write, in order; must contain at least one element
     * @param outputVarName name for the result, or {@code null} to keep the generated name
     * @return the output of the last write op
     */
    public static SDVariable createTensorArrayFrom(SameDiff sd, SDVariable[] inputs, String outputVarName) {
        TensorArray tensorArray = sd.tensorArray(inputs[0].dataType());
        SDVariable previous = tensorArray.getVar();
        for (int i = 0; i < inputs.length; ++i) {
            SDVariable written = tensorArray.write(previous, i, inputs[i]);
            if (previous != null) {
                written.addControlDependency(previous);
            }
            previous = written;
        }
        if (outputVarName != null) {
            previous = previous.rename(outputVarName);
        }
        return previous;
    }

    public DataType getTensorArrayDataType() {
        return this.tensorArrayDataType;
    }

    public void setTensorArrayDataType(DataType tensorArrayDataType) {
        this.tensorArrayDataType = tensorArrayDataType;
    }

    public SDVariable getFlow() {
        return this.flow;
    }

    public void setFlow(SDVariable flow) {
        this.flow = flow;
    }

    public boolean isClearOnRead() {
        return this.clearOnRead;
    }

    public void setClearOnRead(boolean clearOnRead) {
        this.clearOnRead = clearOnRead;
    }
}

