package ai.djl.tensorflow.engine;

import ai.djl.Device;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.NDIndexBooleans;
import ai.djl.ndarray.index.NDIndexFullSlice;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.BroadcastTo;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Max;
import org.tensorflow.op.core.Min;
import org.tensorflow.op.core.Prod;
import org.tensorflow.op.core.ReduceAll;
import org.tensorflow.op.core.ReduceAny;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.core.Sum;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Cumsum;
import org.tensorflow.op.math.Equal;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.math.NotEqual;
import org.tensorflow.op.nn.TopK;
import org.tensorflow.op.train.BatchMatMul;
import org.tensorflow.tools.buffer.ByteDataBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:ai/djl/tensorflow/engine/TfNDArray.class */
public class TfNDArray implements NDArray {
    // Display limits. They are not referenced in the visible part of this file; presumably they
    // bound toString()/debug printing elsewhere in the class. TODO confirm.
    private static final int MAX_SIZE = 100;
    private static final int MAX_DEPTH = 10;
    private static final int MAX_ROWS = 10;
    private static final int MAX_COLUMNS = 20;
    // Maximum number of split indices that split() hands to one splitHelper()/tf.splitV call;
    // longer index lists are processed in overlapping windows of this size.
    private static final int MAX_OUTPUTS_PER_OP = 8;
    // Random unique id; used as the key when the constructors attach this array to its manager.
    private String uid = UUID.randomUUID().toString();
    // Backing TensorFlow tensor. copyTo() replaces it on the target array.
    private Tensor<?> tensor;
    // Cached shape; getShape() rebuilds it from the tensor when it is null.
    private Shape shape;
    // Owning manager (always a TfNDManager; the constructors cast to it).
    private TfNDManager manager;
    // Op builder obtained from the manager (getTf()); every TensorFlow op is created through it.
    private Ops tf;
    // Cached operand view of the tensor. copyTo() resets it to null on the target; presumably
    // asOperand() (not in view) rebuilds it lazily. TODO confirm.
    private Operand<?> operand;
    // Optional user-assigned name (getName/setName).
    private String name;
    // Engine-specific extension helper, created by every constructor.
    private TfNDArrayEx tfNDArrayEx;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: ai.djl.tensorflow.engine.TfNDArray$1, reason: invalid class name */
    /* loaded from: input_file:ai/djl/tensorflow/engine/TfNDArray$1.class */
    /*
     * Compiler-generated lookup table for a switch over ai.djl DataType values. The ordinal of
     * each listed constant (INT8, INT32, INT64, FLOAT16, FLOAT32, FLOAT64) maps to a 1-based case
     * label; all other entries stay 0. The switch that reads it is not in the visible part of this
     * file (presumably somewhere else in TfNDArray). TODO confirm.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$ai$djl$ndarray$types$DataType = new int[DataType.values().length];

        static {
            // Each assignment is guarded separately: if a DataType constant is missing at runtime,
            // the resulting NoSuchFieldError is swallowed and only that entry is skipped.
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.INT8.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.INT32.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.INT64.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.FLOAT16.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.FLOAT32.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$ai$djl$ndarray$types$DataType[DataType.FLOAT64.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Multi-variable type inference failed */
    public TfNDArray(NDManager nDManager, Tensor<?> tensor) {
        this.manager = (TfNDManager) nDManager;
        this.manager.attach(getUid(), this);
        this.tensor = tensor;
        this.shape = new Shape(tensor.shape().asArray());
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Multi-variable type inference failed */
    public TfNDArray(NDManager nDManager, Operand<?> operand) {
        this.manager = (TfNDManager) nDManager;
        this.manager.attach(getUid(), this);
        this.tensor = operand.asOutput().tensor();
        this.shape = new Shape(this.tensor.shape().asArray());
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /* JADX WARN: Multi-variable type inference failed */
    public TfNDArray(NDManager nDManager, Shape shape, FloatBuffer floatBuffer) {
        this.manager = (TfNDManager) nDManager;
        this.manager.attach(getUid(), this);
        this.tensor = Tensor.of(TFloat32.DTYPE, toTfShape(shape), toDataBuffer(floatBuffer));
        this.shape = shape;
        this.tf = this.manager.getTf();
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Multi-variable type inference failed */
    public TfNDArray(NDManager nDManager, Shape shape, ByteBuffer byteBuffer) {
        this.manager = (TfNDManager) nDManager;
        this.manager.attach(getUid(), this);
        this.shape = shape;
        this.tf = this.manager.getTf();
        this.tensor = Tensor.of(TUint8.DTYPE, toTfShape(shape), DataBuffers.of(byteBuffer));
        this.tfNDArrayEx = new TfNDArrayEx(this);
    }

    /** Returns the manager that owns this array. */
    public NDManager getManager() {
        return manager;
    }

    /** Returns the user-assigned name, or {@code null} if none was set. */
    public String getName() {
        return name;
    }

    /** Assigns a name to this array. */
    public void setName(String str) {
        name = str;
    }

    /** Returns the unique id under which this array is attached to its manager. */
    public final String getUid() {
        return uid;
    }

    /** Returns the DJL data type, translated from the tensor's TensorFlow type. */
    public DataType getDataType() {
        org.tensorflow.DataType<? extends TType> native_ = getTfDataType();
        return TfDataType.fromTf(native_);
    }

    /** Returns the device of the owning manager. */
    public Device getDevice() {
        return manager.getDevice();
    }

    /** Returns the shape, rebuilding the cached value from the tensor when absent. */
    public Shape getShape() {
        if (shape != null) {
            return shape;
        }
        shape = new Shape(tensor.shape().asArray());
        return shape;
    }

    /** Returns the TensorFlow data type of the backing tensor. */
    public org.tensorflow.DataType<? extends TType> getTfDataType() {
        return tensor.dataType();
    }

    /** Unsupported by the TensorFlow engine. */
    public SparseFormat getSparseFormat() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Unsupported by the TensorFlow engine. */
    public boolean isSparse() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray toDevice(Device device, boolean z) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Casts this array to another data type.
     *
     * @param dataType the target type
     * @param z if {@code true}, the cast result is additionally deep-copied
     * @return a new array holding the converted values
     */
    public NDArray toType(DataType dataType, boolean z) {
        Operand converted =
                this.tf.dtypes.cast(asOperand(), TfDataType.toTf(dataType), new Cast.Options[0]);
        Operand result = z ? this.tf.deepCopy(converted) : converted;
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** No-op: gradients are not tracked by this engine. */
    public void attachGradient() {
    }

    /** No-op: gradients are not tracked by this engine. */
    public void attachGradient(SparseFormat sparseFormat) {
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray getGradient() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Copies the raw tensor contents, read as doubles, into a new array. */
    public double[] toDoubleArray() {
        int count = (int) getShape().size();
        double[] values = new double[count];
        this.tensor.rawData().asDoubles().read(values);
        return values;
    }

    /** Copies the raw tensor contents, read as floats, into a new array. */
    public float[] toFloatArray() {
        int count = (int) getShape().size();
        float[] values = new float[count];
        this.tensor.rawData().asFloats().read(values);
        return values;
    }

    /** Copies the raw tensor contents, read as ints, into a new array. */
    public int[] toIntArray() {
        int count = (int) getShape().size();
        int[] values = new int[count];
        this.tensor.rawData().asInts().read(values);
        return values;
    }

    /** Copies the raw tensor contents, read as longs, into a new array. */
    public long[] toLongArray() {
        int count = (int) getShape().size();
        long[] values = new long[count];
        this.tensor.rawData().asLongs().read(values);
        return values;
    }

    /** Copies the raw tensor contents, read as booleans, into a new array. */
    public boolean[] toBooleanArray() {
        int count = (int) getShape().size();
        boolean[] values = new boolean[count];
        this.tensor.rawData().asBooleans().read(values);
        return values;
    }

    /**
     * Copies the raw tensor bytes into a new heap {@link ByteBuffer}.
     *
     * <p>The buffer is tagged with the platform's native byte order, so multi-byte views such as
     * {@code asFloatBuffer()} decode the element values correctly.
     *
     * @return a buffer holding {@code numOfBytes(dtype) * size} bytes
     * @throws ArithmeticException if the byte count does not fit in an {@code int}
     */
    public ByteBuffer toByteBuffer() {
        Shape shape = getShape();
        DataType dataType = getDataType();
        byte[] bArr = new byte[Math.toIntExact(dataType.getNumOfBytes() * shape.size())];
        this.tensor.rawData().read(bArr);
        // TensorFlow host tensors hold their elements in native byte order, whereas
        // ByteBuffer.wrap() defaults to BIG_ENDIAN; without this, typed views are garbled on
        // little-endian machines. The underlying bytes are unchanged.
        return ByteBuffer.wrap(bArr).order(ByteOrder.nativeOrder());
    }

    /**
     * Unsupported: the backing TensorFlow tensor is immutable once created.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(Buffer buffer) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Unsupported: the backing TensorFlow tensor is immutable once created.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(NDIndex nDIndex, NDArray nDArray) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Unsupported: the backing TensorFlow tensor is immutable once created.
     *
     * @throws UnsupportedOperationException always
     */
    public void set(NDIndex nDIndex, Number number) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Unsupported: the backing TensorFlow tensor is immutable once created.
     *
     * @throws UnsupportedOperationException always
     */
    public void setScalar(NDIndex nDIndex, Number number) {
        throw new UnsupportedOperationException("Tensor cannot be modified after creation");
    }

    /**
     * Reads a sub-array selected by an index.
     *
     * <p>Supported forms are a single boolean mask (delegated to {@code booleanMask}), or any
     * index expressible as a full slice (all / fixed / slice components). Fixed components are
     * squeezed out of the result.
     *
     * @param nDIndex the index to apply
     * @return the selected sub-array ({@code this} for an empty index on a scalar)
     * @throws IllegalArgumentException if more than one boolean index is given
     * @throws UnsupportedOperationException for index kinds that cannot form a full slice
     */
    public NDArray get(NDIndex nDIndex) {
        if (nDIndex.getRank() == 0 && getShape().isScalar()) {
            return this;
        }
        List indices = nDIndex.getIndices();
        boolean booleanIndexed = !indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans;
        if (booleanIndexed) {
            if (indices.size() != 1) {
                throw new IllegalArgumentException("get() currently didn't support more that one boolean NDArray");
            }
            NDIndexBooleans mask = (NDIndexBooleans) indices.get(0);
            return booleanMask(mask.getIndex());
        }
        NDIndexFullSlice fullSlice =
                (NDIndexFullSlice) nDIndex.getAsFullSlice(getShape()).orElse(null);
        if (fullSlice == null) {
            throw new UnsupportedOperationException("get() currently supports all, fixed, and slices indices");
        }
        long[] begin = fullSlice.getMin();
        long[] end = fullSlice.getMax();
        long[] extent = new long[begin.length];
        for (int d = 0; d < extent.length; d++) {
            extent[d] = end[d] - begin[d];
        }
        Operand result =
                this.tf.slice(asOperand(), this.tf.constant(begin), this.tf.constant(extent));
        List<Integer> squeezeAxes = fullSlice.getToSqueeze();
        if (!squeezeAxes.isEmpty()) {
            List<Long> axes = new ArrayList<>(squeezeAxes.size());
            for (Integer axis : squeezeAxes) {
                axes.add(axis.longValue());
            }
            result = this.tf.squeeze(result, new Squeeze.Options[] {Squeeze.axis(axes)});
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /**
     * Replaces the contents of {@code nDArray} with a deep copy of this array.
     *
     * @param nDArray the destination; must be a {@link TfNDArray} of the same shape
     * @throws IllegalArgumentException if the destination is not a TfNDArray or shapes differ
     */
    public void copyTo(NDArray nDArray) {
        if (!(nDArray instanceof TfNDArray)) {
            throw new IllegalArgumentException("Only TfNDArray is supported.");
        }
        TfNDArray target = (TfNDArray) nDArray;
        Shape source = getShape();
        Shape destination = target.getShape();
        boolean sameShape = Arrays.equals(source.getShape(), destination.getShape());
        if (!sameShape) {
            throw new IllegalArgumentException("shape are diff. Required: " + destination + ", Actual " + source);
        }
        target.tensor = this.tf.deepCopy(asOperand()).asOutput().tensor();
        // Drop the destination's cached operand so it reflects the new tensor.
        target.operand = null;
        target.shape = new Shape(this.tensor.shape().asArray());
    }

    /**
     * Keeps the slices along {@code i} whose mask entry is true.
     *
     * <p>For a scalar input the mask must also be a scalar. A true mask yields a 1-element
     * array; a false mask yields {@code manager.create(new Shape(new long[0]))}.
     *
     * @param nDArray the boolean mask; must be a {@link TfNDArray} when this array is not scalar
     * @param i the axis to gather along
     * @return the masked array
     * @throws IllegalArgumentException if this array is scalar but the mask is not
     */
    public NDArray booleanMask(NDArray nDArray, int i) {
        if (isScalar()) {
            if (!nDArray.isScalar()) {
                throw new IllegalArgumentException("Input is scalar, index must also be scalar.");
            }
            boolean keep = nDArray.toBooleanArray()[0];
            return keep ? expandDims(0) : this.manager.create(new Shape(new long[0]));
        }
        Operand params = asOperand();
        // tf.where returns an (n, 1) matrix of true positions; squeeze it to a vector of indices.
        Operand positions =
                this.tf.squeeze(
                        this.tf.where(((TfNDArray) nDArray).asOperand()),
                        new Squeeze.Options[] {Squeeze.axis(Collections.singletonList(1L))});
        Operand gathered =
                this.tf.gather(params, positions, this.tf.constant(i), new Gather.Options[0]);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) gathered);
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray sequenceMask(NDArray nDArray, float f) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray sequenceMask(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Returns an array of zeros with this array's shape and type. */
    public NDArray zerosLike() {
        Operand zeros = this.tf.zerosLike(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) zeros);
    }

    /** Returns an array of ones with this array's shape and type. */
    public NDArray onesLike() {
        Operand ones = this.tf.onesLike(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) ones);
    }

    /**
     * Checks whether every element equals {@code number}.
     *
     * <p>The element-wise comparison and its reduction are temporaries. Both are closed so they do
     * not stay attached to the manager; the decompiled version leaked the reduction.
     *
     * @param number the value to compare against; {@code null} yields {@code false}
     * @return {@code true} if all elements equal {@code number}
     */
    public boolean contentEquals(Number number) {
        if (number == null) {
            return false;
        }
        try (NDArray equal = eq(number);
                NDArray allEqual = equal.all()) {
            return allEqual.getBoolean(new long[0]);
        }
    }

    /**
     * Checks whether {@code nDArray} has the same shape, data type and element values.
     *
     * @param nDArray the array to compare against; {@code null} yields {@code false}
     * @return {@code true} if shape, type and all elements match
     */
    public boolean contentEquals(NDArray nDArray) {
        if (nDArray == null || !shapeEquals(nDArray) || getDataType() != nDArray.getDataType()) {
            return false;
        }
        // Close the comparison temporaries instead of leaving them attached to the manager.
        try (NDArray equal = eq(nDArray);
                NDArray allEqual = equal.all()) {
            return allEqual.toBooleanArray()[0];
        }
    }

    /** Element-wise {@code ==} against a scalar converted to this array's data type. */
    public NDArray eq(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return eq(other);
    }

    /** Element-wise {@code ==}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray eq(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.equal(lhs, rhs, new Equal.Options[0]).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise {@code !=} against a scalar converted to this array's data type. */
    public NDArray neq(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return neq(other);
    }

    /** Element-wise {@code !=}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray neq(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.notEqual(lhs, rhs, new NotEqual.Options[0]).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise {@code >} against a scalar converted to this array's data type. */
    public NDArray gt(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return gt(other);
    }

    /** Element-wise {@code >}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray gt(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.greater(lhs, rhs).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise {@code >=} against a scalar converted to this array's data type. */
    public NDArray gte(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return gte(other);
    }

    /** Element-wise {@code >=}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray gte(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.greaterEqual(lhs, rhs).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise {@code <} against a scalar converted to this array's data type. */
    public NDArray lt(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return lt(other);
    }

    /** Element-wise {@code <}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray lt(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.less(lhs, rhs).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise {@code <=} against a scalar converted to this array's data type. */
    public NDArray lte(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return lte(other);
    }

    /** Element-wise {@code <=}; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray lte(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        Operand result = this.tf.math.lessEqual(lhs, rhs).asOutput();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /**
     * Returns whether every element is true (non-zero) after casting to boolean.
     *
     * @return a scalar boolean array
     */
    public NDArray all() {
        // tf.range needs start, limit and delta of a single dtype. The limit is therefore
        // widened to long to match the int64 start/delta; a plain getRank() would be int32.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.reduceAll(
                                this.tf.dtypes.cast(asOperand(), TBool.DTYPE, new Cast.Options[0]),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new ReduceAll.Options[0]));
    }

    /**
     * Returns whether at least one element is true (non-zero) after casting to boolean.
     *
     * @return a scalar boolean array
     */
    public NDArray any() {
        // Same dtype requirement as all(): every tf.range argument must be int64.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.reduceAny(
                                this.tf.dtypes.cast(asOperand(), TBool.DTYPE, new Cast.Options[0]),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new ReduceAny.Options[0]));
    }

    /** Element-wise addition of a scalar converted to this array's data type. */
    public NDArray add(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return add(other);
    }

    /** Element-wise addition; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray add(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.add(lhs, rhs));
    }

    /** Element-wise subtraction of a scalar converted to this array's data type. */
    public NDArray sub(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return sub(other);
    }

    /** Element-wise subtraction; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray sub(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.sub(lhs, rhs));
    }

    /** Element-wise multiplication by a scalar converted to this array's data type. */
    public NDArray mul(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return mul(other);
    }

    /** Element-wise multiplication; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray mul(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.mul(lhs, rhs));
    }

    /** Element-wise division by a scalar converted to this array's data type. */
    public NDArray div(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return div(other);
    }

    /** Element-wise division; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray div(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.div(lhs, rhs));
    }

    /** Element-wise modulo by a scalar converted to this array's data type. */
    public NDArray mod(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return mod(other);
    }

    /** Element-wise modulo; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray mod(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.mod(lhs, rhs));
    }

    /** Element-wise power with a scalar exponent converted to this array's data type. */
    public NDArray pow(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return pow(other);
    }

    /** Element-wise power; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray pow(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.pow(lhs, rhs));
    }

    /** Element-wise maximum with a scalar converted to this array's data type. */
    public NDArray maximum(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return maximum(other);
    }

    /** Element-wise maximum; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray maximum(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.maximum(lhs, rhs));
    }

    /** Element-wise minimum with a scalar converted to this array's data type. */
    public NDArray minimum(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return minimum(other);
    }

    /** Element-wise minimum; {@code nDArray} must be a {@link TfNDArray}. */
    public NDArray minimum(NDArray nDArray) {
        Operand lhs = asOperand();
        Operand rhs = ((TfNDArray) nDArray).asOperand();
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.minimum(lhs, rhs));
    }

