1from typing import Any, Dict, List, Set
2
3from ..language import (
4    DocumentNode,
5    FragmentDefinitionNode,
6    FragmentSpreadNode,
7    OperationDefinitionNode,
8    SelectionSetNode,
9    Visitor,
10    visit,
11)
12
__all__ = ["separate_operations"]


# Maps each fragment name to the names of the fragments spread within that
# fragment's selection set (its direct dependencies, in encounter order and
# possibly with repeats).
DepGraph = Dict[str, List[str]]
17
18
def separate_operations(document_ast: DocumentNode) -> Dict[str, DocumentNode]:
    """Separate operations in a given AST document.

    This function accepts a single AST document which may contain many operations and
    fragments and returns a collection of AST documents, each of which contains a
    single operation as well as the fragment definitions it refers to, directly or
    transitively.

    The result maps operation names to their documents; an anonymous operation is
    stored under the empty string. If several operations share a name, the one that
    appears last in the document wins. Definitions in each resulting document keep
    the order they had in ``document_ast``, and the definition nodes themselves are
    shared with ``document_ast`` rather than copied.
    """
    operations: List[OperationDefinitionNode] = []
    dep_graph: DepGraph = {}

    # Populate metadata and build a dependency graph.
    for definition_node in document_ast.definitions:
        if isinstance(definition_node, OperationDefinitionNode):
            operations.append(definition_node)
        elif isinstance(
            definition_node, FragmentDefinitionNode
        ):  # pragma: no cover else
            dep_graph[definition_node.name.value] = collect_dependencies(
                definition_node.selection_set
            )

    # For each operation, produce a new synthesized AST which includes only what is
    # necessary for completing that operation.
    separated_document_asts: Dict[str, DocumentNode] = {}
    for operation in operations:
        dependencies: Set[str] = set()

        for fragment_name in collect_dependencies(operation.selection_set):
            collect_transitive_dependencies(dependencies, dep_graph, fragment_name)

        # Provides the empty string for anonymous operations.
        operation_name = operation.name.value if operation.name else ""

        # Keep only this operation and the fragments it depends on. Filtering the
        # original definitions (rather than building a new list) preserves the
        # order in which they appear in the source document.
        separated_document_asts[operation_name] = DocumentNode(
            definitions=[
                node
                for node in document_ast.definitions
                if node is operation
                or (
                    isinstance(node, FragmentDefinitionNode)
                    and node.name.value in dependencies
                )
            ]
        )

    return separated_document_asts
67
68
def collect_transitive_dependencies(
    collected: Set[str], dep_graph: "DepGraph", from_name: str
) -> None:
    """Collect transitive dependencies.

    Add ``from_name`` and every name reachable from it through ``dep_graph`` to
    ``collected``, which is updated in place. A name that is already in
    ``collected`` is not expanded again. This makes cycles terminate and means
    shared dependencies are visited only once. A name with no entry in
    ``dep_graph`` is still added, but contributes no further dependencies.

    The traversal uses an explicit stack instead of recursion, so very long
    dependency chains cannot exceed Python's recursion limit.
    """
    pending: List[str] = [from_name]
    while pending:
        name = pending.pop()
        if name in collected:
            continue
        collected.add(name)
        # Missing entries (e.g. spreads of undefined fragments) have no edges.
        pending.extend(dep_graph.get(name, ()))
84
85
class DependencyCollector(Visitor):
    """AST visitor that records the name of each fragment spread it enters.

    Names are appended in the order the spreads are visited. A fragment spread
    several times is recorded once per spread.
    """

    # Fragment names gathered so far, in visit order.
    dependencies: List[str]

    def __init__(self) -> None:
        super().__init__()
        names: List[str] = []
        self.dependencies = names
        # Bound once up front so each spread costs a single call.
        self.add_dependency = names.append

    def enter_fragment_spread(self, node: FragmentSpreadNode, *_args: Any) -> None:
        spread_name = node.name.value
        self.add_dependency(spread_name)
96
97
def collect_dependencies(selection_set: SelectionSetNode) -> List[str]:
    """Return the names of the fragments spread within ``selection_set``.

    The whole selection set is walked with a ``DependencyCollector``. Names come
    back in visit order and may contain repeats.
    """
    spread_visitor = DependencyCollector()
    visit(selection_set, spread_visitor)
    names = spread_visitor.dependencies
    return names
102