/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.nativeblas;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link BaseNDArrayFactory} for native backends that adds numpy interop: exporting an
 * {@link INDArray} as an in-memory {@code .npy} image and loading {@code .npy} / {@code .npz}
 * data, mostly through {@link NativeOps}.
 */
public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
    private static final Logger log = LoggerFactory.getLogger(BaseNativeNDArrayFactory.class);

    /** "PK\3\4" read as a little-endian int: the signature that opens every ZIP local file header. */
    private static final int ZIP_LOCAL_FILE_HEADER_SIGNATURE = 0x04034b50;
    /** Size in bytes of the fixed part of a ZIP local file header. */
    private static final int ZIP_LOCAL_FILE_HEADER_SIZE = 30;
    /** General-purpose flag bit 3: CRC and sizes follow the entry data in a data descriptor. */
    private static final int ZIP_FLAG_DATA_DESCRIPTOR = 0x08;
    /** ZIP compression method 0 ("stored", i.e. uncompressed). */
    private static final int ZIP_METHOD_STORED = 0;

    /** Native operations used for all numpy conversions in this factory. */
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    public BaseNativeNDArrayFactory(DataType dtype, Character order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory(DataType dtype, char order) {
        super(dtype, order);
    }

    public BaseNativeNDArrayFactory() {
    }

    /**
     * Serializes {@code array} into a newly allocated native buffer laid out as a {@code .npy}
     * image: the header produced by {@code NativeOps.numpyHeaderForNd4j} followed by the raw
     * contents of the array's data buffer.
     *
     * @param array the array to export
     * @return a {@link BytePointer}, positioned at 0, holding the header and data back to back
     */
    public Pointer convertToNumpy(INDArray array) {
        DataBuffer buffer = array.data();
        // long arithmetic throughout: the previous (int) cast truncated buffers larger than 2 GiB.
        long dataBytes = (long) buffer.getElementSize() * buffer.length();
        // NOTE(review): the whole backing buffer (data().length() elements) is copied; for views
        // this may not correspond to the shape written into the header — confirm.

        LongPointer headerSizeOut = new LongPointer(1L);
        Pointer header = NativeOpsHolder.getInstance().getDeviceNativeOps().numpyHeaderForNd4j(
                buffer.pointer(), array.shapeInfoDataBuffer().pointer(), buffer.getElementSize(), headerSizeOut);
        // The reported size is reduced by one before use (presumably a trailing NUL — TODO confirm).
        long headerSize = headerSizeOut.get() - 1L;
        header.capacity(headerSize);
        header.position(0L);
        BytePointer headerBytes = new BytePointer(header);

        BytePointer result = new BytePointer(headerSize + dataBytes);
        result.position(0L);
        Pointer.memcpy(result, headerBytes, headerBytes.capacity());

        // The data goes directly after the header; sync host memory before reading it.
        result.position(headerBytes.capacity());
        Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);
        Pointer.memcpy(result, buffer.pointer(), dataBytes);

        result.position(0L);
        return result;
    }

    /**
     * Builds an array from a native numpy structure (as returned by, e.g.,
     * {@code NativeOps.numpyFromFile}). The shape and data are copied into freshly allocated
     * memory, so {@code pointer} may be released afterwards.
     *
     * @param pointer native numpy handle
     * @return a FLOAT array for 4-byte elements or a DOUBLE array for 8-byte elements
     * @throws IllegalArgumentException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyPointer(Pointer pointer) {
        Pointer dataPointer = this.nativeOps.dataPointForNumpy(pointer);
        int dataBufferElementSize = this.nativeOps.elementSizeForNpyArray(pointer);
        Pointer shapeBufferPointer = this.nativeOps.shapeBufferForNumpy(pointer);
        INDArray ret = createFromNpyComponents(dataPointer, dataBufferElementSize, shapeBufferPointer);
        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.DEVICE);
        return ret;
    }

    /**
     * Builds an array from a pointer to an in-memory {@code .npy} image (header followed by data).
     * The shape and data are copied into freshly allocated memory.
     *
     * @param pointer pointer to the start of the {@code .npy} header
     * @return a FLOAT array for 4-byte elements or a DOUBLE array for 8-byte elements
     * @throws IllegalArgumentException if the element size is neither 4 nor 8 bytes
     */
    public INDArray createFromNpyHeaderPointer(Pointer pointer) {
        Pointer dataPointer = this.nativeOps.dataPointForNumpyHeader(pointer);
        int dataBufferElementSize = this.nativeOps.elementSizeForNpyArrayHeader(pointer);
        Pointer shapeBufferPointer = this.nativeOps.shapeBufferForNumpyHeader(pointer);
        return createFromNpyComponents(dataPointer, dataBufferElementSize, shapeBufferPointer);
    }

    /** Copies a native shape buffer and data block into a new array (shared by the npy loaders). */
    private INDArray createFromNpyComponents(Pointer dataPointer, int dataBufferElementSize, Pointer shapeBufferPointer) {
        if (dataBufferElementSize != 4 && dataBufferElementSize != 8) {
            // Previously this fell through with data == null and failed later inside Nd4j.create.
            throw new IllegalArgumentException("Unsupported numpy element size: " + dataBufferElementSize
                    + " bytes (only 4-byte float and 8-byte double are supported)");
        }
        DataBuffer shapeBuffer = copyShapeBuffer(shapeBufferPointer);
        long elementCount = Shape.length(shapeBuffer);
        long byteCount = (long) dataBufferElementSize * elementCount;
        dataPointer.position(0L);
        dataPointer.limit(byteCount);
        dataPointer.capacity(byteCount);

        DataBuffer data;
        if (dataBufferElementSize == 4) {
            FloatPointer floats = new FloatPointer(elementCount);
            copyHostToHost(floats, dataPointer, byteCount);
            data = Nd4j.createBuffer(floats, DataType.FLOAT, elementCount, FloatIndexer.create(floats));
        } else {
            DoublePointer doubles = new DoublePointer(elementCount);
            copyHostToHost(doubles, dataPointer, byteCount);
            data = Nd4j.createBuffer(doubles, DataType.DOUBLE, elementCount, DoubleIndexer.create(doubles));
        }
        return Nd4j.create(data, Shape.shape(shapeBuffer), Shape.strideArr(shapeBuffer), 0L, Shape.order(shapeBuffer));
    }

    /** Copies a native nd4j shape-info buffer (LONG entries) into a new LONG {@link DataBuffer}. */
    private DataBuffer copyShapeBuffer(Pointer shapeBufferPointer) {
        int length = this.nativeOps.lengthForShapeBufferPointer(shapeBufferPointer);
        long byteCount = 8L * length;
        shapeBufferPointer.capacity(byteCount);
        shapeBufferPointer.limit(byteCount);
        shapeBufferPointer.position(0L);
        LongPointer source = new LongPointer(shapeBufferPointer);
        LongPointer copy = new LongPointer((long) length);
        copyHostToHost(copy, source, shapeBufferPointer.limit());
        return Nd4j.createBuffer(copy, DataType.LONG, (long) length, LongIndexer.create(copy));
    }

    /** memcpy of {@code bytes} bytes, recorded with the {@link PerformanceTracker} as HOST_TO_HOST. */
    private static void copyHostToHost(Pointer dst, Pointer src, long bytes) {
        long transaction = PerformanceTracker.getInstance().helperStartTransaction();
        Pointer.memcpy(dst, src, bytes);
        PerformanceTracker.getInstance().helperRegisterTransaction(0, transaction, bytes, MemcpyDirection.HOST_TO_HOST);
    }

    /**
     * Encodes the absolute path of {@code file} as UTF-8 in a direct buffer for native calls.
     * The buffer's limit is the path length; one extra byte past the limit is zero (allocateDirect
     * zero-fills), so code that reads the memory as a C string also finds a terminator.
     */
    private static ByteBuffer nativePath(File file) {
        byte[] pathBytes = file.getAbsolutePath().getBytes(StandardCharsets.UTF_8);
        ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length + 1).order(ByteOrder.nativeOrder());
        directBuffer.put(pathBytes);
        directBuffer.flip(); // position 0, limit = pathBytes.length; terminator stays at index limit
        return directBuffer;
    }

    /**
     * Loads a {@code .npy} file via {@code NativeOps.numpyFromFile}. The native structure is
     * released even if conversion fails.
     *
     * @param file the {@code .npy} file
     * @return the loaded array (see {@link #createFromNpyPointer(Pointer)})
     */
    public INDArray createFromNpyFile(File file) {
        ByteBuffer path = nativePath(file);
        Pointer pointer = this.nativeOps.numpyFromFile(new BytePointer(path));
        try {
            return this.createFromNpyPointer(pointer);
        } finally {
            // Safe: createFromNpyPointer copies shape and data into memory it owns.
            this.nativeOps.releaseNumpy(pointer);
        }
    }

    /**
     * Loads every array from an uncompressed {@code .npz} archive with a pure-Java parser.
     * Entry names have their {@code .npy} suffix removed. Supported element types are
     * little/big-endian {@code f8} (loaded as DOUBLE) and {@code f4} (loaded as FLOAT).
     *
     * @param file the {@code .npz} archive
     * @return map from array name to array
     * @throws IOException if the archive is malformed, truncated, compressed, or uses data descriptors
     * @throws Exception   with message "Unsupported data type: ..." for other element types
     */
    public Map<String, INDArray> createFromNpzFile(File file) throws Exception {
        Map<String, INDArray> map = new HashMap<>();
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            byte[] localHeader = new byte[ZIP_LOCAL_FILE_HEADER_SIZE];
            ByteBuffer fields = ByteBuffer.wrap(localHeader).order(ByteOrder.LITTLE_ENDIAN);
            while (true) {
                int read = readAtMost(in, localHeader);
                // Any other signature (central directory, end record) or EOF ends the entry list.
                if (read < 4 || fields.getInt(0) != ZIP_LOCAL_FILE_HEADER_SIGNATURE) {
                    break;
                }
                if (read < ZIP_LOCAL_FILE_HEADER_SIZE) {
                    throw new IOException("Truncated ZIP local file header in " + file);
                }
                // ZIP header fields are unsigned little-endian; the old code read one signed byte.
                int flags = fields.getShort(6) & 0xFFFF;
                int method = fields.getShort(8) & 0xFFFF;
                int nameLength = fields.getShort(26) & 0xFFFF;
                int extraLength = fields.getShort(28) & 0xFFFF;

                byte[] nameBytes = new byte[nameLength];
                in.readFully(nameBytes);
                in.readFully(new byte[extraLength]); // extra field (e.g. zip64 sizes) is not needed
                String entryName = new String(nameBytes, StandardCharsets.UTF_8);

                if (method != ZIP_METHOD_STORED) {
                    throw new IOException("Entry " + entryName + " in " + file + " uses ZIP compression method "
                            + method + "; only uncompressed (stored) entries are supported");
                }
                if ((flags & ZIP_FLAG_DATA_DESCRIPTOR) != 0) {
                    throw new IOException("Entry " + entryName + " in " + file
                            + " uses a trailing data descriptor, which is not supported");
                }
                String arrayName = entryName.endsWith(".npy")
                        ? entryName.substring(0, entryName.length() - 4)
                        : entryName;
                // A stored .npy payload is self-delimiting, so this consumes exactly the entry data.
                map.put(arrayName, readNpyEntry(in, entryName));
            }
        }
        return map;
    }

    /** Reads up to {@code buffer.length} bytes, stopping early only at EOF; returns the count read. */
    private static int readAtMost(InputStream in, byte[] buffer) throws IOException {
        int total = 0;
        while (total < buffer.length) {
            int n = in.read(buffer, total, buffer.length - total);
            if (n < 0) {
                break;
            }
            total += n;
        }
        return total;
    }

    /** Parses one {@code .npy} payload (magic, version, header dict, raw data) from {@code in}. */
    private static INDArray readNpyEntry(DataInputStream in, String entryName) throws Exception {
        byte[] preamble = new byte[8]; // "\x93NUMPY" + major + minor version
        in.readFully(preamble);
        if ((preamble[0] & 0xFF) != 0x93 || preamble[1] != 'N' || preamble[2] != 'U'
                || preamble[3] != 'M' || preamble[4] != 'P' || preamble[5] != 'Y') {
            throw new IOException("Entry " + entryName + " is not a .npy array");
        }
        int major = preamble[6] & 0xFF;
        if (major < 1 || major > 3) {
            throw new IOException("Unsupported .npy format version " + major + " in entry " + entryName);
        }
        // Version 1 stores the header length in 2 bytes; versions 2 and 3 use 4 bytes.
        byte[] lengthField = new byte[major == 1 ? 2 : 4];
        in.readFully(lengthField);
        ByteBuffer lengthBuffer = ByteBuffer.wrap(lengthField).order(ByteOrder.LITTLE_ENDIAN);
        long headerLength = major == 1 ? (lengthBuffer.getShort() & 0xFFFFL) : (lengthBuffer.getInt() & 0xFFFFFFFFL);
        if (headerLength > Integer.MAX_VALUE) {
            throw new IOException("Implausible .npy header length " + headerLength + " in entry " + entryName);
        }
        byte[] headerBytes = new byte[(int) headerLength];
        in.readFully(headerBytes);
        String header = new String(headerBytes, major == 3 ? StandardCharsets.UTF_8 : StandardCharsets.ISO_8859_1);

        // 'descr': '<f8' -> byte-order prefix plus type code.
        String descrValue = npyHeaderValue(header, "descr", entryName);
        if (descrValue.isEmpty() || (descrValue.charAt(0) != '\'' && descrValue.charAt(0) != '"')) {
            throw new Exception("Unsupported data type: " + descrValue); // e.g. structured dtypes
        }
        int descrEnd = descrValue.indexOf(descrValue.charAt(0), 1);
        if (descrEnd < 0) {
            throw new IOException("Malformed descr in .npy header of " + entryName + ": " + header);
        }
        String descr = descrValue.substring(1, descrEnd);
        ByteOrder byteOrder;
        switch (descr.isEmpty() ? '\0' : descr.charAt(0)) {
            case '<':
            case '|':
                byteOrder = ByteOrder.LITTLE_ENDIAN;
                break;
            case '>':
                byteOrder = ByteOrder.BIG_ENDIAN;
                break;
            case '=':
                byteOrder = ByteOrder.nativeOrder();
                break;
            default:
                throw new Exception("Unsupported data type: " + descr);
        }
        String typeCode = descr.substring(1);
        int elemSize;
        if (typeCode.equals("f8")) {
            elemSize = 8;
        } else if (typeCode.equals("f4")) {
            elemSize = 4;
        } else {
            throw new Exception("Unsupported data type: " + typeCode);
        }

        String fortranValue = npyHeaderValue(header, "fortran_order", entryName);
        char order;
        if (fortranValue.startsWith("True")) {
            order = 'f';
        } else if (fortranValue.startsWith("False")) {
            order = 'c';
        } else {
            throw new IOException("Malformed fortran_order in .npy header of " + entryName + ": " + header);
        }

        // 'shape': (3, 4) / (3,) / () — an empty tuple is a scalar.
        String shapeValue = npyHeaderValue(header, "shape", entryName);
        int close = shapeValue.indexOf(')');
        if (!shapeValue.startsWith("(") || close < 0) {
            throw new IOException("Malformed shape in .npy header of " + entryName + ": " + header);
        }
        String dimsText = shapeValue.substring(1, close).replace(" ", "");
        String[] dims = dimsText.isEmpty() ? new String[0] : dimsText.split(",");
        long[] shape = new long[dims.length];
        long size = 1L;
        for (int i = 0; i < dims.length; ++i) {
            shape[i] = Long.parseLong(dims[i]);
            size = Math.multiplyExact(size, shape[i]);
        }

        long byteCount = Math.multiplyExact(size, (long) elemSize);
        if (byteCount > Integer.MAX_VALUE) {
            throw new IOException("Array " + entryName + " is too large to load (" + byteCount + " bytes)");
        }
        byte[] raw = new byte[(int) byteCount];
        in.readFully(raw);
        // Bulk views decode element-by-element in the declared byte order (the old code passed
        // element indices to the byte-indexed absolute getDouble/getFloat in big-endian order).
        ByteBuffer values = ByteBuffer.wrap(raw).order(byteOrder);
        if (elemSize == 8) {
            double[] doubles = new double[(int) size];
            values.asDoubleBuffer().get(doubles);
            return Nd4j.create(doubles, shape, order);
        }
        float[] floats = new float[(int) size];
        values.asFloatBuffer().get(floats);
        return Nd4j.create(Nd4j.createBuffer(floats), shape, Nd4j.getStrides(shape, order), 0L, order, DataType.FLOAT);
    }

    /**
     * Returns the text of a {@code .npy} header dict following {@code 'key':}, with leading
     * whitespace removed. The caller parses the value prefix it expects.
     */
    private static String npyHeaderValue(String header, String key, String entryName) throws IOException {
        String quotedKey = "'" + key + "'";
        int keyIndex = header.indexOf(quotedKey);
        int colon = keyIndex < 0 ? -1 : header.indexOf(':', keyIndex + quotedKey.length());
        if (colon < 0) {
            throw new IOException("Malformed .npy header of " + entryName + " (missing " + quotedKey + "): " + header);
        }
        int start = colon + 1;
        while (start < header.length() && Character.isWhitespace(header.charAt(start))) {
            ++start;
        }
        return header.substring(start);
    }

    /**
     * Loads every array from a {@code .npz} archive through the native npz reader. Data is copied
     * into memory owned by the returned arrays.
     *
     * @param file the {@code .npz} archive
     * @return map from array name (as reported natively) to array
     * @throws Exception with message "Unsupported data type: &lt;bits&gt;" for element sizes other
     *                   than 32 or 64 bits
     */
    public Map<String, INDArray> _createFromNpzFile(File file) throws Exception {
        ByteBuffer path = nativePath(file);
        Pointer pointer = this.nativeOps.mapFromNpzFile(new BytePointer(path));
        // NOTE(review): no native call releasing `pointer` is visible here; it may leak — confirm.
        int n = this.nativeOps.getNumNpyArraysInMap(pointer);
        Map<String, INDArray> map = new HashMap<>();
        for (int i = 0; i < n; ++i) {
            String arrName = this.nativeOps.getNpyArrayNameFromMap(pointer, i);
            Pointer arrPtr = this.nativeOps.getNpyArrayFromMap(pointer, i);
            int ndim = this.nativeOps.getNpyArrayRank(arrPtr);
            LongPointer shapePtr = this.nativeOps.getNpyArrayShape(arrPtr);
            long[] shape = new long[ndim];
            long length = 1L;
            for (int j = 0; j < ndim; ++j) {
                shape[j] = shapePtr.get((long) j);
                length *= shape[j];
            }
            int elementBytes = this.nativeOps.getNpyArrayElemSize(arrPtr);
            char order = this.nativeOps.getNpyArrayOrder(arrPtr);
            Pointer dataPointer = this.nativeOps.dataPointForNumpyStruct(arrPtr);

            // Sizes in bytes (previously computed in bits, overstating the limit 8x).
            long byteCount = (long) elementBytes * length;
            dataPointer.position(0L);
            dataPointer.limit(byteCount);
            dataPointer.capacity(byteCount);

            DataBuffer data;
            DataType dataType;
            if (elementBytes == 4) {
                FloatPointer floats = new FloatPointer(length);
                // The old code never copied the native data, leaving the array uninitialised.
                copyHostToHost(floats, dataPointer, byteCount);
                data = Nd4j.createBuffer(floats, DataType.FLOAT, length, FloatIndexer.create(floats));
                dataType = DataType.FLOAT;
            } else if (elementBytes == 8) {
                DoublePointer doubles = new DoublePointer(length);
                copyHostToHost(doubles, dataPointer, byteCount);
                data = Nd4j.createBuffer(doubles, DataType.DOUBLE, length, DoubleIndexer.create(doubles));
                dataType = DataType.DOUBLE;
            } else {
                throw new Exception("Unsupported data type: " + String.valueOf(elementBytes * 8));
            }
            map.put(arrName, Nd4j.create(data, shape, Nd4j.getStrides(shape, order), 0L, order, dataType));
        }
        return map;
    }
}

