package ai.minxiao.ds4s.core.dl4j.vectorization

import java.io.File
import java.util.Random

import org.datavec.api.records.reader.impl.csv.{CSVRecordReader, CSVSequenceRecordReader}
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader
import org.datavec.api.records.reader.RecordReader
import org.datavec.api.split.FileSplit
import org.datavec.api.transform.schema.Schema
import org.datavec.api.transform.schema.Schema.{Builder => SchemaBuilder}
import org.datavec.api.transform.transform.categorical.{CategoricalToIntegerTransform, CategoricalToOneHotTransform}
import org.datavec.api.transform.transform.column.RemoveColumnsTransform
import org.datavec.api.transform.TransformProcess
import org.datavec.api.transform.TransformProcess.{Builder => TPBuilder}
import org.deeplearning4j.datasets.datavec.{RecordReaderDataSetIterator, SequenceRecordReaderDataSetIterator}

/**
  * CSV Iterator: loads CSV inputs into dataset iterators that are ready for training.
  *
  * The plain-CSV loaders share the private helpers below, so reader setup, schema
  * construction, transform construction and iterator configuration are defined once.
  *
  * @author mx
  */
object CSVIterator {

  /** Column type names accepted in `columnNameTypes` by the schema-based loaders. */
  private val SupportedColumnTypes: Seq[String] =
    Seq("double", "float", "long", "integer", "string", "categorical")

  /** Creates a [[CSVRecordReader]] and initializes it on a `FileSplit` over `filePath`. */
  private def openCSVReader(
    filePath: String,
    numLinesToSkip: Int,
    delimiter: Char,
    quote: Char): CSVRecordReader = {
    val reader = new CSVRecordReader(numLinesToSkip, delimiter, quote)
    reader.initialize(new FileSplit(new File(filePath)))
    reader
  }

  /**
    * Builds a dataset iterator over `reader` through `RecordReaderDataSetIterator.Builder`.
    * The classification call is issued first and the regression call second, each only
    * when its flag is set; with neither flag set the builder is used unconfigured.
    */
  private def buildDataSetIterator(
    reader: RecordReader,
    batchSize: Int,
    classification: Boolean, labelIndex: Int, numPossibleLabels: Int,
    regression: Boolean, labelIndexTo: Int): RecordReaderDataSetIterator = {
    val iteratorBuilder = new RecordReaderDataSetIterator.Builder(reader, batchSize)
    if (classification) iteratorBuilder.classification(labelIndex, numPossibleLabels)
    if (regression) iteratorBuilder.regression(labelIndex, labelIndexTo)
    iteratorBuilder.build()
  }

  /**
    * Builds the input schema from `(columnName, typeName)` pairs, in the given order.
    *
    * @throws IllegalArgumentException if a type name is not one of the supported types
    * @throws NoSuchElementException if a "categorical" column has no entry in `columnCategoricalValues`
    */
  private def buildSchema(
    columnNameTypes: Array[(String, String)],
    columnCategoricalValues: Map[String, Array[String]]): Schema = {
    val schemaBuilder = new SchemaBuilder()
    columnNameTypes.foreach { case (name, typeName) =>
      typeName match {
        case "double" => schemaBuilder.addColumnDouble(name)
        case "float" => schemaBuilder.addColumnFloat(name)
        case "long" => schemaBuilder.addColumnLong(name)
        case "integer" => schemaBuilder.addColumnInteger(name)
        case "string" => schemaBuilder.addColumnString(name)
        case "categorical" =>
          // Same exception type as a plain Map lookup, but with a message that says what is missing.
          val values = columnCategoricalValues.getOrElse(name,
            throw new NoSuchElementException(
              s"No categorical values provided for column '$name' in columnCategoricalValues"))
          schemaBuilder.addColumnCategorical(name, values: _*)
        case other =>
          // Without this case an unknown type name surfaces as a bare MatchError.
          throw new IllegalArgumentException(
            s"Unsupported type '$other' for column '$name'; " +
              s"supported types: ${SupportedColumnTypes.mkString(", ")}")
      }
    }
    schemaBuilder.build()
  }

  /**
    * Builds the transform process on top of `schema`. Steps are added in a fixed order
    * (remove columns, categorical to one-hot, categorical to integer) and each step is
    * added only when its column list is non-empty.
    */
  private def buildTransformProcess(
    schema: Schema,
    removeColumns: Array[String],
    categorical2OneHots: Array[String],
    categorical2Ints: Array[String]): TransformProcess = {
    val tpBuilder = new TPBuilder(schema)
    if (removeColumns.nonEmpty) tpBuilder.removeColumns(removeColumns: _*)
    if (categorical2OneHots.nonEmpty) tpBuilder.categoricalToOneHot(categorical2OneHots: _*)
    if (categorical2Ints.nonEmpty) tpBuilder.categoricalToInteger(categorical2Ints: _*)
    tpBuilder.build()
  }

