package ai.minxiao.ds4s.core.dl4j.vectorization

import java.io.File
import java.util.Random

import org.datavec.api.records.reader.impl.csv.{CSVRecordReader, CSVSequenceRecordReader}
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader
import org.datavec.api.split.FileSplit
import org.datavec.api.transform.schema.Schema
import org.datavec.api.transform.schema.Schema.{Builder => SchemaBuilder}
import org.datavec.api.transform.transform.categorical.{CategoricalToIntegerTransform, CategoricalToOneHotTransform}
import org.datavec.api.transform.transform.column.RemoveColumnsTransform
import org.datavec.api.transform.TransformProcess
import org.datavec.api.transform.TransformProcess.{Builder => TPBuilder}
import org.deeplearning4j.datasets.datavec.{
  RecordReaderDataSetIterator, RecordReaderMultiDataSetIterator,
  SequenceRecordReaderDataSetIterator}

/**
  * CSV Multi-Dataset Iterator
  *
  * @author mx
  */
object CSVMDSIterator {

  /**
    * Creates a [[CSVRecordReader]] and initializes it on the file at `filePath`.
    * Shared by both factory methods so the reader is configured the same way in each.
    */
  private def initCsvReader(
    filePath: String,
    numLinesToSkip: Int,
    delimiter: Char,
    quote: Char): CSVRecordReader = {
    val reader = new CSVRecordReader(numLinesToSkip, delimiter, quote)
    reader.initialize(new FileSplit(new File(filePath)))
    reader
  }

  /**
    * Builds a multi-dataset iterator over a columnar CSV file, addressing columns by index.
    *
    * @param filePath file name path
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=','.
    * Other options:
    * {{{
    * '\t': for tab-delimited
    * CSVRecordReader.QUOTE_HANDLING_DELIMITER: for records containing quotes with commas in them
    * }}}
    * @param quote quote, default='\"'
    *
    * @param batchSize batch size, default=128
    *
    * @param rrName record reader name, default=rr
    * @param iCols input columns, an array of tuples (indexFrom, indexTo), default=Array()
    * @param iOHCols one-hot input columns, an array of tuples (index, cardinality), default=Array()
    * @param oCols output columns, an array of tuples (indexFrom, indexTo), default=Array()
    * @param oOHCols one-hot output columns, an array of tuples (index, cardinality), default=Array()
    *
    * @param seed random generator seed, default=2018L; currently not used by this method
    *
    * @return dataset iterator
    *
    * @author mx
    */
  def fromPlainCSV(
    filePath: String,
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    rrName: String = "rr",
    // inputs & outputs
    iCols: Array[Tuple2[Int, Int]] = Array(),
    iOHCols: Array[Tuple2[Int, Int]] = Array(),
    oCols: Array[Tuple2[Int, Int]] = Array(),
    oOHCols: Array[Tuple2[Int, Int]] = Array(),
    seed: Long = 2018L): RecordReaderMultiDataSetIterator = {
    val reader = initCsvReader(filePath, numLinesToSkip, delimiter, quote)

    val builder = new RecordReaderMultiDataSetIterator.Builder(batchSize).
      addReader(rrName, reader)

    // Registration order is kept: plain inputs, one-hot inputs, plain outputs, one-hot outputs.
    iCols.foreach { case (from, to) => builder.addInput(rrName, from, to) }
    iOHCols.foreach { case (index, cardinality) => builder.addInputOneHot(rrName, index, cardinality) }
    oCols.foreach { case (from, to) => builder.addOutput(rrName, from, to) }
    oOHCols.foreach { case (index, cardinality) => builder.addOutputOneHot(rrName, index, cardinality) }

    builder.build()
  }


