/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.recordreader.objdetect;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import org.datavec.api.records.impl.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataImageURI;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.util.files.FileFromPathIterator;
import org.datavec.api.util.files.URIUtil;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.datavec.image.data.Image;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.BaseImageRecordReader;
import org.datavec.image.recordreader.objdetect.ImageObject;
import org.datavec.image.recordreader.objdetect.ImageObjectLabelProvider;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.util.ImageUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class ObjectDetectionRecordReader
extends BaseImageRecordReader {
    private final int gridW;
    private final int gridH;
    private final ImageObjectLabelProvider labelProvider;
    protected Image currentImage;

    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider) {
        super(height, width, channels, null, null);
        this.gridW = gridW;
        this.gridH = gridH;
        this.labelProvider = labelProvider;
        this.appendLabel = labelProvider != null;
    }

    public ObjectDetectionRecordReader(int height, int width, int channels, int gridH, int gridW, ImageObjectLabelProvider labelProvider, ImageTransform imageTransform) {
        super(height, width, channels, null, null);
        this.gridW = gridW;
        this.gridH = gridH;
        this.labelProvider = labelProvider;
        this.appendLabel = labelProvider != null;
        this.imageTransform = imageTransform;
    }

    @Override
    public List<Writable> next() {
        return this.next(1).get(0);
    }

    @Override
    public void initialize(InputSplit split) throws IOException {
        if (this.imageLoader == null) {
            this.imageLoader = new NativeImageLoader(this.height, this.width, this.channels, this.imageTransform);
        }
        this.inputSplit = split;
        URI[] locations = split.locations();
        HashSet<String> labelSet = new HashSet<String>();
        if (locations != null && locations.length >= 1) {
            for (URI location : locations) {
                List<ImageObject> imageObjects = this.labelProvider.getImageObjectsForPath(location);
                for (ImageObject io : imageObjects) {
                    String name = io.getLabel();
                    if (labelSet.contains(name)) continue;
                    labelSet.add(name);
                }
            }
        } else {
            throw new IllegalArgumentException("No path locations found in the split.");
        }
        this.iter = new FileFromPathIterator(this.inputSplit.locationsPathIterator());
        if (split instanceof FileSplit) {
            FileSplit split1 = (FileSplit)split;
            this.labels.remove(split1.getRootDir());
        }
        this.labels = new ArrayList(labelSet);
        Collections.sort(this.labels);
    }

    @Override
    public List<List<Writable>> next(int num) {
        ArrayList<File> files = new ArrayList<File>(num);
        ArrayList<List<ImageObject>> objects = new ArrayList<List<ImageObject>>(num);
        for (int i = 0; i < num && this.hasNext(); ++i) {
            File f;
            this.currentFile = f = (File)this.iter.next();
            if (f.isDirectory()) continue;
            files.add(f);
            objects.add(this.labelProvider.getImageObjectsForPath(f.getPath()));
        }
        int nClasses = this.labels.size();
        INDArray outImg = Nd4j.create((long[])new long[]{files.size(), this.channels, this.height, this.width});
        INDArray outLabel = Nd4j.create((int[])new int[]{files.size(), 4 + nClasses, this.gridH, this.gridW});
        int exampleNum = 0;
        for (int i = 0; i < files.size(); ++i) {
            File imageFile;
            this.currentFile = imageFile = (File)files.get(i);
            try {
                Image image;
                this.invokeListeners(imageFile);
                this.currentImage = image = this.imageLoader.asImageMatrix(imageFile);
                Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
                outImg.put(new INDArrayIndex[]{NDArrayIndex.point((long)exampleNum), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}, image.getImage());
                List objectsThisImg = (List)objects.get(exampleNum);
                this.label(image, objectsThisImg, outLabel, exampleNum);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            ++exampleNum;
        }
        return new NDArrayRecordBatch(Arrays.asList(outImg, outLabel));
    }

    private void label(Image image, List<ImageObject> objectsThisImg, INDArray outLabel, int exampleNum) {
        int oW = image.getOrigW();
        int oH = image.getOrigH();
        int W = oW;
        int H = oH;
        for (ImageObject io : objectsThisImg) {
            double cx = io.getXCenterPixels();
            double cy = io.getYCenterPixels();
            if (this.imageTransform != null) {
                W = this.imageTransform.getCurrentImage().getWidth();
                H = this.imageTransform.getCurrentImage().getHeight();
                float[] pts = this.imageTransform.query(io.getX1(), io.getY1(), io.getX2(), io.getY2());
                int minX = Math.round(Math.min(pts[0], pts[2]));
                int maxX = Math.round(Math.max(pts[0], pts[2]));
                int minY = Math.round(Math.min(pts[1], pts[3]));
                int maxY = Math.round(Math.max(pts[1], pts[3]));
                io = new ImageObject(minX, minY, maxX, maxY, io.getLabel());
                cx = io.getXCenterPixels();
                cy = io.getYCenterPixels();
                if (cx < 0.0 || cx >= (double)W || cy < 0.0 || cy >= (double)H) continue;
            }
            double[] cxyPostScaling = ImageUtils.translateCoordsScaleImage(cx, cy, W, H, this.width, this.height);
            double[] tlPost = ImageUtils.translateCoordsScaleImage(io.getX1(), io.getY1(), W, H, this.width, this.height);
            double[] brPost = ImageUtils.translateCoordsScaleImage(io.getX2(), io.getY2(), W, H, this.width, this.height);
            int imgGridX = (int)(cxyPostScaling[0] / (double)this.width * (double)this.gridW);
            int imgGridY = (int)(cxyPostScaling[1] / (double)this.height * (double)this.gridH);
            tlPost[0] = tlPost[0] / (double)this.width * (double)this.gridW;
            tlPost[1] = tlPost[1] / (double)this.height * (double)this.gridH;
            brPost[0] = brPost[0] / (double)this.width * (double)this.gridW;
            brPost[1] = brPost[1] / (double)this.height * (double)this.gridH;
            Preconditions.checkState((imgGridY >= 0 && (long)imgGridY < outLabel.size(2) ? 1 : 0) != 0, (String)"Invalid image center in Y axis: calculated grid location of %s, must be between 0 (inclusive) and %s (exclusive). Object label center is outside of image bounds. Image object: %s", (Object)imgGridY, (Object)outLabel.size(2), (Object)io);
            Preconditions.checkState((imgGridX >= 0 && (long)imgGridX < outLabel.size(3) ? 1 : 0) != 0, (String)"Invalid image center in X axis: calculated grid location of %s, must be between 0 (inclusive) and %s (exclusive). Object label center is outside of image bounds. Image object: %s", (Object)imgGridY, (Object)outLabel.size(2), (Object)io);
            outLabel.putScalar((long)exampleNum, 0L, (long)imgGridY, (long)imgGridX, tlPost[0]);
            outLabel.putScalar((long)exampleNum, 1L, (long)imgGridY, (long)imgGridX, tlPost[1]);
            outLabel.putScalar((long)exampleNum, 2L, (long)imgGridY, (long)imgGridX, brPost[0]);
            outLabel.putScalar((long)exampleNum, 3L, (long)imgGridY, (long)imgGridX, brPost[1]);
            int labelIdx = this.labels.indexOf(io.getLabel());
            outLabel.putScalar((long)exampleNum, (long)(4 + labelIdx), (long)imgGridY, (long)imgGridX, 1.0);
        }
    }

    @Override
    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        this.invokeListeners(uri);
        if (this.imageLoader == null) {
            this.imageLoader = new NativeImageLoader(this.height, this.width, this.channels, this.imageTransform);
        }
        Image image = this.imageLoader.asImageMatrix(dataInputStream);
        Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE);
        List ret = RecordConverter.toRecord((INDArray)image.getImage());
        if (this.appendLabel) {
            List<ImageObject> imageObjectsForPath = this.labelProvider.getImageObjectsForPath(uri.getPath());
            int nClasses = this.labels.size();
            INDArray outLabel = Nd4j.create((int[])new int[]{1, 4 + nClasses, this.gridH, this.gridW});
            this.label(image, imageObjectsForPath, outLabel, 0);
            ret.add(new NDArrayWritable(outLabel));
        }
        return ret;
    }

    @Override
    public org.datavec.api.records.Record nextRecord() {
        List<Writable> list = this.next();
        URI uri = URIUtil.fileToURI((File)this.currentFile);
        return new Record(list, (RecordMetaData)new RecordMetaDataImageURI(uri, BaseImageRecordReader.class, this.currentImage.getOrigC(), this.currentImage.getOrigH(), this.currentImage.getOrigW()));
    }
}

