package ai.minxiao.ds4s.core.dl4j.cgnn

import org.deeplearning4j.nn.conf.layers.{
  ConvolutionLayer, DenseLayer,
  EmbeddingLayer, EmbeddingSequenceLayer,
  GravesBidirectionalLSTM, GravesLSTM, LSTM, OutputLayer, RnnOutputLayer,
  SubsamplingLayer}
  import org.deeplearning4j.nn.conf.layers.SubsamplingLayer.PoolingType
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder
import org.deeplearning4j.nn.graph.vertex.GraphVertex
import org.deeplearning4j.nn.conf.graph.{ElementWiseVertex,
  L2NormalizeVertex, L2Vertex, MergeVertex,
  PreprocessorVertex, PoolHelperVertex, ReshapeVertex, ScaleVertex, ShiftVertex,
  StackVertex, SubsetVertex, UnstackVertex}
import org.deeplearning4j.nn.conf.graph.rnn.{DuplicateToTimeSeriesVertex, LastTimeStepVertex, ReverseTimeSeriesVertex}
import org.deeplearning4j.nn.conf.InputPreProcessor
import org.deeplearning4j.nn.conf.preprocessor.{FeedForwardToRnnPreProcessor, RnnToFeedForwardPreProcessor}
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.impl._
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction



/**
  * Declarative description of a single computation-graph vertex.
  *
  * Only `vType`, `vName` and `vInputs` are mandatory; every other field is an
  * `Option` that is read only by the vertex types that need it (see
  * `CGVertex.oneVertexConfBuilder`). Fields that a given vertex type does not
  * read are ignored.
  *
  * @constructor
  * @param vType vertex type, refer to VType; selects which layer/vertex is built
  * @param vName vertex name under which the vertex is added to the graph
  * @param vInputs names of the vertex inputs
  * @param lossFunction loss function (read by OutputLayer and RnnOutputLayer)
  * @param classWeights per-output weights; when set on an OutputLayer or
  *                     RnnOutputLayer, the loss is replaced by the weighted
  *                     implementation matching `lossFunction`
  * @param nIn input cardinality
  * @param nOut output cardinality
  * @param activation activation
  * @param poolingType pooling layer type (read by SubsamplingLayer)
  * @param weightInit weight initialization scheme
  * @param biasInit bias initialization value
  * @param dropOut drop-out value (read by DenseLayer and ConvolutionLayer)
  * @param kernelSize kernel size (read by ConvolutionLayer and SubsamplingLayer)
  * @param stride stride (read by ConvolutionLayer and SubsamplingLayer)
  * @param padding padding (read by ConvolutionLayer and SubsamplingLayer)
  * @param subIndices pair of indices handed, in order, to SubsetVertex
  * @param op operation type of elementwise vertex Op.X (Add, Subtract, Product, Average, Max)
  * @param scaleFactor factor handed to ScaleVertex
  * @param shiftFactor factor handed to ShiftVertex
  * @param from first argument handed to UnstackVertex
  * @param stackSize second argument handed to UnstackVertex
  * @param processor input pre-processor registered for this vertex, if any
  * @param maskArrayInputName input name handed to LastTimeStepVertex and
  *                           ReverseTimeSeriesVertex
  * @param inputName input name handed to DuplicateToTimeSeriesVertex
  *
  * @author mx
  */
@SerialVersionUID(677186L)
case class CGVertex(
  // identity and wiring
  vType: VType.Value,
  vName: String,
  vInputs: Array[String],
  // loss configuration (output layers)
  lossFunction: Option[LossFunction] = None,
  classWeights: Option[Array[Double]] = None,
  // layer dimensions
  nIn: Option[Int] = None,
  nOut: Option[Int] = None,
  // generic layer settings
  activation: Option[Activation] = None,
  poolingType: Option[PoolingType] = None,
  weightInit: Option[WeightInit] = None,
  biasInit: Option[Double] = None,
  dropOut: Option[Double] = None,
  // convolution / subsampling geometry
  kernelSize: Option[Array[Int]] = None,
  stride: Option[Array[Int]] = None,
  padding: Option[Array[Int]] = None,
  // arguments of the non-layer graph vertices
  subIndices: Option[(Int, Int)] = None,
  op: Option[ElementWiseVertex.Op] = None,
  scaleFactor: Option[Double] = None,
  shiftFactor: Option[Double] = None,
  from: Option[Int] = None,
  stackSize: Option[Int] = None,
  // pre-processing and time-series inputs
  processor: Option[InputPreProcessor] = None,
  maskArrayInputName: Option[String] = None,
  inputName: Option[String] = None
)

/**
  * Companion Object: turns [[CGVertex]] descriptions into calls on a
  * computation-graph builder.
  *
  * @author mx
  */
object CGVertex {

