/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.memory.impl;

import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.jita.memory.impl.CudaDirectProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Memory provider that keeps released host chunks in per-shape caches and hands them
 * back out on later host allocations of the same shape.
 *
 * <p>Only {@link AllocationStatus#HOST} requests strictly smaller than the configured
 * maximum host cacheable length go through the cache; everything else is delegated
 * unchanged to {@link CudaDirectProvider}. Chunks are zeroed with
 * {@code Pointer.memset} before they are put into the cache by {@link #free}.
 *
 * <p>Bookkeeping is done with atomics and a {@link ConcurrentHashMap}; cache holders are
 * created under {@link #singleLock}.
 */
public class CudaCachingZeroProvider
extends CudaDirectProvider
implements MemoryProvider {
    private static final Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);

    /** Requests of this many bytes or more never trigger background preallocation (16 MiB). */
    private static final long PREALLOCATION_MAX_LENGTH = 16L * 1024L * 1024L;

    /** Host cache: one holder (queue of free chunks) per allocation shape. */
    protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<AllocationShape, CacheHolder>();
    /** Number of host allocations served from {@link #zeroCache}. */
    protected final AtomicLong cacheZeroHit = new AtomicLong(0L);
    /** Number of cacheable host allocations that could not be served from {@link #zeroCache}. */
    protected final AtomicLong cacheZeroMiss = new AtomicLong(0L);
    /** Device cache hit counter; not updated in this class. */
    protected final AtomicLong cacheDeviceHit = new AtomicLong(0L);
    /** Device cache miss counter; not updated in this class. */
    protected final AtomicLong cacheDeviceMiss = new AtomicLong(0L);
    private final AtomicLong allocRequests = new AtomicLong(0L);
    /** Total number of bytes currently sitting in {@link #zeroCache}. */
    protected final AtomicLong zeroCachedAmount = new AtomicLong(0L);
    /** Per-device cached byte counters; this class never adds entries to the list. */
    protected List<AtomicLong> deviceCachedAmount = new ArrayList<AtomicLong>();
    /** Single-permit semaphore guarding creation of cache holders. */
    protected final Semaphore singleLock = new Semaphore(1);
    /** Size threshold in bytes; not consulted by this class any more (both sides of it were handled identically). */
    protected final long FORCED_CACHE_THRESHOLD = 96L;

    /**
     * Allocates memory for the given shape, reusing a cached host chunk when possible.
     *
     * <p>On a cache hit the returned pair uses the cached address for both its host and
     * device pointer and {@code point} is marked as {@link AllocationStatus#HOST}. On a
     * miss the request is delegated to the parent provider and, if preallocation is
     * enabled and the cache is below a tenth of its configured maximum, a
     * {@link CachePreallocator} thread is started to fill the cache for this shape.
     *
     * @param shape    shape describing the requested chunk
     * @param point    allocation point the chunk is allocated for
     * @param location requested memory location
     * @return pointers to the allocated (or reused) chunk
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        boolean cacheable = location == AllocationStatus.HOST
                && reqMemory < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength();
        if (!cacheable) {
            return super.malloc(shape, point, location);
        }

        CacheHolder cache = this.zeroCache.get(shape);
        if (cache != null) {
            Pointer pointer = cache.poll();
            if (pointer != null) {
                this.cacheZeroHit.incrementAndGet();
                // The chunk leaves the cache, so it no longer counts towards the cached amount.
                this.zeroCachedAmount.addAndGet(-1L * reqMemory);
                PointersPair pair = new PointersPair();
                pair.setDevicePointer(new CudaPointer(pointer.address()));
                pair.setHostPointer(new CudaPointer(pointer.address()));
                point.setAllocationStatus(AllocationStatus.HOST);
                return pair;
            }
        }

        // Exactly one miss per failed lookup; counting it twice would skew the hit ratio.
        this.cacheZeroMiss.incrementAndGet();

        if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation()
                && this.zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10L
                && reqMemory < PREALLOCATION_MAX_LENGTH) {
            CachePreallocator preallocator = new CachePreallocator(shape, location, CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls());
            preallocator.start();
        }
        return super.malloc(shape, point, location);
    }

    /**
     * Makes sure {@link #zeroCache} contains a holder for the given shape, creating one
     * under {@link #singleLock} if it is missing.
     *
     * @param shape shape that needs a cache holder
     * @throws RuntimeException if the calling thread is interrupted while waiting for the
     *                          lock; the thread's interrupt flag is restored first
     */
    protected void ensureCacheHolder(AllocationShape shape) {
        if (this.zeroCache.containsKey(shape)) {
            return;
        }
        // Acquire outside the guarded region: if acquire() is interrupted no permit was
        // taken, and releasing anyway would leave the semaphore with an extra permit.
        try {
            this.singleLock.acquire();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        try {
            // Re-check under the lock; another thread may have created the holder meanwhile.
            if (!this.zeroCache.containsKey(shape)) {
                this.zeroCache.put(shape, new CacheHolder(shape, this.zeroCachedAmount));
            }
        }
        finally {
            this.singleLock.release();
        }
    }

    /**
     * Releases the memory behind the given allocation point.
     *
     * <p>Device allocations, chunks that are not smaller than the maximum host cacheable
     * length, and any chunk released while the cache is at or above its configured
     * maximum are handed to the parent provider. Everything else is zeroed and stored in
     * the cache for its shape.
     *
     * @param point allocation point whose memory is being released
     */
    @Override
    public void free(AllocationPoint point) {
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            super.free(point);
            return;
        }

        AllocationShape shape = point.getShape();
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        // ">=" mirrors the strict "<" used by malloc(): a chunk of exactly the maximum
        // cacheable length would otherwise be cached here but never be handed out again.
        if (reqMemory >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()
                || this.zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) {
            super.free(point);
            return;
        }

        this.ensureCacheHolder(shape);
        CacheHolder cache = this.zeroCache.get(shape);
        // Zero the chunk before it goes into the cache.
        Pointer.memset((Pointer)point.getHostPointer(), 0, reqMemory);
        cache.put(new CudaPointer(point.getHostPointer().address()));
    }

    /**
     * Computes a hit percentage from a pair of counters.
     *
     * @param hits   number of hits
     * @param misses number of misses
     * @return {@code hits * 100 / (hits + misses)}, or {@code 0} when both are zero
     */
    private static float hitRatio(long hits, long misses) {
        long total = hits + misses;
        if (total == 0L) {
            // No requests yet: report 0 instead of the NaN that 0/0 would produce.
            return 0.0f;
        }
        return (float)(hits * 100L) / (float)total;
    }

    /** @return host cache hit percentage, or 0 if no cacheable host request was recorded */
    private float getZeroCacheHitRatio() {
        return hitRatio(this.cacheZeroHit.get(), this.cacheZeroMiss.get());
    }

    /** @return device cache hit percentage, or 0 if no device request was recorded */
    private float getDeviceCacheHitRatio() {
        return hitRatio(this.cacheDeviceHit.get(), this.cacheDeviceMiss.get());
    }

    /**
     * Logs the current cache statistics at debug level.
     *
     * <p>The device amount is taken from the first entry of {@link #deviceCachedAmount};
     * 0 is reported when the list is empty.
     */
    @Deprecated
    public void printCacheStats() {
        // This class never populates deviceCachedAmount, so an unguarded get(0) would throw.
        long deviceAmount = this.deviceCachedAmount.isEmpty() ? 0L : this.deviceCachedAmount.get(0).get();
        log.debug("Cached host amount: {}", this.zeroCachedAmount.get());
        log.debug("Cached device amount: {}", deviceAmount);
        log.debug("Total shapes in cache: {}", this.zeroCache.size());
        log.debug("Current host hit ratio: {}", this.getZeroCacheHitRatio());
        log.debug("Current device hit ratio: {}", this.getDeviceCacheHitRatio());
    }

    /**
     * Drains every host cache holder, passing each cached chunk to {@code freeHost}, and
     * resets the cached byte counter to zero.
     */
    @Override
    public void purgeCache() {
        for (CacheHolder holder : this.zeroCache.values()) {
            Pointer ptr;
            while ((ptr = holder.poll()) != null) {
                this.freeHost(ptr);
            }
        }
        // NOTE(review): a chunk cached concurrently with the purge is not reflected after
        // this reset; confirm purgeCache() is only called while the provider is quiescent.
        this.zeroCachedAmount.set(0L);
    }

    /**
     * Background thread that allocates a number of chunks of one shape through the parent
     * provider and, for host allocations, stores them in the cache.
     */
    protected class CachePreallocator
    extends Thread {
        private final AllocationShape shape;
        private final AllocationStatus location;
        private final int target;

        /**
         * @param shape           shape of the chunks to allocate
         * @param location        memory location to allocate in
         * @param numberOfEntries number of chunks to allocate
         */
        public CachePreallocator(AllocationShape shape, AllocationStatus location, int numberOfEntries) {
            this.shape = shape;
            this.target = numberOfEntries;
            this.location = location;
        }

        /**
         * Allocates {@code target} chunks; chunks are added to the host cache only when
         * the location is {@link AllocationStatus#HOST}.
         */
        @Override
        public void run() {
            CudaCachingZeroProvider.this.ensureCacheHolder(this.shape);
            for (int i = 0; i < this.target; ++i) {
                AllocationPoint point = new AllocationPoint();
                PointersPair pair = CudaCachingZeroProvider.super.malloc(this.shape, point, this.location);
                if (this.location != AllocationStatus.HOST) {
                    continue;
                }
                CudaPointer pointer = new CudaPointer(pair.getHostPointer().address());
                CudaCachingZeroProvider.this.zeroCache.get(this.shape).put(pointer);
            }
        }
    }

    /**
     * Queue of free chunks for a single shape, together with an entry counter and a
     * reference to the shared cached-bytes counter that is increased on every put.
     */
    protected class CacheHolder {
        private final Queue<Pointer> queue = new ConcurrentLinkedQueue<Pointer>();
        private final AtomicInteger counter = new AtomicInteger(0);
        private final long reqMem;
        private final AtomicLong allocCounter;

        /**
         * @param shape   shape of the chunks held here; determines the per-chunk byte size
         * @param counter shared counter increased by the chunk size on every {@link #put}
         */
        public CacheHolder(AllocationShape shape, AtomicLong counter) {
            this.reqMem = AllocationUtils.getRequiredMemory(shape);
            this.allocCounter = counter;
        }

        /** @return current value of the entry counter */
        public int size() {
            return this.counter.get();
        }

        /**
         * Removes and returns one cached chunk. The shared byte counter is not touched
         * here; callers adjust it themselves.
         *
         * @return a cached chunk, or {@code null} if the queue is empty
         */
        public Pointer poll() {
            Pointer pointer = this.queue.poll();
            if (pointer != null) {
                this.counter.decrementAndGet();
            }
            return pointer;
        }

        /**
         * Adds a chunk to the queue and increases both the entry counter and the shared
         * byte counter.
         *
         * @param pointer chunk to cache
         */
        public void put(Pointer pointer) {
            this.allocCounter.addAndGet(this.reqMem);
            this.counter.incrementAndGet();
            this.queue.add(pointer);
        }
    }
}

