package jaicore.search.algorithms.standard.mcts;

import ai.libs.jaicore.basic.ILoggingCustomizable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Action-selection policy for MCTS based on upper confidence bounds (UCB).
 *
 * <p>For every node that appears in an updated path the policy keeps a visit counter and the
 * observed scores. {@link #getAction(Object, Map)} first plays every action whose successor has no
 * label yet; once all successors are labelled it picks the action whose successor optimizes
 * {@code mean(scores) +/- sqrt(2 * ln(parentVisits) / childVisits)}. The exploration term is added
 * when maximizing and subtracted when minimizing.
 *
 * @param <T> type of the nodes
 * @param <A> type of the actions
 */
public class UCBPolicy<T, A> implements IPathUpdatablePolicy<T, A, Double>, ILoggingCustomizable {
    private String loggerName;
    private Logger logger;
    private final boolean maximize;
    private final Map<T, NodeLabel> labels;

    /** Statistics kept for a single node: number of visits and all scores observed through it. */
    public class NodeLabel {
        private final DescriptiveStatistics scores = new DescriptiveStatistics();
        private int visits;

        NodeLabel() {
        }

        @Override
        public String toString() {
            return "NodeLabel [scores=" + this.scores + ", visits=" + this.visits + "]";
        }
    }

    /** Creates a policy that maximizes the observed scores. */
    public UCBPolicy() {
        this(true);
    }

    /**
     * Creates a policy.
     *
     * @param z {@code true} to prefer high scores (maximize), {@code false} to prefer low scores
     */
    public UCBPolicy(boolean z) {
        this.logger = LoggerFactory.getLogger(UCBPolicy.class);
        this.labels = new HashMap<>();
        this.maximize = z;
    }

    /**
     * Records one observation for every node on the given path: the node's visit counter is
     * incremented and the score is added to its statistics. Nodes without a label get one.
     *
     * @param list the nodes of the path to update
     * @param d the score observed for the path; must not be {@code null}
     */
    @Override // jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy
    public void updatePath(List<T> list, Double d) {
        this.logger.debug("Updating path {} with score {}", list, d);
        final double score = d.doubleValue();
        for (T node : list) {
            NodeLabel label = this.labels.computeIfAbsent(node, key -> new NodeLabel());
            label.visits++;
            label.scores.addValue(score);
            this.logger.trace("Updated label of node {}. Visits now {}, stats contains {} entries with mean {}", new Object[] {node, label.visits, label.scores.getN(), label.scores.getMean()});
        }
    }

    /**
     * Chooses the action to play in the given node.
     *
     * <p>If some successor has no label yet, the first such action (in key-set iteration order) is
     * returned and an empty label is registered for its successor. Otherwise the action with the
     * best UCB score is returned; ties keep the earlier action, and NaN scores are never selected.
     *
     * @param t the node in which an action has to be chosen
     * @param map the applicable actions mapped to the successor node each of them leads to
     * @return the chosen action, or {@code null} if no action has a comparable score (this violates
     *         an assertion when assertions are enabled)
     * @throws IllegalStateException if all successors are labelled but {@code t} itself has never
     *         been part of an updated path, so its visit count is unknown
     */
    @Override // jaicore.search.algorithms.standard.mcts.IPolicy
    public A getAction(T t, Map<A, T> map) {
        Set<A> actions = map.keySet();
        this.logger.debug("Deriving action for node {}. The {} options are: {}", new Object[] {t, actions.size(), map});

        /* if an applicable action has not been tried yet, play it to get an initial estimate */
        List<A> untriedActions = actions.stream().filter(action -> !this.labels.containsKey(map.get(action))).collect(Collectors.toList());
        if (!untriedActions.isEmpty()) {
            A action = untriedActions.get(0);
            this.labels.put(map.get(action), new NodeLabel());
            this.logger.info("Dictating action {}, because this was never played before.", action);
            return action;
        }

        /* otherwise play the action with the best UCB score */
        NodeLabel parentLabel = this.labels.get(t);
        this.logger.debug("All actions have been tried. Label is: {}", parentLabel);
        if (parentLabel == null) {
            throw new IllegalStateException("No label is known for node " + t + ", so no UCB score can be computed for its actions.");
        }
        int parentVisits = parentLabel.visits;
        int explorationSign = this.maximize ? 1 : -1;

        /* Infinite sentinels: Double.MIN_VALUE is the smallest POSITIVE double and would make every
         * non-positive score unselectable when maximizing. */
        double bestScore = this.maximize ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        A bestAction = null;
        for (A action : actions) {
            T successor = map.get(action);
            NodeLabel label = this.labels.get(successor);
            assert label.visits != 0 : "Visits of node " + successor + " cannot be 0 if we already used this action before!";
            assert label.scores.getN() != 0 : "Number of observations cannot be 0 if we already visited this node before";
            this.logger.trace("Considering action {} whose successor state has stats {} and {} visits", new Object[] {action, label.scores.getMean(), label.visits});
            double ucb = label.scores.getMean() + (explorationSign * Math.sqrt((2.0d * Math.log(parentVisits)) / label.visits));
            assert !Double.isNaN(ucb) : "The UCB score is NaN, which cannot be the case. Score mean is " + label.scores.getMean() + ", number of visits is " + label.visits;

            /* strict comparison: a NaN score compares false and therefore never becomes the best */
            boolean better = this.maximize ? ucb > bestScore : ucb < bestScore;
            if (better) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", new Object[] {bestAction, action, bestScore});
                bestScore = ucb;
                bestAction = action;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", new Object[] {action, ucb, bestScore});
            }
        }
        assert bestAction != null : "Would return null, but this must not be the case!";
        this.logger.info("Recommending action {}.", bestAction);
        return bestAction;
    }

    /**
     * @return the logger name set via {@link #setLoggerName(String)}, or {@code null} if none was set
     */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy to the logger with the given name.
     *
     * @param str name of the logger to use from now on
     */
    public void setLoggerName(String str) {
        this.loggerName = str;
        this.logger = LoggerFactory.getLogger(str);
    }
}
