1from ..language.ast import FragmentDefinition, FragmentSpread, OperationDefinition
2from ..language.visitor import ParallelVisitor, TypeInfoVisitor, Visitor, visit
3from ..type import GraphQLSchema
4from ..utils.type_info import TypeInfo
5from .rules import specified_rules
6
7# Necessary for static type checking
8if False:  # flake8: noqa
9    from typing import List, Union, Optional, Dict, Set, Any, Type
10    from ..language.ast import Document, SelectionSet, Node
11    from ..error import GraphQLError
12    from .rules.base import ValidationRule
13    from ..type.definition import (
14        GraphQLObjectType,
15        GraphQLInterfaceType,
16        GraphQLField,
17        GraphQLArgument,
18        GraphQLType,
19        GraphQLInputObjectType,
20    )
21
22
def validate(schema, ast, rules=specified_rules):
    # type: (GraphQLSchema, Document, List[Type[ValidationRule]]) -> List
    """Validate a parsed document against a schema and return the errors.

    A fresh ``TypeInfo`` is created for ``schema`` and the work is handed
    to ``visit_using_rules``, whose result is returned unchanged.

    :param schema: the ``GraphQLSchema`` to validate against.
    :param ast: the document to validate.
    :param rules: rule classes to apply; defaults to ``specified_rules``.
    :raises AssertionError: if ``schema`` or ``ast`` is falsy, or if
        ``schema`` is not a ``GraphQLSchema`` instance.
    """
    # The checks run in this order so a missing schema is reported before
    # a missing document.
    assert schema, "Must provide schema"
    assert ast, "Must provide document"
    assert isinstance(schema, GraphQLSchema)
    return visit_using_rules(schema, TypeInfo(schema), ast, rules)
30
31
def visit_using_rules(schema, type_info, ast, rules):
    # type: (GraphQLSchema, TypeInfo, Document, List[Type[ValidationRule]]) -> List
    """Run every rule over ``ast`` in one traversal and return the errors.

    Each entry of ``rules`` is called with a shared ``ValidationContext``
    to build a visitor; the visitors are wrapped in a ``ParallelVisitor``
    and a ``TypeInfoVisitor`` bound to ``type_info`` before visiting.

    :returns: the list obtained from the context's ``get_errors()``.
    """
    validation_context = ValidationContext(schema, ast, type_info)

    # One visitor instance per rule, all sharing the same context.
    rule_visitors = []
    for rule_class in rules:
        rule_visitors.append(rule_class(validation_context))

    combined_visitor = ParallelVisitor(rule_visitors)
    visit(ast, TypeInfoVisitor(type_info, combined_visitor))
    return validation_context.get_errors()
38
39
class VariableUsage(object):
    """Plain record pairing a variable node with the type recorded for it."""

    __slots__ = ("node", "type")

    def __init__(self, node, type):
        # ``type`` shadows the builtin inside this initializer only; the
        # parameter name is part of the public signature and is kept.
        self.node, self.type = node, type
46
47
class UsageVisitor(Visitor):
    """Visitor that records a ``VariableUsage`` for every ``Variable`` entered.

    Results are appended to the ``usages`` list supplied by the caller, so
    the caller keeps ownership of the collected list.
    """

    __slots__ = ("usages", "type_info")

    def __init__(self, usages, type_info):
        # type: (List[VariableUsage], TypeInfo) -> None
        self.type_info = type_info
        self.usages = usages

    def enter_VariableDefinition(self, node, key, parent, path, ancestors):
        # NOTE(review): returning False presumably makes the traversal skip
        # this definition's children, so the variable being *declared* is
        # not counted as a usage -- confirm against the visitor contract.
        return False

    def enter_Variable(self, node, key, parent, path, ancestors):
        # Capture the input type reported by the type tracker at this node.
        input_type = self.type_info.get_input_type()
        record = VariableUsage(node, type=input_type)
        self.usages.append(record)
