package ai.djl.tensorflow.engine;

import ai.djl.ndarray.types.DataType;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfDataType.class */
/**
 * Converts between DJL {@link DataType} values and TensorFlow Java data types.
 *
 * <p>Only a fixed subset of types is supported. Lookups for any other type return {@code null}.
 * Both lookup tables are built once during class initialization and cannot be modified
 * afterwards.
 */
public final class TfDataType {

    /** DJL to TensorFlow lookup table; unmodifiable after class initialization. */
    private static final Map<DataType, org.tensorflow.DataType<? extends TType>> TO_TF =
            createMapToTf();

    /** TensorFlow to DJL lookup table; unmodifiable after class initialization. */
    private static final Map<org.tensorflow.DataType<? extends TType>, DataType> FROM_TF =
            createMapFromTf();

    private TfDataType() {}

    /**
     * Builds the DJL-to-TensorFlow mapping.
     *
     * <p>{@link DataType#INT8} is mapped to {@code TUint8}, the same TensorFlow type used for
     * {@link DataType#UINT8}. As a result, this mapping is not injective, and converting INT8 to
     * TensorFlow and back yields UINT8.
     *
     * @return an unmodifiable map from DJL types to TensorFlow types
     */
    private static Map<DataType, org.tensorflow.DataType<? extends TType>> createMapToTf() {
        Map<DataType, org.tensorflow.DataType<? extends TType>> map = new ConcurrentHashMap<>();
        map.put(DataType.FLOAT32, TFloat32.DTYPE);
        map.put(DataType.FLOAT64, TFloat64.DTYPE);
        map.put(DataType.INT32, TInt32.DTYPE);
        map.put(DataType.INT64, TInt64.DTYPE);
        map.put(DataType.UINT8, TUint8.DTYPE);
        // NOTE(review): signed INT8 deliberately shares TUint8 with UINT8 — presumably because no
        // signed 8-bit TensorFlow type is imported/available here; confirm before changing.
        map.put(DataType.INT8, TUint8.DTYPE);
        map.put(DataType.BOOLEAN, TBool.DTYPE);
        // Wrapping keeps ConcurrentHashMap's get semantics (NPE on a null key) but forbids mutation.
        return Collections.unmodifiableMap(map);
    }

    /**
     * Builds the TensorFlow-to-DJL mapping.
     *
     * <p>{@code TUint8} always maps back to {@link DataType#UINT8}. There is no reverse entry for
     * {@link DataType#INT8}.
     *
     * @return an unmodifiable map from TensorFlow types to DJL types
     */
    private static Map<org.tensorflow.DataType<? extends TType>, DataType> createMapFromTf() {
        Map<org.tensorflow.DataType<? extends TType>, DataType> map = new ConcurrentHashMap<>();
        map.put(TFloat32.DTYPE, DataType.FLOAT32);
        map.put(TFloat64.DTYPE, DataType.FLOAT64);
        map.put(TInt32.DTYPE, DataType.INT32);
        map.put(TInt64.DTYPE, DataType.INT64);
        map.put(TUint8.DTYPE, DataType.UINT8);
        map.put(TBool.DTYPE, DataType.BOOLEAN);
        return Collections.unmodifiableMap(map);
    }

    /**
     * Converts a TensorFlow data type to the corresponding DJL data type.
     *
     * @param dataType the TensorFlow data type
     * @return the matching DJL data type, or {@code null} if the type is not supported
     * @throws NullPointerException if {@code dataType} is {@code null}
     */
    public static DataType fromTf(org.tensorflow.DataType<? extends TType> dataType) {
        return FROM_TF.get(dataType);
    }

    /**
     * Converts a DJL data type to the corresponding TensorFlow data type.
     *
     * @param dataType the DJL data type
     * @return the matching TensorFlow data type, or {@code null} if the type is not supported
     * @throws NullPointerException if {@code dataType} is {@code null}
     */
    public static org.tensorflow.DataType<? extends TType> toTf(DataType dataType) {
        return TO_TF.get(dataType);
    }
}