  /**
    * Columnar CSV file
    * @param filePath file name path
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=CSVRecordReader.DEFAULT_DELIMITER.
    * Other options:
    * {{{
    * '\t': for tab-delimited
    * }}}
    * NOTE(review): an earlier version of this doc also listed
    * CSVRecordReader.QUOTE_HANDLING_DELIMITER; confirm it can be used with a Char delimiter.
    * @param quote quote, default=CSVRecordReader.DEFAULT_QUOTE
    *
    * @param batchSize batch size, default=128
    *
    * @param classification whether for classification, default=false
    * @param labelIndex label index (also the first label index when regression=true), default=0
    * @param numPossibleLabels label cardinality when classification=true, default=2
    * @param regression whether for regression, default=false
    * @param labelIndexTo only for regression, the last index (inclusive) for multi-output regression, default=0
    * @param seed currently unused by this method; kept so the signature stays stable, default=2018L
    *
    * @return dataset iterator
    *
    * @author mx
    */
  def fromPlainCSV(
    filePath: String,
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    classification: Boolean = false, labelIndex: Int = 0, numPossibleLabels: Int = 2,
    regression: Boolean = false, labelIndexTo: Int = 0,
    seed: Long = 2018L): RecordReaderDataSetIterator = {
    val rr = openCSVReader(filePath, numLinesToSkip, delimiter, quote)
    buildDataSetIterator(
      rr, batchSize,
      classification, labelIndex, numPossibleLabels,
      regression, labelIndexTo)
  }

  /**
    * Columnar CSV file, read through a schema and an optional transform process.
    * Delegates to [[fromPlainCSVWithSchemaToSRR]] and returns only the iterator.
    *
    * @param filePath file name path
    * @param columnNameTypes column names and types, preserving the order of the input files
    *   types: "double", "float", "long", "integer", "string", "categorical" (must provide the value list).
    * @param columnCategoricalValues categorical column values
    * @param removeColumns columns to ignore
    * @param categorical2OneHots columns to transform to one-hots
    * @param categorical2Ints columns to transform to ints
    *
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=CSVRecordReader.DEFAULT_DELIMITER.
    * Other options:
    * {{{
    * '\t': for tab-delimited
    * }}}
    * NOTE(review): an earlier version of this doc also listed
    * CSVRecordReader.QUOTE_HANDLING_DELIMITER; confirm it can be used with a Char delimiter.
    * @param quote quote, default=CSVRecordReader.DEFAULT_QUOTE
    *
    * @param batchSize batch size, default=128
    *
    * @param classification whether for classification, default=false
    * @param labelIndex label index (also the first label index when regression=true), default=0
    * @param numPossibleLabels label cardinality when classification=true, default=2
    * @param regression whether for regression, default=false
    * @param labelIndexTo only for regression, the last index (inclusive) for multi-output regression, default=0
    * @param seed currently unused by this method; kept so the signature stays stable, default=2018L
    *
    * @return dataset iterator
    * @throws IllegalArgumentException if a column type name is not supported
    * @throws NoSuchElementException if a "categorical" column has no entry in `columnCategoricalValues`
    *
    * @author mx
    */
  def fromPlainCSVWithSchema(
    filePath: String,
    // -----------------------------
    columnNameTypes: Array[(String, String)],
    columnCategoricalValues: Map[String, Array[String]] = Map(),
    removeColumns: Array[String] = Array(),
    categorical2OneHots: Array[String] = Array(),
    categorical2Ints: Array[String] = Array(),
    // -------------------------------
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    classification: Boolean = false, labelIndex: Int = 0, numPossibleLabels: Int = 2,
    regression: Boolean = false, labelIndexTo: Int = 0,
    seed: Long = 2018L): RecordReaderDataSetIterator = {
    val (iterator, _, _) = fromPlainCSVWithSchemaToSRR(
      filePath = filePath,
      columnNameTypes = columnNameTypes,
      columnCategoricalValues = columnCategoricalValues,
      removeColumns = removeColumns,
      categorical2OneHots = categorical2OneHots,
      categorical2Ints = categorical2Ints,
      numLinesToSkip = numLinesToSkip,
      delimiter = delimiter,
      quote = quote,
      batchSize = batchSize,
      classification = classification,
      labelIndex = labelIndex,
      numPossibleLabels = numPossibleLabels,
      regression = regression,
      labelIndexTo = labelIndexTo,
      seed = seed)
    iterator
  }

