/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.ndarray.impl.dimension;

import java.util.Arrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.impl.dimension.Axis;
import org.tensorflow.ndarray.impl.dimension.Dimension;
import org.tensorflow.ndarray.impl.dimension.ReducedDimension;
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
import org.tensorflow.ndarray.index.Index;

/**
 * Layout of an N-dimensional array, expressed as an ordered list of {@link Dimension}s.
 *
 * <p>Each dimension reports its number of elements and an element size. For spaces built by
 * {@link #create(Shape)}, the element size is the product of the sizes of all following
 * dimensions. An optional segmentation index records the last dimension reported as segmented.
 */
public class DimensionalSpace {
    /** Dimensions of this space, outermost first. */
    private final Dimension[] dimensions;
    /** Index of the last segmented dimension, or -1 when no dimension is segmented. */
    private final int segmentationIdx;
    /** Lazily computed shape; supplied directly by {@link #create(Shape)}. */
    private Shape shape;

    /**
     * Creates a dense space for the given shape, where each dimension is an {@link Axis}.
     *
     * @param shape shape of the array
     * @return a non-segmented dimensional space whose element sizes are the products of the
     *     sizes of all trailing dimensions
     */
    public static DimensionalSpace create(Shape shape) {
        Dimension[] dimensions = new Dimension[shape.numDimensions()];
        // Walk from the innermost dimension outwards. A long accumulator is required so that
        // element sizes larger than Integer.MAX_VALUE are not truncated.
        long elementSize = 1L;
        for (int i = dimensions.length - 1; i >= 0; --i) {
            dimensions[i] = new Axis(shape.get(i), elementSize);
            elementSize *= dimensions[i].numElements();
        }
        return new DimensionalSpace(dimensions, shape);
    }

    /**
     * Maps a sequence of indices onto this space and returns the resulting relative space.
     *
     * <p>Indices are consumed left to right:
     * <ul>
     *   <li>consecutive point indices remove their dimensions. Their combined offset is added to
     *       the initial offset when no output dimension exists yet. Otherwise it wraps the previous
     *       output dimension in a {@link ReducedDimension};</li>
     *   <li>a new-axis index inserts an {@link Axis} of size 1 without consuming a dimension;</li>
     *   <li>an ellipsis copies as many dimensions as needed so that the remaining
     *       non-new-axis indices line up with the trailing dimensions;</li>
     *   <li>any other index is applied to its dimension via {@link Index#apply}.</li>
     * </ul>
     * Dimensions not covered by any index are copied unchanged.
     *
     * @param indices indices to apply, one per consumed dimension (new axes consume none)
     * @return the mapped space, its segmentation index and its offset within this space
     * @throws IllegalArgumentException if more than one ellipsis is given
     * @throws ArrayIndexOutOfBoundsException if the indices consume more dimensions than exist
     */
    public RelativeDimensionalSpace mapTo(Index[] indices) {
        if (this.dimensions == null) {
            throw new ArrayIndexOutOfBoundsException();
        }
        int dimIdx = 0;
        int indexIdx = 0;
        int newDimIdx = 0;
        int segmentationIdx = -1;
        long initialOffset = 0L;

        // Pre-scan: count inserted axes so the output array is large enough, and reject
        // multiple ellipses.
        int newAxes = 0;
        boolean seenEllipsis = false;
        for (Index idx : indices) {
            if (idx.isNewAxis()) {
                ++newAxes;
            }
            if (idx.isEllipsis()) {
                if (seenEllipsis) {
                    throw new IllegalArgumentException("Only one ellipsis allowed");
                }
                seenEllipsis = true;
            }
        }
        Dimension[] newDimensions = new Dimension[this.dimensions.length + newAxes];

        while (indexIdx < indices.length) {
            Index index = indices[indexIdx];
            if (index.isPoint()) {
                // Fold a run of consecutive points into a single offset.
                long offset = 0L;
                do {
                    offset += indices[indexIdx].mapCoordinate(0L, this.dimensions[dimIdx]);
                    ++dimIdx;
                } while (++indexIdx < indices.length && indices[indexIdx].isPoint());
                if (newDimIdx == 0) {
                    // Accumulate: an empty ellipsis can separate two leading point runs.
                    initialOffset += offset;
                } else {
                    long reducedSize = this.dimensions[dimIdx - 1].elementSize();
                    newDimensions[newDimIdx - 1] =
                            new ReducedDimension(newDimensions[newDimIdx - 1], offset, reducedSize);
                    segmentationIdx = newDimIdx - 1;
                }
            } else if (index.isNewAxis()) {
                long newSize;
                if (dimIdx > 0) {
                    newSize = this.dimensions[dimIdx - 1].elementSize();
                } else if (this.dimensions.length > 0) {
                    // Leading new axis spans the whole space.
                    newSize = this.dimensions[0].numElements() * this.dimensions[0].elementSize();
                } else {
                    // A rank-0 space holds exactly one element.
                    newSize = 1L;
                }
                newDimensions[newDimIdx] = new Axis(1L, newSize);
                segmentationIdx = newDimIdx++;
                ++indexIdx;
            } else if (index.isEllipsis()) {
                int remainingDimensions = this.dimensions.length - dimIdx;
                int requiredDimensions = 0;
                for (int i = indexIdx + 1; i < indices.length; ++i) {
                    if (!indices[i].isNewAxis()) {
                        ++requiredDimensions;
                    }
                }
                // Copy dimensions until only those needed by the remaining indices are left.
                while (remainingDimensions > requiredDimensions) {
                    Dimension dim = this.dimensions[dimIdx++];
                    if (dim.isSegmented()) {
                        segmentationIdx = newDimIdx;
                    }
                    newDimensions[newDimIdx++] = dim;
                    --remainingDimensions;
                }
                ++indexIdx;
            } else {
                Dimension newDimension = index.apply(this.dimensions[dimIdx++]);
                newDimensions[newDimIdx] = newDimension;
                if (newDimension.isSegmented()) {
                    segmentationIdx = newDimIdx;
                }
                ++newDimIdx;
                ++indexIdx;
            }
        }

        // Trailing dimensions not addressed by any index are carried over as-is.
        while (dimIdx < this.dimensions.length) {
            Dimension dim = this.dimensions[dimIdx++];
            if (dim.isSegmented()) {
                segmentationIdx = newDimIdx;
            }
            newDimensions[newDimIdx++] = dim;
        }
        return new RelativeDimensionalSpace(Arrays.copyOf(newDimensions, newDimIdx), segmentationIdx, initialOffset);
    }

