package ai.minxiao.ds4s.core.h2o.prediction

import hex.genmodel.easy.RowData
import hex.genmodel.easy.prediction._
import hex.ModelCategory

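/**
  * H2O AutoEncoder predictor mixin: adds anomaly-reason ranking based on
  * per-feature reconstruction error.
  *
  * @note must be mixed into an H2OPredictor whose model category is AutoEncoder
  * @example new H2OPredictor(tp, src) with H2OAEPredictor
  *
  * @author mx
  */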
trait H2OAEPredictor {
  // Self-type: this trait can only be mixed into an H2OPredictor, which
  // supplies `modelCategory` and the raw prediction function `_f` used below.
  this: H2OPredictor =>

  // Fail fast at construction time when the wrapped model is not an AutoEncoder.
  require(
    modelCategory == ModelCategory.AutoEncoder,
    s"H2OAEPredictor requires an AutoEncoder model, but the model category is $modelCategory"
  )

  /**
    * Ranks anomaly reasons for a row by aggregated reconstruction error.
    *
    * For every position in the prediction's `reconstructed` / `original` arrays
    * the squared difference is computed, mapped to a reason through
    * `feat2Reason`, summed per reason, and the reasons with the largest totals
    * are returned in descending order of total error.
    *
    * Reasons whose total is NaN are ranked last, so the ordering stays
    * well-defined even when a squared error is NaN.
    *
    * NOTE(review): the indices are positions in the prediction arrays; H2O may
    * expand categorical inputs, so they are not necessarily input column
    * positions -- confirm against the model.
    *
    * @param rowData data to infer anomaly reason(s)
    * @param feat2Reason feature to reason mappings, feature-index => reason;
    *                    several indices may share one reason
    * @param topN top-N reason(s), default=1; a non-positive value yields an empty array
    * @return at most `topN` pairs of (reason, summed squared reconstruction error),
    *         largest error first
    * @throws NoSuchElementException if a feature index has no entry in `feat2Reason`
    *                                (unless the map provides a default value)
    */
  def anomalyReason(rowData: RowData, feat2Reason: Map[Int, String], topN: Int = 1): Array[(String, Double)] = {
    // The model category is checked at construction; presumably `_f` then
    // yields an AutoEncoderModelPrediction -- the cast fails loudly otherwise.
    val prediction = _f(rowData).asInstanceOf[AutoEncoderModelPrediction]

    // Squared reconstruction error per position (zip truncates to the shorter array).
    val squaredErrors: Array[Double] =
      prediction.reconstructed.zip(prediction.original).map { case (rec, orig) =>
        val diff = rec - orig
        diff * diff
      }

    // `feat2Reason(index)` is intentionally `apply`, so maps built with
    // `withDefaultValue` keep working and missing keys still raise.
    val errorByReason: Map[String, Double] =
      squaredErrors.zipWithIndex.
        map { case (error, index) => (feat2Reason(index), error) }.
        groupBy(_._1).
        map { case (reason, scored) => (reason, scored.map(_._2).sum) }

    // Descending by score using a total order: `_ > _` is not a valid ordering
    // once NaN is involved, which can corrupt the sort. Comparing the negated
    // scores with Double.compare keeps finite scores in the same order as
    // before and pushes NaN totals to the end.
    errorByReason.toArray.
      sortWith((a, b) => java.lang.Double.compare(-a._2, -b._2) < 0).
      take(topN)
  }
}