  /**
    * Same loading pipeline as [[fromPlainCSVWithSchema]], but additionally returns the
    * schema produced by the transform process and the transforming record reader that
    * backs the iterator.
    *
    * @param filePath file name path
    * @param columnNameTypes column names and types, preserving the order of the input files
    *   types: "double", "float", "long", "integer", "string", "categorical" (must provide the value list).
    * @param columnCategoricalValues categorical column values
    * @param removeColumns columns to ignore
    * @param categorical2OneHots columns to transform to one-hots
    * @param categorical2Ints columns to transform to ints
    *
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=CSVRecordReader.DEFAULT_DELIMITER
    * @param quote quote, default=CSVRecordReader.DEFAULT_QUOTE
    *
    * @param batchSize batch size, default=128
    *
    * @param classification whether for classification, default=false
    * @param labelIndex label index (also the first label index when regression=true), default=0
    * @param numPossibleLabels label cardinality when classification=true, default=2
    * @param regression whether for regression, default=false
    * @param labelIndexTo only for regression, the last index (inclusive) for multi-output regression, default=0
    * @param seed currently unused by this method; kept so the signature stays stable, default=2018L
    *
    * @return (dataset iterator, final schema of the transform process, transforming record reader)
    * @throws IllegalArgumentException if a column type name is not supported
    * @throws NoSuchElementException if a "categorical" column has no entry in `columnCategoricalValues`
    */
  def fromPlainCSVWithSchemaToSRR(
    filePath: String,
    // -----------------------------
    columnNameTypes: Array[(String, String)],
    columnCategoricalValues: Map[String, Array[String]] = Map(),
    removeColumns: Array[String] = Array(),
    categorical2OneHots: Array[String] = Array(),
    categorical2Ints: Array[String] = Array(),
    // -------------------------------
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    classification: Boolean = false, labelIndex: Int = 0, numPossibleLabels: Int = 2,
    regression: Boolean = false, labelIndexTo: Int = 0,
    seed: Long = 2018L): (RecordReaderDataSetIterator, Schema, RecordReader) = {
    // The reader is opened before the schema is built, so failures surface in the same order as before.
    val rr = openCSVReader(filePath, numLinesToSkip, delimiter, quote)

    val schema = buildSchema(columnNameTypes, columnCategoricalValues)
    val transformProcess =
      buildTransformProcess(schema, removeColumns, categorical2OneHots, categorical2Ints)
    val outputSchema = transformProcess.getFinalSchema()
    val tprr = new TransformProcessRecordReader(rr, transformProcess)

    val iterator = buildDataSetIterator(
      tprr, batchSize,
      classification, labelIndex, numPossibleLabels,
      regression, labelIndexTo)
    (iterator, outputSchema, tprr)
  }

  /**
    * Columnar CSV file, using the `RecordReaderDataSetIterator` constructors directly
    * instead of its builder. Classification takes precedence when both `classification`
    * and `regression` are set; with neither set, the two-argument constructor is used.
    *
    * @param filePath file name path
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=CSVRecordReader.DEFAULT_DELIMITER
    * @param quote quote, default=CSVRecordReader.DEFAULT_QUOTE
    * @param batchSize batch size, default=128
    * @param classification whether for classification, default=false
    * @param labelIndex label index (also the first label index when regression=true), default=0
    * @param numPossibleLabels label cardinality when classification=true, default=2
    * @param regression whether for regression, default=false
    * @param labelIndexTo only for regression, the last index (inclusive) for multi-output regression, default=0
    * @param seed currently unused by this method; kept so the signature stays stable, default=2018L
    *
    * @return dataset iterator
    */
  def fromPlainCSV0(
    filePath: String,
    numLinesToSkip: Int = 1,
    delimiter: Char = CSVRecordReader.DEFAULT_DELIMITER,
    quote: Char = CSVRecordReader.DEFAULT_QUOTE,
    batchSize: Int = 128,
    classification: Boolean = false, labelIndex: Int = 0, numPossibleLabels: Int = 2,
    regression: Boolean = false, labelIndexTo: Int = 0,
    seed: Long = 2018L): RecordReaderDataSetIterator = {
    val rr = openCSVReader(filePath, numLinesToSkip, delimiter, quote)
    if (classification) new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numPossibleLabels)
    else if (regression) new RecordReaderDataSetIterator(rr, batchSize, labelIndex, labelIndexTo, regression)
    else new RecordReaderDataSetIterator(rr, batchSize)
  }

  /**
    * Load CSV files, one file per sequence; within one file,
    * one row per time step, with columnar features and a label index.
    * @param baseDir folder
    * @param labelIndex label index
    * @param numClasses number of classes
    * @param regression whether for regression, default=false (for classification)
    * @param numLinesToSkip number of lines to skip, default=1
    * @param delimiter delimiter, default=","
    * @param batchSize batch size, default=128
    * @param seed seed for the `Random` handed to the `FileSplit` over `baseDir`, default=2018L
    *
    * @return sequence dataset iterator
    */
  def fromSeqCSVs(
    baseDir: String,
    labelIndex: Int,
    numClasses: Int,
    regression: Boolean = false,
    numLinesToSkip: Int = 1,
    delimiter: String = ",",
    batchSize: Int = 128,
    seed: Long = 2018L
  ): SequenceRecordReaderDataSetIterator = {
    val rr = new CSVSequenceRecordReader(numLinesToSkip, delimiter)
    rr.initialize(new FileSplit(new File(baseDir), new Random(seed)))
    new SequenceRecordReaderDataSetIterator(rr, batchSize, numClasses, labelIndex, regression)
  }

}
