/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.iterator.impl;

import java.util.List;
import org.datavec.image.loader.CifarLoader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;

/**
 * Data set iterator that serves mini-batches produced by a {@link CifarLoader}.
 *
 * <p>The iterator does not use the record reader machinery of its parent class (it passes a
 * {@code null} reader to the super constructor); every batch is pulled directly from the loader.
 *
 * <p><b>Caveat:</b> {@link #loader}, {@link #train} and {@link #useSpecialPreProcessCifar} are
 * {@code static}. Every constructor call replaces the shared loader and overwrites the two flags,
 * so (a) two live instances share one loader, and (b) the flags passed to the most recent
 * constructor call become the defaults used by the shorter constructors afterwards. The fields
 * are kept static because they are part of the class's protected interface.
 */
public class CifarDataSetIterator
extends RecordReaderDataSetIterator {
    /** Number of examples assumed for the training split. */
    private static final int TRAIN_EXAMPLES = 50000;
    /** Number of examples assumed for the test split. */
    private static final int TEST_EXAMPLES = 10000;
    /** Label count used by the convenience constructors. */
    private static final int DEFAULT_NUM_LABELS = 10;

    /** Default image height used by the constructors that take no {@code imgDim}. */
    protected static int height = 32;
    /** Default image width used by the constructors that take no {@code imgDim}. */
    protected static int width = 32;
    /** Default channel count used by the constructors that take no {@code imgDim}. */
    protected static int channels = 3;
    /** Loader shared by all instances; replaced on every constructor call. */
    protected static CifarLoader loader;
    /** Size of the currently selected split (train or test). */
    protected int totalExamples;
    /** Upper bound on the number of examples this iterator hands out before {@link #hasNext()} turns false. */
    protected int numExamples;
    /** Running count of examples requested so far (advanced by the requested batch size on every load). */
    protected int exampleCount;
    /** Set once the loader returned no data or the batch cap was exceeded. */
    protected boolean overshot;
    /** Transform forwarded to the loader; may be {@code null}. */
    protected ImageTransform imageTransform;
    /** Selects the loader's two-argument {@code next(batchSize, offset)} variant when {@code true}. */
    protected static boolean useSpecialPreProcessCifar = false;
    /** Whether the most recently configured split is the training split. */
    protected static boolean train = true;

    /**
     * Creates an iterator with the default image dimensions and label count.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve (clamped to the split size)
     * @param train {@code true} for the training split, {@code false} for the test split
     */
    public CifarDataSetIterator(int batchSize, int numExamples, boolean train) {
        this(batchSize, numExamples, new int[]{height, width, channels}, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Creates an iterator with custom image dimensions, using the current static
     * {@code train} / {@code useSpecialPreProcessCifar} flags.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve (clamped to the split size)
     * @param imgDim image dimensions as {@code {imgDim[0], imgDim[1], imgDim[2]}}, passed to the loader in that order
     */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim) {
        this(batchSize, numExamples, imgDim, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Creates an iterator with custom image dimensions for the given split.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve (clamped to the split size)
     * @param imgDim image dimensions, passed element-wise to the loader
     * @param train {@code true} for the training split, {@code false} for the test split
     */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, boolean train) {
        this(batchSize, numExamples, imgDim, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Creates an iterator with default image dimensions, using the current static
     * {@code train} / {@code useSpecialPreProcessCifar} flags.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve (clamped to the split size)
     */
    public CifarDataSetIterator(int batchSize, int numExamples) {
        this(batchSize, numExamples, new int[]{height, width, channels}, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Creates an iterator over up to 50000 examples with custom image dimensions, using the
     * current static {@code train} / {@code useSpecialPreProcessCifar} flags.
     *
     * @param batchSize number of examples per batch
     * @param imgDim image dimensions, passed element-wise to the loader
     */
    public CifarDataSetIterator(int batchSize, int[] imgDim) {
        this(batchSize, TRAIN_EXAMPLES, imgDim, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Creates an iterator with custom image dimensions and explicit flags.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve (clamped to the split size)
     * @param imgDim image dimensions, passed element-wise to the loader
     * @param useSpecialPreProcessCifar forwarded to the loader and stored in the static flag
     * @param train {@code true} for the training split, {@code false} for the test split
     */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, boolean useSpecialPreProcessCifar, boolean train) {
        this(batchSize, numExamples, imgDim, DEFAULT_NUM_LABELS, null, useSpecialPreProcessCifar, train);
    }

    /**
     * Full constructor; all other constructors delegate here.
     *
     * <p>Side effect: replaces the static {@link #loader} and overwrites the static
     * {@link #useSpecialPreProcessCifar} and {@link #train} flags.
     *
     * @param batchSize number of examples per batch
     * @param numExamples maximum number of examples to serve; values above the split size
     *        (50000 for train, 10000 for test) are clamped to the split size
     * @param imgDim image dimensions; elements 0, 1 and 2 are passed to the loader in that order
     * @param numPossibleLables number of label classes
     * @param imageTransform transform handed to the loader; may be {@code null}
     * @param useSpecialPreProcessCifar forwarded to the loader and stored in the static flag
     * @param train {@code true} for the training split, {@code false} for the test split
     */
    public CifarDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numPossibleLables, ImageTransform imageTransform, boolean useSpecialPreProcessCifar, boolean train) {
        // The parent's record reader is unused, hence null. The last super argument is
        // superseded by the numPossibleLabels assignment below.
        super(null, batchSize, 1, numExamples);
        this.exampleCount = 0;
        this.overshot = false;
        loader = new CifarLoader(imgDim[0], imgDim[1], imgDim[2], imageTransform, train, useSpecialPreProcessCifar);
        this.totalExamples = train ? TRAIN_EXAMPLES : TEST_EXAMPLES;
        this.numExamples = Math.min(numExamples, this.totalExamples);
        this.numPossibleLabels = numPossibleLables;
        this.imageTransform = imageTransform;
        CifarDataSetIterator.useSpecialPreProcessCifar = useSpecialPreProcessCifar;
        CifarDataSetIterator.train = train;
    }

    /**
     * Loads the next batch from the loader.
     *
     * <p>If {@code useCurrent} is set, the flag is cleared and the previously returned batch is
     * handed out again without touching the loader. Otherwise a batch is loaded, the example and
     * batch counters are advanced, the pre-processor (if any) is applied and the loader's label
     * names (if any) are attached.
     *
     * <p>When the loader yields no data, or the batch cap {@code maxNumBatches} has been
     * exceeded, the iterator is marked as overshot and the previous batch is returned instead
     * (which is {@code null} if no batch was ever produced).
     *
     * @param batchSize number of examples requested from the loader
     * @return the freshly loaded batch, or the previous batch as described above
     */
    @Override
    public org.nd4j.linalg.dataset.DataSet next(int batchSize) {
        if (this.useCurrent) {
            this.useCurrent = false;
            return this.last;
        }
        org.nd4j.linalg.dataset.DataSet result = useSpecialPreProcessCifar
                ? loader.next(batchSize, this.exampleCount)
                : loader.next(batchSize);
        this.exampleCount += batchSize;
        ++this.batchNum;

        // The former "result == new DataSet()" reference comparison could never be true;
        // a null guard is what protects the getFeatureMatrix() call.
        boolean noData = result == null || result.getFeatureMatrix() == null;
        // batchNum already counts this batch, so the cap is only exceeded when it is strictly
        // greater than maxNumBatches. This matches hasNext(), which permits a call while
        // batchNum < maxNumBatches.
        boolean pastBatchCap = this.maxNumBatches > -1 && this.batchNum > this.maxNumBatches;
        if (noData || pastBatchCap) {
            this.overshot = true;
            return this.last;
        }

        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)result);
        }
        this.last = result;
        if (loader.getLabels() != null) {
            result.setLabelNames(loader.getLabels());
        }
        return result;
    }

    /**
     * @return {@code true} while fewer than {@code numExamples} examples have been requested,
     *         the batch cap (if any) has not been reached and the iterator has not overshot
     */
    @Override
    public boolean hasNext() {
        boolean examplesLeft = this.exampleCount < this.numExamples;
        boolean batchesLeft = this.maxNumBatches == -1 || this.batchNum < this.maxNumBatches;
        return examplesLeft && batchesLeft && !this.overshot;
    }

    /**
     * @return the size of the currently selected split
     */
    @Override
    public int totalExamples() {
        return this.totalExamples;
    }

    /**
     * Rewinds the iterator: clears the counters and the overshot flag and resets the loader.
     */
    @Override
    public void reset() {
        this.exampleCount = 0;
        this.overshot = false;
        this.batchNum = 0;
        loader.reset();
    }

    /**
     * @return the label names reported by the loader
     */
    @Override
    public List<String> getLabels() {
        return loader.getLabels();
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Switches the iterator (and the shared loader) to the training split and rewinds it.
     *
     * <p>{@code numExamples} is left untouched, so a limit installed by an earlier
     * {@code test(...)} call stays in effect.
     */
    public void train() {
        train = true;
        loader.train();
        // Restore the split size; a preceding test(...) call leaves it at the test size.
        this.totalExamples = TRAIN_EXAMPLES;
        this.reset();
    }

    /**
     * Switches to the test split, serving up to 10000 examples with the current batch size.
     */
    public void test() {
        this.test(TEST_EXAMPLES, this.batchSize);
    }

    /**
     * Switches to the test split with the current batch size.
     *
     * @param numExamples maximum number of test examples to serve (clamped to 10000)
     */
    public void test(int numExamples) {
        this.test(numExamples, this.batchSize);
    }

    /**
     * Switches the iterator (and the shared loader) to the test split and clears the counters.
     *
     * @param numExamples maximum number of test examples to serve (clamped to 10000)
     * @param batchSize batch size to use from now on
     */
    public void test(int numExamples, int batchSize) {
        this.batchSize = batchSize;
        train = false;
        loader.test();
        this.totalExamples = TEST_EXAMPLES;
        // Clamp like the constructor does, so hasNext() cannot stay true past the split's end.
        this.numExamples = Math.min(numExamples, this.totalExamples);
        this.exampleCount = 0;
        this.overshot = false;
        this.batchNum = 0;
    }
}

