package ai.minxiao.ds4s.core.dl4j.vectorization

import java.io.File

import org.datavec.api.conf.Configuration
import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader
import org.datavec.api.split.FileSplit
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator

/**
  * svmLight Iterator
  *
  * @author mx
  */
object SvmLightIterator {

  /**
    * Reads an svmLight-formatted file through an [[SVMLightRecordReader]] and
    * exposes it as a batched [[RecordReaderDataSetIterator]].
    *
    * @param filePath path of the svmLight file to read
    * @param numFeatures number of features per record (required); also passed to the
    *                    iterator constructor as its third argument
    * @param batchSize number of records per batch (required)
    * @param zeroBasedIndexing whether feature indices in the file are zero-based, default=true
    * @param zeroBasedLabelIndexing whether label indices in the file are zero-based, default=false
    * @param multilabel whether records carry multiple labels, default=false
    * @param numLabels number of labels, default=-1 (positive for multi-task/multiple outputs);
    *                  for multilabel data the labels are separated by commas
    * @param numClasses number of classes handed to the iterator, default=-1
    * @return an iterator over the file, backed by an already-initialized record reader
    *
    * @note refer to <a href="https://github.com/deeplearning4j/DataVec/blob/master/datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java">
    * SVMLightRecordReader.java</a>
    */
  def load(filePath: String,
    numFeatures: Int,
    batchSize: Int,
    zeroBasedIndexing: Boolean = true,
    zeroBasedLabelIndexing: Boolean = false,
    multilabel: Boolean = false,
    numLabels: Int = -1,
    numClasses: Int = -1
    ): RecordReaderDataSetIterator = {

    val readerConf = readerConfiguration(
      numFeatures = numFeatures,
      numLabels = numLabels,
      zeroBasedIndexing = zeroBasedIndexing,
      zeroBasedLabelIndexing = zeroBasedLabelIndexing,
      multilabel = multilabel)

    val reader = new SVMLightRecordReader()
    reader.initialize(readerConf, new FileSplit(new File(filePath)))

    // NOTE(review): numFeatures sits in the constructor's label-index slot, presumably
    // because the reader emits labels after the feature columns -- confirm against DL4J docs.
    new RecordReaderDataSetIterator(reader, batchSize, numFeatures, numClasses)
  }

  /** Collects the reader options into a DataVec Configuration under SVMLightRecordReader's keys. */
  private def readerConfiguration(numFeatures: Int,
    numLabels: Int,
    zeroBasedIndexing: Boolean,
    zeroBasedLabelIndexing: Boolean,
    multilabel: Boolean
    ): Configuration = {
    val conf = new Configuration()
    conf.setInt(SVMLightRecordReader.NUM_FEATURES, numFeatures)
    conf.setInt(SVMLightRecordReader.NUM_LABELS, numLabels)
    conf.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, zeroBasedIndexing)
    conf.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, zeroBasedLabelIndexing)
    conf.setBoolean(SVMLightRecordReader.MULTILABEL, multilabel)
    conf
  }
}