  /**
    * Builds a multi-dataset iterator over a columnar CSV file that is first passed through
    * a schema-driven transform process (column removal, categorical to one-hot / integer).
    * The schema before and after transformation is printed to standard output.
    *
    * @param filePath file name path
    * @param columnNameTypes column names and types, preserving the order of the input file
    *   types: "double", "float", "long", "integer", "string", "categorical" (must provide the value list).
    * @param columnCategoricalValues categorical column values, keyed by column name
    * @param removeColumns columns to ignore
    * @param categorical2OneHots columns to transform to one-hots
    * @param categorical2Ints columns to transform to ints
    *
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=','.
    * Other options:
    * {{{
    * '\t': for tab-delimited
    * CSVRecordReader.QUOTE_HANDLING_DELIMITER: for records containing quotes with commas in them
    * }}}
    * @param quote quote, default='\"'
    *
    * @param batchSize batch size, default=128
    *
    * @param rrName record reader name, default=rr
    * @param iCols input columns, an array of tuples (kind, (a, b)), default=Array().
    *   When kind is "onehot" the pair is registered as a one-hot input (index, cardinality);
    *   any other kind registers it as a plain input range (indexFrom, indexTo).
    *   The columns are read from the transform-process reader, i.e. after the transformations.
    * @param oCols output columns, same encoding as `iCols`, default=Array()
    *
    * @param seed random generator seed, default=2018L; currently not used by this method
    *
    * @return dataset iterator
    *
    * @throws IllegalArgumentException if a column type is not one of the supported names,
    *   or a "categorical" column has no entry in `columnCategoricalValues`
    *
    * @author mx
    */
  def fromPlainCSVWithSchema(
    filePath: String,
    // -----------------------------
    columnNameTypes: Array[(String, String)],
    columnCategoricalValues: Map[String, Array[String]] = Map(),
    removeColumns: Array[String] = Array(),
    categorical2OneHots: Array[String] = Array(),
    categorical2Ints: Array[String] = Array(),
    // -------------------------------
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    // ---------------------------------
    rrName: String = "rr",
    iCols: Array[(String, (Int, Int))] = Array(),
    oCols: Array[(String, (Int, Int))] = Array(),
    seed: Long = 2018L): RecordReaderMultiDataSetIterator = {

    // Declares one schema column per (name, type) pair, in the given order.
    def buildSchema(): Schema = {
      val schemaBuilder = new SchemaBuilder()
      columnNameTypes.foreach { case (name, tpe) =>
        tpe match {
          case "double" => schemaBuilder.addColumnDouble(name)
          case "float" => schemaBuilder.addColumnFloat(name)
          case "long" => schemaBuilder.addColumnLong(name)
          case "integer" => schemaBuilder.addColumnInteger(name)
          case "string" => schemaBuilder.addColumnString(name)
          case "categorical" =>
            val values = columnCategoricalValues.getOrElse(
              name,
              throw new IllegalArgumentException(
                s"No categorical values provided for categorical column '$name'"))
            schemaBuilder.addColumnCategorical(name, values: _*)
          case other =>
            // Fail with a descriptive error instead of an opaque MatchError.
            throw new IllegalArgumentException(
              s"Unsupported column type '$other' for column '$name'; expected one of: " +
                "double, float, long, integer, string, categorical")
        }
      }
      schemaBuilder.build
    }

    // Applies the optional steps in a fixed order: remove, to one-hot, to integer.
    def buildTransformProcess(schema: Schema): TransformProcess = {
      val tpBuilder = new TPBuilder(schema)
      if (removeColumns.nonEmpty) tpBuilder.removeColumns(removeColumns: _*)
      if (categorical2OneHots.nonEmpty) tpBuilder.categoricalToOneHot(categorical2OneHots: _*)
      if (categorical2Ints.nonEmpty) tpBuilder.categoricalToInteger(categorical2Ints: _*)
      tpBuilder.build
    }

    // Validate the column specification before the reader is initialized, so a bad
    // spec fails without a reader having been set up on the file.
    val schema = buildSchema()
    println("schema before transformation ...")
    println(schema)
    val transformProcess = buildTransformProcess(schema)
    val outputSchema = transformProcess.getFinalSchema()

    println("schema after transformation ...")
    println(outputSchema)

    val reader = initCsvReader(filePath, numLinesToSkip, delimiter, quote)
    val transformedReader = new TransformProcessRecordReader(reader, transformProcess)

    val builder = new RecordReaderMultiDataSetIterator.Builder(batchSize).
      addReader(rrName, transformedReader)

    iCols.foreach { case (kind, (a, b)) =>
      if (kind == "onehot") builder.addInputOneHot(rrName, a, b)
      else builder.addInput(rrName, a, b)
    }
    oCols.foreach { case (kind, (a, b)) =>
      if (kind == "onehot") builder.addOutputOneHot(rrName, a, b)
      else builder.addOutput(rrName, a, b)
    }

    builder.build()
  }

}
