/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.reduce;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Matrix-multiplication op backed by the native {@code matmul} custom op.
 *
 * <p>Imported from TensorFlow ({@code MatMul}, {@code BatchMatMul}, {@code BatchMatMulV2}) and
 * ONNX ({@code MatMul}). Transposition of the first input, second input and result is held in
 * {@link #mt}. It is mirrored into the integer arguments in the order
 * transposeA, transposeB, transposeResult, encoded with {@code ArrayUtil.fromBoolean}.
 * {@link #alpha} and {@link #beta} are passed to the native op as its two floating-point
 * arguments; their exact meaning is defined by the native implementation.
 */
public class Mmul
extends DynamicCustomOp {
    /** Transpose configuration; only null for instances built with the no-arg constructor. */
    protected MMulTranspose mt;
    protected double alpha = 1.0;
    protected double beta = 0.0;

    /**
     * SameDiff constructor with an explicit transpose configuration.
     *
     * @param mt transpose flags; {@code null} is treated as "no transposes"
     */
    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, MMulTranspose mt) {
        super(null, sameDiff, new SDVariable[]{i_v1, i_v2});
        this.mt = mt == null ? MMulTranspose.allFalse() : mt;
        this.addTransposeArgs(this.mt);
        this.addTArgument(this.alpha, this.beta);
    }

    /** SameDiff constructor without any transposes. */
    public Mmul(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        this(sameDiff, i_v1, i_v2, MMulTranspose.allFalse());
    }

    /**
     * Eager-execution constructor.
     *
     * @param z  optional output array; when {@code null} no output argument is registered
     * @param mt optional transpose flags; when {@code null} no integer arguments are added and
     *           {@link #mt} defaults to "no transposes" (consistent with
     *           {@link #configureFromArguments()} on an empty argument list)
     */
    public Mmul(INDArray x, INDArray y, INDArray z, double alpha, double beta, MMulTranspose mt) {
        this.addInputArgument(x, y);
        if (z != null) {
            this.addOutputArgument(z);
        }
        if (mt != null) {
            this.addTransposeArgs(mt);
        }
        // Never leave mt null: doDiff and the property accessors rely on it.
        this.mt = mt == null ? MMulTranspose.allFalse() : mt;
        this.alpha = alpha;
        this.beta = beta;
        this.addTArgument(alpha, beta);
    }

    public Mmul(INDArray x, INDArray y, INDArray z, MMulTranspose mt) {
        this(x, y, z, 1.0, 0.0, mt);
    }

    public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
        this(x, y, 1.0, 0.0, transposeX, transposeY, transposeZ);
    }

    public Mmul(INDArray x, INDArray y, double alpha, double beta, boolean transposeX, boolean transposeY, boolean transposeZ) {
        this.addInputArgument(x, y);
        this.mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
        this.addTransposeArgs(this.mt);
        this.addTArgument(alpha, beta);
        this.alpha = alpha;
        this.beta = beta;
    }

    public Mmul(INDArray x, INDArray y, double alpha, double beta) {
        this(x, y, null, alpha, beta, null);
    }

    public Mmul(INDArray x, INDArray y) {
        this(x, y, 1.0, 0.0);
    }

    public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, boolean transposeZ) {
        super(null, sameDiff, new SDVariable[]{x, y});
        this.mt = MMulTranspose.builder().transposeA(transposeX).transposeB(transposeY).transposeResult(transposeZ).build();
        this.addTransposeArgs(this.mt);
        this.addTArgument(this.alpha, this.beta);
    }

    /** No-arg constructor; {@link #mt} stays null until configured. */
    public Mmul() {
    }

    /** Appends the three transpose flags of {@code transpose} as integer arguments (A, B, result). */
    private void addTransposeArgs(MMulTranspose transpose) {
        this.addIArgument(ArrayUtil.fromBoolean(transpose.isTransposeA()),
                ArrayUtil.fromBoolean(transpose.isTransposeB()),
                ArrayUtil.fromBoolean(transpose.isTransposeResult()));
    }

    /** Delegates property lookup to {@link #mt}, lazily creating a default configuration. */
    @Override
    public Object getValue(Field property) {
        if (this.mt == null) {
            this.mt = MMulTranspose.builder().build();
        }
        return this.mt.getValue(property);
    }

    /** Returns the transpose configuration as properties, or an empty map when unset. */
    @Override
    public Map<String, Object> propertiesForFunction() {
        if (this.mt == null) {
            return Collections.emptyMap();
        }
        return this.mt.toProperties();
    }

    /** Rebuilds {@link #mt} from integer arguments 0..2; a missing argument counts as false. */
    @Override
    public void configureFromArguments() {
        this.mt = MMulTranspose.builder()
                .transposeA(this.iArgFlag(0))
                .transposeB(this.iArgFlag(1))
                .transposeResult(this.iArgFlag(2))
                .build();
    }

    /** True when integer argument {@code index} exists and is positive. */
    private boolean iArgFlag(int index) {
        return this.numIArguments() > index && this.getIArgument(index) > 0L;
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "mt";
    }

    /** Applies {@code properties} to {@link #mt}, creating a default configuration first if needed. */
    @Override
    public void setPropertiesForFunction(Map<String, Object> properties) {
        if (this.mt == null) {
            this.mt = MMulTranspose.builder().build();
        }
        this.mt.setProperties(properties);
    }

    /**
     * Returns the shape after transposing the matrix dimensions.
     *
     * <p>A rank-2 shape is reversed. For a rank-3 shape the first dimension is kept and the
     * last two are swapped.
     *
     * @throws IllegalArgumentException for any other rank
     */
    public long[] transposeShapeArray(long[] shape) {
        switch (shape.length) {
            case 2:
                return ArrayUtil.reverseCopy(shape);
            case 3:
                return new long[]{shape[0], shape[2], shape[1]};
            default:
                throw new IllegalArgumentException("Matrix input has to be of length 2 or 3, got: " + shape.length);
        }
    }

    @Override
    public String onnxName() {
        return "MatMul";
    }

    @Override
    public String[] tensorflowNames() {
        return new String[]{"MatMul", "BatchMatMul", "BatchMatMulV2"};
    }

    @Override
    public String opName() {
        return "matmul";
    }

    /**
     * Configures transposes from a TensorFlow node.
     *
     * <p>{@code MatMul} carries {@code transpose_a}/{@code transpose_b}; the batch variants may
     * carry {@code adj_x}/{@code adj_y} instead. A missing attribute is treated as false. The
     * integer arguments are replaced by the two input flags.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
        boolean transposeA = tfBooleanAttr(attributesForNode, "transpose_a", "adj_x");
        boolean transposeB = tfBooleanAttr(attributesForNode, "transpose_b", "adj_y");
        this.mt = MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build();
        this.iArguments.clear();
        this.addIArgument(ArrayUtil.fromBoolean(transposeA), ArrayUtil.fromBoolean(transposeB));
    }

    /** Reads boolean attr {@code primary}, falling back to {@code fallback}; false if neither exists. */
    private static boolean tfBooleanAttr(Map<String, AttrValue> attrs, String primary, String fallback) {
        AttrValue value = attrs.get(primary);
        if (value == null) {
            value = attrs.get(fallback);
        }
        return value != null && value.getB();
    }

    /**
     * Configures transposes from ONNX {@code transA}/{@code transB} integer attributes.
     *
     * <p>A missing attribute is treated as false. The integer arguments are replaced so the
     * native op agrees with {@link #mt}, mirroring the TensorFlow import path.
     */
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        // Short-circuit keeps both flags definitely assigned (the decompiled ternary did not compile).
        boolean transposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0L;
        boolean transposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0L;
        this.mt = MMulTranspose.builder().transposeA(transposeA).transposeB(transposeB).build();
        this.iArguments.clear();
        this.addIArgument(ArrayUtil.fromBoolean(transposeA), ArrayUtil.fromBoolean(transposeB));
    }

    /** Builds the gradient via {@link MmulBp} using this op's inputs, incoming gradient and transposes. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        if (this.mt == null) {
            // e.g. created via the no-arg constructor: derive transposes from the integer arguments
            this.configureFromArguments();
        }
        return Arrays.asList(new MmulBp(this.sameDiff, this.larg(), this.rarg(), gradients.get(0), this.mt).outputVariables());
    }

    /** Maps ONNX {@code transA}/{@code transB} and TF {@code transpose_a}/{@code transpose_b} onto the transpose properties. */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> map = new HashMap<>();
        map.put("transposeA", PropertyMapping.builder().onnxAttrName("transA").tfAttrName("transpose_a").propertyNames(new String[]{"transposeA"}).build());
        map.put("transposeB", PropertyMapping.builder().onnxAttrName("transB").tfAttrName("transpose_b").propertyNames(new String[]{"transposeB"}).build());
        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
        for (String tfName : this.tensorflowNames()) {
            ret.put(tfName, map);
        }
        ret.put(this.onnxName(), map);
        return ret;
    }

    /**
     * Returns the output data type.
     *
     * <p>An explicit data-type argument wins. Otherwise the type of the first input is used,
     * after checking that at least two inputs exist and both are floating point.
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        if (!this.dArguments.isEmpty()) {
            return Collections.singletonList((DataType) this.dArguments.get(0));
        }
        Preconditions.checkState(dataTypes != null && dataTypes.size() >= 2, "Expected at least 2 inputs to mmul op, got %s", dataTypes);
        Preconditions.checkState(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Inputs to mmul op must both be a floatingpoint type: got %s", dataTypes);
        return Collections.singletonList(dataTypes.get(0));
    }

    /** Equal when alpha, beta and the transpose configuration match. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Mmul)) {
            return false;
        }
        Mmul other = (Mmul) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(this.alpha, other.alpha) != 0 || Double.compare(this.beta, other.beta) != 0) {
            return false;
        }
        return this.mt == null ? other.mt == null : this.mt.equals(other.mt);
    }

    protected boolean canEqual(Object other) {
        return other instanceof Mmul;
    }

    /** Hash over alpha, beta and transposes; values identical to the previous Lombok-style scheme. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Double.hashCode(this.alpha);
        result = result * prime + Double.hashCode(this.beta);
        result = result * prime + (this.mt == null ? 43 : this.mt.hashCode());
        return result;
    }
}

