/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.datavec;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataComposableMap;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.ConcatenatingRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapts a DataVec {@link RecordReader} (or {@link SequenceRecordReader}) into a {@link DataSetIterator}.
 * <p>
 * Iteration is delegated to a lazily created {@link RecordReaderMultiDataSetIterator} that wraps the single
 * reader under the key {@code "reader"}. Each {@link MultiDataSet} it produces is converted into a {@link DataSet}.
 * The feature/label column split is configured when the first record is seen, because some settings (e.g. "label
 * is the last column") depend on the record width.
 * <p>
 * Note: the {@code converter} field is stored by the constructors but is not read anywhere in this class.
 */
public class RecordReaderDataSetIterator
implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(RecordReaderDataSetIterator.class);
    private static final String READER_KEY = "reader";
    protected RecordReader recordReader;
    protected WritableConverter converter;
    protected int batchSize = 10;
    protected int maxNumBatches = -1;
    protected int batchNum = 0;
    protected int labelIndex = -1;
    protected int labelIndexTo = -1;
    protected int numPossibleLabels = -1;
    protected Iterator<List<Writable>> sequenceIter;
    /** Batch fetched ahead of time by inputColumns()/totalOutcomes(); handed out by the next call to next(). */
    protected DataSet last;
    /** True while {@link #last} has been fetched ahead of time but not yet returned from next(). */
    protected boolean useCurrent = false;
    protected boolean regression = false;
    protected DataSetPreProcessor preProcessor;
    private boolean collectMetaData = false;
    private RecordReaderMultiDataSetIterator underlying;
    /** True when the features were read as two column ranges (labels in the middle) that must be hstacked. */
    private boolean underlyingIsDisjoint;
    /** Whether {@link #last} already had the preprocessor applied when it was fetched ahead of time. */
    private boolean lastPreProcessed = false;

    /**
     * Creates an iterator without an explicit label column. If the reader reports a non-null label list, its size
     * is used as the number of classes and the label column is resolved to the last column of the first record.
     *
     * @param recordReader source of records
     * @param batchSize    number of examples per batch
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, -1, recordReader.getLabels() == null ? -1 : recordReader.getLabels().size(), -1, false);
    }

    /**
     * Creates a classification iterator whose label is the single column {@code labelIndex}.
     *
     * @param recordReader      source of records
     * @param batchSize         number of examples per batch
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes (one-hot width)
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels, -1, false);
    }

    /**
     * Creates a classification iterator that stops after {@code maxNumBatches} batches.
     *
     * @param recordReader      source of records
     * @param batchSize         number of examples per batch
     * @param labelIndex        column holding the class index
     * @param numPossibleLabels number of classes (one-hot width)
     * @param maxNumBatches     maximum number of batches to return; negative means unlimited
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels, int maxNumBatches) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, labelIndex, numPossibleLabels, maxNumBatches, false);
    }

    /**
     * Creates a regression iterator whose labels are the columns {@code labelIndexFrom..labelIndexTo} inclusive.
     *
     * @throws IllegalArgumentException if {@code regression} is false
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndexFrom, int labelIndexTo, boolean regression) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndexFrom, labelIndexTo, -1, -1, regression);
        if (!regression) {
            throw new IllegalArgumentException("This constructor is only for creating regression iterators. If you're doing classification you need to use another constructor that (implicitly) specifies numPossibleLabels");
        }
    }

    /**
     * Fully specified constructor; all other constructors delegate here.
     *
     * @param recordReader      source of records
     * @param converter         stored only; not read by this class
     * @param batchSize         number of examples per batch
     * @param labelIndexFrom    first label column, or -1 for none / auto
     * @param labelIndexTo      last label column (inclusive)
     * @param numPossibleLabels number of classes for classification, or -1
     * @param maxNumBatches     maximum number of batches; negative means unlimited
     * @param regression        true for regression labels, false for one-hot classification
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize, int labelIndexFrom, int labelIndexTo, int numPossibleLabels, int maxNumBatches, boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.maxNumBatches = maxNumBatches;
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /** Creates an iterator from a {@link Builder}, including its preprocessor and metadata-collection setting. */
    protected RecordReaderDataSetIterator(Builder b) {
        this.recordReader = b.recordReader;
        this.converter = b.converter;
        this.batchSize = b.batchSize;
        this.maxNumBatches = b.maxNumBatches;
        this.labelIndex = b.labelIndex;
        this.labelIndexTo = b.labelIndexTo;
        this.numPossibleLabels = b.numPossibleLabels;
        this.regression = b.regression;
        this.preProcessor = b.preProcessor;
        // Previously dropped, which made Builder.collectMetaData(true) a no-op. The underlying iterator
        // does not exist yet; initializeUnderlying() propagates this flag when it is created.
        this.collectMetaData = b.collectMetaData;
    }

    /**
     * Enables or disables attaching per-example {@link RecordMetaData} to returned DataSets.
     * The setting is forwarded to the underlying iterator if it already exists.
     */
    public void setCollectMetaData(boolean collectMetaData) {
        if (this.underlying != null) {
            this.underlying.setCollectMetaData(collectMetaData);
        }
        this.collectMetaData = collectMetaData;
    }

    /** Creates the underlying iterator on first use by reading one record to learn the record width. */
    private void initializeUnderlying() {
        if (this.underlying == null) {
            Record next = this.recordReader.nextRecord();
            this.initializeUnderlying(next);
        }
    }

    /**
     * Builds the underlying {@link RecordReaderMultiDataSetIterator} using {@code next} as a sample record.
     * <p>
     * The reader is then reset if it supports reset. Otherwise the sample record is prepended back via a
     * {@link ConcatenatingRecordReader}, so that iteration still starts from the sample record.
     */
    private void initializeUnderlying(Record next) {
        int totalSize = next.getRecord().size();
        // A class count without an explicit label column means "label is the last column".
        if (this.numPossibleLabels >= 1 && this.labelIndex < 0) {
            this.labelIndexTo = this.labelIndex = totalSize - 1;
        }
        if (this.recordReader.resetSupported()) {
            this.recordReader.reset();
        } else {
            this.recordReader = new ConcatenatingRecordReader(new RecordReader[]{new CollectionRecordReader(Collections.singletonList(next.getRecord())), this.recordReader});
        }
        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(this.batchSize);
        if (this.recordReader instanceof SequenceRecordReader) {
            builder.addSequenceReader(READER_KEY, (SequenceRecordReader) this.recordReader);
        } else {
            builder.addReader(READER_KEY, this.recordReader);
        }
        if (this.regression) {
            builder.addOutput(READER_KEY, this.labelIndex, this.labelIndexTo);
        } else if (this.numPossibleLabels >= 1) {
            builder.addOutputOneHot(READER_KEY, this.labelIndex, this.numPossibleLabels);
        }
        if (this.labelIndex >= 0 && (this.labelIndex == 0 || this.labelIndexTo == totalSize - 1)) {
            // Labels form a prefix or a suffix of the record, so the features are one contiguous range.
            int inputFrom;
            int inputTo;
            if (this.labelIndex == 0) {
                inputFrom = this.labelIndexTo + 1;
                inputTo = totalSize - 1;
            } else {
                inputFrom = 0;
                inputTo = this.labelIndex - 1;
            }
            builder.addInput(READER_KEY, inputFrom, inputTo);
            this.underlyingIsDisjoint = false;
        } else if (this.labelIndex >= 0) {
            // Labels sit strictly inside the record: the features are the ranges before and after them,
            // which mdsToDataSet() concatenates back together.
            Preconditions.checkState(this.labelIndex < totalSize, "Invalid label (from) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s", totalSize - 1, this.labelIndex);
            Preconditions.checkState(this.labelIndexTo < totalSize, "Invalid label (to) index: index must be in range 0 to first record size of (0 to %s inclusive), got %s", totalSize - 1, this.labelIndexTo);
            builder.addInput(READER_KEY, 0, this.labelIndex - 1);
            builder.addInput(READER_KEY, this.labelIndexTo + 1, totalSize - 1);
            this.underlyingIsDisjoint = true;
        } else {
            // No label columns: every column is a feature.
            builder.addInput(READER_KEY);
            this.underlyingIsDisjoint = false;
        }
        this.underlying = builder.build();
        if (this.collectMetaData) {
            this.underlying.setCollectMetaData(true);
        }
    }

    /**
     * Converts one batch from the underlying iterator into a {@link DataSet}.
     * <p>
     * It also unwraps per-example metadata (when collected). When there are no labels and no label configuration,
     * it uses the features as labels. Finally it applies the preprocessor, if one is set.
     */
    private DataSet mdsToDataSet(MultiDataSet mds) {
        INDArray f;
        INDArray fm;
        if (this.underlyingIsDisjoint) {
            INDArray f1 = getOrNull(mds.getFeatures(), 0);
            INDArray f2 = getOrNull(mds.getFeatures(), 1);
            fm = getOrNull(mds.getFeaturesMaskArrays(), 0);
            f = Nd4j.hstack(new INDArray[]{f1, f2});
        } else {
            f = getOrNull(mds.getFeatures(), 0);
            fm = getOrNull(mds.getFeaturesMaskArrays(), 0);
        }
        INDArray l = getOrNull(mds.getLabels(), 0);
        INDArray lm = getOrNull(mds.getLabelsMaskArrays(), 0);
        DataSet ds = new DataSet(f, l, fm, lm);
        if (this.collectMetaData) {
            // The underlying iterator wraps each example's metadata in a composable map keyed by reader name.
            List<? extends Serializable> composite = mds.getExampleMetaData();
            List<Serializable> unwrapped = new ArrayList<>(composite.size());
            for (Serializable s : composite) {
                RecordMetaDataComposableMap m = (RecordMetaDataComposableMap) s;
                unwrapped.add(m.getMeta().get(READER_KEY));
            }
            ds.setExampleMetaData(unwrapped);
        }
        if (this.labelIndex == -1 && this.numPossibleLabels == -1 && ds.getLabels() == null) {
            ds.setLabels(ds.getFeatures());
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(ds);
        }
        return ds;
    }

    /**
     * Returns the next batch. If a batch was fetched ahead of time by {@link #inputColumns()} or
     * {@link #totalOutcomes()}, that batch is returned instead, and {@code num} is ignored for it.
     *
     * @param num requested number of examples
     */
    public DataSet next(int num) {
        if (this.useCurrent) {
            this.useCurrent = false;
            // The cached batch went through mdsToDataSet(), which already applied the preprocessor if one was
            // set at that time. Apply it only if it was not, so the batch is never preprocessed twice.
            if (!this.lastPreProcessed && this.preProcessor != null) {
                this.preProcessor.preProcess(this.last);
                this.lastPreProcessed = true;
            }
            return this.last;
        }
        this.initializeUnderlying();
        ++this.batchNum;
        return this.mdsToDataSet(this.underlying.next(num));
    }

    /**
     * Returns {@code arr[idx]}, or {@code null} when {@code arr} is null or empty.
     * An out-of-range {@code idx} on a non-empty array still throws.
     */
    static INDArray getOrNull(INDArray[] arr, int idx) {
        if (arr == null || arr.length == 0) {
            return null;
        }
        return arr[idx];
    }

    /**
     * Returns the cached batch, fetching it on first call and marking it to be returned by the next next().
     */
    private DataSet peekLast() {
        if (this.last == null) {
            this.last = this.next();
            this.lastPreProcessed = this.preProcessor != null;
            this.useCurrent = true;
        }
        return this.last;
    }

    /** Number of feature columns, taken from a batch fetched ahead of time if none has been seen yet. */
    public int inputColumns() {
        return this.peekLast().numInputs();
    }

    /** Number of label columns, taken from a batch fetched ahead of time if none has been seen yet. */
    public int totalOutcomes() {
        return this.peekLast().numOutcomes();
    }

    /** Delegates to the underlying iterator, creating it first (which reads one record) if needed. */
    public boolean resetSupported() {
        this.initializeUnderlying();
        return this.underlying.resetSupported();
    }

    public boolean asyncSupported() {
        return true;
    }

    /** Resets the batch counter, the underlying iterator (if created), and any batch cached by a peek. */
    public void reset() {
        this.batchNum = 0;
        if (this.underlying != null) {
            this.underlying.reset();
        }
        this.last = null;
        this.useCurrent = false;
        this.lastPreProcessed = false;
    }

    public int batch() {
        return this.batchSize;
    }

    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * True if a batch fetched ahead of time is still pending. Otherwise true if more records exist and the
     * {@code maxNumBatches} limit (if non-negative) has not been reached.
     */
    public boolean hasNext() {
        if (this.useCurrent) {
            // inputColumns()/totalOutcomes() already pulled this batch (and counted it in batchNum),
            // so it must still be reported as available.
            return true;
        }
        boolean moreData = this.sequenceIter != null && this.sequenceIter.hasNext() || this.recordReader.hasNext();
        boolean underLimit = this.maxNumBatches < 0 || this.batchNum < this.maxNumBatches;
        return moreData && underLimit;
    }

    public DataSet next() {
        return this.next(this.batchSize);
    }

    public void remove() {
        throw new UnsupportedOperationException();
    }

    /** Label names as reported by the record reader (may be null). */
    public List<String> getLabels() {
        return this.recordReader.getLabels();
    }

    /** Loads a single-example DataSet for the given metadata. */
    public DataSet loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return this.loadFromMetaData(Collections.singletonList(recordMetaData));
    }

    /**
     * Loads a DataSet containing the examples described by {@code list}.
     * If the underlying iterator does not exist yet, the first entry's record is used to configure it.
     *
     * @throws IOException if the reader fails to load a record
     */
    public DataSet loadFromMetaData(List<RecordMetaData> list) throws IOException {
        if (this.underlying == null) {
            Record r = this.recordReader.loadFromMetaData(list.get(0));
            this.initializeUnderlying(r);
        }
        // The underlying multi-reader iterator expects metadata keyed by reader name.
        List<RecordMetaData> wrapped = new ArrayList<>(list.size());
        for (RecordMetaData m : list) {
            wrapped.add(new RecordMetaDataComposableMap(Collections.singletonMap(READER_KEY, m)));
        }
        MultiDataSet mds = this.underlying.loadFromMetaData(wrapped);
        return this.mdsToDataSet(mds);
    }

    public RecordReader getRecordReader() {
        return this.recordReader;
    }

    public DataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    public boolean isCollectMetaData() {
        return this.collectMetaData;
    }

    /** Fluent builder for {@link RecordReaderDataSetIterator}. */
    public static class Builder {
        protected RecordReader recordReader;
        protected WritableConverter converter;
        protected int batchSize;
        protected int maxNumBatches = -1;
        protected int labelIndex = -1;
        protected int labelIndexTo = -1;
        protected int numPossibleLabels = -1;
        protected boolean regression = false;
        protected DataSetPreProcessor preProcessor;
        private boolean collectMetaData = false;
        // Set by classification()/regression(); not currently read anywhere in this class.
        private boolean clOrRegCalled = false;

        /**
         * @param rr        source of records; must not be null
         * @param batchSize number of examples per batch
         * @throws NullPointerException if {@code rr} is null
         */
        public Builder(@NonNull RecordReader rr, int batchSize) {
            if (rr == null) {
                throw new NullPointerException("rr is marked @NonNull but is null");
            }
            this.recordReader = rr;
            this.batchSize = batchSize;
        }

        /** Sets the converter; note the iterator stores but does not read it. */
        public Builder writableConverter(WritableConverter converter) {
            this.converter = converter;
            return this;
        }

        /** Limits the number of batches returned; negative means unlimited. */
        public Builder maxNumBatches(int maxNumBatches) {
            this.maxNumBatches = maxNumBatches;
            return this;
        }

        /** Regression on the single column {@code labelIndex}. */
        public Builder regression(int labelIndex) {
            return this.regression(labelIndex, labelIndex);
        }

        /** Regression on the columns {@code labelIndexFrom..labelIndexTo} inclusive. */
        public Builder regression(int labelIndexFrom, int labelIndexTo) {
            this.labelIndex = labelIndexFrom;
            this.labelIndexTo = labelIndexTo;
            this.regression = true;
            this.clOrRegCalled = true;
            return this;
        }

        /** One-hot classification using column {@code labelIndex} as the class index. */
        public Builder classification(int labelIndex, int numClasses) {
            this.labelIndex = labelIndex;
            this.labelIndexTo = labelIndex;
            this.numPossibleLabels = numClasses;
            this.regression = false;
            this.clOrRegCalled = true;
            return this;
        }

        public Builder preProcessor(DataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /** Whether returned DataSets carry per-example metadata. */
        public Builder collectMetaData(boolean collectMetaData) {
            this.collectMetaData = collectMetaData;
            return this;
        }

        public RecordReaderDataSetIterator build() {
            return new RecordReaderDataSetIterator(this);
        }
    }
}

