package ai.djl.tensorflow.engine;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfSymbolBlock.class */
/**
 * A {@link SymbolBlock} backed by a TensorFlow {@link SavedModelBundle}.
 *
 * <p>Inference runs the bundle's {@link Session}. Inputs are fed and outputs fetched by the tensor
 * names listed in the model's signature. That signature is {@code "serving_default"} when present,
 * otherwise the first signature in the meta graph's signature map.
 *
 * <p>Parameter access, initializer configuration, casting, child-block access and parameter
 * (de)serialization are not supported and throw {@link UnsupportedOperationException}.
 */
public class TfSymbolBlock implements SymbolBlock {

    private static final String UNSUPPORTED = "Not supported for TensorFlow Engine";
    private static final String DEFAULT_SIGNATURE = "serving_default";

    private final NDManager manager;
    private final MetaGraphDef metaGraphDef;
    // Set to null by clear() so the block knows its native resources are gone.
    private SavedModelBundle bundle;
    private Session session;

    /**
     * Creates a block over an already-loaded SavedModel.
     *
     * @param nDManager manager used to wrap the session's output tensors in {@link #forward}
     * @param savedModelBundle the loaded model; its session and bundle are closed by {@link #clear()}
     */
    public TfSymbolBlock(NDManager nDManager, SavedModelBundle savedModelBundle) {
        this.manager = nDManager;
        this.bundle = savedModelBundle;
        this.session = savedModelBundle.session();
        this.metaGraphDef = savedModelBundle.metaGraphDef();
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeLastBlock() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Runs one inference pass through the TensorFlow session.
     *
     * <p>The i-th array of {@code inputs} is fed to the i-th input tensor of the signature. Extra
     * arrays beyond the signature's input count are ignored. Every output tensor of the signature
     * is fetched, and the results are returned in fetch order.
     *
     * @param parameterStore ignored by this implementation
     * @param inputs the input arrays; each one is cast to {@code TfNDArray} to obtain its tensor
     * @param training ignored by this implementation
     * @param params ignored by this implementation
     * @return the fetched outputs, wrapped via the block's manager
     * @throws IllegalStateException if {@link #clear()} has already closed the session, or the
     *     model defines no signature
     * @throws IllegalArgumentException if fewer arrays are supplied than the signature's inputs
     */
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        Session activeSession = this.session;
        if (activeSession == null) {
            throw new IllegalStateException("The TensorFlow session has already been closed");
        }
        // Resolve the signature once and derive both input and output descriptions from it.
        SignatureDef signature = getSignatureDef();
        PairList<String, Shape> inputDescriptions = describeTensors(signature.getInputsMap());
        PairList<String, Shape> outputDescriptions = describeTensors(signature.getOutputsMap());
        if (inputs.size() < inputDescriptions.size()) {
            throw new IllegalArgumentException(
                    "Expected "
                            + inputDescriptions.size()
                            + " input arrays but got "
                            + inputs.size());
        }

        Session.Runner runner = activeSession.runner();
        for (int i = 0; i < inputDescriptions.size(); i++) {
            TfNDArray input = (TfNDArray) inputs.get(i);
            runner.feed(inputDescriptions.get(i).getKey(), input.getTensor());
        }
        for (int i = 0; i < outputDescriptions.size(); i++) {
            runner.fetch(outputDescriptions.get(i).getKey());
        }

        NDList results = new NDList();
        for (Tensor<?> tensor : runner.run()) {
            results.add(this.manager.create(tensor));
        }
        return results;
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void setInitializer(Initializer initializer, String str) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * No-op: the SavedModel is already initialized when loaded.
     *
     * @return an empty array
     */
    public Shape[] initialize(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        return new Shape[0];
    }

    /**
     * Reports whether the block still holds its SavedModel.
     *
     * @return {@code true} until {@link #clear()} has been called
     */
    public boolean isInitialized() {
        return this.bundle != null;
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Closes the session and then the bundle, and releases the references.
     *
     * <p>Safe to call more than once. The bundle is closed even if closing the session throws.
     */
    public void clear() {
        Session currentSession = this.session;
        SavedModelBundle currentBundle = this.bundle;
        // Drop the references first so a repeated clear() cannot close them again.
        this.session = null;
        this.bundle = null;
        try {
            if (currentSession != null) {
                currentSession.close();
            }
        } finally {
            if (currentBundle != null) {
                currentBundle.close();
            }
        }
    }

    /**
     * Describes the input tensors of the model's signature.
     *
     * @return pairs of (tensor name, shape as recorded in the signature), in the signature's map
     *     iteration order
     * @throws IllegalStateException if the model defines no signature
     */
    public PairList<String, Shape> describeInput() {
        return describeTensors(getSignatureDef().getInputsMap());
    }

    /**
     * Describes the output tensors of the model's signature.
     *
     * @return pairs of (tensor name, shape as recorded in the signature), in the signature's map
     *     iteration order
     * @throws IllegalStateException if the model defines no signature
     */
    PairList<String, Shape> describeOutput() {
        return describeTensors(getSignatureDef().getOutputsMap());
    }

    /**
     * Selects the signature used for inference.
     *
     * <p>Uses {@code "serving_default"} if present, otherwise the first signature in the map.
     */
    private SignatureDef getSignatureDef() {
        Map<String, SignatureDef> signatures = this.metaGraphDef.getSignatureDefMap();
        SignatureDef preferred = signatures.get(DEFAULT_SIGNATURE);
        if (preferred != null) {
            return preferred;
        }
        Iterator<SignatureDef> candidates = signatures.values().iterator();
        if (!candidates.hasNext()) {
            throw new IllegalStateException("The SavedModel does not define any signature");
        }
        return candidates.next();
    }

    /** Converts a signature's tensor map into (tensor name, shape) pairs. */
    private static PairList<String, Shape> describeTensors(Map<String, TensorInfo> tensors) {
        PairList<String, Shape> descriptions = new PairList<>();
        for (TensorInfo info : tensors.values()) {
            long[] dims =
                    info.getTensorShape().getDimList().stream()
                            .mapToLong(dim -> dim.getSize())
                            .toArray();
            descriptions.add(info.getName(), new Shape(dims));
        }
        return descriptions;
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public BlockList getChildren() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public List<Parameter> getDirectParameters() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public ParameterList getParameters() {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Output shapes are not inferred by this block.
     *
     * @return an empty array
     */
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        return new Shape[0];
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void saveParameters(DataOutputStream dataOutputStream) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }

    /**
     * Not supported by this engine.
     *
     * @throws UnsupportedOperationException always
     */
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream) {
        throw new UnsupportedOperationException(UNSUPPORTED);
    }
}