  /**
    * Builds the class-weighted loss implementation that corresponds to a
    * `LossFunction` enum value. Shared by the OutputLayer and RnnOutputLayer
    * branches of [[oneVertexConfBuilder]].
    *
    * @param lossFunction the configured loss function
    * @param weights the per-output weights
    * @return the weighted loss implementation
    * @throws MatchError if `lossFunction` is not one of the values listed below
    */
  private def weightedLoss(lossFunction: LossFunction, weights: Array[Double]) = {
    val classWeights = Nd4j.create(weights)
    // @unchecked: only the losses listed here have a weighted variant wired up;
    // any other value fails with a MatchError.
    (lossFunction: @unchecked) match {
      case LossFunction.L2                              => new LossL2(classWeights)
      case LossFunction.MSE | LossFunction.SQUARED_LOSS => new LossMSE(classWeights)
      case LossFunction.L1                              => new LossL1(classWeights)
      case LossFunction.MEAN_ABSOLUTE_ERROR             => new LossMAE(classWeights)
      case LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR  => new LossMAPE(classWeights)
      case LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR  => new LossMSLE(classWeights)
      case LossFunction.XENT                            => new LossBinaryXENT(classWeights)
      case LossFunction.MCXENT                          => new LossMCXENT(classWeights)
      case LossFunction.NEGATIVELOGLIKELIHOOD           => new LossNegativeLogLikelihood(classWeights)
    }
  }

