package ai.minxiao.ds4s.core.dl4j.ui

import java.io.File

import org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter
import org.deeplearning4j.nn.api.NeuralNetwork
import org.deeplearning4j.nn.graph.ComputationGraph
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener
import org.deeplearning4j.ui.api.UIServer
import org.deeplearning4j.ui.stats.{J7StatsListener, StatsListener}
import org.deeplearning4j.ui.storage.{FileStatsStorage, InMemoryStatsStorage}
import org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage

/**
  * Helpers that attach DL4J training listeners to a network so that training progress is sent
  * to the console, an in-memory training UI, a stats file, or a UI server running in another JVM.
  *
  * Every entry point has one overload for [[MultiLayerNetwork]] and one for [[ComputationGraph]].
  * Scala allows default arguments on only one alternative of an overloaded method, so only the
  * `MultiLayerNetwork` overloads declare defaults; the `ComputationGraph` overloads take every
  * argument explicitly.
  *
  * @author mx
  */
object UIStarter {

  /** Attach a score listener that reports to the console.
    *
    * @param net multilayer network
    * @param listenFreq frequency passed to the [[ScoreIterationListener]]
    */
  def initNewInConsole(net: MultiLayerNetwork, listenFreq: Int): Unit =
    net.setListeners(new ScoreIterationListener(listenFreq))

  /** Attach a score listener that reports to the console.
    *
    * @param net computation graph
    * @param listenFreq frequency passed to the [[ScoreIterationListener]]
    */
  def initNewInConsole(net: ComputationGraph, listenFreq: Int): Unit =
    net.setListeners(new ScoreIterationListener(listenFreq))

  /** Start (or reuse) the UI server and record all training stats in memory.
    *
    * Replaces any listeners already set on `net` with a console score listener and a stats
    * listener feeding a fresh [[InMemoryStatsStorage]] attached to the server. Use
    * `write2File` / `initNewFromFile` instead to persist stats.
    *
    * @param net multilayer network
    * @param listenFreq frequency passed to both the score listener and the stats listener
    * @param enableRemote whether to also enable the server's remote listener
    * @return the UI server instance returned by `UIServer.getInstance`
    */
  def initNewInMemory(net: MultiLayerNetwork, listenFreq: Int = 1, enableRemote: Boolean = false): UIServer =
    newInMemoryUI(enableRemote) { storage =>
      net.setListeners(new ScoreIterationListener(listenFreq), new StatsListener(storage, listenFreq))
    }

  /** Start (or reuse) the UI server and record all training stats in memory.
    *
    * Replaces any listeners already set on `net` with a console score listener and a stats
    * listener feeding a fresh [[InMemoryStatsStorage]] attached to the server.
    *
    * @param net computation graph
    * @param listenFreq frequency passed to both the score listener and the stats listener
    * @param enableRemote whether to also enable the server's remote listener
    * @return the UI server instance returned by `UIServer.getInstance`
    */
  def initNewInMemory(net: ComputationGraph, listenFreq: Int, enableRemote: Boolean): UIServer =
    newInMemoryUI(enableRemote) { storage =>
      net.setListeners(new ScoreIterationListener(listenFreq), new StatsListener(storage, listenFreq))
    }

  /** Write all training stats to a file that can later be loaded with `initNewFromFile`.
    *
    * On a Java 7 JVM (`java.version` starting with "1.7") the J7 listener/storage variants are
    * used; otherwise [[StatsListener]] with [[FileStatsStorage]].
    *
    * @param net multilayer network
    * @param storagePath path of the stats file
    * @param listenFreq frequency passed to the stats listener
    */
  def write2File(net: MultiLayerNetwork, storagePath: String, listenFreq: Int = 1): Unit =
    net.setListeners(fileStatsListener(storagePath, listenFreq))

  /** Write all training stats to a file that can later be loaded with `initNewFromFile`.
    *
    * On a Java 7 JVM (`java.version` starting with "1.7") the J7 listener/storage variants are
    * used; otherwise [[StatsListener]] with [[FileStatsStorage]].
    *
    * @param net computation graph
    * @param storagePath path of the stats file
    * @param listenFreq frequency passed to the stats listener
    */
  def write2File(net: ComputationGraph, storagePath: String, listenFreq: Int): Unit =
    net.setListeners(fileStatsListener(storagePath, listenFreq))

