/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseModule;
import org.nd4j.linalg.api.ops.Module;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.WeightInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public class Linear
extends BaseModule {
    private DifferentialFunction forward;
    private int nIn;
    private int nOut;
    private WeightInitScheme weightInitScheme;
    private WeightInitScheme biasWeightInitScheme;

    public Linear(int nIn, int nOut, WeightInitScheme weightInitScheme, WeightInitScheme biasWeightInitScheme) {
        super(null, Linear.getParams(nIn, nOut, weightInitScheme, biasWeightInitScheme), new INDArray[0], new ArrayList<Double>(), new ArrayList<Integer>(), new ArrayList<Module>());
        this.weightInitScheme = weightInitScheme;
        this.biasWeightInitScheme = biasWeightInitScheme;
        this.nIn = nIn;
        this.nOut = nOut;
    }

    public Linear(SameDiff sameDiff, int nIn, int nOut, WeightInitScheme weightInitScheme, WeightInitScheme biasWeightInitScheme) {
        super(null, sameDiff, null, false, new ArrayList<Module>());
        this.weightInitScheme = weightInitScheme;
        this.biasWeightInitScheme = biasWeightInitScheme;
        this.nIn = nIn;
        this.nOut = nOut;
    }

    @Override
    public String opName() {
        return "linear";
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        this.execSameDiff(new SDVariable[0]);
        return this.forward.doDiff(f1);
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        ArrayList<LongShapeDescriptor> ret = new ArrayList<LongShapeDescriptor>();
        ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(this.inputArguments()[0].shape(), new long[]{this.nOut, this.nIn}), this.inputArguments()[1].dataType()));
        ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(this.inputArguments()[0].shape(), this.inputArguments()[1].transpose().shape()), this.inputArguments()[1].dataType()));
        if (this.biasWeightInitScheme != null) {
            ret.add(LongShapeDescriptor.fromShape(new long[]{this.nOut, 1L}, this.inputArguments()[1].dataType()));
        }
        return ret;
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    @Override
    public void exec(INDArray ... inputs) {
        INDArray[] inputArguments = this.inputArguments();
        if (inputArguments == null || inputArguments.length < 1) {
            throw new IllegalStateException("No arguments found.");
        }
        INDArray weights = inputArguments[0];
        INDArray right = inputArguments[1];
        INDArray[] outputArguments = this.outputArguments();
        if (outputArguments == null || outputArguments.length < 1) {
            if (inputArguments.length == 1) {
                this.addOutputArgument(inputs[0].mmul(weights.transpose()));
            } else {
                this.addOutputArgument(inputs[0].mmul(weights.transpose()).addiColumnVector(right));
            }
        } else {
            inputs[0].mmul(weights.transpose(), outputArguments[0]);
        }
    }

    @Override
    public void execSameDiff(SDVariable ... input) {
        SDVariable[] args = this.args();
        if (args == null || args.length == 0) {
            throw new IllegalStateException("No arguments found");
        }
        if (this.forward == null) {
            this.forward = args.length > 1 ? this.f().add(new Mmul(this.sameDiff, input[0], this.args()[0], MMulTranspose.builder().transposeA(false).transposeB(true).build()).outputVariables()[0], this.args()[1]) : new Mmul(this.sameDiff, input[0], this.args()[0], MMulTranspose.builder().transposeA(false).transposeB(true).build());
            this.outputVariables = this.forward.outputVariables();
        }
    }

    private static INDArray[] getParams(int nIn, int nOut, WeightInitScheme paramsScheme, WeightInitScheme biasInitScheme) {
        if (biasInitScheme != null) {
            return new INDArray[]{paramsScheme.create(Nd4j.defaultFloatingPointType(), nOut, nIn), biasInitScheme.create(Nd4j.defaultFloatingPointType(), nOut, 1L)};
        }
        return new INDArray[]{paramsScheme.create(Nd4j.defaultFloatingPointType(), nOut, nIn)};
    }

    public static LinearBuilder execBuilder() {
        return new LinearBuilder();
    }

    public static LinearBuilder sameDiffBuilder() {
        return new LinearBuilder();
    }

    public Linear() {
    }

    public static class LinearBuilder {
        private int nIn;
        private int nOut;
        private WeightInitScheme weightInitScheme;
        private WeightInitScheme biasWeightInitScheme;
        private SameDiff sameDiff;

        LinearBuilder() {
        }

        public LinearBuilder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        public LinearBuilder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        public LinearBuilder weightInitScheme(WeightInitScheme weightInitScheme) {
            this.weightInitScheme = weightInitScheme;
            return this;
        }

        public LinearBuilder biasWeightInitScheme(WeightInitScheme biasWeightInitScheme) {
            this.biasWeightInitScheme = biasWeightInitScheme;
            return this;
        }

        public Linear build() {
            return new Linear(this.nIn, this.nOut, this.weightInitScheme, this.biasWeightInitScheme);
        }

        public String toString() {
            return "Linear.LinearBuilder(nIn=" + this.nIn + ", nOut=" + this.nOut + ", weightInitScheme=" + this.weightInitScheme + ", biasWeightInitScheme=" + this.biasWeightInitScheme + ")";
        }

        public LinearBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }
    }
}

