from typing import Any, Dict, Iterator, List, Set, Tuple

from ...error import GraphQLError
from ...language import FragmentDefinitionNode, FragmentSpreadNode, VisitorAction, SKIP
from . import ASTValidationContext, ASTValidationRule

__all__ = ["NoFragmentCyclesRule"]


class NoFragmentCyclesRule(ASTValidationRule):
    """No fragment cycles

    The graph of fragment spreads must not form any cycles, including a fragment
    spreading itself. An error is reported for every cycle found.

    See https://spec.graphql.org/draft/#sec-Fragment-spreads-must-not-form-cycles
    """

    def __init__(self, context: ASTValidationContext):
        super().__init__(context)
        # Tracks already visited fragments to maintain O(N) and to ensure that
        # cycles are not redundantly reported.
        self.visited_frags: Set[str] = set()
        # List of AST nodes used to produce meaningful errors
        self.spread_path: List[FragmentSpreadNode] = []
        # Position in the spread path
        self.spread_path_index_by_name: Dict[str, int] = {}

    @staticmethod
    def enter_operation_definition(*_args: Any) -> VisitorAction:
        """Skip operations: they cannot be spread, so they cannot be in a cycle."""
        return SKIP

    def enter_fragment_definition(
        self, node: FragmentDefinitionNode, *_args: Any
    ) -> VisitorAction:
        """Search for cycles starting at this fragment definition.

        Returns SKIP because the search already inspects the fragment's spreads
        via the validation context, so the visitor need not descend into it.
        """
        self.detect_cycle_recursive(node)
        return SKIP

    def detect_cycle_recursive(self, fragment: FragmentDefinitionNode) -> None:
        """Report every fragment-spread cycle reachable from ``fragment``.

        This does a straight-forward depth-first search to find cycles. It does
        not terminate when a cycle was found but continues to explore the graph
        to find all possible cycles. Fragments already recorded in
        ``visited_frags`` are not explored again, so each fragment is expanded
        at most once over the lifetime of this rule instance.

        Despite its name, the search uses an explicit stack instead of Python
        recursion. A long chain of spreads in a client-supplied document
        therefore cannot exhaust the interpreter's recursion limit. The name is
        kept for backward compatibility.
        """
        visited_frags = self.visited_frags
        spread_path = self.spread_path
        spread_path_index = self.spread_path_index_by_name
        get_fragment = self.context.get_fragment
        get_fragment_spreads = self.context.get_fragment_spreads

        # One frame per fragment currently on the DFS path: the fragment's name
        # and an iterator over its not-yet-processed spreads.
        stack: List[Tuple[str, Iterator[FragmentSpreadNode]]] = []

        def push_fragment(frag: FragmentDefinitionNode) -> bool:
            """Start exploring ``frag``; return True if a frame was pushed."""
            frag_name = frag.name.value
            if frag_name in visited_frags:
                return False
            visited_frags.add(frag_name)
            spread_nodes = get_fragment_spreads(frag.selection_set)
            if not spread_nodes:
                # A fragment without spreads cannot take part in a cycle.
                return False
            # Remember where this fragment's own spreads begin in the path, so a
            # later spread back to it can slice out exactly the cycle.
            spread_path_index[frag_name] = len(spread_path)
            stack.append((frag_name, iter(spread_nodes)))
            return True

        push_fragment(fragment)
        while stack:
            frag_name, remaining = stack[-1]
            spread_node = next(remaining, None)
            if spread_node is None:
                # All spreads of this fragment are processed: leave it.
                stack.pop()
                del spread_path_index[frag_name]
                if stack:
                    # Drop the spread through which the parent entered it.
                    spread_path.pop()
                continue

            spread_name = spread_node.name.value
            cycle_index = spread_path_index.get(spread_name)

            spread_path.append(spread_node)
            if cycle_index is None:
                spread_fragment = get_fragment(spread_name)
                if spread_fragment and push_fragment(spread_fragment):
                    # Descend; the spread stays on the path until the child
                    # fragment has been fully explored.
                    continue
            else:
                cycle_path = spread_path[cycle_index:]
                via_path = ", ".join("'" + s.name.value + "'" for s in cycle_path[:-1])
                self.report_error(
                    GraphQLError(
                        f"Cannot spread fragment '{spread_name}' within itself"
                        + (f" via {via_path}." if via_path else "."),
                        cycle_path,
                    )
                )
            spread_path.pop()