/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.util;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Static arithmetic helpers that derive slice, vector, matrix and tensor counts and offsets
 * from the shape of an {@link INDArray}.
 *
 * <p>Products of dimension sizes are accumulated in {@code long} so that large shapes do not
 * silently wrap around. The class only exposes static methods and cannot be instantiated.
 */
public final class NDArrayMath {

    private NDArrayMath() {
        // static utility class; not instantiable
    }

    /**
     * Returns the linear offset of the first element of slice {@code slice} along dimension 0.
     *
     * <p>Computed as {@code slice * lengthPerSlice(arr)}. NOTE(review): this presumes slices are
     * laid out back-to-back in the underlying buffer; confirm for views / non-default orders.
     *
     * @param arr   the array whose slices are addressed
     * @param slice zero-based slice index along dimension 0
     * @return {@code slice * lengthPerSlice(arr)}, computed in {@code long}
     */
    public static long offsetForSlice(INDArray arr, int slice) {
        return (long) slice * lengthPerSlice(arr);
    }

    /**
     * Returns the product of the dimension sizes of {@code arr} that remain after the given
     * dimension indices are removed from its shape (via {@code ArrayUtil.removeIndex}).
     *
     * @param arr       the array whose shape is inspected
     * @param dimension dimension indices to exclude from the product
     * @return product of the remaining dimension sizes, as computed by
     *         {@code ArrayUtil.prodLong}
     */
    public static long lengthPerSlice(INDArray arr, int... dimension) {
        long[] remaining = ArrayUtil.removeIndex(arr.shape(), dimension);
        return ArrayUtil.prodLong(remaining);
    }

    /**
     * Returns the number of elements in one slice along dimension 0, i.e. the product of every
     * dimension size except dimension 0.
     *
     * @param arr the array whose shape is inspected
     * @return {@code lengthPerSlice(arr, 0)}
     */
    public static long lengthPerSlice(INDArray arr) {
        return lengthPerSlice(arr, 0);
    }

    /**
     * Returns the product of all dimension sizes except the last one.
     *
     * <p>For rank 0 and rank 1 this is {@code 1}; for rank 2 it is {@code arr.size(0)}.
     *
     * @param arr the array whose shape is inspected
     * @return product of {@code arr.size(i)} for {@code 0 <= i < rank - 1}
     */
    public static long numVectors(INDArray arr) {
        // Accumulate in long: the previous int accumulator truncated on every step and
        // silently overflowed for large shapes.
        long count = 1L;
        for (int i = 0; i < arr.rank() - 1; i++) {
            count *= arr.size(i);
        }
        return count;
    }

    /**
     * For rank greater than 2, returns {@code arr.size(-1) * arr.size(-2)}; otherwise returns
     * {@code arr.slices()}.
     *
     * <p>NOTE(review): for rank &gt; 2 this is the product of the last two dimension sizes (the
     * element count of one trailing matrix), not the number of rows. Kept as-is; confirm the
     * intended meaning with callers.
     *
     * @param arr the array whose shape is inspected
     * @return see above
     */
    public static long vectorsPerSlice(INDArray arr) {
        if (arr.rank() > 2) {
            return arr.size(-1) * arr.size(-2);
        }
        return arr.slices();
    }

    /**
     * Returns how many tensors of shape {@code tensorShape} fit into one slice along dimension 0,
     * computed as {@code lengthPerSlice(arr) / prod(tensorShape)} with integer division.
     *
     * @param arr         the array whose shape is inspected
     * @param tensorShape shape of the sub-tensor
     * @return {@code lengthPerSlice(arr) / prodLong(tensorShape)}
     * @throws ArithmeticException if the product of {@code tensorShape} is zero
     */
    public static long tensorsPerSlice(INDArray arr, int[] tensorShape) {
        // prodLong (as in sliceOffsetForTensor) avoids int overflow of the tensor length.
        return lengthPerSlice(arr) / ArrayUtil.prodLong(tensorShape);
    }

    /**
     * Returns the product of the sizes of dimensions {@code 1 .. rank-3} for arrays of rank 3 or
     * more (so exactly {@code 1} for rank 3), and {@code arr.size(-2)} for lower ranks.
     *
     * @param arr the array whose shape is inspected
     * @return see above
     */
    public static long matricesPerSlice(INDArray arr) {
        int rank = arr.rank();
        if (rank < 3) {
            return arr.size(-2);
        }
        // Skip dimension 0 (the slice axis) and the trailing two (the matrix itself).
        // Accumulate in long to avoid the former int truncation/overflow.
        long count = 1L;
        for (int i = 1; i < rank - 2; i++) {
            count *= arr.size(i);
        }
        return count;
    }

    /**
     * Returns {@code arr.size(-2) * arr.size(-1)} for rank greater than 2, otherwise
     * {@code arr.size(-1)}.
     *
     * @param arr  the array whose shape is inspected
     * @param rank ignored; retained only for API compatibility
     * @return see above
     */
    public static long vectorsPerSlice(INDArray arr, int... rank) {
        if (arr.rank() > 2) {
            return arr.size(-2) * arr.size(-1);
        }
        return arr.size(-1);
    }

    /**
     * Returns the index of the slice (along dimension 0) that contains the start of tensor number
     * {@code index}, computed as {@code index * prod(tensorShape) / lengthPerSlice(arr)}.
     *
     * @param index       zero-based tensor index
     * @param arr         the array whose shape is inspected
     * @param tensorShape shape of one tensor
     * @return the slice index, computed in {@code long}
     * @throws ArithmeticException if {@code lengthPerSlice(arr)} is zero
     */
    public static long sliceOffsetForTensor(int index, INDArray arr, int[] tensorShape) {
        return sliceOffset(index, ArrayUtil.prodLong(tensorShape), arr);
    }

    /**
     * {@code long[]}-shape overload of {@link #sliceOffsetForTensor(int, INDArray, int[])}.
     *
     * @param index       zero-based tensor index
     * @param arr         the array whose shape is inspected
     * @param tensorShape shape of one tensor
     * @return the slice index, computed in {@code long}
     * @throws ArithmeticException if {@code lengthPerSlice(arr)} is zero
     */
    public static long sliceOffsetForTensor(int index, INDArray arr, long[] tensorShape) {
        return sliceOffset(index, ArrayUtil.prodLong(tensorShape), arr);
    }

    /** Shared arithmetic for both {@code sliceOffsetForTensor} overloads. */
    private static long sliceOffset(int index, long tensorLength, INDArray arr) {
        long perSlice = lengthPerSlice(arr);
        return (long) index * tensorLength / perSlice;
    }

    /**
     * Returns {@code index} multiplied by the product of the dimension sizes left after removing
     * the {@code rank} dimension indices from the shape (i.e. {@code lengthPerSlice(arr, rank)}).
     *
     * @param index zero-based index to scale
     * @param arr   the array whose shape is inspected
     * @param rank  dimension indices to exclude from the product
     * @return the scaled index
     * @throws ArithmeticException if the result does not fit in an {@code int} (previously this
     *                             wrapped around silently)
     */
    public static int mapIndexOntoTensor(int index, INDArray arr, int... rank) {
        return Math.toIntExact((long) index * lengthPerSlice(arr, rank));
    }

    /**
     * Returns {@code index * arr.size(-1)}, computed in {@code long}.
     *
     * @param index zero-based vector index
     * @param arr   the array whose shape is inspected
     * @return the scaled index
     */
    public static long mapIndexOntoVector(int index, INDArray arr) {
        return (long) index * arr.size(-1);
    }
}

