/*
 Copyright 2019 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 =======================================================================
 */
package org.tensorflow.ndarray.impl.dense;

import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArraySequence;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.impl.AbstractNdArray;
import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace;
import org.tensorflow.ndarray.impl.sequence.FastElementSequence;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBufferWindow;
import org.tensorflow.ndarray.impl.dimension.DimensionalSpace;
import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence;
import org.tensorflow.ndarray.impl.sequence.SingleElementSequence;

/**
 * Base implementation shared by dense ndarrays.
 *
 * <p>Values are read from and written to the {@link DataBuffer} returned by {@link #buffer()}.
 * Coordinates are translated to buffer positions through this array's {@link DimensionalSpace}.
 *
 * @param <T> the type of the values stored in this ndarray
 * @param <U> the concrete ndarray type returned by view-producing methods
 */
@SuppressWarnings("unchecked")
public abstract class AbstractDenseNdArray<T, U extends NdArray<T>> extends AbstractNdArray<T, U> {

  /**
   * Returns a sequence over the elements of this array found in dimension {@code dimensionIdx}.
   *
   * <p>If the underlying buffer supports windows, a single element view is created over a
   * {@link DataBufferWindow} and handed to a {@link FastElementSequence}. Otherwise, a
   * {@link SlicingElementSequence} is returned. A rank-0 array queried with a negative
   * {@code dimensionIdx} yields a {@link SingleElementSequence} over itself.
   *
   * @param dimensionIdx index of the dimension whose elements are iterated
   * @return a sequence of the elements in that dimension
   * @throws IllegalArgumentException if {@code dimensionIdx} is not lower than the number of
   *     dimensions of this array
   */
  @Override
  public NdArraySequence<U> elements(int dimensionIdx) {
    if (dimensionIdx >= shape().numDimensions()) {
      throw new IllegalArgumentException("Cannot iterate elements in dimension '" + dimensionIdx +
          "' of array with shape " + shape());
    }
    if (rank() == 0 && dimensionIdx < 0) {
      return new SingleElementSequence<>(this);
    }
    DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1);
    try {
      DataBufferWindow<? extends DataBuffer<T>> elemWindow = buffer().window(elemDims.physicalSize());
      U element = instantiateView(elemWindow.buffer(), elemDims);
      return new FastElementSequence(this, dimensionIdx, element, elemWindow);
    } catch (UnsupportedOperationException e) {
      // If buffer windows are not supported, fall back to the slicing (and slower) sequence
      return new SlicingElementSequence<>(this, dimensionIdx, elemDims);
    }
  }

  /**
   * Returns a view of this array with a different shape but the same total number of values.
   *
   * @param shape the new shape
   * @return this array if {@code shape} equals the current shape, otherwise a new view sharing
   *     this array's buffer
   * @throws IllegalArgumentException if {@code shape} is null, unknown, or has a different size
   */
  @Override
  public U withShape(Shape shape) {
    if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) {
      throw new IllegalArgumentException("Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape());
    }
    if (shape.equals(this.shape())) {
      return (U)this;
    }
    return instantiateView(buffer(), DimensionalSpace.create(shape));
  }

  /**
   * Returns a view over the buffer region that starts at {@code position} and spans the physical
   * size of {@code sliceDimensions}.
   *
   * @param position start position in this array's buffer
   * @param sliceDimensions dimensions of the resulting view
   * @return the sliced view
   */
  @Override
  public U slice(long position, DimensionalSpace sliceDimensions) {
    DataBuffer<T> sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize());
    return instantiateView(sliceBuffer, sliceDimensions);
  }

  /**
   * Returns a view of this array selected by the given indices.
   *
   * @param indices indices mapped onto this array's dimensions
   * @return the sliced view
   * @throws IllegalArgumentException if {@code indices} is null
   */
  @Override
  public U slice(Index... indices) {
    if (indices == null) {
      throw new IllegalArgumentException("Slicing requires at least one index");
    }
    RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices);
    return slice(sliceDimensions.position(), sliceDimensions);
  }

  /**
   * Returns a view of the element found at the given coordinates.
   *
   * <p>A null or empty coordinate list returns a view over the whole array. A null list is handled
   * the same way {@link #positionOf(long[], boolean)} and {@link #set(NdArray, long...)} already
   * handle it.
   *
   * @param coords coordinates of the element, possibly null or empty
   * @return a view of the element
   */
  @Override
  public U get(long... coords) {
    // Treat null like an empty coordinate list instead of failing on coords.length.
    int coordsCount = (coords == null) ? 0 : coords.length;
    return slice(positionOf(coords, false), dimensions().from(coordsCount));
  }

  /**
   * Returns the value stored at the given coordinates.
   *
   * @param coords coordinates of the value
   * @return the value read from the buffer
   */
  @Override
  public T getObject(long... coords) {
    return buffer().getObject(positionOf(coords, true));
  }

  /**
   * Copies {@code src} into the element at the given coordinates, or into this whole array if no
   * coordinates are given.
   *
   * @param src source array
   * @param coordinates coordinates of the destination element, possibly null or empty
   * @return this array
   */
  @Override
  public U set(NdArray<T> src, long... coordinates) {
    src.copyTo((coordinates == null || coordinates.length == 0) ? this : get(coordinates));
    return (U)this;
  }

  /**
   * Stores {@code value} at the given coordinates.
   *
   * @param value value to store
   * @param coords coordinates of the value
   * @return this array
   */
  @Override
  public U setObject(T value, long... coords) {
    buffer().setObject(value, positionOf(coords, true));
    return (U)this;
  }

  /**
   * Copies the content of this array into {@code dst}.
   *
   * @param dst destination buffer, validated by {@code Validator.copyToBufferArgs}
   * @return this array
   */
  @Override
  public U copyTo(DataBuffer<T> dst) {
    Validator.copyToBufferArgs(this, dst);
    DataTransfer.execute(buffer(), dimensions(), dst, DataTransfer::ofValue);
    return (U)this;
  }

  /**
   * Copies the content of {@code src} into this array.
   *
   * @param src source buffer, validated by {@code Validator.copyFromBufferArgs}
   * @return this array
   */
  @Override
  public U copyFrom(DataBuffer<T> src) {
    Validator.copyFromBufferArgs(this, src);
    DataTransfer.execute(src, buffer(), dimensions(), DataTransfer::ofValue);
    return (U)this;
  }

  /**
   * Hashes the buffer and shape of this array, or delegates to {@code slowHashCode()} when the
   * dimensions are segmented.
   */
  @Override
  public int hashCode() {
    // NOTE(review): equals() decides a segmented/non-segmented pair via slowEquals(), while the
    // two hash codes come from different formulas (slowHashCode() vs. the buffer/shape hash
    // below). Confirm that these formulas agree, so that equal arrays always share a hash code.
    if (dimensions().isSegmented()) {
      return slowHashCode();
    }
    final int prime = 31;
    int result = 1;
    result = prime * result + buffer().hashCode();
    result = prime * result + shape().hashCode();
    return result;
  }

  /**
   * Compares this array with another object.
   *
   * <p>Two non-segmented dense arrays are equal when their shapes and buffers are equal. If either
   * side is segmented, {@code slowEquals} decides. Non-dense objects are delegated to the parent
   * implementation.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof AbstractDenseNdArray)) {
      return super.equals(obj);
    }
    AbstractDenseNdArray<?, ?> other = (AbstractDenseNdArray<?, ?>)obj;
    if (dimensions().isSegmented() || other.dimensions().isSegmented()) {
      return slowEquals(other);
    }
    if (!shape().equals(other.shape())) {
      return false;
    }
    return buffer().equals(other.buffer());
  }

  /**
   * A String showing the type and shape of this dense ndarray.
   * @return A string containing the type and shape.
   */
  @Override
  public String toString() {
    return this.getClass().getSimpleName() + "(shape=" + this.shape() + ")";
  }

  /**
   * @param dimensions dimensional space of this array
   */
  protected AbstractDenseNdArray(DimensionalSpace dimensions) {
    super(dimensions);
  }

  /**
   * @return the buffer holding the values of this array
   */
  abstract protected DataBuffer<T> buffer();

  /**
   * Creates a new array of the concrete type {@code U} over the given buffer and dimensions.
   *
   * @param buffer buffer backing the new view
   * @param dimensions dimensions of the new view
   * @return the new view
   */
  abstract U instantiateView(DataBuffer<T> buffer, DimensionalSpace dimensions);

  /**
   * Translates coordinates into a position in this array's buffer.
   *
   * @param coords coordinates to translate; null or empty maps to position 0 without validation
   * @param isValue passed to {@code Validator.coordinates} along with the coordinates
   * @return the buffer position of {@code coords}
   */
  long positionOf(long[] coords, boolean isValue) {
    if (coords == null || coords.length == 0) {
      return 0;
    }
    Validator.coordinates(dimensions, coords, isValue);
    return dimensions.positionOf(coords);
  }

  /**
   * Copies the values of this array into {@code array} one scalar at a time.
   *
   * <p>If {@code array} is a dense ndarray whose dimensions are not segmented, scalars are written
   * directly into its buffer at consecutive positions starting at 0. Any other destination is
   * handled by {@link AbstractNdArray#slowCopyTo(NdArray)}.
   *
   * @param array destination array
   */
  @Override
  protected void slowCopyTo(NdArray<T> array) {
    if (array instanceof AbstractDenseNdArray) {
      // The destination's view type is unrelated to this array's U, so keep it as a wildcard.
      AbstractDenseNdArray<T, ?> dst = (AbstractDenseNdArray<T, ?>) array;
      // Consecutive offsets are only safe when the destination is not segmented. hashCode() and
      // equals() use the same condition to avoid treating a segmented buffer as linear.
      if (!dst.dimensions().isSegmented()) {
        DataBuffer<T> dstBuffer = dst.buffer();
        long offset = 0L;
        for (NdArray<T> scalar : scalars()) {
          dstBuffer.setObject(scalar.getObject(), offset++);
        }
        return;
      }
    }
    super.slowCopyTo(array);
  }
}