    /** In-place addition of a scalar converted to this array's data type. */
    public NDArray addi(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return addi(other);
    }

    /** In-place addition. */
    public NDArray addi(NDArray nDArray) {
        NDArray result = add(nDArray);
        return inPlaceHelper(result, this);
    }

    /** In-place subtraction of a scalar converted to this array's data type. */
    public NDArray subi(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return subi(other);
    }

    /** In-place subtraction. */
    public NDArray subi(NDArray nDArray) {
        NDArray result = sub(nDArray);
        return inPlaceHelper(result, this);
    }

    /** In-place multiplication by a scalar converted to this array's data type. */
    public NDArray muli(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return muli(other);
    }

    /** In-place multiplication. */
    public NDArray muli(NDArray nDArray) {
        NDArray result = mul(nDArray);
        return inPlaceHelper(result, this);
    }

    /** In-place division by a scalar converted to this array's data type. */
    public NDArray divi(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return divi(other);
    }

    /** In-place division. */
    public NDArray divi(NDArray nDArray) {
        NDArray result = div(nDArray);
        return inPlaceHelper(result, this);
    }

    /**
     * Overwrites {@code nDArray2}'s tensor with the rows of {@code nDArray} via
     * {@code tf.inplaceUpdate} over every index of the first dimension.
     *
     * @param nDArray the source values (must be a {@link TfNDArray})
     * @param nDArray2 the destination (must be a {@link TfNDArray}); its cached operand is cleared
     * @return {@code nDArray2}
     * @throws UnsupportedOperationException if this array is a scalar
     */
    public NDArray inPlaceHelper(NDArray nDArray, NDArray nDArray2) {
        if (getShape().isScalar()) {
            throw new UnsupportedOperationException("TensorFlow engine does not support inplace operations on scalars yet");
        }
        TfNDArray destination = (TfNDArray) nDArray2;
        Operand current = destination.asOperand();
        int rowCount = (int) getShape().getShape()[0];
        Operand rows =
                this.tf.range(this.tf.constant(0), this.tf.constant(rowCount), this.tf.constant(1));
        Operand update = ((TfNDArray) nDArray).asOperand();
        destination.setTensor(this.tf.inplaceUpdate(current, rows, update).asOutput().tensor());
        destination.clearOperand();
        return nDArray2;
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray toSparse(SparseFormat sparseFormat) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** In-place modulo by a scalar converted to this array's data type. */
    public NDArray modi(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return modi(other);
    }

    /** In-place modulo. */
    public NDArray modi(NDArray nDArray) {
        NDArray result = mod(nDArray);
        return inPlaceHelper(result, this);
    }

    /** In-place power with a scalar exponent converted to this array's data type. */
    public NDArray powi(Number number) {
        NDArray other = this.manager.create(number).toType(getDataType(), false);
        return powi(other);
    }

    /** In-place power. */
    public NDArray powi(NDArray nDArray) {
        NDArray result = pow(nDArray);
        return inPlaceHelper(result, this);
    }

    /** In-place reverse power: stores {@code nDArray ^ this} into this array. */
    NDArray rpowi(NDArray nDArray) {
        NDArray result = nDArray.pow(this);
        return inPlaceHelper(result, this);
    }

    /** Element-wise negation. */
    public NDArray neg() {
        Operand negated = this.tf.math.neg(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) negated);
    }