  /**
    * Add one vertex to the graph builder.
    *
    * If the vertex carries an input pre-processor it is registered first; then
    * the layer or graph vertex selected by `vertex.vType` is added under
    * `vertex.vName` with `vertex.vInputs` as its inputs. A `vType` that is not
    * handled below adds nothing (a pre-processor, if present, has still been
    * registered).
    *
    * @param graphBuilder the builder to extend
    * @param vertex the vertex description
    * @return the graph builder returned by the underlying add call, or
    *         `graphBuilder` itself for an unhandled vertex type
    * @throws NoSuchElementException if a field the selected vertex type needs
    *                                (e.g. `nIn`, `nOut`, `lossFunction`) is `None`
    * @throws MatchError if `classWeights` is set together with a loss function
    *                    that has no weighted variant here
    */
  def oneVertexConfBuilder(graphBuilder: GraphBuilder, vertex: CGVertex): GraphBuilder = {

    // Unwraps a field the current vertex type cannot be built without. Same
    // exception type as a bare `.get`, but the message names the vertex and
    // the missing field instead of just "None.get".
    def required[A](field: Option[A], fieldName: String): A =
      field.getOrElse(
        throw new NoSuchElementException(
          s"CGVertex '${vertex.vName}' of type ${vertex.vType} requires '$fieldName' to be set"
        )
      )

    // preprocessor
    vertex.processor.foreach(p => graphBuilder.inputPreProcessor(vertex.vName, p))

    vertex.vType match {
      // LayerVertex
      //   Embedding Layer
      case VType.EmbeddingLayer =>
        val eBuilder = new EmbeddingLayer.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.weightInit.foreach(w => eBuilder.weightInit(w))
        vertex.biasInit.foreach(b => eBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, eBuilder.build(), vertex.vInputs: _*)

      case VType.EmbeddingSequenceLayer =>
        val eBuilder = new EmbeddingSequenceLayer.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.weightInit.foreach(w => eBuilder.weightInit(w))
        vertex.biasInit.foreach(b => eBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, eBuilder.build(), vertex.vInputs: _*)

      //   Fully Connected
      case VType.DenseLayer =>
        val dBuilder = new DenseLayer.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.activation.foreach(a => dBuilder.activation(a))
        vertex.weightInit.foreach(w => dBuilder.weightInit(w))
        vertex.biasInit.foreach(b => dBuilder.biasInit(b))
        vertex.dropOut.foreach(d => dBuilder.dropOut(d))
        graphBuilder.addLayer(vertex.vName, dBuilder.build(), vertex.vInputs: _*)

      case VType.OutputLayer =>
        val loss = required(vertex.lossFunction, "lossFunction")
        val oBuilder = new OutputLayer.Builder(loss)
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        // weighted loss: overrides the plain loss passed to the constructor
        vertex.classWeights.foreach(cw => oBuilder.lossFunction(weightedLoss(loss, cw)))
        vertex.activation.foreach(a => oBuilder.activation(a))
        vertex.weightInit.foreach(w => oBuilder.weightInit(w))
        vertex.biasInit.foreach(b => oBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, oBuilder.build(), vertex.vInputs: _*)

      //    Recurrent
      case VType.LSTM =>
        val rBuilder = new LSTM.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.activation.foreach(a => rBuilder.activation(a))
        vertex.weightInit.foreach(w => rBuilder.weightInit(w))
        vertex.biasInit.foreach(b => rBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, rBuilder.build(), vertex.vInputs: _*)

      case VType.GravesLSTM =>
        val rBuilder = new GravesLSTM.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.activation.foreach(a => rBuilder.activation(a))
        vertex.weightInit.foreach(w => rBuilder.weightInit(w))
        vertex.biasInit.foreach(b => rBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, rBuilder.build(), vertex.vInputs: _*)

      case VType.GravesBidirectionalLSTM =>
        val rBuilder = new GravesBidirectionalLSTM.Builder()
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        vertex.activation.foreach(a => rBuilder.activation(a))
        vertex.weightInit.foreach(w => rBuilder.weightInit(w))
        vertex.biasInit.foreach(b => rBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, rBuilder.build(), vertex.vInputs: _*)

      case VType.RnnOutputLayer =>
        val loss = required(vertex.lossFunction, "lossFunction")
        val roBuilder = new RnnOutputLayer.Builder(loss)
          .nIn(required(vertex.nIn, "nIn"))
          .nOut(required(vertex.nOut, "nOut"))
        // weighted loss: overrides the plain loss passed to the constructor
        vertex.classWeights.foreach(cw => roBuilder.lossFunction(weightedLoss(loss, cw)))
        vertex.activation.foreach(a => roBuilder.activation(a))
        vertex.weightInit.foreach(w => roBuilder.weightInit(w))
        vertex.biasInit.foreach(b => roBuilder.biasInit(b))
        graphBuilder.addLayer(vertex.vName, roBuilder.build(), vertex.vInputs: _*)

      //    Convolution
      case VType.ConvolutionLayer =>
        // only nOut is set here; nIn is not read for this vertex type
        val cBuilder = new ConvolutionLayer.Builder().nOut(required(vertex.nOut, "nOut"))
        vertex.activation.foreach(a => cBuilder.activation(a))
        vertex.weightInit.foreach(w => cBuilder.weightInit(w))
        vertex.biasInit.foreach(b => cBuilder.biasInit(b))
        vertex.dropOut.foreach(d => cBuilder.dropOut(d))
        vertex.kernelSize.foreach(k => cBuilder.kernelSize(k: _*))
        vertex.stride.foreach(s => cBuilder.stride(s: _*))
        vertex.padding.foreach(p => cBuilder.padding(p: _*))
        graphBuilder.addLayer(vertex.vName, cBuilder.build(), vertex.vInputs: _*)

      case VType.SubsamplingLayer =>
        val pBuilder = new SubsamplingLayer.Builder(required(vertex.poolingType, "poolingType"))
        vertex.kernelSize.foreach(k => pBuilder.kernelSize(k: _*))
        vertex.stride.foreach(s => pBuilder.stride(s: _*))
        vertex.padding.foreach(p => pBuilder.padding(p: _*))
        graphBuilder.addLayer(vertex.vName, pBuilder.build(), vertex.vInputs: _*)

      // Misc
      case VType.MergeVertex =>
        graphBuilder.addVertex(vertex.vName, new MergeVertex(), vertex.vInputs: _*)

      case VType.SubsetVertex =>
        val (first, second) = required(vertex.subIndices, "subIndices")
        graphBuilder.addVertex(vertex.vName, new SubsetVertex(first, second), vertex.vInputs: _*)

      case VType.ElementWiseVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new ElementWiseVertex(required(vertex.op, "op")),
          vertex.vInputs: _*
        )

      case VType.L2NormalizeVertex =>
        graphBuilder.addVertex(vertex.vName, new L2NormalizeVertex(), vertex.vInputs: _*)

      case VType.L2Vertex =>
        graphBuilder.addVertex(vertex.vName, new L2Vertex(), vertex.vInputs: _*)

      case VType.PoolHelperVertex =>
        graphBuilder.addVertex(vertex.vName, new PoolHelperVertex(), vertex.vInputs: _*)

      case VType.ReshapeVertex =>
        // NOTE(review): the vertex is constructed without any arguments and
        // CGVertex has no field carrying a target shape -- confirm this is intended.
        graphBuilder.addVertex(vertex.vName, new ReshapeVertex(), vertex.vInputs: _*)

      case VType.ScaleVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new ScaleVertex(required(vertex.scaleFactor, "scaleFactor")),
          vertex.vInputs: _*
        )

      case VType.ShiftVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new ShiftVertex(required(vertex.shiftFactor, "shiftFactor")),
          vertex.vInputs: _*
        )

      case VType.StackVertex =>
        graphBuilder.addVertex(vertex.vName, new StackVertex(), vertex.vInputs: _*)

      case VType.UnstackVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new UnstackVertex(required(vertex.from, "from"), required(vertex.stackSize, "stackSize")),
          vertex.vInputs: _*
        )

      case VType.DuplicateToTimeSeriesVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new DuplicateToTimeSeriesVertex(required(vertex.inputName, "inputName")),
          vertex.vInputs: _*
        )

      case VType.LastTimeStepVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new LastTimeStepVertex(required(vertex.maskArrayInputName, "maskArrayInputName")),
          vertex.vInputs: _*
        )

      case VType.ReverseTimeSeriesVertex =>
        graphBuilder.addVertex(
          vertex.vName,
          new ReverseTimeSeriesVertex(required(vertex.maskArrayInputName, "maskArrayInputName")),
          vertex.vInputs: _*
        )

      // Unknown Type: nothing is added, the builder is handed back as is
      case _ =>
        graphBuilder
    }
  }
}
