package ai.grakn.graql.internal.gremlin;

import ai.grakn.GraknGraph;
import ai.grakn.graql.admin.PatternAdmin;
import ai.grakn.graql.internal.gremlin.fragment.Fragment;
import ai.grakn.graql.internal.gremlin.spanningtree.Arborescence;
import ai.grakn.graql.internal.gremlin.spanningtree.ChuLiuEdmonds;
import ai.grakn.graql.internal.gremlin.spanningtree.graph.DirectedEdge;
import ai.grakn.graql.internal.gremlin.spanningtree.graph.Node;
import ai.grakn.graql.internal.gremlin.spanningtree.graph.NodeId;
import ai.grakn.graql.internal.gremlin.spanningtree.graph.SparseWeightedGraph;
import ai.grakn.util.CommonUtil;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/grakn/graql/internal/gremlin/GreedyTraversalPlan.class */
public class GreedyTraversalPlan {
    // Class-wide logger; planForConjunction writes each finished plan to it at TRACE level.
    protected static final Logger LOG = LoggerFactory.getLogger(GreedyTraversalPlan.class);

    /**
     * Builds a traversal for the given pattern.
     *
     * <p>The pattern is expanded to disjunctive normal form; each conjunction in it is wrapped in a
     * {@code ConjunctionQuery} and planned on its own by {@code planForConjunction}. The resulting
     * plans are gathered with {@code CommonUtil.toImmutableSet()} and passed to
     * {@code GraqlTraversal.create}.
     *
     * @param patternAdmin the pattern to plan
     * @param graknGraph the graph handed to every {@code ConjunctionQuery}
     * @return the traversal created from the per-conjunction fragment plans
     */
    public static GraqlTraversal createTraversal(PatternAdmin patternAdmin, GraknGraph graknGraph) {
        Set<? extends List<Fragment>> conjunctionPlans = patternAdmin.getDisjunctiveNormalForm().getPatterns().stream()
                .map(conjunction -> new ConjunctionQuery(conjunction, graknGraph))
                .map(GreedyTraversalPlan::planForConjunction)
                .collect(CommonUtil.toImmutableSet());
        return GraqlTraversal.create(conjunctionPlans);
    }

    /**
     * Plans a single conjunction and returns its fragments in execution order.
     *
     * <p>The fragments are first grouped by {@code getConnectedFragmentSets}. For each group, the
     * fragments accepted by the {@code filterNodeFragment} predicate are expanded into weighted
     * directed edges. When at least one edge exists, {@code ChuLiuEdmonds.getMaxArborescence} is run
     * once per candidate root, the result with the greatest weight is kept, and
     * {@code greedyTraversal} appends its fragments to the plan. Node fragments that are still
     * pending afterwards are appended by {@code addUnvisitedNodeFragments}.
     *
     * <p>NOTE(review): the locals below are raw-typed (this file looks like decompiler output), so
     * the lambdas depend on unchecked conversions — confirm the intended generic types before
     * editing the code.
     *
     * @param conjunctionQuery the conjunction whose fragments are to be ordered
     * @return the ordered fragment list; the same list is written to the TRACE log
     */
    private static List<Fragment> planForConjunction(ConjunctionQuery conjunctionQuery) {
        // The plan being built; every helper below appends to this list.
        ArrayList arrayList = new ArrayList();
        // NodeId -> Node registry shared by all helper calls in this method.
        HashMap hashMap = new HashMap();
        getConnectedFragmentSets(conjunctionQuery, hashMap).forEach(set -> {
            // Nodes touched by this fragment set (start nodes of filtered-out fragments and edge endpoints).
            HashSet hashSet = new HashSet();
            // Filled by Fragment.getDirectedEdges and later read by greedyTraversal to find edge fragments.
            HashMap hashMap2 = new HashMap();
            // Start nodes of fragments that report a fixed cost (see filterNodeFragment).
            HashSet hashSet2 = new HashSet();
            // Weighted directed edges produced by the fragments that pass the filter.
            HashSet hashSet3 = new HashSet();
            set.stream().filter(filterNodeFragment(arrayList, hashMap, hashSet, hashSet2)).flatMap(fragment -> {
                return fragment.getDirectedEdges(hashMap, hashMap2).stream();
            }).forEach(weighted -> {
                hashSet3.add(weighted);
                hashSet.add(((DirectedEdge) weighted.val).destination);
                hashSet.add(((DirectedEdge) weighted.val).source);
            });
            // Without any edge there is nothing to build a spanning tree from.
            if (!hashSet3.isEmpty()) {
                SparseWeightedGraph from = SparseWeightedGraph.from(hashSet3);
                // Candidate roots: the fixed-cost nodes when there are any, otherwise the graph nodes
                // that report isValidStartingPoint(). The arborescence with the greatest weight wins;
                // with no candidates at all, Arborescence.empty() is passed on.
                greedyTraversal(arrayList, (Arborescence) (hashSet2.isEmpty() ? (Collection) from.getNodes().stream().filter((v0) -> {
                    return v0.isValidStartingPoint();
                }).collect(Collectors.toSet()) : hashSet2).stream().map(node -> {
                    return ChuLiuEdmonds.getMaxArborescence(from, node);
                }).max(Comparator.comparingDouble(weighted2 -> {
                    return weighted2.weight;
                })).map(weighted3 -> {
                    return (Arborescence) weighted3.val;
                }).orElse(Arborescence.empty()), hashMap, hashMap2);
            }
            // Append node fragments of this set that the tree walk did not emit.
            addUnvisitedNodeFragments(arrayList, hashMap, hashSet);
        });
        // Final sweep over every registered node for fragments that are still pending.
        addUnvisitedNodeFragments(arrayList, hashMap, hashMap.values());
        LOG.trace("Greedy Plan = " + arrayList);
        return arrayList;
    }

