package ai.minxiao.ds4s.core.dl4j.evaluation

import java.util.{Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MMap}

import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.evaluation.classification.{Evaluation, ROC, ROCMultiClass}
import org.nd4j.evaluation.IEvaluation
import org.nd4j.evaluation.regression.RegressionEvaluation
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.dataset.api.iterator.{DataSetIterator, MultiDataSetIterator}
import org.nd4j.linalg.dataset.DataSet

import ai.minxiao.ds4s.core.dl4j.prediction.Predictor

/**
  * @author mx
  */
object Evaluator {

  /**
    * Evaluates a [[MultiLayerNetwork]] classifier on the data produced by `data`.
    *
    * Thin wrapper over DL4J's `MultiLayerNetwork.evaluate(DataSetIterator)`.
    *
    * @param net  network to evaluate
    * @param data iterator over the evaluation data
    * @return classification [[Evaluation]] produced by DL4J
    */
  def run(net: MultiLayerNetwork, data: DataSetIterator): Evaluation =
    net.evaluate(data)

  /**
    * Evaluates a [[MultiLayerNetwork]] classifier on a single in-memory [[DataSet]].
    *
    * The feature and label masks carried by `data` are forwarded to both the forward
    * pass and the evaluation. Masked-out entries (e.g. padding in variable-length time
    * series) are therefore not counted. When `data` has no masks, both mask arrays are
    * `null` and this is a plain unmasked evaluation.
    *
    * @param net       network to evaluate
    * @param data      features and labels to evaluate on
    * @param numLabels number of classes passed to the [[Evaluation]] constructor
    * @return classification [[Evaluation]] for `data`
    */
  def run(net: MultiLayerNetwork, data: DataSet, numLabels: Int): Evaluation = {
    val eval = new Evaluation(numLabels)
    val labelsMask: INDArray = data.getLabelsMaskArray
    // train = false: inference-mode forward pass, using the same masks as the labels.
    val output: INDArray = net.output(data.getFeatures, false, data.getFeaturesMaskArray, labelsMask)
    eval.eval(data.getLabels, output, labelsMask)
    eval
  }

  /**
    * Evaluates a [[ComputationGraph]] classifier on the data produced by `data`.
    *
    * Thin wrapper over DL4J's `ComputationGraph.evaluate(DataSetIterator)`.
    *
    * @param net  graph to evaluate
    * @param data iterator over the evaluation data
    * @return classification [[Evaluation]] produced by DL4J
    */
  def run(net: ComputationGraph, data: DataSetIterator): Evaluation =
    net.evaluate(data)

  /**
    * Evaluates a [[ComputationGraph]] on a [[MultiDataSetIterator]] with one
    * evaluation type, selected by name.
    *
    * Recognised `evalType` values:
    *  - `"ROCMultiClass"`: multi-task, multi-class (MTMC), via `evaluateROCMultiClass`
    *  - `"ROC"`: multi-task, binary class (MTBC), via `evaluateROC`
    *  - `"Regression"`: regression, via `evaluateRegression`
    *  - anything else (including the default `"default"`): single-output
    *    classification via `evaluate`
    *
    * For both ROC variants, `0` is passed as `rocThresholdSteps`.
    *
    * @param net      graph to evaluate
    * @param data     iterator over the evaluation data
    * @param evalType name of the evaluation to run (see above)
    * @return the evaluation object produced by DL4J
    */
  def run(net: ComputationGraph, data: MultiDataSetIterator,
    evalType: String = "default"): IEvaluation[_ <: IEvaluation[_ <: AnyRef]] =
    // The match is the method body, so every branch is typed against the declared
    // result type. If the DL4J `evaluate*` methods are generic in their return type,
    // this lets scalac infer the type argument instead of `Nothing`, and no unchecked
    // `asInstanceOf` is needed.
    evalType match {
      case "ROCMultiClass" => net.evaluateROCMultiClass(data, 0)
      case "ROC"           => net.evaluateROC(data, 0)
      case "Regression"    => net.evaluateRegression(data)
      case _               => net.evaluate(data)
    }

  /**
    * Runs caller-supplied evaluations on a [[ComputationGraph]].
    *
    * Thin wrapper over DL4J's `ComputationGraph.evaluate(MultiDataSetIterator, Map)`.
    *
    * @param net         graph to evaluate
    * @param data        iterator over the evaluation data
    * @param evaluations evaluations to run, keyed by output index
    * @return the map returned by DL4J
    */
  def run(net: ComputationGraph, data: MultiDataSetIterator,
    evaluations: JMap[Integer, Array[IEvaluation[_ <: IEvaluation[_ <: AnyRef]]]]
  ): JMap[Integer, Array[IEvaluation[_ <: IEvaluation[_ <: AnyRef]]]] =
    net.evaluate(data, evaluations)

  /**
    * Builds evaluations from names for each output of a [[ComputationGraph]], then runs them.
    *
    * For every `(output, names)` entry in `evalTypes`, one evaluation is created per
    * name. The names follow the same rules as the `evalType` overload: `"ROCMultiClass"`,
    * `"ROC"`, `"Regression"`, and anything else means classification.
    *
    * A classification evaluation uses `nCMap(output)` as its number of classes. If
    * `nCMap` has no entry for that output, it falls back to `new Evaluation()`, which
    * leaves the class count to DL4J.
    *
    * @param net       graph to evaluate
    * @param data      iterator over the evaluation data
    * @param evalTypes evaluation names, keyed by output index
    * @param nCMap     number of classes, keyed by output index (used for classification only)
    * @return the map returned by DL4J, keyed by output index
    */
  def run(net: ComputationGraph, data: MultiDataSetIterator,
    evalTypes: Map[Int, Array[String]], nCMap: Map[Int, Int]
  ): JMap[Integer, Array[IEvaluation[_ <: IEvaluation[_ <: AnyRef]]]] = {
    val evaluations: JMap[Integer, Array[IEvaluation[_ <: IEvaluation[_ <: AnyRef]]]] =
      evalTypes.map { case (output, names) =>
        Int.box(output) -> names.map(name => newEvaluation(name, nCMap.get(output)))
      }.asJava
    run(net, data, evaluations)
  }

  /** Creates one empty evaluation for `evalType`, typed as the generic `IEvaluation`. */
  private def newEvaluation(evalType: String,
    numClasses: Option[Int]): IEvaluation[_ <: IEvaluation[_ <: AnyRef]] =
    evalType match {
      case "ROCMultiClass" => new ROCMultiClass
      case "ROC"           => new ROC
      case "Regression"    => new RegressionEvaluation
      case _               => numClasses.fold(new Evaluation())(n => new Evaluation(n))
    }

}