  /** Start (or reuse) the UI server and attach pre-recorded stats loaded from a file.
    *
    * Files in SQLite format (as written by [[J7FileStatsStorage]] on Java 7) are opened with
    * [[J7FileStatsStorage]]; anything else is opened with [[FileStatsStorage]].
    *
    * @param storagePath path of the file containing the recorded stats
    * @return the UI server instance returned by `UIServer.getInstance`
    */
  def initNewFromFile(storagePath: String): UIServer = {
    val uiServer = UIServer.getInstance
    uiServer.attach(openFileStatsStorage(storagePath))
    uiServer
  }

  /** Enable the remote listener so that training in other JVMs can post stats to `uiServer`.
    *
    * @param uiServer UI server instance
    * @param storagePath optional file in which to store the remotely received stats; when
    *                    `None`, `uiServer.enableRemoteListener()` is called with no storage
    * @return the same `uiServer`, with the remote listener enabled
    */
  def enableRemote(uiServer: UIServer, storagePath: Option[String] = None): UIServer = {
    storagePath match {
      // second argument `true`: attach the storage to the UI as well as routing into it
      case Some(path) => uiServer.enableRemoteListener(openFileStatsStorage(path), true)
      case None       => uiServer.enableRemoteListener()
    }
    uiServer
  }

  /** Send training stats to a UI server running in another JVM with its remote listener enabled.
    *
    * @param net multilayer network
    * @param listenFreq frequency passed to the stats listener
    * @param ip host of the UI server, e.g. 127.0.0.1 or localhost
    * @param port port of the UI server, e.g. 9000
    */
  def connect(net: MultiLayerNetwork, listenFreq: Int = 1,
    ip: String = "localhost", port: String = "9000"): Unit =
    net.setListeners(remoteStatsListener(ip, port, listenFreq))

  /** Send training stats to a UI server running in another JVM with its remote listener enabled.
    *
    * @param net computation graph
    * @param listenFreq frequency passed to the stats listener
    * @param ip host of the UI server, e.g. 127.0.0.1 or localhost
    * @param port port of the UI server, e.g. 9000
    */
  def connect(net: ComputationGraph, listenFreq: Int,
    ip: String, port: String): Unit =
    net.setListeners(remoteStatsListener(ip, port, listenFreq))

  // ---- internal helpers ----

  /** Leading bytes of every SQLite 3 database file: "SQLite format 3" followed by a NUL byte. */
  private val SqliteHeaderMagic: Array[Byte] =
    "SQLite format 3".getBytes(java.nio.charset.StandardCharsets.US_ASCII) :+ 0.toByte

  /** True when `java.version` reports a 1.7 JVM; selects the J7 listener/storage variants. */
  private def isJava7: Boolean = System.getProperty("java.version").startsWith("1.7")

  /** Whether `file` exists and starts with the SQLite 3 header (a [[J7FileStatsStorage]] file). */
  private def isSqliteFile(file: File): Boolean =
    file.isFile && file.length() >= SqliteHeaderMagic.length && {
      val in = new java.io.DataInputStream(new java.io.FileInputStream(file))
      try {
        val header = new Array[Byte](SqliteHeaderMagic.length)
        in.readFully(header)
        java.util.Arrays.equals(header, SqliteHeaderMagic)
      } finally in.close()
    }

  /** Open a file-backed stats storage whose backend matches the file's on-disk format. */
  private def openFileStatsStorage(storagePath: String) = {
    val file = new File(storagePath)
    if (isSqliteFile(file)) new J7FileStatsStorage(file) else new FileStatsStorage(file)
  }

  /** Build a stats listener writing to `storagePath`, using the J7 variants on Java 7. */
  private def fileStatsListener(storagePath: String, listenFreq: Int) = {
    val file = new File(storagePath)
    if (isJava7) new J7StatsListener(new J7FileStatsStorage(file), listenFreq)
    else new StatsListener(new FileStatsStorage(file), listenFreq)
  }

  /** Build a stats listener that posts to the remote UI at `http://ip:port`. */
  private def remoteStatsListener(ip: String, port: String, listenFreq: Int): StatsListener =
    new StatsListener(new RemoteUIStatsStorageRouter(s"http://${ip}:${port}"), listenFreq)

  /** Attach a fresh in-memory storage to the UI server, let the caller register listeners that
    * feed it, then optionally enable the remote listener.
    */
  private def newInMemoryUI(remote: Boolean)(registerListeners: InMemoryStatsStorage => Unit): UIServer = {
    val uiServer = UIServer.getInstance
    val statsStorage = new InMemoryStatsStorage()
    uiServer.attach(statsStorage)
    registerListeners(statsStorage)
    if (remote) uiServer.enableRemoteListener()
    uiServer
  }

}