    /**
     * Appends the pending fragments of the given nodes to the plan.
     *
     * <p>A node counts as pending while it still holds dependency-free fragments or fragments whose
     * dependency has been visited. Pending nodes are flushed through
     * {@code addNodeFragmentToPlan(node, list, map, false)}; because flushing one node can make
     * fragments on another node ready, the selection is recomputed until no node is pending.
     *
     * @param list the plan to append to
     * @param map the NodeId to Node registry forwarded to {@code addNodeFragmentToPlan}
     * @param collection the nodes to inspect
     */
    private static void addUnvisitedNodeFragments(List<Fragment> list, Map<NodeId, Node> map, Collection<Node> collection) {
        Predicate<Node> hasPendingFragments = candidate ->
                !candidate.getFragmentsWithoutDependency().isEmpty()
                        || !candidate.getFragmentsWithDependencyVisited().isEmpty();

        Set<Node> pending = collection.stream().filter(hasPendingFragments).collect(Collectors.toSet());
        while (!pending.isEmpty()) {
            for (Node pendingNode : pending) {
                addNodeFragmentToPlan(pendingNode, list, map, false);
            }
            // Re-scan: the flush above may have released fragments on other nodes.
            pending = collection.stream().filter(hasPendingFragments).collect(Collectors.toSet());
        }
    }

    /**
     * Creates a predicate that lets through only fragments that have an end, and files every other
     * fragment as a side effect.
     *
     * <p>For a fragment without an end, the node of its start variable is looked up (or created) in
     * {@code map} and added to {@code set}. Then:
     * <ul>
     *   <li>a fragment with a fixed cost is appended to {@code list} immediately and its node is
     *       added to {@code set2};</li>
     *   <li>otherwise, a fragment without dependencies is stored on the node as a dependency-free
     *       fragment;</li>
     *   <li>a fragment with dependencies is left untouched here.</li>
     * </ul>
     *
     * @param list the plan that receives fixed-cost fragments
     * @param map the NodeId to Node registry
     * @param set collects the start node of every fragment that has no end
     * @param set2 collects the start nodes of fixed-cost fragments
     * @return a predicate that is {@code true} exactly for fragments whose end is present
     */
    private static Predicate<Fragment> filterNodeFragment(List<Fragment> list, Map<NodeId, Node> map, Set<Node> set, Set<Node> set2) {
        return candidate -> {
            boolean hasEnd = candidate.getEnd().isPresent();
            if (!hasEnd) {
                Node startNode = Node.addIfAbsent(NodeId.NodeType.VAR, candidate.getStart(), map);
                set.add(startNode);
                if (candidate.hasFixedFragmentCost()) {
                    list.add(candidate);
                    set2.add(startNode);
                } else if (candidate.mo15getDependencies().isEmpty()) {
                    startNode.getFragmentsWithoutDependency().add(candidate);
                }
            }
            return hasEnd;
        };
    }

