package ai.minxiao.ds4s.core.h2o.prediction

import hex.genmodel.{GenModel, MojoModel}
import hex.genmodel.easy.{EasyPredictModelWrapper, RowData}
import hex.genmodel.easy.exception.PredictWrongModelCategoryException
import hex.genmodel.easy.prediction._
import hex.ModelCategory

/**
  * H2OPredictor
  *
  * Loads a pre-trained H2O POJO/MOJO model and exposes a single `predict` function on top of
  * H2O's `EasyPredictModelWrapper`.
  *
  * Note on parameter order: `tp` has a default value but precedes `src`, which has none, so the
  * default is only usable when `src` is passed by name, e.g. `new H2OPredictor(src = "model.zip")`.
  *
  * @constructor load a trained model with a type and source (name/path).
  * @param tp the model's type: "pojo" instantiates the class named by `src` reflectively;
  *           any other value hands `src` to `MojoModel.load`
  * @param src the class name (pojo) or the path (mojo) of the model
  *
  * @author mx
  */
@SerialVersionUID(7280L)
class H2OPredictor(private val tp: String = "mojo", private val src: String)
    extends Serializable {

  /** The underlying generated model, loaded on first access. */
  protected lazy val genModel: GenModel =
    if (tp == "pojo")
      // Class#newInstance() is deprecated since Java 9; go through the no-arg constructor instead.
      Class.forName(src).getDeclaredConstructor().newInstance().asInstanceOf[GenModel]
    else MojoModel.load(src)

  /**
    * Easy-predict wrapper around [[genModel]].
    *
    * Unknown categorical levels and invalid numbers are converted to NA instead of failing, and
    * extended output is enabled.
    */
  protected lazy val easyModel: EasyPredictModelWrapper =
    new EasyPredictModelWrapper(
      new EasyPredictModelWrapper.Config()
        .setModel(genModel)
        .setConvertUnknownCategoricalLevelsToNa(true)
        .setConvertInvalidNumbersToNa(true)
        .setUseExtendedOutput(true)
    )

  /** feature and response column names */
  protected lazy val names: Array[String] = genModel.getNames

  /**
    * Category of the loaded model.
    *
    * This is a strict `val`, so it forces `easyModel` (and therefore `genModel`) to be
    * initialized while the predictor is being constructed.
    */
  val modelCategory: ModelCategory = easyModel.getModelCategory

  /**
    * Prediction function selected once from the model category, auxiliary function.
    * Categories that accept an offset are called with an offset of 0.0.
    * <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-genmodel/src/main/java/hex/genmodel/easy/EasyPredictModelWrapper.java" target="_blank">h2o github</a>
    *
    * Evaluating this value throws `PredictWrongModelCategoryException` when the model category
    * is not supported by this class.
    */
  protected lazy val _f: RowData => AbstractPrediction = modelCategory match {
    case ModelCategory.Binomial      => easyModel.predictBinomial(_: RowData, 0.0)
    case ModelCategory.Multinomial   => easyModel.predictMultinomial(_: RowData, 0.0)
    case ModelCategory.Ordinal       => easyModel.predictOrdinal(_: RowData, 0.0)
    case ModelCategory.Regression | ModelCategory.CoxPH
                                     => easyModel.predictRegression(_: RowData, 0.0)
    case ModelCategory.Clustering    => easyModel.predictClustering _
    case ModelCategory.DimReduction  => easyModel.predictDimReduction _
    case ModelCategory.AutoEncoder   => easyModel.predictAutoEncoder _
    case ModelCategory.WordEmbedding => easyModel.predictWord2Vec _
    case ModelCategory.AnomalyDetection
                                     //=> easyModel.predictAnomaly _
                                     => throw new PredictWrongModelCategoryException("unsupported model prediction")
    case ModelCategory.Unknown       => throw new PredictWrongModelCategoryException("unknown model category")
    // ModelCategory has more constants than the ones handled above; report them with the same
    // exception type instead of letting a scala.MatchError escape.
    case other                       => throw new PredictWrongModelCategoryException(s"unsupported model category: $other")
  }

  /**
    * Root of the mean squared difference between a reconstruction and its original.
    *
    * The values are paired position by position and the sum of squares is divided by the length
    * of `reconstructed`. An empty `reconstructed` yields NaN (0.0 / 0) rather than an exception.
    */
  private def rootMeanSquaredError(reconstructed: Array[Double], original: Array[Double]): Double = {
    val squaredErrors = reconstructed.zip(original).map { case (r, o) => (r - o) * (r - o) }
    math.sqrt(squaredErrors.sum / reconstructed.length)
  }

  /**
    * Element-wise sum of the given embedding vectors.
    *
    * Works directly on the Java collection, so no cast to a Scala collection is needed.
    * Null entries are skipped (presumably words the model has no embedding for - confirm against
    * the h2o version in use). Returns an empty array when there is nothing to sum.
    */
  private def sumEmbeddings(embeddings: java.util.Collection[Array[Float]]): Array[Double] = {
    val it = embeddings.iterator()
    Iterator
      .continually(it)
      .takeWhile(_.hasNext)
      .map(_.next())
      .filter(_ != null)
      .map(_.map(_.toDouble))
      .reduceOption((x, y) => x.zip(y).map { case (a, b) => a + b })
      .getOrElse(Array.empty[Double])
  }

  /**
    * predict on each RowData
    *
    * @param rowData instance (feature-values)
    * @return Tuple2[String, Double]:
    *    String: class label / cluster index / comma-separated dimensions, reconstruction or
    *            summed word embeddings / regression value as a string
    *    Double: probability of the predicted class (binomial, multinomial, ordinal),
    *            the regression value, the distance entry of the assigned cluster,
    *            the root mean squared reconstruction error (auto-encoder),
    *            or 0.0 when no score applies (dimension reduction, word embedding)
    * @throws PredictWrongModelCategoryException if the model category is unsupported or the
    *         wrapper returns a prediction type this class does not handle
    */
  def predict(rowData: RowData): (String, Double) =
    _f(rowData) match {
      case p: BinomialModelPrediction =>
        (p.label, p.classProbabilities(p.labelIndex))
      case p: MultinomialModelPrediction =>
        (p.label, p.classProbabilities(p.labelIndex))
      case p: OrdinalModelPrediction =>
        (p.label, p.classProbabilities(p.labelIndex))
      case p: RegressionModelPrediction =>
        (p.value.toString, p.value)
      case p: ClusteringModelPrediction =>
        (p.cluster.toString, p.distances(p.cluster))
      case p: DimReductionModelPrediction =>
        (p.dimensions.mkString(","), 0.0)
      case p: AutoEncoderModelPrediction =>
        (p.reconstructed.mkString(","), rootMeanSquaredError(p.reconstructed, p.original))
      case p: Word2VecPrediction =>
        // wordEmbeddings is a Java map; its values() view is a java.util.Collection and cannot
        // be cast to a Scala Vector, so it is folded through its own iterator.
        (sumEmbeddings(p.wordEmbeddings.values()).mkString(","), 0.0)
      case other =>
        throw new PredictWrongModelCategoryException(s"unexpected prediction result: $other")
    }
}
