package won.matcher.utils.tensor;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;

import org.la4j.Matrices;
import org.la4j.matrix.sparse.CCSMatrix;
import org.la4j.vector.functor.VectorProcedure;

/**
 * A third-order tensor stored as a list of sparse two-dimensional slices.
 *
 * <p>Every slice is a {@link CCSMatrix} of shape {@code rows x columns}; the third dimension is the
 * number of slices. Slices are allocated on demand by {@link #setEntry(double, int, int, int)}, so
 * the third dimension grows whenever an entry is written to a slice index that does not exist yet.
 *
 * <p>No synchronization is performed; callers sharing an instance across threads must coordinate
 * access themselves.
 */
public class ThirdOrderSparseTensor {

    /**
     * Current shape as {@code {rows, columns, numberOfSlices}}. The array is replaced, never mutated,
     * whenever the shape changes.
     */
    private int[] dims;

    /** The slices of the tensor; the list index is the position along the third dimension. */
    private final List<CCSMatrix> slices = new ArrayList<>();

    /** Collects the column indices reported by a la4j non-zero row traversal. */
    private static final class NonZeroVectorProcedure implements VectorProcedure {
        private final List<Integer> nonZeroIndices = new ArrayList<>();

        @Override // org.la4j.vector.functor.VectorProcedure
        public void apply(int index, double value) {
            // Defensive: skip explicitly stored zeros so the result really lists non-zero entries.
            if (value != 0.0d) {
                nonZeroIndices.add(index);
            }
        }

        public Collection<Integer> getNonZeroIndices() {
            return nonZeroIndices;
        }
    }

    /**
     * Creates an empty tensor (no slices yet) whose slices will have {@code rows x columns} entries.
     *
     * @param rows number of rows of every slice
     * @param columns number of columns of every slice
     */
    public ThirdOrderSparseTensor(int rows, int columns) {
        // There are no slices yet, so set the shape directly instead of calling the
        // overridable resize(...) from the constructor.
        this.dims = new int[] {rows, columns, 0};
    }

    /**
     * Changes the row/column shape of every existing slice. Each slice is replaced by a CCS copy
     * produced by la4j {@code copyOfShape}. Entries outside a smaller shape are presumably
     * dropped by la4j; confirm against the la4j version in use. The number of slices is unchanged.
     *
     * @param rows new number of rows of every slice
     * @param columns new number of columns of every slice
     */
    public void resize(int rows, int columns) {
        for (int k = 0; k < slices.size(); k++) {
            slices.set(k, slices.get(k).copyOfShape(rows, columns).to(Matrices.CCS));
        }
        this.dims = new int[] {rows, columns, slices.size()};
    }

    /**
     * Sets one entry of the tensor. If {@code slice} is beyond the current number of slices,
     * all-zero slices are appended up to and including {@code slice}.
     *
     * @param value value to store
     * @param row row index within the slice
     * @param column column index within the slice
     * @param slice index along the third dimension
     */
    public void setEntry(double value, int row, int column, int slice) {
        if (slice >= slices.size()) {
            while (slices.size() <= slice) {
                slices.add(CCSMatrix.zero(dims[0], dims[1]));
            }
            this.dims = new int[] {dims[0], dims[1], slices.size()};
        }
        slices.get(slice).set(row, column, value);
    }

    /**
     * Returns one entry of the tensor.
     *
     * @param row row index within the slice
     * @param column column index within the slice
     * @param slice index along the third dimension
     * @return the stored value
     * @throws IndexOutOfBoundsException if {@code slice} is not smaller than the number of slices
     */
    public double getEntry(int row, int column, int slice) {
        return slices.get(slice).get(row, column);
    }

    /**
     * Returns the la4j cardinality of a slice, i.e. its number of non-zero entries.
     *
     * @param slice index along the third dimension
     * @return the slice's cardinality
     */
    public int getNonZeroEntries(int slice) {
        return slices.get(slice).cardinality();
    }

    /**
     * Returns the current shape as {@code {rows, columns, numberOfSlices}}.
     *
     * @return a copy of the shape; modifying it does not affect this tensor
     */
    public int[] getDimensions() {
        return dims.clone();
    }

    /**
     * Writes one slice to a file in la4j's Matrix Market text form. The la4j-specific
     * "column-major" header token is removed from the output.
     *
     * @param fileName path of the file to create or overwrite
     * @param slice index along the third dimension
     * @throws IOException if the file cannot be written
     */
    public void writeSliceToFile(String fileName, int slice) throws IOException {
        NumberFormat format = NumberFormat.getInstance(Locale.US);
        // The default US number format groups thousands ("1,234.5"), which is not a valid
        // Matrix Market real value.
        format.setGroupingUsed(false);
        // NOTE(review): this format also keeps at most 3 fraction digits (JDK default); confirm
        // that precision is sufficient for whoever reads these files.
        // NOTE(review): "column-major" is presumably stripped because the downstream reader does
        // not accept la4j's extra header token; confirm.
        String content = slices.get(slice).toMatrixMarket(format).replace("column-major", "");
        try (FileOutputStream out = new FileOutputStream(new File(fileName))) {
            out.write(content.getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Returns the column indices of the non-zero entries in one row of a slice.
     *
     * @param row row index within the slice
     * @param slice index along the third dimension
     * @return column indices, in the order la4j reports them
     */
    public Collection<Integer> getNonZeroIndicesOfRow(int row, int slice) {
        NonZeroVectorProcedure procedure = new NonZeroVectorProcedure();
        slices.get(slice).eachNonZeroInRow(row, procedure);
        return procedure.getNonZeroIndices();
    }

    /**
     * Tells whether a row of a slice contains at least one non-zero entry, positive or negative.
     *
     * @param row row index within the slice
     * @param slice index along the third dimension
     * @return {@code true} if the row has a non-zero entry
     */
    public boolean hasNonZeroEntryInRow(int row, int slice) {
        // A "max() > 0" test would miss rows whose only non-zero entries are negative.
        return !getNonZeroIndicesOfRow(row, slice).isEmpty();
    }
}
