package ai.libs.jaicore.ml.dyadranking.algorithm.featuretransform;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.ConfigurationException;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModelConfiguration;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.algorithm.IPLDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.optimizing.BilinFunction;
import ai.libs.jaicore.ml.dyadranking.optimizing.DyadRankingFeatureTransformNegativeLogLikelihood;
import ai.libs.jaicore.ml.dyadranking.optimizing.DyadRankingFeatureTransformNegativeLogLikelihoodDerivative;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import edu.stanford.nlp.optimization.QNMinimizer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/dyadranking/algorithm/featuretransform/FeatureTransformPLDyadRanker.class */
/**
 * Plackett-Luce dyad ranker that scores a dyad by {@code exp(w . phi(dyad))}.
 *
 * <p>Here {@code phi} is a pluggable {@link IDyadFeatureTransform} and {@code w} is a weight
 * vector learned by {@link #train(DyadRankingDataset)}.
 *
 * <p>Rankings use position 0 as the most preferred dyad. This matches the Plackett-Luce
 * likelihood evaluated by {@code likelihoodOfParameter}, where the dyad at position {@code i}
 * is chosen among the dyads at positions {@code i..n-1}.
 */
public class FeatureTransformPLDyadRanker implements IPLDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(FeatureTransformPLDyadRanker.class);

    /** Value assigned to every entry of {@code w} before optimization starts. */
    private static final double INITIAL_WEIGHT = 0.3d;

    /** Function tolerance passed to the quasi-Newton minimizer. */
    private static final double OPTIMIZER_TOLERANCE = 0.01d;

    private final IDyadFeatureTransform featureTransform;

    /** Learned weight vector; {@code null} until {@link #train(DyadRankingDataset)} succeeds. */
    private Vector w;

    // NOTE(review): these two are initialized in train() but the optimizer below minimizes a
    // BilinFunction instead; confirm whether they are still needed.
    private final IDyadRankingFeatureTransformPLGradientDescendableFunction negativeLogLikelihood;
    private final IDyadRankingFeatureTransformPLGradientFunction negativeLogLikelihoodDerivative;

    /** Creates a ranker that uses a {@code BiliniearFeatureTransform}. */
    public FeatureTransformPLDyadRanker() {
        this(new BiliniearFeatureTransform());
    }

    /**
     * Creates a ranker with the given feature transform.
     *
     * @param iDyadFeatureTransform transform mapping a dyad to the vector that is dotted with {@code w}
     * @throws NullPointerException if {@code iDyadFeatureTransform} is {@code null}
     */
    public FeatureTransformPLDyadRanker(IDyadFeatureTransform iDyadFeatureTransform) {
        this.negativeLogLikelihood = new DyadRankingFeatureTransformNegativeLogLikelihood();
        this.negativeLogLikelihoodDerivative = new DyadRankingFeatureTransformNegativeLogLikelihoodDerivative();
        this.featureTransform = Objects.requireNonNull(iDyadFeatureTransform, "iDyadFeatureTransform");
    }

    /**
     * Ranks the dyads of the given instance by decreasing skill.
     *
     * @param iDyadRankingInstance the dyads to rank
     * @return a new instance whose position 0 holds the dyad with the highest skill
     * @throws PredictionException if the ranker has not been trained yet
     */
    @Override // ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModel
    public IDyadRankingInstance predict(IDyadRankingInstance iDyadRankingInstance) throws PredictionException {
        if (this.w == null) {
            throw new PredictionException("The Ranker has not been trained yet.");
        }
        log.debug("Predicting ranking for instance {}", iDyadRankingInstance);
        List<Pair<Double, Dyad>> logSkillsAndDyads = new ArrayList<>();
        for (Dyad dyad : iDyadRankingInstance) {
            logSkillsAndDyads.add(new Pair<>(computeLogSkillForDyad(dyad), dyad));
        }
        // Sort descending so the most preferred dyad comes first, consistent with the PL likelihood.
        // Sorting on log-skill (the dot product) is order-equivalent to sorting on exp(.) but cannot
        // collapse distinct scores into Infinity/0 ties. The stable sort keeps exact ties in input order.
        List<Dyad> ranking = logSkillsAndDyads.stream()
                .sorted((first, second) -> Double.compare(second.getX(), first.getX()))
                .map(Pair::getY)
                .collect(Collectors.toList());
        return new DyadRankingInstance(ranking);
    }

    /**
     * Ranks every instance of the dataset.
     *
     * @param dyadRankingDataset the instances to rank
     * @return one predicted ranking per instance, in dataset order
     * @throws PredictionException if the ranker has not been trained yet
     */
    @Override // ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModel
    public List<IDyadRankingInstance> predict(DyadRankingDataset dyadRankingDataset) throws PredictionException {
        List<IDyadRankingInstance> predictions = new ArrayList<>(dyadRankingDataset.size());
        Iterator<IDyadRankingInstance> it = dyadRankingDataset.iterator();
        while (it.hasNext()) {
            predictions.add(predict(it.next()));
        }
        return predictions;
    }

    /**
     * Computes the log-skill {@code w . phi(dyad)} of a dyad; its skill is {@code exp} of this value.
     *
     * @param dyad the dyad to score
     * @return the dot product of {@code w} with the transformed dyad
     */
    private double computeLogSkillForDyad(Dyad dyad) {
        Vector transformed = this.featureTransform.transform(dyad);
        double dotProduct = this.w.dotProduct(transformed);
        if (log.isDebugEnabled()) {
            log.debug("Feature transform for dyad {} is {}. \n Dot-Product is {} and skill is {}", dyad, transformed, dotProduct, Math.exp(dotProduct));
        }
        return dotProduct;
    }

    /**
     * Learns {@code w} by quasi-Newton minimization of a {@link BilinFunction} built from the dataset.
     *
     * <p>The feature dimensions are inferred from the first dyad of the first ranking. {@code w}
     * is only replaced once the optimizer has returned.
     *
     * @param dyadRankingDataset the training rankings
     * @throws TrainingException if the dataset is empty or its first ranking contains no dyads
     */
    @Override // ai.libs.jaicore.ml.core.predictivemodel.IBatchLearner
    public void train(DyadRankingDataset dyadRankingDataset) throws TrainingException {
        if (dyadRankingDataset.size() == 0) {
            throw new TrainingException("Cannot train the ranker on an empty dataset.");
        }
        IDyadRankingInstance firstRanking = dyadRankingDataset.get(0);
        if (firstRanking.length() == 0) {
            throw new TrainingException("Cannot infer feature dimensions: the first ranking of the dataset contains no dyads.");
        }

        Map<IDyadRankingInstance, Map<Dyad, Vector>> preComputedFeatureTransforms = this.featureTransform.getPreComputedFeatureTransforms(dyadRankingDataset);
        this.negativeLogLikelihood.initialize(dyadRankingDataset, preComputedFeatureTransforms);
        this.negativeLogLikelihoodDerivative.initialize(dyadRankingDataset, preComputedFeatureTransforms);

        Dyad firstDyad = firstRanking.getDyadAtPosition(0);
        int alternativeLength = firstDyad.getAlternative().length();
        int instanceLength = firstDyad.getInstance().length();
        int transformedLength = this.featureTransform.getTransformedVectorLength(alternativeLength, instanceLength);

        Vector initialW = new DenseDoubleVector(transformedLength, INITIAL_WEIGHT);
        // The likelihood re-transforms every dyad, so only compute it when it will actually be logged.
        if (log.isDebugEnabled()) {
            log.debug("Likelihood of the initial w (all entries {}) is {}", INITIAL_WEIGHT, likelihoodOfParameter(initialW, dyadRankingDataset));
        }

        BilinFunction objective = new BilinFunction(preComputedFeatureTransforms, dyadRankingDataset, transformedLength);
        double[] optimum = new QNMinimizer().minimize(objective, OPTIMIZER_TOLERANCE, initialW.asArray());
        this.w = new DenseDoubleVector(optimum);
        log.debug("Finished optimizing, the final w is {}", this.w);
    }

    /**
     * Computes the Plackett-Luce likelihood of all rankings in the dataset under the given weights.
     *
     * <p>For a ranking of n dyads with skills {@code s_i = exp(vector . phi(dyad_i))}, the
     * likelihood is {@code prod_i s_i / sum_{j >= i} s_j}. The result is a plain product, so it
     * may underflow to 0 on large datasets.
     *
     * @param vector the weight vector to evaluate
     * @param dyadRankingDataset the rankings to evaluate
     * @return the product of the likelihoods of all rankings
     */
    private double likelihoodOfParameter(Vector vector, DyadRankingDataset dyadRankingDataset) {
        double likelihood = 1.0d;
        for (int n = 0; n < dyadRankingDataset.size(); n++) {
            IDyadRankingInstance ranking = dyadRankingDataset.get(n);
            int length = ranking.length();
            double[] skills = new double[length];
            for (int i = 0; i < length; i++) {
                skills[i] = Math.exp(vector.dotProduct(this.featureTransform.transform(ranking.getDyadAtPosition(i))));
            }
            // Walk backwards so suffixSum always equals sum_{j >= i} skills[j]; each dyad is transformed once.
            double suffixSum = 0.0d;
            for (int i = length - 1; i >= 0; i--) {
                suffixSum += skills[i];
                likelihood *= skills[i] / suffixSum;
            }
        }
        return likelihood;
    }

    /**
     * This ranker has no configuration object.
     *
     * @return always {@code null}
     */
    @Override // ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModel
    public IPredictiveModelConfiguration getConfiguration() {
        return null;
    }

    /**
     * Configuration is not supported; the argument is ignored.
     *
     * @param iPredictiveModelConfiguration ignored
     */
    @Override // ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModel
    public void setConfiguration(IPredictiveModelConfiguration iPredictiveModelConfiguration) throws ConfigurationException {
        // Intentionally a no-op: this ranker exposes no configurable parameters.
    }
}
