/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.ShapeOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

/**
 * An {@link OpExecutioner} that records the ops it is asked to execute into a {@link SameDiff}
 * graph rather than running them itself.
 *
 * <p>Every non-null input/output array of a processed op ({@code x}, {@code y}, {@code z}) is
 * registered once as a randomly named {@link SDVariable}. The op is then added to the graph via
 * {@link SameDiff#invoke}, and the op's output array {@code z} is re-mapped to the variable that
 * call returns. Work this class does not translate is delegated to the backend executioner that
 * {@code Nd4j.getExecutioner()} returned when this instance was created. That covers shape ops,
 * custom ops, graph execution, compression helpers, and environment/profiling settings.
 *
 * <p>NOTE(review): no synchronization is performed apart from the atomic last-op reference.
 * Presumably instances are meant to be used from a single thread — confirm before sharing.
 */
public class SameDiffOpExecutioner implements OpExecutioner, OpProfiler.OpProfilerListener {
    /**
     * Maps each array seen so far to its graph variable. The map is keyed by reference identity
     * (IdentityHashMap), so two distinct array objects never share a variable, even if they
     * compare equal.
     */
    private final Map<INDArray, SDVariable> variables;
    /** The graph that recorded ops are added to. */
    private final SameDiff sameDiff;
    /** Most recently processed op; holds {@code null} until the first op is processed. */
    private final AtomicReference<Op> opAtomicReference = new AtomicReference<Op>();
    /** Executioner that was active at construction time; receives all non-recorded work. */
    private final OpExecutioner backendExecutioner = Nd4j.getExecutioner();

    /**
     * Creates an executioner backed by a fresh {@link SameDiff} graph and registers it as a
     * listener with the global {@link OpProfiler}.
     *
     * <p>NOTE(review): {@code this} is handed to the profiler before construction finishes.
     * Nothing in this class unregisters it, so the instance stays reachable from the profiler.
     */
    public SameDiffOpExecutioner() {
        this.variables = new IdentityHashMap<INDArray, SDVariable>();
        this.sameDiff = SameDiff.create();
        OpProfiler.getInstance().addListener(this);
    }

    /**
     * Records {@code op} in the SameDiff graph.
     *
     * <p>Remembers the op as the last one processed. Ensures x, y and z have graph variables,
     * then invokes the op on x (and y, when both x and y are present). Finally, maps z to the
     * resulting variable.
     *
     * @param op the op to record
     * @return the same {@code op} instance that was passed in
     */
    private Op processOp(Op op) {
        // Update on every call: getLastOp() must reflect the latest op, not only the first one.
        this.opAtomicReference.set(op);

        this.registerIfAbsent(op.x());
        this.registerIfAbsent(op.y());
        this.registerIfAbsent(op.z());

        SDVariable result;
        if (op.x() != null && op.y() != null) {
            result = this.sameDiff.invoke(op, this.variables.get(op.x()), this.variables.get(op.y()));
        } else {
            result = this.sameDiff.invoke(op, this.variables.get(op.x()));
        }
        // The output array is now represented by the op's result variable.
        this.variables.put(op.z(), result);
        return op;
    }

    /** Registers {@code arr} as a new, randomly named graph variable unless it is null or already known. */
    private void registerIfAbsent(INDArray arr) {
        if (arr == null || this.variables.containsKey(arr)) {
            return;
        }
        SDVariable sdVariable = this.sameDiff.var(UUID.randomUUID().toString(), arr);
        this.variables.put(arr, sdVariable);
    }

    /**
     * Returns the name of the most recently processed op.
     *
     * @return the last op's name, or {@code null} if no op has been processed yet
     */
    @Override
    public String getLastOp() {
        Op last = this.opAtomicReference.get();
        return last == null ? null : last.opName();
    }

    @Override
    public Op exec(Op op) {
        return this.processOp(op);
    }

    @Override
    public void iterateOverAllRows(Op op) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void iterateOverAllColumns(Op op) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray execAndReturn(TransformOp op) {
        return this.processOp(op).z();
    }

