package ai.libs.jaicore.ml.dyadranking.dataset;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledDataset;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/dyadranking/dataset/DyadRankingDataset.class */
/**
 * A dataset of dyad rankings: an ordered list of {@link IDyadRankingInstance}s.
 *
 * <p>Supports a simple line-based text (de)serialization and conversion to ND4J matrices
 * (one matrix per ranking, one row per dyad).
 */
public class DyadRankingDataset extends ArrayList<IDyadRankingInstance> implements IOrderedLabeledDataset<IDyadRankingInstance, IDyadRankingInstance> {
    /* Static so that it is never null, e.g. after Java deserialization of this (Serializable) list. */
    private static final Logger LOGGER = LoggerFactory.getLogger(DyadRankingDataset.class);
    private static final long serialVersionUID = -1102494546233523992L;

    /** Creates an empty dataset. */
    public DyadRankingDataset() {
        super();
    }

    /**
     * Creates a dataset containing the given instances in iteration order.
     *
     * @param collection the instances to copy into this dataset
     */
    public DyadRankingDataset(Collection<IDyadRankingInstance> collection) {
        super(collection);
    }

    /**
     * Creates an empty dataset with the given initial capacity.
     *
     * @param i the initial capacity
     */
    public DyadRankingDataset(int i) {
        super(i);
    }

    /**
     * Creates a dataset containing the given instances in list order.
     *
     * @param list the instances to copy into this dataset
     */
    public DyadRankingDataset(List<IDyadRankingInstance> list) {
        super(list);
    }

    /**
     * Writes this dataset as UTF-8 text, one ranking per line.
     *
     * <p>Each dyad is written as {@code instance.toString() + ";" + alternative.toString() + "|"}
     * and each ranking is terminated by {@code "\n"}. I/O failures are logged, not thrown.
     *
     * @param outputStream the stream to write to; it is not closed by this method
     */
    public void serialize(OutputStream outputStream) {
        try {
            for (IDyadRankingInstance ranking : this) {
                StringBuilder row = new StringBuilder();
                for (Dyad dyad : ranking) {
                    row.append(dyad.getInstance().toString()).append(';').append(dyad.getAlternative().toString()).append('|');
                }
                row.append('\n');
                // Explicit UTF-8 so the output matches what deserialize() reads.
                outputStream.write(row.toString().getBytes(StandardCharsets.UTF_8));
            }
        } catch (IOException e) {
            LOGGER.warn("Could not serialize dyad ranking dataset.", e);
        }
    }

    /**
     * Replaces the contents of this dataset with rankings read from UTF-8 text in the format
     * produced by {@link #serialize(OutputStream)}.
     *
     * <p>Parsing stops at the first empty line. Dyad tokens that are malformed (no {@code ';'}) or
     * whose vector parts are at most one character long are skipped. I/O failures are logged, not thrown.
     *
     * @param inputStream the stream to read from; it is not closed by this method
     * @throws NumberFormatException if a vector entry is not a parsable double
     */
    public void deserialize(InputStream inputStream) {
        clear();
        try {
            LineIterator lineIterator = IOUtils.lineIterator(inputStream, StandardCharsets.UTF_8);
            while (lineIterator.hasNext()) {
                String row = lineIterator.next();
                if (row.isEmpty()) {
                    // NOTE(review): an empty line ends parsing, so a ranking serialized with zero dyads
                    // truncates all following rankings. Kept for compatibility with existing behavior.
                    break;
                }
                add(new DyadRankingInstance(parseDyads(row)));
            }
        } catch (IOException e) {
            LOGGER.warn("Could not deserialize dyad ranking dataset.", e);
        }
    }