    /** In-place negation. */
    public NDArray negi() {
        NDArray negated = neg();
        return inPlaceHelper(negated, this);
    }

    /** Element-wise absolute value. */
    public NDArray abs() {
        Operand result = this.tf.math.abs(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise square. */
    public NDArray square() {
        Operand result = this.tf.math.square(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise square root. */
    public NDArray sqrt() {
        Operand result = this.tf.math.sqrt(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /**
     * Element-wise real cube root.
     *
     * <p>{@code pow(x, 1/3)} is NaN for negative {@code x}, unlike a real cube root (e.g.
     * {@code cbrt(-8) == -2}). The root is therefore computed on {@code |x|} and the sign of
     * {@code x} is restored afterwards. Zero maps to zero, and NaN stays NaN.
     *
     * @return a new array holding the cube roots
     */
    public NDArray cbrt() {
        Operand input = asOperand();
        Operand magnitude =
                this.tf.math.pow(
                        this.tf.math.abs(input),
                        toConstant(Float.valueOf(0.33333334f), getDataType()));
        Operand signed = this.tf.math.mul(this.tf.math.sign(input), magnitude);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) signed);
    }

    /** Element-wise floor. */
    public NDArray floor() {
        Operand result = this.tf.math.floor(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise ceiling. */
    public NDArray ceil() {
        Operand result = this.tf.math.ceil(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise rounding (TensorFlow {@code Round}). */
    public NDArray round() {
        Operand result = this.tf.math.round(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray trunc() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Element-wise natural exponential. */
    public NDArray exp() {
        Operand result = this.tf.math.exp(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise natural logarithm. */
    public NDArray log() {
        Operand result = this.tf.math.log(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise base-10 logarithm, computed as {@code ln(x) / ln(10)}. */
    public NDArray log10() {
        Operand numerator = this.tf.math.log(asOperand());
        Operand denominator = this.tf.math.log(toConstant(10, getDataType()));
        Operand result = this.tf.math.div(numerator, denominator);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise base-2 logarithm, computed as {@code ln(x) / ln(2)}. */
    public NDArray log2() {
        Operand numerator = this.tf.math.log(asOperand());
        Operand denominator = this.tf.math.log(toConstant(2, getDataType()));
        Operand result = this.tf.math.div(numerator, denominator);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise sine. */
    public NDArray sin() {
        Operand result = this.tf.math.sin(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise cosine. */
    public NDArray cos() {
        Operand result = this.tf.math.cos(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise tangent. */
    public NDArray tan() {
        Operand result = this.tf.math.tan(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise arcsine. */
    public NDArray asin() {
        Operand result = this.tf.math.asin(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise arccosine. */
    public NDArray acos() {
        Operand result = this.tf.math.acos(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise arctangent. */
    public NDArray atan() {
        Operand result = this.tf.math.atan(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise hyperbolic sine. */
    public NDArray sinh() {
        Operand result = this.tf.math.sinh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise hyperbolic cosine. */
    public NDArray cosh() {
        Operand result = this.tf.math.cosh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise hyperbolic tangent. */
    public NDArray tanh() {
        Operand result = this.tf.math.tanh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise inverse hyperbolic sine. */
    public NDArray asinh() {
        Operand result = this.tf.math.asinh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise inverse hyperbolic cosine. */
    public NDArray acosh() {
        Operand result = this.tf.math.acosh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Element-wise inverse hyperbolic tangent. */
    public NDArray atanh() {
        Operand result = this.tf.math.atanh(asOperand());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /** Converts radians to degrees: {@code x * 180 / PI}. */
    public NDArray toDegrees() {
        NDArray scaled = mul(180);
        return scaled.div(Math.PI);
    }

    /** Converts degrees to radians: {@code x * PI / 180}. */
    public NDArray toRadians() {
        NDArray scaled = mul(Math.PI);
        return scaled.div(180);
    }

    /**
     * Returns the maximum over all elements.
     *
     * @return a scalar array
     */
    public NDArray max() {
        // The axis list is built with tf.range rather than manager.arange, so no manager-attached
        // NDArray is allocated (and leaked) per call. This matches sum()/prod().
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.max(
                                asOperand(),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new Max.Options[0]));
    }

    /**
     * Returns the maximum along the given axes.
     *
     * @param iArr the axes to reduce
     * @param z whether to keep reduced axes with size 1
     * @return the reduced array
     */
    public NDArray max(int[] iArr, boolean z) {
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.max(asOperand(), this.tf.constant(iArr), new Max.Options[]{Max.keepDims(Boolean.valueOf(z))}));
    }

    /**
     * Returns the minimum over all elements.
     *
     * @return a scalar array
     */
    public NDArray min() {
        // See max(): avoid a leaked manager-attached axis array.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.min(
                                asOperand(),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new Min.Options[0]));
    }

    /**
     * Returns the minimum along the given axes.
     *
     * @param iArr the axes to reduce
     * @param z whether to keep reduced axes with size 1
     * @return the reduced array
     */
    public NDArray min(int[] iArr, boolean z) {
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.min(asOperand(), this.tf.constant(iArr), new Min.Options[]{Min.keepDims(Boolean.valueOf(z))}));
    }

    /**
     * Returns the sum over all elements. Booleans are counted as int64, because TensorFlow
     * cannot sum them directly.
     *
     * @return a scalar array
     */
    public NDArray sum() {
        Operand input =
                getDataType() == DataType.BOOLEAN
                        ? this.tf.dtypes.cast(asOperand(), TInt64.DTYPE, new Cast.Options[0])
                        : asOperand();
        // Every tf.range argument must share one dtype, hence the (long) widening of the limit.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.sum(
                                input,
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new Sum.Options[0]));
    }

    /**
     * Returns the sum along the given axes. Booleans are counted as int64, as in {@link #sum()}.
     *
     * @param iArr the axes to reduce
     * @param z whether to keep reduced axes with size 1
     * @return the reduced array
     */
    public NDArray sum(int[] iArr, boolean z) {
        Operand input =
                getDataType() == DataType.BOOLEAN
                        ? this.tf.dtypes.cast(asOperand(), TInt64.DTYPE, new Cast.Options[0])
                        : asOperand();
        // tf.constant keeps the axis list out of the manager (manager.create leaked it).
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.sum(
                                input,
                                this.tf.constant(iArr),
                                new Sum.Options[] {Sum.keepDims(Boolean.valueOf(z))}));
    }

    /**
     * Returns the product over all elements.
     *
     * @return a scalar array
     */
    public NDArray prod() {
        // Every tf.range argument must share one dtype, hence the (long) widening of the limit.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.prod(
                                asOperand(),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new Prod.Options[0]));
    }

    /**
     * Returns the product along the given axes.
     *
     * @param iArr the axes to reduce
     * @param z whether to keep reduced axes with size 1
     * @return the reduced array
     */
    public NDArray prod(int[] iArr, boolean z) {
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.prod(asOperand(), this.tf.constant(iArr), new Prod.Options[]{Prod.keepDims(Boolean.valueOf(z))}).asOutput());
    }

    /**
     * Returns the mean over all elements.
     *
     * @return a scalar array
     */
    public NDArray mean() {
        // See max(): avoid a leaked manager-attached axis array.
        return new TfNDArray(
                (NDManager) this.manager,
                (Operand<?>)
                        this.tf.math.mean(
                                asOperand(),
                                this.tf.range(
                                        this.tf.constant(0L),
                                        this.tf.constant((long) getRank()),
                                        this.tf.constant(1L)),
                                new Mean.Options[0]));
    }

    /**
     * Returns the mean along the given axes.
     *
     * @param iArr the axes to reduce
     * @param z whether to keep reduced axes with size 1
     * @return the reduced array
     */
    public NDArray mean(int[] iArr, boolean z) {
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.mean(asOperand(), this.tf.constant(iArr), new Mean.Options[]{Mean.keepDims(Boolean.valueOf(z))}).asOutput());
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray trace(int i, int i2, int i3) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Splits this array along axis {@code i} at the given (sorted) split indices.
     *
     * <p>Index lists of at most MAX_OUTPUTS_PER_OP entries go straight to splitHelper(). Longer
     * lists are processed in overlapping windows:
     * <ul>
     *   <li>Each window takes 7 consecutive indices (MAX_OUTPUTS_PER_OP - 1) and appends the axis
     *       length as a sentinel final index.</li>
     *   <li>The window's trailing piece ([last index, axis length)) is discarded.</li>
     *   <li>For every window after the first, the leading piece ([0, first index)) is also
     *       discarded, so each kept piece appears exactly once.</li>
     *   <li>Windows advance by 6 (MAX_OUTPUTS_PER_OP - 2), so consecutive windows share one
     *       boundary index.</li>
     * </ul>
     * The remaining indices form a final call whose leading piece is dropped.
     *
     * @param jArr split indices along the axis
     * @param i the axis to split
     * @return the pieces, in order
     */
    public NDList split(long[] jArr, int i) {
        if (jArr.length <= MAX_OUTPUTS_PER_OP) {
            return splitHelper(jArr, i);
        }
        NDList nDList = new NDList();
        // Axis length, used as the sentinel last index of every window.
        long j = getShape().get(i);
        int i2 = 0;
        // Loop while more than 6 indices remain from i2, so the tail call stays small.
        while (i2 < (jArr.length - MAX_OUTPUTS_PER_OP) + 2) {
            long[] jArr2 = new long[MAX_OUTPUTS_PER_OP];
            for (int i3 = 0; i3 < 7; i3++) {
                jArr2[i3] = jArr[i2 + i3];
            }
            jArr2[7] = j;
            NDList splitHelper = splitHelper(jArr2, i);
            // Drop [jArr[i2 + 6], axis length); the next window/tail covers that range.
            splitHelper.remove(splitHelper.get(splitHelper.size() - 1));
            if (i2 > 0) {
                // Drop [0, jArr[i2]); earlier windows already produced it.
                // NOTE(review): splitHelper only emits this leading piece when jArr[i2] > 0, so a
                // zero index here would drop a real piece instead. Confirm indices are positive.
                splitHelper.remove(splitHelper.get(0));
            }
            // NOTE(review): discarded pieces are not closed. Each TfNDArray attaches itself to
            // the manager in its constructor, so they stay alive until the manager is closed.
            nDList.addAll(splitHelper);
            i2 += 6;
        }
        // Tail: the remaining indices, starting at the last window's final shared index.
        long[] jArr3 = new long[jArr.length - i2];
        for (int i4 = 0; i4 < jArr3.length; i4++) {
            jArr3[i4] = jArr[i2 + i4];
        }
        NDList splitHelper2 = splitHelper(jArr3, i);
        // Drop the leading [0, jArr[i2]) piece, which the windows already covered.
        splitHelper2.remove(splitHelper2.get(0));
        nDList.addAll(splitHelper2);
        return nDList;
    }

    /**
     * Splits along axis {@code i} at the given indices with a single {@code tf.splitV} call.
     *
     * <p>Piece sizes are the gaps between consecutive indices. A leading piece is added only when
     * the first index is positive, and a trailing piece only when the last index is below the
     * axis length.
     *
     * @param jArr split indices along the axis (must be non-empty)
     * @param i the axis to split
     * @return the pieces, in order
     * @throws IllegalArgumentException if the piece sizes do not add up to the axis length
     */
    private NDList splitHelper(long[] jArr, int i) {
        long axisLength = getShape().get(i);
        List<Long> sizes = new ArrayList<>();
        if (jArr[0] > 0) {
            sizes.add(jArr[0]);
        }
        for (int k = 1; k < jArr.length; k++) {
            sizes.add(jArr[k] - jArr[k - 1]);
        }
        long lastIndex = jArr[jArr.length - 1];
        if (lastIndex < axisLength) {
            sizes.add(axisLength - lastIndex);
        }
        long total = 0;
        for (Long size : sizes) {
            total += size;
        }
        if (total != getShape().get(i)) {
            throw new IllegalArgumentException("split sizes :" + total + " must sum to dimension on axis " + i + ": " + getShape().get(i));
        }
        int[] sizeArray = new int[sizes.size()];
        for (int k = 0; k < sizeArray.length; k++) {
            sizeArray[k] = sizes.get(k).intValue();
        }
        NDList pieces = new NDList();
        this.tf
                .splitV(
                        asOperand(),
                        this.tf.constant(sizeArray),
                        this.tf.constant(i),
                        Long.valueOf(sizes.size()))
                .forEach(piece -> pieces.add(new TfNDArray((NDManager) this.manager, (Operand<?>) piece)));
        return pieces;
    }

    /** Reshapes this array into one dimension. */
    public NDArray flatten() {
        Shape flat = new Shape(new long[] {-1});
        return reshape(flat);
    }

    /** Reshapes this array to {@code shape} (a {@code -1} entry is inferred by TensorFlow). */
    public NDArray reshape(Shape shape) {
        Operand input = asOperand();
        Operand dims = this.tf.constant(shape.getShape());
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.reshape(input, dims));
    }

    /** Unsupported by the TensorFlow engine. */
    public NDArray reshapeLike(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Inserts a size-1 dimension at axis {@code i}. */
    public NDArray expandDims(int i) {
        Operand input = asOperand();
        Operand axis = this.tf.constant(i);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.expandDims(input, axis));
    }

    /**
     * Removes the given size-1 axes. For a scalar the axis list is ignored and treated as empty.
     *
     * @param iArr the axes to remove
     * @return the squeezed array
     */
    public NDArray squeeze(int[] iArr) {
        int[] axes = isScalar() ? new int[0] : iArr;
        Operand input = asOperand();
        List<Long> axisList = new ArrayList<>(axes.length);
        for (int axis : axes) {
            axisList.add((long) axis);
        }
        Operand squeezed = this.tf.squeeze(input, new Squeeze.Options[] {Squeeze.axis(axisList)});
        return new TfNDArray((NDManager) this.manager, (Operand<?>) squeezed);
    }

    /** Element-wise logical AND after casting both operands to boolean. */
    public NDArray logicalAnd(NDArray nDArray) {
        Operand lhs = this.tf.dtypes.cast(asOperand(), TBool.DTYPE, new Cast.Options[0]);
        Operand rhs =
                this.tf.dtypes.cast(((TfNDArray) nDArray).asOperand(), TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.logicalAnd(lhs, rhs));
    }

    /** Element-wise logical OR after casting both operands to boolean. */
    public NDArray logicalOr(NDArray nDArray) {
        Operand lhs = this.tf.dtypes.cast(asOperand(), TBool.DTYPE, new Cast.Options[0]);
        Operand rhs =
                this.tf.dtypes.cast(((TfNDArray) nDArray).asOperand(), TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.logicalOr(lhs, rhs));
    }

    /** Element-wise logical XOR, computed as {@code (a OR b) AND NOT (a AND b)}. */
    public NDArray logicalXor(NDArray nDArray) {
        NDArray either = logicalOr(nDArray);
        NDArray notBoth = logicalAnd(nDArray).logicalNot();
        return either.logicalAnd(notBoth);
    }

    /** Element-wise logical NOT after casting to boolean. */
    public NDArray logicalNot() {
        Operand asBool = this.tf.dtypes.cast(asOperand(), TBool.DTYPE, new Cast.Options[0]);
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.logicalNot(asBool));
    }

    /**
     * Returns the indices that would sort this array along an axis.
     *
     * @param i the axis to sort along
     * @param z {@code true} for ascending order, {@code false} for descending
     * @return int64 indices
     */
    public NDArray argSort(int i, boolean z) {
        final boolean returnIndices = true;
        return sortHelper(i, z, returnIndices);
    }

    /**
     * Sorts this array in ascending order along an axis.
     *
     * @param i the axis to sort along
     * @return the sorted values
     */
    public NDArray sort(int i) {
        final boolean ascending = true;
        final boolean returnIndices = false;
        return sortHelper(i, ascending, returnIndices);
    }

    /**
     * Sorts this array in ascending order along its last axis.
     *
     * @return the sorted values
     */
    public NDArray sort() {
        return sort(-1);
    }

    /**
     * Sorts, or arg-sorts, this array along one axis using TensorFlow's {@code top_k}.
     *
     * <p>{@code top_k} works only on the last dimension and yields values in descending order.
     * A non-last axis is therefore swapped with the last one first; the swap permutation is its
     * own inverse, so the same permutation restores the layout afterwards. Ascending order is
     * obtained by running {@code top_k} on the negated input and negating the values back.
     *
     * @param i the axis to sort along; negative values count back from the last dimension
     * @param z {@code true} for ascending order, {@code false} for descending
     * @param z2 {@code true} to return int64 indices (argSort), {@code false} to return values
     * @return the sorted values or indices; this array itself when it is a scalar
     * @throws IllegalArgumentException if the axis is outside {@code [-rank, rank)}
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private NDArray sortHelper(int i, boolean z, boolean z2) {
        if (isScalar()) {
            return this;
        }
        long[] dims = getShape().getShape();
        int rank = dims.length;
        // Normalize negative axes. Previously only -1 was recognized, so e.g. -2 produced an
        // invalid transpose permutation.
        int axis = i < 0 ? i + rank : i;
        if (axis < 0 || axis >= rank) {
            throw new IllegalArgumentException("Invalid axis " + i + " for array of rank " + rank);
        }
        int k = (int) dims[axis];

        Operand input = asOperand();
        TfNDArray perm = null;
        if (axis != rank - 1) {
            // Permutation [0, axis) ++ [rank-1] ++ (axis, rank-1) ++ [axis] swaps axis and last.
            perm =
                    (TfNDArray)
                            NDArrays.concat(
                                    new NDList(
                                            this.manager.arange(0, axis, 1, DataType.INT32, getDevice()),
                                            this.manager.create(new int[] {rank - 1}),
                                            this.manager.arange(axis + 1, rank - 1, 1, DataType.INT32, getDevice()),
                                            this.manager.create(new int[] {axis})));
            input = this.tf.linalg.transpose(input, perm.asOperand());
        }

        TopK topK = this.tf.nn.topK(z ? this.tf.math.neg(input) : input, this.tf.constant(k));
        Operand result = z2 ? this.tf.dtypes.cast(topK.indices(), TInt64.DTYPE) : topK.values();
        if (perm != null) {
            result = this.tf.linalg.transpose(result, perm.asOperand());
            perm.close();
        }
        if (z && !z2) {
            // Undo the negation used to get ascending order from top_k.
            result = this.tf.math.neg(result);
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) result);
    }

    /**
     * Applies softmax along a single axis.
     *
     * @param iArr the axis to normalize over; exactly one axis is supported
     * @param f the temperature; only {@code 1.0} is supported by this engine
     * @return the softmax of this array
     * @throws UnsupportedOperationException if {@code f} is not {@code 1.0} or several axes are
     *     given
     * @throws IllegalArgumentException if the axis is missing or out of range
     */
    public NDArray softmax(int[] iArr, float f) {
        if (f != 1.0f) {
            // Fixed typo ("suuport") in the message.
            throw new UnsupportedOperationException("TensorFlow softmax does not support temperature");
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) softmaxHelper(iArr, false));
    }

    /**
     * Applies log-softmax along a single axis.
     *
     * @param iArr the axis to normalize over; exactly one axis is supported
     * @param f the temperature; only {@code 1.0} is supported by this engine
     * @return the log-softmax of this array
     * @throws UnsupportedOperationException if {@code f} is not {@code 1.0} or several axes are
     *     given
     * @throws IllegalArgumentException if the axis is missing or out of range
     */
    public NDArray logSoftmax(int[] iArr, float f) {
        if (f != 1.0f) {
            // Fixed typo ("suuport") and name the op that was actually called.
            throw new UnsupportedOperationException("TensorFlow logSoftmax does not support temperature");
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) softmaxHelper(iArr, true));
    }

    /**
     * Applies softmax or log-softmax along a single axis.
     *
     * <p>TensorFlow's {@code softmax}/{@code log_softmax} always operate on the last dimension.
     * For any other axis, the input is transposed so that axis becomes last, the op is applied,
     * and the result is transposed back. The swap permutation is self-inverse, so the same
     * permutation is used both ways.
     *
     * @param iArr the axes; exactly one is supported, and negative values count from the end
     * @param z {@code true} for log-softmax, {@code false} for softmax
     * @return the resulting operand; for a scalar array the input operand itself
     * @throws UnsupportedOperationException if more than one axis is given
     * @throws IllegalArgumentException if no axis is given or the axis is out of range
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Operand softmaxHelper(int[] iArr, boolean z) {
        long dimension = getShape().dimension();
        if (iArr.length > 1) {
            throw new UnsupportedOperationException("TensorFlow softmax does not support multiple axes");
        }
        if (dimension == 0) {
            // NOTE(review): a scalar is returned unchanged, although softmax over a single element
            // is mathematically 1 (log-softmax 0) — confirm the intended contract.
            return asOperand();
        }
        if (iArr.length == 0) {
            // Previously an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("TensorFlow softmax requires exactly one axis");
        }
        long axis = iArr[0];
        if (axis < -dimension || axis >= dimension) {
            throw new IllegalArgumentException("Invalid axes value: " + iArr[0] + ", must be in range [" + (-dimension) + ", " + dimension + ") where " + dimension + " is the number of dimensions in the input.");
        }
        if (axis < 0) {
            // Normalize: previously only -1 was handled, and `axis % dim` stayed negative.
            axis += dimension;
        }
        if (axis == dimension - 1) {
            return z ? this.tf.nn.logSoftmax(asOperand()) : this.tf.nn.softmax(asOperand());
        }
        // Permutation [0, axis) ++ [dim-1] ++ (axis, dim-1) ++ [axis]. Every piece is int64 so
        // they can be concatenated.
        List<Operand<TInt64>> pieces = new ArrayList<>();
        pieces.add(this.tf.range(this.tf.constant(0L), this.tf.constant(axis), this.tf.constant(1L)));
        pieces.add(this.tf.expandDims(this.tf.constant(dimension - 1), this.tf.constant(0)));
        pieces.add(this.tf.range(this.tf.constant(axis + 1), this.tf.constant(dimension - 1), this.tf.constant(1L)));
        pieces.add(this.tf.expandDims(this.tf.constant(axis), this.tf.constant(0)));
        Operand<TInt64> perm = this.tf.concat(pieces, this.tf.constant(0));

        Operand transposed = this.tf.linalg.transpose(asOperand(), perm);
        Operand normalized = z ? this.tf.nn.logSoftmax(transposed) : this.tf.nn.softmax(transposed);
        return this.tf.linalg.transpose(normalized, perm);
    }

    /**
     * Returns the cumulative sum of the elements along the given axis.
     *
     * @param i the axis to accumulate along
     * @return the running sums; a scalar input yields a one-element array of shape {@code (1)}
     */
    public NDArray cumSum(int i) {
        if (isScalar()) {
            return expandDims(0);
        }
        boolean hasZeroDim = Arrays.stream(getShape().getShape()).anyMatch(dim -> dim == 0);
        if (hasZeroDim) {
            // Nothing to accumulate. Keep the input's own shape; the previous fixed Shape(0)
            // dropped the other dimensions, e.g. (2, 0) became (0).
            return reshape(getShape());
        }
        return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.math.cumsum(asOperand(), this.tf.constant(i)));
    }

    /**
     * Returns the cumulative sum of all elements of the flattened array.
     *
     * <p>Previously this accumulated along axis 0 only, which differs from the flattened
     * semantics for arrays of rank 2 or higher.
     *
     * @return a 1-D array of running sums
     */
    public NDArray cumSum() {
        return flatten().cumSum(0);
    }

    /**
     * Flags positive and negative infinities element-wise.
     *
     * @return a boolean array
     */
    public NDArray isInfinite() {
        Operand<TBool> infMask = this.tf.math.isInf(asOperand());
        Operand<?> boolMask = this.tf.dtypes.cast(infMask, TBool.DTYPE);
        return new TfNDArray((NDManager) this.manager, boolMask);
    }

    /**
     * Flags NaN values element-wise.
     *
     * @return a boolean array
     */
    public NDArray isNaN() {
        Operand<TBool> nanMask = this.tf.math.isNan(asOperand());
        Operand<?> boolMask = this.tf.dtypes.cast(nanMask, TBool.DTYPE);
        return new TfNDArray((NDManager) this.manager, boolMask);
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param nDIndex the index that would select masked elements
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(NDIndex nDIndex) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param predicate the element test that would build the mask
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray createMask(Predicate<Number> predicate) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Repeats this array {@code j} times along every axis.
     *
     * @param j the repeat count applied to each dimension
     * @return the tiled array
     */
    public NDArray tile(long j) {
        int rank = getShape().dimension();
        long[] repeats = new long[rank];
        for (int axis = 0; axis < rank; axis++) {
            repeats[axis] = j;
        }
        return tile(repeats);
    }

    /**
     * Repeats this array {@code j} times along a single axis.
     *
     * @param i the axis to tile; negative values count back from the last dimension
     * @param j the repeat count for that axis (all other axes are repeated once)
     * @return the tiled array
     * @throws IllegalArgumentException if the axis is outside {@code [-rank, rank)}
     */
    public NDArray tile(int i, long j) {
        int rank = getShape().dimension();
        // Normalize negative axes; previously they caused an ArrayIndexOutOfBoundsException.
        int axis = i < 0 ? i + rank : i;
        if (axis < 0 || axis >= rank) {
            throw new IllegalArgumentException("Invalid axis " + i + " for array of rank " + rank);
        }
        long[] repeats = new long[rank];
        Arrays.fill(repeats, 1L);
        repeats[axis] = j;
        return tile(repeats);
    }

    /**
     * Tiles this array with TensorFlow's {@code tile} op.
     *
     * @param jArr the repeat count for each dimension
     * @return the tiled array
     */
    public NDArray tile(long[] jArr) {
        Operand<?> tiled = this.tf.tile(asOperand(), this.tf.constant(jArr));
        return new TfNDArray((NDManager) this.manager, tiled);
    }

    /**
     * Tiles this array until it has the desired shape.
     *
     * <p>Previously unimplemented. The trailing dimensions of this array are aligned with
     * {@code shape}, and each target dimension must be a non-zero multiple of the matching
     * current dimension. Leading dimensions not covered by {@code shape} are repeated once.
     *
     * @param shape the desired shape of the result
     * @return the tiled array
     * @throws IllegalArgumentException if {@code shape} has more dimensions than this array, or is
     *     not a multiple of this array's shape
     */
    public NDArray tile(Shape shape) {
        long[] current = getShape().getShape();
        long[] desired = shape.getShape();
        if (desired.length > current.length) {
            throw new IllegalArgumentException("The desired shape has too many dimensions");
        }
        int offset = current.length - desired.length;
        long[] repeats = new long[current.length];
        Arrays.fill(repeats, 1L);
        for (int k = 0; k < desired.length; k++) {
            long dim = current[offset + k];
            if (dim == 0 || desired[k] % dim != 0) {
                throw new IllegalArgumentException(
                        "The desired shape is not a multiple of the original shape");
            }
            repeats[offset + k] = desired[k] / dim;
        }
        return tile(repeats);
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param j the repeat count
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(long j) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param i the axis
     * @param j the repeat count
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(int i, long j) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param jArr the per-dimension repeat counts
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(long[] jArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param shape the desired shape
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray repeat(Shape shape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine; for the 1-D/2-D cases, see {@code matMul}.
     *
     * @param nDArray the other operand
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray dot(NDArray nDArray) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Matrix product with NumPy-style handling of 1-D operands.
     *
     * <p>A 1-D left operand {@code (k)} is promoted to a row vector {@code (1, k)}. A 1-D right
     * operand {@code (k)} is promoted to a column vector {@code (k, 1)}. Only the promoted axes
     * are squeezed out of the result. Previously the right operand was broadcast to
     * {@code (1, lhs.shape[0])}, using the wrong array's shape and orientation, and every size-1
     * axis was squeezed.
     *
     * @param nDArray the right-hand operand
     * @return the matrix product
     * @throws IllegalArgumentException if either operand is a scalar
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public NDArray matMul(NDArray nDArray) {
        if (isScalar() || nDArray.isScalar()) {
            throw new IllegalArgumentException("scalar is not allowed for matMul()");
        }
        TfNDArray other = (TfNDArray) nDArray;
        int lhsRank = getShape().dimension();
        int rhsRank = nDArray.getShape().dimension();
        if (lhsRank > 2 || rhsRank > 2) {
            // NOTE(review): a 1-D operand is passed to batchMatMul unpromoted — confirm whether
            // mixed 1-D / N-D inputs need the same promotion as below.
            return new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.train.batchMatMul(asOperand(), other.asOperand()));
        }
        Operand lhs = asOperand();
        Operand rhs = other.asOperand();
        if (lhsRank == 1) {
            lhs = this.tf.expandDims(lhs, this.tf.constant(0)); // (k) -> (1, k)
        }
        if (rhsRank == 1) {
            rhs = this.tf.expandDims(rhs, this.tf.constant(-1)); // (k) -> (k, 1)
        }
        TfNDArray product = new TfNDArray((NDManager) this.manager, (Operand<?>) this.tf.linalg.matMul(lhs, rhs));
        if (lhsRank == 1 && rhsRank == 1) {
            return product.squeeze(new int[] {0, 1}); // (1, 1) -> scalar
        }
        if (lhsRank == 1) {
            return product.squeeze(new int[] {0}); // (1, n) -> (n)
        }
        if (rhsRank == 1) {
            return product.squeeze(new int[] {1}); // (m, 1) -> (m)
        }
        return product;
    }

    /**
     * Clamps every element into {@code [number, number2]}.
     *
     * @param number the lower bound, converted to this array's data type
     * @param number2 the upper bound, converted to this array's data type
     * @return the clipped array
     */
    public NDArray clip(Number number, Number number2) {
        DataType dtype = getDataType();
        Operand<?> clipped =
                this.tf.clipByValue(asOperand(), toConstant(number, dtype), toConstant(number2, dtype));
        return new TfNDArray((NDManager) this.manager, clipped);
    }

    /**
     * Reverses the order of all axes.
     *
     * @return the transposed array
     */
    public NDArray transpose() {
        int rank = getShape().dimension();
        int[] reversed = new int[rank];
        for (int k = 0; k < rank; k++) {
            reversed[k] = rank - 1 - k;
        }
        return transpose(reversed);
    }

    /**
     * Permutes the axes of this array.
     *
     * @param iArr a permutation of {@code 0 .. rank-1}; negative values are rejected
     * @return the transposed array
     * @throws UnsupportedOperationException if any axis is negative
     * @throws IllegalArgumentException if {@code iArr} is not a permutation of all axes
     */
    public NDArray transpose(int... iArr) {
        for (int axis : iArr) {
            if (axis < 0) {
                throw new UnsupportedOperationException("Passing -1 for broadcasting the dimension is not currently supported");
            }
        }
        int[] sortedAxes = iArr.clone();
        Arrays.sort(sortedAxes);
        int[] allAxes = IntStream.range(0, getShape().dimension()).toArray();
        if (!Arrays.equals(sortedAxes, allAxes)) {
            throw new IllegalArgumentException("You must include each of the dimensions from 0 until " + getShape().dimension());
        }
        Operand<?> permuted = this.tf.linalg.transpose(asOperand(), this.tf.constant(iArr));
        return new TfNDArray((NDManager) this.manager, permuted);
    }

    /**
     * Broadcasts this array to a target shape via TensorFlow's {@code broadcast_to}.
     *
     * @param shape the target shape
     * @return the broadcast array
     */
    public NDArray broadcast(Shape shape) {
        long[] targetDims = shape.getShape();
        Operand<?> broadcasted = this.tf.broadcastTo(asOperand(), this.tf.constant(targetDims));
        return new TfNDArray((NDManager) this.manager, broadcasted);
    }

    /**
     * Returns the index of the maximum element in the flattened array.
     *
     * @return the flat index of the maximum
     * @throws IllegalArgumentException if this array is empty
     */
    public NDArray argMax() {
        if (isEmpty()) {
            // The message previously said "argMin" (copy-paste from argMin()).
            throw new IllegalArgumentException("attempt to get argMax of an empty NDArray");
        }
        return flatten().argMax(0);
    }

    /**
     * Returns the indices of the maximum values along an axis.
     *
     * @param i the axis to reduce
     * @return the indices; a scalar input yields the constant {@code 0L}
     */
    public NDArray argMax(int i) {
        if (isScalar()) {
            return this.manager.create(0L);
        }
        Operand<?> indices = this.tf.math.argMax(asOperand(), this.tf.constant(i));
        return new TfNDArray((NDManager) this.manager, indices);
    }

    /**
     * Returns the index of the minimum element in the flattened array.
     *
     * @return the flat index of the minimum
     * @throws IllegalArgumentException if this array is empty
     */
    public NDArray argMin() {
        if (!isEmpty()) {
            NDArray flat = flatten();
            return flat.argMin(0);
        }
        throw new IllegalArgumentException("attempt to get argMin of an empty NDArray");
    }

    /**
     * Returns the indices of the minimum values along an axis.
     *
     * @param i the axis to reduce
     * @return the indices; a scalar input yields the constant {@code 0L}
     */
    public NDArray argMin(int i) {
        if (isScalar()) {
            return this.manager.create(0L);
        }
        Operand<?> indices = this.tf.math.argMin(asOperand(), this.tf.constant(i));
        return new TfNDArray((NDManager) this.manager, indices);
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param number the percentile
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number number) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param number the percentile
     * @param iArr the axes
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray percentile(Number number, int[] iArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray median() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @param iArr the axes
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray median(int[] iArr) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray toDense() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not implemented by the TensorFlow engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public NDArray nonzero() {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns the engine-specific helper for internal operators.
     *
     * @return the {@code tfNDArrayEx} field ({@code null} after {@code close()})
     */
    public NDArrayEx getNDArrayInternal() {
        return tfNDArrayEx;
    }

    /**
     * Compares against another {@code TfNDArray} by delegating to {@code contentEquals}.
     *
     * @param obj the object to compare with
     * @return {@code false} for non-{@code TfNDArray} objects, otherwise the result of
     *     {@code contentEquals}
     */
    public boolean equals(Object obj) {
        return obj instanceof TfNDArray && contentEquals((TfNDArray) obj);
    }

    /**
     * Always returns {@code 0}.
     *
     * <p>{@code equals} delegates to {@code contentEquals}, so distinct instances may compare
     * equal. A constant hash keeps the equals/hashCode contract, at the cost of every
     * {@code TfNDArray} colliding in hash-based collections.
     *
     * @return {@code 0}
     */
    public int hashCode() {
        return 0;
    }

    /**
     * Renders a debug view of the array, or a notice if it has been closed.
     *
     * @return the debug string
     */
    public String toString() {
        if (this.tensor != null) {
            return toDebugString(MAX_SIZE, 10, 10, MAX_COLUMNS);
        }
        return "This array is already closed";
    }

    /**
     * Closes the backing tensor, if any, and drops the references to the op builder, cached
     * operand and helper.
     *
     * <p>Afterwards {@code toString()} reports the array as closed.
     */
    public void close() {
        Tensor<?> owned = this.tensor;
        if (owned != null) {
            owned.close();
        }
        this.tensor = null;
        this.tf = null;
        this.operand = null;
        this.tfNDArrayEx = null;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns this array as a TensorFlow operand.
     *
     * <p>On first use, the backing tensor is wrapped in a constant and cached; later calls reuse
     * the cached operand.
     *
     * @param <T> the element type the caller expects (unchecked)
     * @return the cached operand
     */
    @SuppressWarnings("unchecked")
    public <T extends TType> Operand<T> asOperand() {
        Operand<?> cached = this.operand;
        if (cached == null) {
            cached = this.tf.constant(this.tensor);
            this.operand = cached;
        }
        return (Operand<T>) cached;
    }

    /**
     * Returns the backing tensor.
     *
     * @return the tensor, or {@code null} once closed
     */
    public Tensor<?> getTensor() {
        return tensor;
    }

    /**
     * Replaces the backing tensor. The cached operand is not touched here.
     *
     * @param newTensor the new backing tensor
     */
    void setTensor(Tensor<?> newTensor) {
        tensor = newTensor;
    }

    /** Drops the cached operand so the next {@code asOperand()} call rebuilds it from the tensor. */
    void clearOperand() {
        operand = null;
    }

    /**
     * Returns the number of dimensions of this array.
     *
     * <p>Read from the shape metadata instead of building and executing a TensorFlow
     * {@code rank} op and materializing its output tensor just to read one int. The rest of
     * this class already treats {@code getShape().dimension()} as the rank.
     *
     * @return the rank
     */
    int getRank() {
        return getShape().dimension();
    }

    /**
     * Builds a scalar constant of {@code dataType} using this array's op builder.
     *
     * @param number the value
     * @param dataType the target data type
     * @return the constant
     */
    private <T extends TType> Constant<T> toConstant(Number number, DataType dataType) {
        Ops ops = this.tf;
        return getConstant(number, dataType, ops);
    }

    /**
     * Converts a DJL shape into a TensorFlow shape with the same dimensions.
     *
     * @param shape the DJL shape
     * @return the equivalent TensorFlow shape
     */
    public static org.tensorflow.tools.Shape toTfShape(Shape shape) {
        long[] dims = shape.getShape();
        return org.tensorflow.tools.Shape.of(dims);
    }

    /**
     * Copies the remaining floats of {@code floatBuffer} into a new byte buffer.
     *
     * <p>{@code ByteBuffer.allocate} always starts in BIG_ENDIAN order, while tensor memory uses
     * the platform's native order. The buffer is therefore switched to native order before
     * writing; previously the floats were stored big-endian.
     *
     * <p>Like the original, this consumes {@code floatBuffer}: its position advances to its
     * limit.
     *
     * @param floatBuffer the source floats, from position to limit
     * @return a byte data buffer holding the floats in native byte order
     */
    public static ByteDataBuffer toDataBuffer(FloatBuffer floatBuffer) {
        ByteBuffer allocate =
                ByteBuffer.allocate(floatBuffer.remaining() * Float.BYTES)
                        .order(java.nio.ByteOrder.nativeOrder());
        allocate.asFloatBuffer().put(floatBuffer);
        return DataBuffers.of(allocate);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates a scalar TensorFlow constant from {@code number}, dispatching on {@code dataType}.
     *
     * <p>Cases 1-6 convert the number with {@code byteValue}, {@code intValue},
     * {@code longValue}, {@code shortValue}, {@code floatValue} and {@code doubleValue}
     * respectively, then pass it to {@code ops.constant}.
     *
     * <p>NOTE(review): the numbered case labels come from a decompiler-generated switch map
     * over {@link DataType}; which DataType constant maps to each case is not visible here —
     * confirm against the original source before editing. If {@code Ops} has no
     * {@code constant(short)} overload, case 4 widens to the int overload — TODO confirm.
     *
     * @param number the value to wrap
     * @param dataType the data type selecting the conversion
     * @param ops the op builder used to create the constant
     * @return the constant
     * @throws EngineException for any data type not handled by the switch
     */
    public static <T extends TType> Constant<T> getConstant(Number number, DataType dataType, Ops ops) {
        switch (AnonymousClass1.$SwitchMap$ai$djl$ndarray$types$DataType[dataType.ordinal()]) {
            case 1:
                return ops.constant(number.byteValue());
            case 2:
                return ops.constant(number.intValue());
            case 3:
                return ops.constant(number.longValue());
            case 4:
                return ops.constant(number.shortValue());
            case 5:
                return ops.constant(number.floatValue());
            case 6:
                return ops.constant(number.doubleValue());
            default:
                throw new EngineException("unsupported type");
        }
    }
}