    /**
     * Returns the sub-space made of the dimensions starting at {@code dimensionStart}.
     *
     * <p>The segmentation index is shifted accordingly. It is dropped if it pointed before
     * {@code dimensionStart}.
     *
     * @param dimensionStart first dimension to keep, in {@code [0, numDimensions()]}
     * @return the trailing sub-space
     * @throws IndexOutOfBoundsException if {@code dimensionStart} is out of range
     */
    public DimensionalSpace from(int dimensionStart) {
        if (dimensionStart < 0 || dimensionStart > this.dimensions.length) {
            throw new IndexOutOfBoundsException(
                    "Dimension start " + dimensionStart + " out of range [0, " + this.dimensions.length + "]");
        }
        Dimension[] newDimensions = Arrays.copyOfRange(this.dimensions, dimensionStart, this.dimensions.length);
        if (this.segmentationIdx >= dimensionStart) {
            return new DimensionalSpace(newDimensions, this.segmentationIdx - dimensionStart);
        }
        return new DimensionalSpace(newDimensions);
    }

    /**
     * Returns the shape of this space. The shape is built from each dimension's element count
     * and cached on first use.
     */
    public Shape shape() {
        if (this.shape == null) {
            this.shape = DimensionalSpace.toShape(this.dimensions);
        }
        return this.shape;
    }

    /** Returns the number of dimensions (rank) of this space. */
    public int numDimensions() {
        return this.dimensions.length;
    }

    /** Returns the number of elements of dimension {@code i}. */
    public long numElements(int i) {
        return this.dimensions[i].numElements();
    }

    /**
     * Returns the physical size of the outermost dimension, or 1 for a rank-0 space.
     */
    public long physicalSize() {
        return this.dimensions.length > 0 ? this.dimensions[0].physicalSize() : 1L;
    }

    /** Returns dimension {@code i}. */
    public Dimension get(int i) {
        return this.dimensions[i];
    }

    /** Returns true if a segmentation index is set (i.e. it is non-negative). */
    public boolean isSegmented() {
        return this.segmentationIdx >= 0;
    }

    /** Returns the segmentation index, or -1 if this space is not segmented. */
    public int segmentationIdx() {
        return this.segmentationIdx;
    }

    /**
     * Sums each dimension's position for the given coordinates.
     *
     * <p>Fewer coordinates than dimensions are allowed; only the leading dimensions are used.
     *
     * @param coords one coordinate per leading dimension
     * @return the summed position
     */
    public long positionOf(long[] coords) {
        long position = 0L;
        for (int i = 0; i < coords.length; ++i) {
            position += this.dimensions[i].positionOf(coords[i]);
        }
        return position;
    }

    @Override
    public String toString() {
        return Arrays.toString(this.dimensions);
    }

    /**
     * Creates a space over the given dimensions. The array is stored without copying.
     *
     * @param dimensions dimensions, outermost first
     * @param segmentationIdx index of the last segmented dimension, or -1
     */
    DimensionalSpace(Dimension[] dimensions, int segmentationIdx) {
        this.dimensions = dimensions;
        this.segmentationIdx = segmentationIdx;
    }

    /** Creates a non-segmented space over the given dimensions. */
    private DimensionalSpace(Dimension[] dimensions) {
        this(dimensions, -1);
    }

    /** Creates a non-segmented space with a precomputed shape. */
    private DimensionalSpace(Dimension[] dimensions, Shape shape) {
        this(dimensions);
        this.shape = shape;
    }

    /** Builds a {@link Shape} from the element count of each dimension. */
    private static Shape toShape(Dimension[] dimensions) {
        long[] shapeDimSizes = new long[dimensions.length];
        int i = 0;
        for (Dimension dimension : dimensions) {
            shapeDimSizes[i++] = dimension.numElements();
        }
        return Shape.of(shapeDimSizes);
    }
}