    /**
     * Records the accumulation and returns the op itself.
     *
     * <p>The previous code cast the op's {@code z()} array to {@code Accumulation}, which always
     * failed with a ClassCastException. {@code processOp} returns its argument, so this cast is
     * safe.
     */
    @Override
    public Accumulation execAndReturn(Accumulation op) {
        return (Accumulation) this.processOp(op);
    }

    /** Records the variance op. NOTE(review): {@code biasCorrected} is ignored here. */
    @Override
    public Accumulation execAndReturn(Variance op, boolean biasCorrected) {
        return (Accumulation) this.processOp(op);
    }

    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation op) {
        return (IndexAccumulation) this.processOp(op);
    }

    @Override
    public INDArray execAndReturn(ScalarOp op) {
        return this.processOp(op).z();
    }

    @Override
    public INDArray execAndReturn(BroadcastOp op) {
        return this.processOp(op).z();
    }

    /** Shape ops are not recorded; they run on the backend executioner. */
    @Override
    public INDArray execAndReturn(ShapeOp op) {
        return this.backendExecutioner.execAndReturn(op);
    }

    /** Records the op. NOTE(review): {@code dimension} is ignored here. */
    @Override
    public Op exec(Op op, int ... dimension) {
        return this.processOp(op);
    }

    /** Records the accumulation. NOTE(review): {@code dimension} is ignored here. */
    @Override
    public INDArray exec(Accumulation accumulation, int ... dimension) {
        return this.processOp(accumulation).z();
    }

    /** Records the broadcast. NOTE(review): {@code dimension} is ignored here. */
    @Override
    public INDArray exec(BroadcastOp broadcast, int ... dimension) {
        return this.processOp(broadcast).z();
    }

    /** Records the variance op. NOTE(review): {@code biasCorrected} and {@code dimension} are ignored here. */
    @Override
    public INDArray exec(Variance accumulation, boolean biasCorrected, int ... dimension) {
        return this.processOp(accumulation).z();
    }

    /** Records the index accumulation. NOTE(review): {@code dimension} is ignored here. */
    @Override
    public INDArray exec(IndexAccumulation indexAccum, int ... dimension) {
        return this.processOp(indexAccum).z();
    }

    @Override
    public INDArray execAndReturn(Op op) {
        return this.processOp(op).z();
    }

    @Override
    public OpExecutioner.ExecutionMode executionMode() {
        return this.backendExecutioner.executionMode();
    }

    @Override
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.backendExecutioner.setExecutionMode(executionMode);
    }

    /** NOTE(review): silently ignored — neither recorded nor delegated; confirm this is intended. */
    @Override
    public void exec(MetaOp op) {
    }

    /** NOTE(review): silently ignored — neither recorded nor delegated; confirm this is intended. */
    @Override
    public void exec(GridOp op) {
    }

    /** NOTE(review): silently ignored — neither recorded nor delegated; confirm this is intended. */
    @Override
    public void exec(Aggregate op) {
    }

    /** Shape ops are not recorded; they run on the backend executioner. */
    @Override
    public void exec(ShapeOp op) {
        this.backendExecutioner.exec(op);
    }

    /** NOTE(review): silently ignored — neither recorded nor delegated; confirm this is intended. */
    @Override
    public <T extends Aggregate> void exec(Batch<T> batch) {
    }

    /** NOTE(review): silently ignored — neither recorded nor delegated; confirm this is intended. */
    @Override
    public void exec(List<Aggregate> batch) {
    }

    @Override
    public INDArray exec(RandomOp op) {
        return this.processOp(op).z();
    }

    /** Records the random op. NOTE(review): {@code rng} is ignored here. */
    @Override
    public INDArray exec(RandomOp op, Random rng) {
        return this.processOp(op).z();
    }

    // ---- Everything below is delegated unchanged to the backend executioner, except
    // ---- invoke(Op), which records the op in the graph. ----

    @Override
    public Properties getEnvironmentInformation() {
        return this.backendExecutioner.getEnvironmentInformation();
    }

    @Override
    public void setProfilingMode(OpExecutioner.ProfilingMode mode) {
        this.backendExecutioner.setProfilingMode(mode);
    }

    @Override
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return this.backendExecutioner.getProfilingMode();
    }

    @Override
    public TADManager getTADManager() {
        return this.backendExecutioner.getTADManager();
    }

    @Override
    public void printEnvironmentInformation() {
        this.backendExecutioner.printEnvironmentInformation();
    }

    @Override
    public void push() {
        this.backendExecutioner.push();
    }

    @Override
    public void commit() {
        this.backendExecutioner.commit();
    }

    @Override
    public INDArray thresholdEncode(INDArray input, double threshold) {
        return this.backendExecutioner.thresholdEncode(input, threshold);
    }

    @Override
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        return this.backendExecutioner.thresholdEncode(input, threshold, boundary);
    }

    @Override
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        return this.backendExecutioner.thresholdDecode(encoded, target);
    }

    @Override
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        return this.backendExecutioner.bitmapEncode(indArray, target, threshold);
    }

    @Override
    public INDArray bitmapEncode(INDArray indArray, double threshold) {
        return this.backendExecutioner.bitmapEncode(indArray, threshold);
    }

    @Override
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        return this.backendExecutioner.bitmapDecode(encoded, target);
    }

    /** Records {@code op} in the graph, exactly like {@link #exec(Op)} but without a return value. */
    @Override
    public void invoke(Op op) {
        this.processOp(op);
    }

    @Override
    public Map<String, CustomOpDescriptor> getCustomOperations() {
        return this.backendExecutioner.getCustomOperations();
    }

    /** Custom ops are not recorded; they run on the backend executioner. */
    @Override
    public void exec(CustomOp op) {
        this.backendExecutioner.exec(op);
    }

    @Override
    public List<long[]> calculateOutputShape(CustomOp op) {
        return this.backendExecutioner.calculateOutputShape(op);
    }

    @Override
    public INDArray[] allocateOutputArrays(CustomOp op) {
        return this.backendExecutioner.allocateOutputArrays(op);
    }

    @Override
    public void registerGraph(long id, Pointer graph) {
        this.backendExecutioner.registerGraph(id, graph);
    }

    /**
     * Delegates graph execution to the backend.
     *
     * @throws NullPointerException if {@code map} or {@code reverseMap} is null
     */
    @Override
    public Map<String, INDArray> executeGraph(long id, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> reverseMap) {
        if (map == null) {
            throw new NullPointerException("map is marked @NonNull but is null");
        }
        if (reverseMap == null) {
            throw new NullPointerException("reverseMap is marked @NonNull but is null");
        }
        return this.backendExecutioner.executeGraph(id, map, reverseMap);
    }

    @Override
    public void forgetGraph(long id) {
        this.backendExecutioner.forgetGraph(id);
    }

    @Override
    public void enableDebugMode(boolean reallyEnable) {
        this.backendExecutioner.enableDebugMode(reallyEnable);
    }

    @Override
    public void enableVerboseMode(boolean reallyEnable) {
        this.backendExecutioner.enableVerboseMode(reallyEnable);
    }

    @Override
    public void setElementsThreshold(int threshold) {
        this.backendExecutioner.setElementsThreshold(threshold);
    }

    @Override
    public void setTadThreshold(int threshold) {
        this.backendExecutioner.setTadThreshold(threshold);
    }

    @Override
    public OpExecutioner.ExecutionerType type() {
        return this.backendExecutioner.type();
    }

    @Override
    public boolean isVerbose() {
        return this.backendExecutioner.isVerbose();
    }

    @Override
    public boolean isDebug() {
        return this.backendExecutioner.isDebug();
    }

    /** Returns the live (mutable) array-to-variable map; callers' changes affect this executioner. */
    public Map<INDArray, SDVariable> getVariables() {
        return this.variables;
    }

    /** Returns the graph that processed ops are recorded into. */
    public SameDiff getSameDiff() {
        return this.sameDiff;
    }

    /** Returns the holder of the most recently processed op (its value is null before the first op). */
    public AtomicReference<Op> getOpAtomicReference() {
        return this.opAtomicReference;
    }

    /** Returns the executioner that non-recorded work is delegated to. */
    public OpExecutioner getBackendExecutioner() {
        return this.backendExecutioner;
    }
}

