/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Early-stopping {@link ScoreCalculator} that scores a {@link ComputationGraph} by its loss on a
 * supplied data source.
 *
 * <p>The data comes from either a {@link DataSetIterator} or a {@link MultiDataSetIterator},
 * depending on which constructor was used; exactly one of the two fields is non-null. Each
 * minibatch score is multiplied by the number of examples in that minibatch before summing, and
 * the sum is optionally divided by the total example count.
 */
public class DataSetLossCalculatorCG
implements ScoreCalculator<ComputationGraph> {
    /** Single-input data source; {@code null} when constructed from a MultiDataSetIterator. */
    private final DataSetIterator dataSetIterator;
    /** Multi-input data source; {@code null} when constructed from a DataSetIterator. */
    private final MultiDataSetIterator multiDataSetIterator;
    /** When true, {@link #calculateScore} divides the summed loss by the total example count. */
    private final boolean average;

    /**
     * Creates a calculator that evaluates the network on a {@link DataSetIterator}.
     *
     * @param dataSetIterator data to evaluate on; reset at the start of every
     *     {@link #calculateScore} call
     * @param average if true, return the summed loss divided by the number of examples;
     *     otherwise return the summed loss
     */
    public DataSetLossCalculatorCG(DataSetIterator dataSetIterator, boolean average) {
        this.dataSetIterator = dataSetIterator;
        this.multiDataSetIterator = null;
        this.average = average;
    }

    /**
     * Creates a calculator that evaluates the network on a {@link MultiDataSetIterator}.
     *
     * @param dataSetIterator data to evaluate on; reset at the start of every
     *     {@link #calculateScore} call
     * @param average if true, return the summed loss divided by the number of examples;
     *     otherwise return the summed loss
     */
    public DataSetLossCalculatorCG(MultiDataSetIterator dataSetIterator, boolean average) {
        this.dataSetIterator = null;
        this.multiDataSetIterator = dataSetIterator;
        this.average = average;
    }

    /**
     * Resets the configured iterator and accumulates the network's score over all of its
     * minibatches.
     *
     * @param network the graph to evaluate
     * @return the example-weighted summed score, or that sum divided by the total example count
     *     when {@code average} is true. If {@code average} is true and the iterator yields no
     *     examples, the result is {@code NaN} (0.0 / 0.0).
     */
    @Override
    public double calculateScore(ComputationGraph network) {
        double lossSum = 0.0;
        // long rather than int so a very large evaluation set cannot overflow the count
        long exCount = 0;
        if (this.dataSetIterator != null) {
            this.dataSetIterator.reset();
            while (this.dataSetIterator.hasNext()) {
                org.nd4j.linalg.dataset.DataSet dataSet = (org.nd4j.linalg.dataset.DataSet)this.dataSetIterator.next();
                int nEx = dataSet.getFeatureMatrix().size(0);
                // Presumably score(...) is the mean loss over the minibatch (TODO confirm), so
                // multiplying by nEx lets unequal-size minibatches contribute proportionally.
                lossSum += network.score((DataSet)dataSet) * (double)nEx;
                exCount += nEx;
            }
        } else {
            this.multiDataSetIterator.reset();
            while (this.multiDataSetIterator.hasNext()) {
                MultiDataSet dataSet = (MultiDataSet)this.multiDataSetIterator.next();
                // Example count is taken from the first feature array only.
                int nEx = dataSet.getFeatures(0).size(0);
                lossSum += network.score(dataSet) * (double)nEx;
                exCount += nEx;
            }
        }
        if (this.average) {
            return lossSum / (double)exCount;
        }
        return lossSum;
    }

    /**
     * Describes this calculator, showing whichever iterator it was constructed with.
     */
    @Override
    public String toString() {
        // Only one of the two iterators is non-null; show the one actually in use.
        Object iterator = this.dataSetIterator != null ? this.dataSetIterator : this.multiDataSetIterator;
        return "DataSetLossCalculatorCG(" + iterator + ",average=" + this.average + ")";
    }
}

