package org.nd4j.linalg.lossfunctions.impl

import lombok.EqualsAndHashCode
import org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.activations.IActivation
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.lossfunctions.ILossFunction
import org.nd4j.linalg.ops.transforms.Transforms
import org.slf4j.{Logger, LoggerFactory}

/**
  * LossComm
  * Common scaffolding for [[ILossFunction]] implementations.
  *
  * The score and gradient entry points of the interface are implemented here in terms of
  * two abstract hooks, so a concrete loss normally supplies only three things:
  * 1) `scoreArray`, 2) `computedLdYHat`, and 3) `name`.
  *
  * NOTE(review): `weights` is accepted by the constructor but is never read by this class.
  * NOTE(review): Lombok's `@EqualsAndHashCode` is applied by a Java annotation processor and
  * presumably generates nothing for a Scala class -- confirm, and define `equals`/`hashCode`
  * explicitly if value equality is required.
  *
  * @param weights optional weights array; may be null (unused in this base class)
  * @author mx
  */
@EqualsAndHashCode
abstract class LossComm(weights: INDArray) extends ILossFunction {

  /** No-argument constructor; equivalent to constructing with `weights = null`. */
  def this() = this(null)

  /**
    * scoreArray
    * Element-wise loss, before any reduction over the output dimension.
    *
    * @param labels       Labels/expected output
    * @param preOutput    Output of the model (neural network)
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask associated with the labels; may be null
    * @return An array the shape and size of the output of the neural net.
    * @note abstract hook: must be implemented for the actual loss function. The mask is
    *       forwarded as-is, so any masking of the score is the implementation's job.
    */
  def scoreArray(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray

  /**
    * computeScoreArray
    * Compute the score (loss function value) for each example individually by summing
    * `scoreArray` along dimension 1.
    * For input [numExamples,nOut] this yields one score per example.
    *
    * @param labels       Labels/expected output
    * @param preOutput    Output of the model (neural network)
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask array; may be null. Passed through unchanged to `scoreArray`
    * @return Loss function value for each example
    * @note fixed, no need to change
    */
  override def computeScoreArray(labels: INDArray, preOutput: INDArray, activationFn: IActivation,
    mask: INDArray): INDArray = {
    val perElement: INDArray = scoreArray(labels, preOutput, activationFn, mask)
    perElement.sum(1)
  }

  /**
    * computeScore
    * Compute the score (loss function value) for the given inputs: the sum of all
    * per-example scores, optionally divided by the number of rows in `labels`.
    *
    * @param labels       Label/expected preOutput
    * @param preOutput    Output of the model (neural network)
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask array; may be null
    * @param average      Whether the score should be averaged (divided by `labels.size(0)`) or not
    * @return Loss function value
    * @note fixed, no need to change
    */
  override def computeScore(labels: INDArray, preOutput: INDArray, activationFn: IActivation,
    mask: INDArray, average: Boolean): Double = {
    val total: Double = computeScoreArray(labels, preOutput, activationFn, mask).sumNumber.doubleValue
    if (average) total / labels.size(0) else total
  }

  /**
    * computedLdYHat
    * Compute the gradient of the loss function with respect to the prediction: dL/dYHat
    *
    * @param labels       Label/expected output
    * @param preOutput    Output of the model (neural network), before the activation function is applied
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask array; may be null
    * @return Gradient dL/dYHat
    * @note abstract hook: must be implemented for the actual loss function
    */
  def computedLdYHat(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray

  /**
    * computeGradient
    * Compute the gradient of the loss function with respect to the inputs: dL/dPreOut.
    * The gradient from `computedLdYHat` is back-propagated through the activation function
    * (on a copy of `preOutput`), and the result is multiplied in place by the mask when one
    * is given.
    *
    * @param labels       Label/expected output
    * @param preOutput    Output of the model (neural network), before the activation function is applied
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask array; may be null. Either a per-example column vector or an
    *                     array with the same shape as the gradient (per-output mask)
    * @return Gradient dL/dPreOut
    * @note fixed, no need to change
    */
  override def computeGradient(labels: INDArray, preOutput: INDArray, activationFn: IActivation,
    mask: INDArray): INDArray = {
    val dLdYHat: INDArray = computedLdYHat(labels, preOutput, activationFn, mask)
    // backprop receives a duplicate so the caller's preOutput array is not handed over directly
    val dLdPreOut: INDArray = activationFn.backprop(preOutput.dup, dLdYHat).getFirst
    // Masks are always applied. A mask shaped exactly like the gradient is multiplied
    // element-wise; anything else goes through the column-vector path, as before.
    Option(mask).foreach { m =>
      if (java.util.Arrays.equals(m.shape(), dLdPreOut.shape())) dLdPreOut.muli(m)
      else dLdPreOut.muliColumnVector(m)
    }
    dLdPreOut
  }

  /**
    * computeGradientAndScore
    * Compute both the score (loss function value) and gradient. This is equivalent to calling
    * computeScore and computeGradient individually, which is exactly what happens here.
    *
    * @param labels       Label/expected output
    * @param preOutput    Output of the model (neural network)
    * @param activationFn Activation function that should be applied to preOutput
    * @param mask         Mask array; may be null
    * @param average      Whether the score should be averaged (divided by number of rows in labels/output) or not
    * @return The score (loss function value) and gradient
    * @note fixed, no need to change
    */
  override def computeGradientAndScore(labels: INDArray, preOutput: INDArray, activationFn: IActivation,
    mask: INDArray, average: Boolean): Pair[java.lang.Double, INDArray] = {
    // Score first, then gradient: same evaluation order as two separate calls.
    val score: Double = computeScore(labels, preOutput, activationFn, mask, average)
    val gradient: INDArray = computeGradient(labels, preOutput, activationFn, mask)
    new Pair[java.lang.Double, INDArray](score, gradient)
  }

  /**
    * name
    * @return Identifier of this loss function; concrete subclasses are expected to override it
    */
  override def name(): String = "LossComm()"

}