    /**
     * Partitions the fragments of a conjunction into groups connected through shared variable
     * names, and records fragment dependencies on the nodes registered in {@code map}.
     *
     * <p>Every fragment of every equivalent fragment set is visited once:
     * <ul>
     *   <li>If the fragment has dependencies, it is stored as a fragment-with-dependency on the
     *       node of its start variable and as a dependant on the node of its first dependency.
     *       When the fragment is the only member of its equivalent fragment set, the same
     *       bookkeeping is also recorded in the opposite direction.</li>
     *   <li>The fragment joins the first existing group that shares a variable name with it; all
     *       further groups sharing a name are merged into that group. If no group shares a name,
     *       a new group is opened.</li>
     * </ul>
     *
     * @param conjunctionQuery the conjunction supplying the equivalent fragment sets
     * @param map the NodeId to Node registry; nodes are added to it as needed
     * @return the fragment groups, as a view of the values of an internal map
     */
    private static Collection<Set<Fragment>> getConnectedFragmentSets(ConjunctionQuery conjunctionQuery, Map<NodeId, Node> map) {
        // Both maps are keyed by group id. Variable names are held as Object because their declared
        // type is not imported by this file; only equality is needed to detect overlap.
        final Map<Integer, Set<Object>> variablesByGroup = new HashMap<>();
        final Map<Integer, Set<Fragment>> fragmentsByGroup = new HashMap<>();
        // One-element array so the lambda below can advance the counter.
        final int[] lastGroupId = {0};

        conjunctionQuery.getEquivalentFragmentSets().stream()
                .flatMap(equivalentSet -> equivalentSet.stream())
                .forEach(fragment -> {
                    if (!fragment.mo15getDependencies().isEmpty()) {
                        Node start = Node.addIfAbsent(NodeId.NodeType.VAR, fragment.getStart(), map);
                        // Only the first dependency is registered.
                        Node dependency = Node.addIfAbsent(NodeId.NodeType.VAR, fragment.mo15getDependencies().iterator().next(), map);
                        start.getFragmentsWithDependency().add(fragment);
                        dependency.getDependants().add(fragment);
                        if (fragment.getEquivalentFragmentSet().fragments().size() == 1) {
                            // Sole fragment of its equivalent set: record the relation both ways.
                            dependency.getFragmentsWithDependency().add(fragment);
                            start.getDependants().add(fragment);
                        }
                    }

                    Set<Object> fragmentVariables = Sets.newHashSet(fragment.getVariableNames());
                    List<Integer> overlappingGroups = new ArrayList<>();
                    variablesByGroup.forEach((groupId, groupVariables) -> {
                        if (!Collections.disjoint(groupVariables, fragmentVariables)) {
                            overlappingGroups.add(groupId);
                        }
                    });

                    if (overlappingGroups.isEmpty()) {
                        // No shared variable with any existing group: open a new one.
                        lastGroupId[0]++;
                        variablesByGroup.put(lastGroupId[0], fragmentVariables);
                        fragmentsByGroup.put(lastGroupId[0], Sets.newHashSet(fragment));
                        return;
                    }

                    // The first overlapping group absorbs the fragment and every other overlapping group.
                    Integer survivor = overlappingGroups.get(0);
                    Set<Object> survivorVariables = variablesByGroup.get(survivor);
                    Set<Fragment> survivorFragments = fragmentsByGroup.get(survivor);
                    survivorVariables.addAll(fragmentVariables);
                    survivorFragments.add(fragment);
                    for (Integer absorbed : overlappingGroups.subList(1, overlappingGroups.size())) {
                        survivorVariables.addAll(variablesByGroup.remove(absorbed));
                        survivorFragments.addAll(fragmentsByGroup.remove(absorbed));
                    }
                });
        return fragmentsByGroup.values();
    }

