package org.nd4j.linalg.activations.impl

import org.nd4j.linalg.activations.BaseActivationFunction
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.transforms.{Tanh, TanhDerivative}
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.primitives.Pair

/**
  * RATanh: scaled hyperbolic tangent activation, h(x) = 1.7159 * tanh(2x/3).
  *
  * Despite the "rational approximation of tanh" label, this class does not use a
  * rational approximation. It evaluates tanh with ND4J's `Tanh` transform op and
  * its derivative with the `TanhDerivative` op.
  *
  * Both passes work in place: the `in` array handed to [[getActivation]] or
  * [[backprop]] is overwritten. Callers that still need the pre-activations must
  * pass a copy.
  *
  * @note <a href="https://arxiv.org/pdf/1508.01292.pdf">Compact Convolutional Neural Network Cascade for Face Detection</a>
  * @note remember to modify Activation.java
  *
  * @author mx
  */
class ActivationRATanh extends BaseActivationFunction {
  import ActivationRATanh.{DerivativeScale, InputScale, OutputScale}

  /**
    * Forward pass: overwrites `in` with h(in) = 1.7159 * tanh(2 * in / 3).
    *
    * @param in       pre-activation values; modified in place
    * @param training ignored; the activation is identical at training and inference time
    * @return `in`, which now holds the activations
    */
  override def getActivation(in: INDArray, training: Boolean): INDArray = {
    // Scale the input by 2/3, then apply tanh in place on the same array.
    Nd4j.getExecutioner.execAndReturn(new Tanh(in.muli(InputScale)))
    in.muli(OutputScale)
    in
  }

  /**
    * Backward pass (in place).
    *
    * Let out = h(in). Then d(Loss)/d(in) = d(Loss)/d(out) * d(out)/d(in) = epsilon * h'(in),
    * where h'(x) = 1.7159 * (2/3) * tanh'(2x/3).
    *
    * @param in      linear input to the activation node; overwritten and reused as the
    *                storage for the returned gradient
    * @param epsilon gradient of the loss with respect to the activation output
    * @return a pair whose first element is epsilon * h'(in). The second element is `null`
    *         because this activation holds no parameters, so there is no parameter
    *         gradient to report.
    */
  override def backprop(in: INDArray, epsilon: INDArray): Pair[INDArray, INDArray] = {
    // tanh'(2x/3), computed in place on the scaled input.
    val dLdz = Nd4j.getExecutioner.execAndReturn(new TanhDerivative(in.muli(InputScale)))
    // Apply the chain-rule factor 1.7159 * 2/3 in a single pass instead of two.
    dLdz.muli(DerivativeScale)
    dLdz.muli(epsilon)
    new Pair[INDArray, INDArray](dLdz, null)
  }

  override def toString(): String = "ratanh"
}

/** Constants of h(x) = b * tanh(a * x), kept out of the instance so it gains no fields. */
object ActivationRATanh {
  /** Input scaling a = 2/3. */
  private val InputScale: Double = 2 / 3.0

  /** Output scaling b = 1.7159. */
  private val OutputScale: Double = 1.7159

  /** Constant factor a * b of h'(x) = a * b * tanh'(a * x). */
  private val DerivativeScale: Double = OutputScale * InputScale
}
