/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.iterator;

import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * A {@link DataSetIterator} that performs k-fold cross-validation over a single in-memory
 * {@link DataSet}.
 *
 * <p>The examples are split into {@code k} contiguous folds. The {@code i}-th call to
 * {@link #next()} returns the training set for fold {@code i}: every example except those in
 * fold {@code i}. The held-out examples for that fold are then available from
 * {@link #testFold()}. When the number of examples is not divisible by {@code k}, each of the
 * first {@code numExamples % k} folds holds one extra example. Fold sizes therefore differ by at
 * most one, and no fold is empty.
 *
 * <p>The constructor works on a copy of the supplied data. {@link #reset()} shuffles that copy
 * before restarting, so folds produced after a reset generally differ from the previous pass.
 *
 * <p>NOTE(review): the preprocessor stored via {@link #setPreProcessor} is not applied by
 * {@link #next()}. Callers must apply it themselves if needed.
 */
public class KFoldIterator
implements DataSetIterator {
    /** Private copy of the caller's data; shuffled in place by {@link #reset()}. */
    private final DataSet allData;
    /** Number of folds (always at least 2). */
    private final int k;
    /** Size of the largest fold, i.e. {@code ceil(numExamples / k)}. */
    private final int batch;
    /** Size of the final fold. */
    private final int lastBatch;
    /** Fold {@code i} covers example indices {@code [foldBoundaries[i], foldBoundaries[i + 1])}. */
    private final int[] foldBoundaries;
    /** Index of the fold that the next call to {@link #next()} will hold out. */
    private int kCursor = 0;
    /** Held-out examples of the most recently returned fold; null before the first next(). */
    private DataSet test;
    /** Training examples of the most recently returned fold; null before the first next(). */
    private DataSet train;
    protected DataSetPreProcessor preProcessor;

    /**
     * Creates a 10-fold iterator over a copy of {@code allData}.
     *
     * @param allData the examples to split; must contain at least 10 examples
     * @throws IllegalArgumentException if {@code allData} has fewer than 10 examples
     * @throws NullPointerException if {@code allData} is null
     */
    public KFoldIterator(DataSet allData) {
        this(10, allData);
    }

    /**
     * Creates a {@code k}-fold iterator over a copy of {@code allData}.
     *
     * @param k number of folds; must be at least 2
     * @param allData the examples to split; must contain at least {@code k} examples
     * @throws IllegalArgumentException if {@code k < 2} or {@code allData} has fewer than
     *     {@code k} examples (some folds would be empty)
     * @throws NullPointerException if {@code allData} is null
     */
    public KFoldIterator(int k, DataSet allData) {
        // Validate before copying so bad arguments do not trigger an expensive copy.
        if (k <= 1) {
            throw new IllegalArgumentException("k must be at least 2, got " + k);
        }
        Objects.requireNonNull(allData, "allData");
        int numExamples = allData.numExamples();
        if (numExamples < k) {
            throw new IllegalArgumentException("Cannot split " + numExamples
                    + " examples into " + k + " non-empty folds");
        }
        this.k = k;
        this.allData = allData.copy();
        this.foldBoundaries = computeFoldBoundaries(numExamples, k);
        // The first fold is always one of the largest, so this equals ceil(numExamples / k).
        this.batch = this.foldBoundaries[1] - this.foldBoundaries[0];
        this.lastBatch = this.foldBoundaries[k] - this.foldBoundaries[k - 1];
    }

    /**
     * Splits {@code numExamples} indices into {@code k} contiguous, non-empty folds whose sizes
     * differ by at most one. The first {@code numExamples % k} folds get the extra example.
     *
     * @return an array of length {@code k + 1}; fold {@code i} spans
     *     {@code [result[i], result[i + 1])}
     */
    private static int[] computeFoldBoundaries(int numExamples, int k) {
        int baseSize = numExamples / k;
        int remainder = numExamples % k;
        int[] boundaries = new int[k + 1];
        for (int i = 0; i < k; i++) {
            boundaries[i + 1] = boundaries[i] + baseSize + (i < remainder ? 1 : 0);
        }
        return boundaries;
    }

    /**
     * Not supported: a fold is always returned whole.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int num) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("KFoldIterator does not support next(int); use next()");
    }

    /** Returns the number of examples, taken from the first dimension of the labels. */
    public int totalExamples() {
        return (int)this.allData.getLabels().size(0);
    }

    /** Returns the size of dimension 1 of the features array. */
    @Override
    public int inputColumns() {
        return (int)this.allData.getFeatures().size(1);
    }

    /** Returns the size of dimension 1 of the labels array. */
    @Override
    public int totalOutcomes() {
        return (int)this.allData.getLabels().size(1);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Shuffles the internal copy of the data and restarts from the first fold. */
    @Override
    public void reset() {
        this.allData.shuffle();
        this.kCursor = 0;
    }

    /**
     * Returns the size of the largest fold, {@code ceil(numExamples / k)}.
     *
     * <p>This is a fold size, not a minibatch size.
     */
    @Override
    public int batch() {
        return this.batch;
    }

    /** Returns the number of examples in the final fold. */
    public int lastBatch() {
        return this.lastBatch;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    @Override
    public List<String> getLabels() {
        return this.allData.getLabelNamesList();
    }

    /** Returns true while fewer than {@code k} folds have been returned since the last reset. */
    @Override
    public boolean hasNext() {
        return this.kCursor < this.k;
    }

    /**
     * Advances to the next fold and returns its training set. The matching test set is then
     * available from {@link #testFold()}.
     *
     * @throws NoSuchElementException if all {@code k} folds have already been returned
     */
    @Override
    public DataSet next() {
        if (!this.hasNext()) {
            throw new NoSuchElementException("All " + this.k
                    + " folds have been returned; call reset() to start again");
        }
        this.nextFold();
        return this.train;
    }

    /** Does nothing: examples cannot be removed from the underlying data. */
    @Override
    public void remove() {
    }

    /**
     * Computes the train and test sets for fold {@code kCursor}, then advances the cursor.
     * The training set is every example outside the fold's index range.
     */
    private void nextFold() {
        int left = this.foldBoundaries[this.kCursor];
        int right = this.foldBoundaries[this.kCursor + 1];
        int total = this.foldBoundaries[this.k];
        if (right < total) {
            // The fold is not at the end, so training data includes the tail after it,
            // plus the head before it if there is one.
            List<DataSet> kMinusOneFoldList = new ArrayList<>(2);
            if (left > 0) {
                kMinusOneFoldList.add((DataSet)this.allData.getRange(0, left));
            }
            kMinusOneFoldList.add((DataSet)this.allData.getRange(right, total));
            this.train = DataSet.merge(kMinusOneFoldList);
        } else {
            // The last fold: training data is everything before it (non-empty since k >= 2).
            this.train = (DataSet)this.allData.getRange(0, left);
        }
        this.test = (DataSet)this.allData.getRange(left, right);
        ++this.kCursor;
    }

    /** Returns the held-out set of the most recently returned fold, or null before the first next(). */
    public DataSet testFold() {
        return this.test;
    }
}