62
63
class ValidationContext(object):
    """State shared by all validation rules during one validation run.

    Wraps the schema, the document being validated and a ``TypeInfo``
    instance, accumulates the errors that rules report, and memoizes the
    fragment and variable-usage lookups so each is computed once per node.

    Lists returned by the memoized getters are the cached objects
    themselves; callers should treat them as read-only.
    """

    __slots__ = (
        "_schema",
        "_ast",
        "_type_info",
        "_errors",
        "_fragments",
        "_fragment_spreads",
        "_recursively_referenced_fragments",
        "_variable_usages",
        "_recursive_variable_usages",
    )

    def __init__(self, schema, ast, type_info):
        # type: (GraphQLSchema, Document, TypeInfo) -> None
        self._schema = schema
        self._ast = ast
        self._type_info = type_info
        self._errors = []  # type: List[GraphQLError]
        # Name -> fragment index; built lazily by get_fragment().
        self._fragments = None  # type: Optional[Dict[str, FragmentDefinition]]
        self._fragment_spreads = {}  # type: Dict[Node, List[FragmentSpread]]
        self._recursively_referenced_fragments = (
            {}
        )  # type: Dict[OperationDefinition, List[FragmentDefinition]]
        self._variable_usages = {}  # type: Dict[Node, List[VariableUsage]]
        self._recursive_variable_usages = (
            {}
        )  # type: Dict[OperationDefinition, List[VariableUsage]]

    def report_error(self, error):
        """Append ``error`` to the list returned by ``get_errors``."""
        self._errors.append(error)

    def get_errors(self):
        # type: () -> List
        """Return the list of errors reported so far (not a copy)."""
        return self._errors

    def get_schema(self):
        # type: () -> GraphQLSchema
        """Return the schema this context was created with."""
        return self._schema

    def get_variable_usages(self, node):
        # type: (Node) -> List[VariableUsage]
        """Return the variable usages found by visiting ``node``.

        ``node`` is an operation or a fragment definition. The result is
        collected with a ``UsageVisitor`` on first request and cached per
        node afterwards.
        """
        usages = self._variable_usages.get(node)
        if usages is None:
            usages = []
            sub_visitor = UsageVisitor(usages, self._type_info)
            visit(node, TypeInfoVisitor(self._type_info, sub_visitor))
            self._variable_usages[node] = usages

        return usages

    def get_recursive_variable_usages(self, operation):
        # type: (OperationDefinition) -> List[VariableUsage]
        """Return usages in ``operation`` and in every fragment it reaches.

        :raises AssertionError: if ``operation`` is not an
            ``OperationDefinition``.
        """
        assert isinstance(operation, OperationDefinition)
        usages = self._recursive_variable_usages.get(operation)
        if usages is None:
            # Copy first: get_variable_usages() hands back its cached list,
            # and extending that in place would leak fragment usages into
            # the operation's own (non-recursive) cache entry.
            usages = list(self.get_variable_usages(operation))
            for fragment in self.get_recursively_referenced_fragments(operation):
                usages.extend(self.get_variable_usages(fragment))
            self._recursive_variable_usages[operation] = usages

        return usages

    def get_recursively_referenced_fragments(self, operation):
        # type: (OperationDefinition) -> List[FragmentDefinition]
        """Return every fragment definition reachable from ``operation``.

        Spreads are followed transitively; each fragment name is handled
        once, so cyclic fragment references terminate. Spreads naming an
        unknown fragment are skipped.

        :raises AssertionError: if ``operation`` is not an
            ``OperationDefinition``.
        """
        assert isinstance(operation, OperationDefinition)
        fragments = self._recursively_referenced_fragments.get(operation)
        # ``is None`` rather than truthiness, so that an empty result is
        # served from the cache instead of being recomputed on every call.
        if fragments is None:
            fragments = []
            collected_names = set()  # type: Set[str]
            nodes_to_visit = [operation.selection_set]
            while nodes_to_visit:
                node = nodes_to_visit.pop()
                for spread in self.get_fragment_spreads(node):
                    frag_name = spread.name.value
                    if frag_name in collected_names:
                        continue
                    collected_names.add(frag_name)
                    fragment = self.get_fragment(frag_name)
                    if fragment:
                        fragments.append(fragment)
                        nodes_to_visit.append(fragment.selection_set)
            self._recursively_referenced_fragments[operation] = fragments
        return fragments

    def get_fragment_spreads(self, node):
        # type: (SelectionSet) -> List[FragmentSpread]
        """Return all fragment spreads nested anywhere inside ``node``.

        Nested selection sets are descended into; the fragments that the
        spreads refer to are not followed.
        """
        spreads = self._fragment_spreads.get(node)
        # ``is None`` rather than truthiness, so that selection sets with
        # no spreads are cached too.
        if spreads is None:
            spreads = []
            sets_to_visit = [node]
            while sets_to_visit:
                _set = sets_to_visit.pop()
                for selection in _set.selections:
                    if isinstance(selection, FragmentSpread):
                        spreads.append(selection)
                    elif selection.selection_set:
                        sets_to_visit.append(selection.selection_set)

            self._fragment_spreads[node] = spreads
        return spreads

    def get_ast(self):
        """Return the document this context was created with."""
        return self._ast

    def get_fragment(self, name):
        """Return the fragment definition called ``name``, or None.

        The name index is built from the document's definitions on the
        first call; if two fragments share a name the later one wins.
        """
        fragments = self._fragments
        if fragments is None:
            self._fragments = fragments = {}
            for statement in self.get_ast().definitions:
                if isinstance(statement, FragmentDefinition):
                    fragments[statement.name.value] = statement
        return fragments.get(name)

    def get_type(self):
        # type: () -> Optional[GraphQLType]
        """Delegate to ``TypeInfo.get_type``."""
        return self._type_info.get_type()

    def get_parent_type(self):
        # type: () -> Union[GraphQLInterfaceType, GraphQLObjectType, None]
        """Delegate to ``TypeInfo.get_parent_type``."""
        return self._type_info.get_parent_type()

    def get_input_type(self):
        # type: () -> Optional[GraphQLInputObjectType]
        """Delegate to ``TypeInfo.get_input_type``."""
        return self._type_info.get_input_type()  # type: ignore

    def get_field_def(self):
        # type: () -> Optional[GraphQLField]
        """Delegate to ``TypeInfo.get_field_def``."""
        return self._type_info.get_field_def()

    def get_directive(self):
        # type: () -> Optional[Any]
        """Delegate to ``TypeInfo.get_directive``."""
        return self._type_info.get_directive()

    def get_argument(self):
        # type: () -> Optional[GraphQLArgument]
        """Delegate to ``TypeInfo.get_argument``."""
        return self._type_info.get_argument()
202