    /** Parses one serialized row ("inst;alt|inst;alt|...") into its dyads, skipping malformed tokens. */
    private static List<Dyad> parseDyads(String row) {
        List<Dyad> dyads = new ArrayList<>();
        for (String dyadToken : row.split("\\|")) {
            String[] parts = dyadToken.split(";");
            // Guard the length so a token without ';' is skipped instead of throwing.
            if (parts.length >= 2 && parts[0].length() > 1 && parts[1].length() > 1) {
                dyads.add(new Dyad(parseVector(parts[0]), parseVector(parts[1])));
            }
        }
        return dyads;
    }

    /**
     * Parses a vector token by dropping its first and last character and splitting the rest on ','.
     * NOTE(review): assumes the vector's toString() wraps the entries in single-character delimiters
     * (e.g. brackets) — confirm against DenseDoubleVector.toString().
     */
    private static DenseDoubleVector parseVector(String token) {
        String[] entries = token.substring(1, token.length() - 1).split(",");
        DenseDoubleVector vector = new DenseDoubleVector(entries.length);
        for (int i = 0; i < entries.length; i++) {
            vector.setValue(i, Double.parseDouble(entries[i]));
        }
        return vector;
    }

    /**
     * Two datasets are equal if both are {@code DyadRankingDataset}s containing equal rankings in the same order.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DyadRankingDataset)) {
            return false;
        }
        DyadRankingDataset other = (DyadRankingDataset) obj;
        if (other.size() != size()) {
            return false;
        }
        for (int i = 0; i < size(); i++) {
            if (!Objects.equals(get(i), other.get(i))) {
                return false;
            }
        }
        return true;
    }

    /** Order-sensitive hash consistent with {@link #equals(Object)}; null elements hash to 0. */
    @Override
    public int hashCode() {
        int hash = 17;
        for (IDyadRankingInstance ranking : this) {
            hash = (hash * 31) + Objects.hashCode(ranking);
        }
        return hash;
    }

    /**
     * Converts each ranking into an ND4J matrix whose rows are the concatenated
     * (instance, alternative) feature vectors of its dyads, in ranking order.
     *
     * @return one matrix per ranking, in dataset order
     */
    public List<INDArray> toND4j() {
        List<INDArray> matrices = new ArrayList<>(size());
        for (IDyadRankingInstance ranking : this) {
            matrices.add(dyadRankingToMatrix(ranking));
        }
        return matrices;
    }

    /** Horizontally concatenates a dyad's instance and alternative features into one array. */
    private static INDArray dyadToVector(Dyad dyad) {
        return Nd4j.hstack(new INDArray[] { Nd4j.create(dyad.getInstance().asArray()), Nd4j.create(dyad.getAlternative().asArray()) });
    }

    /**
     * Creates a dataset holding a single ranking built from the given ordered dyads.
     *
     * @param list the dyads in ranking order
     * @return a one-element dataset
     */
    public static DyadRankingDataset fromOrderedDyadList(List<Dyad> list) {
        // Assignment context lets Arrays.asList infer IDyadRankingInstance (a cast does not).
        List<IDyadRankingInstance> rankings = Arrays.asList(new DyadRankingInstance(list));
        return new DyadRankingDataset(rankings);
    }

    /** Stacks the per-dyad feature vectors of one ranking vertically into a matrix. */
    private static INDArray dyadRankingToMatrix(IDyadRankingInstance ranking) {
        List<INDArray> rows = new ArrayList<>(ranking.length());
        for (Dyad dyad : ranking) {
            rows.add(dyadToVector(dyad));
        }
        return Nd4j.vstack(rows);
    }

    @Override
    public DyadRankingDataset createEmpty() {
        return new DyadRankingDataset();
    }

    /**
     * Counts how many rankings in this dataset are equal to the given one.
     *
     * @param iDyadRankingInstance the ranking to count
     * @return the number of equal rankings
     * @throws NullPointerException if the argument is null
     */
    @Override
    public int getFrequency(IDyadRankingInstance iDyadRankingInstance) {
        Objects.requireNonNull(iDyadRankingInstance, "iDyadRankingInstance");
        return (int) stream().filter(iDyadRankingInstance::equals).count();
    }
}
