package ai.minxiao.ds4s.core.dl4j

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

/**
  * @author mx
  */
package object util {

  /**
    * Prints a per-layer parameter breakdown of a multi-layer network to stdout.
    *
    * Output consists of a header line with the layer count, one line per layer
    * (indexed from 0) with that layer's parameter count, and a final line with the
    * total across all layers.
    *
    * @param net the network to inspect; its layer array must be available
    * @return a pair of (number of layers, total number of parameters across all layers)
    * @throws IllegalArgumentException if `net.getLayers` returns null
    */
  def showParams(net: MultiLayerNetwork): (Int, Long) = {
    val layers = net.getLayers
    // NOTE(review): presumably getLayers is null until net.init() has been called — verify.
    // Fail with a descriptive message rather than an opaque NullPointerException below.
    require(layers != null, "MultiLayerNetwork has no layers; call net.init() before showParams")

    val nLayers = layers.length
    println(s"$nLayers layers ...")

    // Widen each count to Long before summing so the total cannot overflow Int,
    // regardless of whether numParams() returns int or long in this DL4J version.
    val paramsPerLayer = layers.map(_.numParams().toLong)
    paramsPerLayer.zipWithIndex.foreach { case (nParams, i) =>
      println(s"Number of parameters in layer $i: $nParams")
    }

    val totalNumParams = paramsPerLayer.sum
    println(s"Total number of network parameters: $totalNumParams")
    (nLayers, totalNumParams)
  }

}
