package ai.minxiao.ds4s.core.dl4j.vectorization

import java.io.File
import scala.util.Random

import org.apache.commons.io.FilenameUtils
import org.datavec.api.io.filters.BalancedPathFilter
import org.datavec.api.io.labels.ParentPathLabelGenerator
import org.datavec.api.split.FileSplit
import org.datavec.image.loader.NativeImageLoader
import org.datavec.image.recordreader.ImageRecordReader
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator


/**
  * Factory for DL4J dataset iterators backed by image files on disk.
  *
  * @author mx
  */
object ImageIterator {

  /** Absolute tolerance allowed when checking that the split ratios add up to 100. */
  private val SplitRatioTolerance: Double = 1e-6

  /**
    * Builds one [[RecordReaderDataSetIterator]] per entry of `splitRatios`.
    *
    * Files under `folder` whose extension is in `NativeImageLoader.ALLOWED_FORMATS` are
    * collected into a `FileSplit`, partitioned with `FileSplit.sample` according to
    * `splitRatios`, and each partition is read through its own `ImageRecordReader`.
    * Labels are produced by a `ParentPathLabelGenerator`, i.e. derived from the directory
    * that contains each image.
    *
    * NOTE(review): no `PathFilter` is passed to `sample`, so the files are presumably
    * partitioned in listing order, without shuffling or class balancing -- confirm that
    * this is the intended behaviour for train/test splits.
    *
    * @param folder root directory
    * @param numClasses number of classes
    * @param inputChannels input image depth (number of channels): 3-color, 1-grayscale
    * @param outputHeight output pixels height scale, default=32
    * @param outputWidth output pixels width scale, default=32
    * @param batchSize batch size, default=128
    * @param splitRatios split ratios, default=Array(100); each must be non-negative and
    *                    together they must sum to 100 (within a small floating-point tolerance)
    * @return dataset iterators, one per split ratio, in the same order as `splitRatios`
    * @throws java.lang.IllegalArgumentException if `splitRatios` contains a negative value
    *                                            or does not sum to 100
    */
  def fromLabelPartitions(
    folder: String,
    numClasses: Int,
    inputChannels: Int,
    outputHeight: Int = 32, outputWidth: Int = 32,
    batchSize: Int = 128,
    splitRatios: Array[Double] = Array(100)
  ): Array[RecordReaderDataSetIterator] = {

    require(
      splitRatios.forall(_ >= 0),
      s"splitRatios must be non-negative, got: ${splitRatios.mkString("[", ", ", "]")}"
    )
    // Compare with a tolerance: an exact `==` on a sum of doubles rejects valid
    // fractional ratios whose sum is off by a rounding error.
    val ratioTotal = splitRatios.sum
    require(
      math.abs(ratioTotal - 100.0) <= SplitRatioTolerance,
      s"splitRatios must sum to 100, got: $ratioTotal"
    )

    val rootDir = new File(folder)
    val inputSplit = new FileSplit(rootDir, NativeImageLoader.ALLOWED_FORMATS)
    val labelMaker = new ParentPathLabelGenerator
    // A null path filter means every matching file is kept before partitioning.
    val sampleSplits = inputSplit.sample(null, splitRatios: _*)

    sampleSplits.map { aSplit =>
      val reader = new ImageRecordReader(outputHeight, outputWidth, inputChannels, labelMaker)
      reader.initialize(aSplit)
      new RecordReaderDataSetIterator(reader, batchSize, 1, numClasses) // always 1 for labelIndex
    }
  }

}
