/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class Gather
extends DynamicCustomOp {
    protected int[] indices;
    protected int jaxis = 0;

    public Gather() {
    }

    public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) {
        this(sameDiff, df, indices, axis, false);
    }

    public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) {
        this(sameDiff, df, indices, axis, false);
    }

    public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) {
        super(null, sameDiff, new SDVariable[]{input, sameDiff.constant(Nd4j.createFromArray(indices))}, inPlace);
        this.addIArgument(axis);
        this.addIArgument(indices);
        this.jaxis = axis;
        this.indices = indices;
    }

    public Gather(SameDiff sameDiff, SDVariable input, SDVariable indices, int axis, boolean inPlace) {
        super(null, sameDiff, new SDVariable[]{input, indices}, inPlace);
        this.addIArgument(axis);
        this.jaxis = axis;
    }

    public Gather(INDArray df, int[] indexes, int axis) {
        this.addInputArgument(df);
        this.addIArgument(axis);
        this.addIArgument(indexes);
        this.jaxis = axis;
        this.indices = this.indices;
    }

    public Gather(INDArray df, INDArray indexes, int axis) {
        this.addInputArgument(df, indexes);
        this.addIArgument(axis);
        this.jaxis = axis;
        this.indices = this.indices;
    }

    @Override
    public String onnxName() {
        return "Gather";
    }

    @Override
    public String[] tensorflowNames() {
        return new String[]{"Gather", "GatherV2"};
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    }

    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    @Override
    public void configureFromArguments() {
        if (!this.iArguments.isEmpty()) {
            this.jaxis = ((Long)this.iArguments.get(0)).intValue();
        }
    }

    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        HashMap<String, Map<String, PropertyMapping>> ret = new HashMap<String, Map<String, PropertyMapping>>();
        HashMap<String, PropertyMapping> map = new HashMap<String, PropertyMapping>();
        PropertyMapping broadcast = PropertyMapping.builder().onnxAttrName("indices").tfInputPosition(1).propertyNames(new String[]{"indices"}).build();
        map.put("indices", broadcast);
        ret.put(this.tensorflowNames()[0], map);
        ret.put(this.onnxName(), map);
        HashMap<String, PropertyMapping> map2 = new HashMap<String, PropertyMapping>();
        PropertyMapping broadcast2 = PropertyMapping.builder().tfInputPosition(1).propertyNames(new String[]{"indices"}).build();
        map2.put("indices", broadcast2);
        PropertyMapping axis2 = PropertyMapping.builder().tfInputPosition(2).propertyNames(new String[]{"axis"}).build();
        map2.put("axis", axis2);
        ret.put("GatherV2", map2);
        return ret;
    }

    @Override
    public void setPropertiesForFunction(Map<String, Object> properties) {
        if (properties.containsKey("dimensions")) {
            Long dimensions = (Long)properties.get("dimensions");
            this.jaxis = dimensions.intValue();
        }
    }

    @Override
    public String opName() {
        return "gather";
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        SDVariable indicesSize = this.sameDiff.expandDims(this.args()[1].length(), 0);
        SDVariable paramsShape = this.sameDiff.shape(this.args()[0]);
        paramsShape = paramsShape.reshape(paramsShape.length());
        SDVariable indicesGrad = this.sameDiff.zerosLike(this.arg(1));
        if (this.jaxis == 0) {
            SDVariable paramsTailShape = paramsShape.getView(SDIndex.interval(this.sameDiff.constant(1), this.sameDiff.constant(1), paramsShape.length()));
            SDVariable valueShape = this.sameDiff.concat(0, indicesSize, paramsTailShape);
            SDVariable values = this.sameDiff.reshape(i_v.get(0), valueShape);
            SDVariable indices = this.sameDiff.flatten(this.args()[1]);
            SDVariable retGrad = this.sameDiff.zerosLike(this.arg());
            SDVariable put = retGrad.put(indices, values, indices).reshape(this.arg().shape());
            return Arrays.asList(put, indicesGrad);
        }
        SDVariable batchDims = this.sameDiff.constant(0);
        SDVariable outerShape = paramsShape.getView(SDIndex.interval(0, this.jaxis));
        SDVariable innerShape = paramsShape.getView(SDIndex.interval(this.sameDiff.constant(this.jaxis), paramsShape.length()), SDIndex.interval(this.sameDiff.constant(1), this.sameDiff.constant(-1)));
        SDVariable valueShape = this.sameDiff.concat(0, outerShape, this.sameDiff.constant(-1).castTo(outerShape.dataType()), innerShape.castTo(outerShape.dataType()));
        SDVariable valuesDims = valueShape.length();
        SDVariable axisDims = outerShape.length();
        SDVariable outerBatchIndices = this.sameDiff.range(0.0, 0.0, 0.0, DataType.INT64);
        SDVariable batchAxisIndices = this.sameDiff.range(batchDims, axisDims, this.sameDiff.constant(1), DataType.INT64);
        SDVariable innerAxisIndices = this.sameDiff.range(axisDims.add(1.0), valuesDims, this.sameDiff.constant(1), DataType.INT64);
        SDVariable indices = this.sameDiff.reshape(this.args()[1], indicesSize);
        SDVariable put = this.sameDiff.unsortedSegmentSum(i_v.get(0), this.sameDiff.range(this.sameDiff.constant(0), this.sameDiff.sizeAt(i_v.get(0), 0), this.sameDiff.constant(1), DataType.INT64), this.sameDiff.sizeAt(i_v.get(0), 0));
        SDVariable values = this.sameDiff.reshape(put, valueShape);
        SDVariable transposeDims = this.sameDiff.concat("transposeConcat", 0, outerBatchIndices, axisDims, batchAxisIndices, innerAxisIndices);
        SDVariable valuesTranspose = this.sameDiff.permute(values, transposeDims);
        SDVariable paramsGrad = this.sameDiff.unsortedSegmentSum(valuesTranspose, indices, paramsShape.get(SDIndex.point(this.jaxis)));
        SDVariable invertTransposeDims = this.sameDiff.concat(0, outerBatchIndices.castTo(DataType.INT64), batchAxisIndices.add(1.0).castTo(DataType.INT64), batchDims.castTo(DataType.INT64), innerAxisIndices.castTo(DataType.INT64));
        paramsGrad = this.sameDiff.permute(paramsGrad, invertTransposeDims);
        return Arrays.asList(paramsGrad, indicesGrad);
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        return Collections.singletonList(dataTypes.get(0));
    }
}

