from typing import Any, Dict, Iterator, List, Optional, Set, Tuple

from ...error import GraphQLError
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule
6
7__all__ = ["NoFragmentCyclesRule"]
8
9
class NoFragmentCyclesRule(ASTValidationRule):
    """No fragment cycles

    Reports an error for every fragment spread that closes a cycle in the graph
    of fragment spreads, including a fragment that spreads itself directly.

    See the GraphQL specification section
    "Fragment spreads must not form cycles".
    """

    def __init__(self, context: ASTValidationContext):
        super().__init__(context)
        # Tracks already visited fragments to maintain O(N) and to ensure that
        # cycles are not redundantly reported.
        self.visited_frags: Set[str] = set()
        # List of AST nodes used to produce meaningful errors
        self.spread_path: List[FragmentSpreadNode] = []
        # Position in the spread path
        self.spread_path_index_by_name: Dict[str, int] = {}

    @staticmethod
    def enter_operation_definition(*_args: Any) -> VisitorAction:
        # Operations are not explored here. Fragments are checked when their
        # own definitions are entered (see enter_fragment_definition).
        return SKIP

    def enter_fragment_definition(
        self, node: FragmentDefinitionNode, *_args: Any
    ) -> VisitorAction:
        # Run the cycle search from this fragment. The visitor does not need to
        # descend into the fragment's body itself.
        self.detect_cycle_recursive(node)
        return SKIP

    def detect_cycle_recursive(self, fragment: FragmentDefinitionNode) -> None:
        """Report every fragment spread cycle reachable from ``fragment``.

        This is a depth-first search over fragment spreads. It does not stop at
        the first cycle found but keeps exploring to report all of them. Fragments
        already recorded in ``visited_frags`` are not explored again. That keeps
        the total work linear and prevents the same cycle from being reported
        twice.

        The name is kept for compatibility, but the search is iterative. An
        explicit stack replaces Python recursion, so a long chain of fragment
        spreads cannot exhaust the interpreter's recursion limit and abort
        validation with ``RecursionError``.
        """
        spread_path = self.spread_path
        spread_path_index = self.spread_path_index_by_name
        get_fragment = self.context.get_fragment
        enter_fragment = self._enter_fragment

        root = enter_fragment(fragment)
        if root is None:
            return

        # Each frame is (fragment name, iterator over its not-yet-explored spreads).
        # The top frame plays the role of the innermost recursive call.
        stack = [root]
        while stack:
            fragment_name, spreads = stack[-1]
            spread_node = next(spreads, None)
            if spread_node is None:
                # Every spread of this fragment has been explored: leave it.
                stack.pop()
                del spread_path_index[fragment_name]
                if stack:
                    # Drop the spread through which the parent frame entered
                    # this fragment. Its pop was deferred until now.
                    spread_path.pop()
                continue

            spread_name = spread_node.name.value
            cycle_index = spread_path_index.get(spread_name)

            spread_path.append(spread_node)
            if cycle_index is None:
                spread_fragment = get_fragment(spread_name)
                if spread_fragment:
                    frame = enter_fragment(spread_fragment)
                    if frame is not None:
                        # Descend into the spread fragment. spread_node stays on
                        # the path until that frame is exhausted.
                        stack.append(frame)
                        continue
            else:
                # spread_name is already on the current path. The spreads from
                # its recorded position through this one form the cycle.
                cycle_path = spread_path[cycle_index:]
                via_path = ", ".join("'" + s.name.value + "'" for s in cycle_path[:-1])
                self.report_error(
                    GraphQLError(
                        f"Cannot spread fragment '{spread_name}' within itself"
                        + (f" via {via_path}." if via_path else "."),
                        cycle_path,
                    )
                )
            spread_path.pop()

    def _enter_fragment(
        self, fragment: FragmentDefinitionNode
    ) -> Optional[Tuple[str, Iterator[FragmentSpreadNode]]]:
        """Mark ``fragment`` as visited and open a search frame for it.

        Returns ``None`` when there is nothing to explore, because the fragment
        was already visited or contains no fragment spreads. Otherwise it records
        the fragment's position in the spread path and returns a
        ``(name, spreads)`` frame whose iterator yields the fragment's spread
        nodes.
        """
        fragment_name = fragment.name.value
        if fragment_name in self.visited_frags:
            return None
        self.visited_frags.add(fragment_name)

        spread_nodes = self.context.get_fragment_spreads(fragment.selection_set)
        if not spread_nodes:
            return None

        # Remember where this fragment starts on the path. A later spread of it
        # can then slice out exactly the nodes that form the cycle.
        self.spread_path_index_by_name[fragment_name] = len(self.spread_path)
        return fragment_name, iter(spread_nodes)
75