package ai.minxiao.ds4s.core.dl4j.math

import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j

/**
  * Sawtooth Wave Function
  *
  * Evaluates, element-wise, `y = 2 * (x / p - floor(x / p + 0.5))`, where `p` is the period.
  * Each output value lies in `[-1, 1)`, rising linearly and wrapping once per period.
  *
  * @param sawtoothPeriod period of the wave; must be strictly positive (defaults to 4.0
  *                       via the no-arg constructor)
  * @author mx
  */
class SawtoothMathFunction(sawtoothPeriod: Double) extends MathFunction {
  // A zero or negative period would divide by zero or flip the wave; reject it up front.
  require(sawtoothPeriod > 0.0, s"sawtoothPeriod must be positive, got $sawtoothPeriod")

  /** Creates a sawtooth function with the original default period of 4.0. */
  def this() = this(4.0)

  /**
    * Computes the sawtooth wave at each input point.
    *
    * @param x the points (intervals) at which the wave is evaluated
    * @return a column vector of shape `[n, 1]`, where `n` is the number of values read from `x`
    */
  def getFunctionValues(x: INDArray): INDArray = {
    // NOTE(review): `x.data` is the backing buffer; if `x` can be a non-contiguous view,
    // this may read elements outside the view — confirm callers always pass a plain vector.
    val xs = x.data.asDouble
    val ys = xs.map { v =>
      val t = v / sawtoothPeriod
      // t - floor(t + 0.5) lies in [-0.5, 0.5); scaling by 2 maps it onto [-1, 1).
      2 * (t - Math.floor(t + 0.5))
    }
    Nd4j.create(ys, Array[Int](xs.length, 1)) // Column vector
  }

  def getName: String = "Sawtooth"
}