    /**
     * Walks the arborescence from its root and appends fragments to the plan in greedy order.
     *
     * <p>A set of reachable nodes is maintained, starting with the root. On each step the reachable
     * node whose edge fragment (the fragment linking it to its parent) is cheapest is chosen; a
     * node without such a fragment counts as cost 0. The chosen node's edge fragment, if any, is
     * appended first, then its node fragments via
     * {@code addNodeFragmentToPlan(node, list, map, true)}, and finally its children become
     * reachable.
     *
     * @param list the plan to append to
     * @param arborescence the tree to walk; supplies the root and the child-to-parent mapping
     * @param map the NodeId to Node registry forwarded to {@code addNodeFragmentToPlan}
     * @param map2 edge fragments, looked up by node and then by that node's parent
     */
    private static void greedyTraversal(List<Fragment> list, Arborescence<Node> arborescence, Map<NodeId, Node> map, Map<Node, Map<Node, Fragment>> map2) {
        // Invert the child -> parent mapping so the children of a visited node can be looked up.
        Map<Node, Set<Node>> childrenByParent = new HashMap<>();
        arborescence.getParents().forEach((child, parent) ->
                childrenByParent.computeIfAbsent(parent, unused -> new HashSet<>()).add(child));

        Set<Node> reachable = Sets.newHashSet(arborescence.getRoot());
        while (!reachable.isEmpty()) {
            Node cheapest = reachable.stream()
                    .min(Comparator.comparingDouble(candidate -> getEdgeFragmentCost(candidate, arborescence, map2)))
                    .get();

            // Edge fragment first, then the node's own fragments.
            getEdgeFragment(cheapest, arborescence, map2).ifPresent(list::add);
            addNodeFragmentToPlan(cheapest, list, map, true);

            reachable.remove(cheapest);
            Set<Node> children = childrenByParent.get(cheapest);
            if (children != null) {
                reachable.addAll(children);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Appends the ready fragments of {@code node} to the plan and notifies the nodes that depend
     * on it.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>When {@code z} is {@code false}, the cheapest dependency-free fragment is appended
     *       first and removed from the node.</li>
     *   <li>The remaining dependency-free fragments, together with the fragments whose dependency
     *       has been visited, are appended sorted by {@code fragmentCost()}; both collections are
     *       then cleared.</li>
     *   <li>If the node still holds fragments with an unvisited dependency, each of its dependant
     *       fragments is moved, on the other node involved, from "with dependency" to "dependency
     *       visited"; the node's own "with dependency" collection is then cleared.</li>
     *   <li>The node's dependants are cleared.</li>
     * </ol>
     *
     * @param node the node whose fragments are flushed
     * @param list the plan to append to
     * @param map the NodeId to Node registry used to resolve the other node of a dependant
     * @param z {@code false} to emit the cheapest dependency-free fragment ahead of the sorted rest
     */
    public static void addNodeFragmentToPlan(Node node, List<Fragment> list, Map<NodeId, Node> map, boolean z) {
        Collection<Fragment> ready = node.getFragmentsWithoutDependency();
        Collection<Fragment> released = node.getFragmentsWithDependencyVisited();
        Comparator<Fragment> byCost = Comparator.comparingDouble(candidate -> candidate.fragmentCost());

        if (!z) {
            Optional<Fragment> cheapest = ready.stream().min(byCost);
            if (cheapest.isPresent()) {
                list.add(cheapest.get());
                ready.remove(cheapest.get());
            }
        }

        ready.addAll(released);
        List<Fragment> orderedByCost = ready.stream().sorted(byCost).collect(Collectors.toList());
        list.addAll(orderedByCost);
        ready.clear();
        released.clear();

        Collection<Fragment> blocked = node.getFragmentsWithDependency();
        Collection<Fragment> dependants = node.getDependants();
        if (!blocked.isEmpty()) {
            for (Fragment dependant : dependants) {
                // The other node is the dependant's start node, unless that is this node itself,
                // in which case it is the node of the dependant's first dependency.
                Node other = map.get(new NodeId(NodeId.NodeType.VAR, dependant.getStart()));
                if (node.equals(other)) {
                    other = map.get(new NodeId(NodeId.NodeType.VAR, dependant.mo15getDependencies().iterator().next()));
                }
                other.getFragmentsWithDependencyVisited().add(dependant);
                other.getFragmentsWithDependency().remove(dependant);
            }
            blocked.clear();
        }
        dependants.clear();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the cost of the fragment linking {@code node} to its parent in the arborescence.
     *
     * @param node the node whose incoming edge fragment is priced
     * @param arborescence supplies the parent of {@code node}
     * @param map edge fragments, looked up by node and then by that node's parent
     * @return the fragment's {@code fragmentCost()}, or {@code 0.0} when no such fragment exists
     */
    public static double getEdgeFragmentCost(Node node, Arborescence<Node> arborescence, Map<Node, Map<Node, Fragment>> map) {
        Optional<Fragment> edgeFragment = getEdgeFragment(node, arborescence, map);
        if (!edgeFragment.isPresent()) {
            return 0.0d;
        }
        return edgeFragment.get().fragmentCost();
    }

    /**
     * Looks up the fragment linking {@code node} to its parent in the arborescence.
     *
     * @param node the node whose incoming edge fragment is wanted
     * @param arborescence supplies the parent of {@code node}
     * @param map edge fragments, looked up by node and then by that node's parent
     * @return the fragment, or an empty {@code Optional} when {@code map} has no entry for the
     *         node or that entry has no key for the node's parent
     */
    private static Optional<Fragment> getEdgeFragment(Node node, Arborescence<Node> arborescence, Map<Node, Map<Node, Fragment>> map) {
        if (!map.containsKey(node)) {
            return Optional.empty();
        }
        Map<Node, Fragment> fragmentsByParent = map.get(node);
        Node parent = arborescence.getParents().get(node);
        if (!fragmentsByParent.containsKey(parent)) {
            return Optional.empty();
        }
        return Optional.of(fragmentsByParent.get(parent));
    }
}
