package org.nd4j.linalg.lossfunctions.impl

import lombok.EqualsAndHashCode
import org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.activations.IActivation
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.lossfunctions.ILossFunction
import org.nd4j.linalg.ops.transforms.Transforms
import org.slf4j.{Logger, LoggerFactory}

/** Companion of the `LossL1L2` class; its only member is a private SLF4J logger. */
@EqualsAndHashCode
object LossL1L2 {
  // Logger obtained from LoggerFactory for the LossL1L2 class.
  private val logger: Logger = LoggerFactory.getLogger(classOf[LossL1L2])
}

/**
  * Combined L1 + L2 loss. For every element the score is
  * `(y - yHat)^2 + |y - yHat|`, where `yHat` is the activation of the pre-output.
  *
  * @param weights optional per-output weights, multiplied in as a row vector;
  *                `null` means "no weighting"
  */
@EqualsAndHashCode
class LossL1L2(weights: INDArray) extends LossComm(weights) with ILossFunction {

  /** Creates an unweighted loss (`weights` is `null`, so the weighting step is skipped). */
  def this() = this(null)

  /**
    * Element-wise score: `(y - yHat)^2 + |y - yHat|`.
    *
    * @param labels       the targets `y`
    * @param preOutput    input to the activation function; it is duplicated before the
    *                     activation is applied, so the caller's array is not passed on directly
    * @param activationFn activation used to turn `preOutput` into `yHat`
    * @param mask         optional mask, multiplied in as a column vector; `null` to skip
    * @return a newly allocated array of scores, after optional weighting and masking
    */
  def scoreArray(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray = {
    val output: INDArray  = activationFn.getActivation(preOutput.dup(), true)
    val absDiff: INDArray = Transforms.abs(labels.sub(output)) // |y - yHat|

    // |d| * |d| == d^2, so the squared term can be built from the absolute difference.
    val scoreArr: INDArray = absDiff.mul(absDiff)
    scoreArr.addi(absDiff) // + |y - yHat|, in place

    // Java interop: both arrays may be null, meaning "not supplied".
    Option(weights).foreach(w => scoreArr.muliRowVector(w))
    Option(mask).foreach(m => scoreArr.muliColumnVector(m))

    scoreArr
  }

  /**
    * Gradient of the loss with respect to the activation output `yHat`.
    *
    * With `L = (y - yHat)^2 + |y - yHat|`:
    * {{{
    * dL/dyHat = -2 * (y - yHat) - sign(y - yHat)
    * }}}
    * where `sign` is the element-wise result of `Transforms.sign`.
    * NOTE(review): an earlier comment claimed sign is +1 when `y - yHat == 0`; a standard
    * signum yields 0 there - confirm against the ND4J version in use.
    *
    * The chain-rule step through the activation
    * (`dL/dpreout = dL/dyHat * Activation'(preout)`) is not performed in this method.
    *
    * NOTE(review): `mask` is accepted but not used here - presumably it is applied by
    * the caller or by `LossComm`; confirm.
    *
    * @param labels       the targets `y`
    * @param preOutput    input to the activation function; duplicated before use
    * @param activationFn activation used to turn `preOutput` into `yHat`
    * @param mask         unused in this method (see note above)
    * @return a newly allocated array holding `dL/dyHat`, after optional weighting
    */
  def computedLdYHat(labels: INDArray, preOutput: INDArray, activationFn: IActivation, mask: INDArray): INDArray = {
    val output: INDArray     = activationFn.getActivation(preOutput.dup(), true)
    val yMinusyHat: INDArray = labels.sub(output)
    val dLdYHat: INDArray    = yMinusyHat.mul(-2).sub(Transforms.sign(yMinusyHat))

    // Java interop: weights may be null, meaning "no weighting".
    Option(weights).foreach(w => dLdYHat.muliRowVector(w))

    dLdYHat
  }

  /** Name of this loss function. */
  override def name(): String = "LossL1L2()"

}
