package ai.libs.jaicore.ml.dyadranking.algorithm;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:ai/libs/jaicore/ml/dyadranking/algorithm/PLNetLoss.class */
/**
 * Loss and gradient of the Plackett-Luce negative log-likelihood for PLNet outputs.
 *
 * <p>The entries of the input row vector {@code o_0 .. o_{M-1}} are interpreted in ranking order:
 * index 0 is the top-ranked item, so the term for position {@code m} has {@code o_m} in the
 * numerator and sums over all items at positions {@code >= m} in the denominator.
 *
 * <p>All exponentials are evaluated in log-space (log-sum-exp), so arbitrarily large or small
 * (finite) outputs do not overflow to {@code Infinity} or underflow to {@code log(0)}.
 */
public final class PLNetLoss {

    private PLNetLoss() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Computes the negative log-likelihood
     * {@code sum_{m=0}^{M-2} ( log(sum_{l=m}^{M-1} exp(o_l)) - o_m )} of the ranking given by the
     * order of {@code plNetOutputs}.
     *
     * @param plNetOutputs row vector of at least 2 PLNet outputs, in ranking order
     * @return a single-element array holding the loss
     * @throws IllegalArgumentException if the input is not a row vector of 2 or more elements
     */
    public static INDArray computeLoss(INDArray plNetOutputs) {
        if (!plNetOutputs.isRowVector() || plNetOutputs.size(1) < 2) {
            throw new IllegalArgumentException("Input has to be a row vector of 2 or more elements.");
        }
        double[] outputs = toDoubleArray(plNetOutputs);
        double[] suffixLse = suffixLogSumExp(outputs);
        double loss = 0.0d;
        // The last position contributes log(exp(o_{M-1})) - o_{M-1} = 0, so it is skipped.
        // Each remaining term is non-negative; summing the differences avoids the cancellation
        // caused by subtracting two independently accumulated large sums.
        for (int m = 0; m < outputs.length - 1; m++) {
            loss += suffixLse[m] - outputs[m];
        }
        return Nd4j.create(new double[] { loss });
    }

    /**
     * Computes the partial derivative of {@link #computeLoss(INDArray)} with respect to the output
     * at index {@code k}: {@code sum_{m=0}^{k} exp(o_k) / sum_{l=m}^{M-1} exp(o_l) - 1}.
     *
     * @param plNetOutputs row vector of at least 2 PLNet outputs, in ranking order
     * @param k index of the output to differentiate with respect to
     * @return a single-element array holding the gradient component
     * @throws IllegalArgumentException if the input is not a row vector of 2 or more elements, or
     *     {@code k} is not a valid index into it
     */
    public static INDArray computeLossGradient(INDArray plNetOutputs, int k) {
        if (!plNetOutputs.isRowVector() || plNetOutputs.size(1) < 2 || k < 0 || k >= plNetOutputs.size(1)) {
            throw new IllegalArgumentException("Input has to be a row vector of 2 or more elements. And k has to be a valid index of plNetOutputs.");
        }
        double[] outputs = toDoubleArray(plNetOutputs);
        double[] suffixLse = suffixLogSumExp(outputs);
        double outputK = outputs[k];
        double sum = 0.0d;
        // exp(o_k - L_m) == exp(o_k) / sum_{l>=m} exp(o_l); the exponent is <= 0 for m <= k because
        // o_k is part of that suffix, so no term can overflow. For k == M-1 the final term is
        // exactly 1 and cancels the "- 1" below, matching the loss, which has no o_{M-1} term.
        for (int m = 0; m <= k; m++) {
            sum += Math.exp(outputK - suffixLse[m]);
        }
        return Nd4j.create(new double[] { sum - 1.0d });
    }

    /** Copies the elements of a row vector into a plain array, in linear-index order. */
    private static double[] toDoubleArray(INDArray rowVector) {
        int length = Math.toIntExact(rowVector.size(1));
        double[] values = new double[length];
        for (int j = 0; j < length; j++) {
            values[j] = rowVector.getDouble(j);
        }
        return values;
    }

    /** Returns {@code L} with {@code L[m] = log(sum_{l=m}^{n-1} exp(values[l]))}, computed stably. */
    private static double[] suffixLogSumExp(double[] values) {
        int last = values.length - 1;
        double[] suffix = new double[values.length];
        suffix[last] = values[last];
        for (int m = last - 1; m >= 0; m--) {
            suffix[m] = logAddExp(values[m], suffix[m + 1]);
        }
        return suffix;
    }

    /** Returns {@code log(exp(a) + exp(b))} without overflowing for large arguments. */
    private static double logAddExp(double a, double b) {
        double max = Math.max(a, b);
        if (Double.isInfinite(max)) {
            // both -Infinity -> -Infinity; any +Infinity -> +Infinity (avoids inf - inf = NaN)
            return max;
        }
        return max + Math.log1p(Math.exp(-Math.abs(a - b)));
    }
}
