from ..language.ast import FragmentDefinition, FragmentSpread, OperationDefinition
from ..language.visitor import ParallelVisitor, TypeInfoVisitor, Visitor, visit
from ..type import GraphQLSchema
from ..utils.type_info import TypeInfo
from .rules import specified_rules

# Necessary for static type checking
if False:  # flake8: noqa
    from typing import List, Union, Optional, Dict, Set, Any, Type
    from ..language.ast import Document, SelectionSet, Node
    from ..error import GraphQLError
    from .rules.base import ValidationRule
    from ..type.definition import (
        GraphQLObjectType,
        GraphQLInterfaceType,
        GraphQLField,
        GraphQLArgument,
        GraphQLType,
        GraphQLInputObjectType,
    )


def validate(schema, ast, rules=specified_rules):
    # type: (GraphQLSchema, Document, List[Type[ValidationRule]]) -> List
    """Validate a document against a schema and return the collected errors.

    :param schema: the ``GraphQLSchema`` to validate against.
    :param ast: the parsed document to validate.
    :param rules: validation rule classes to apply; defaults to
        ``specified_rules``.
    :return: the list of errors reported by the rules.
    :raises AssertionError: if ``schema`` or ``ast`` is falsy, or ``schema``
        is not a ``GraphQLSchema`` instance.
    """
    assert schema, "Must provide schema"
    assert ast, "Must provide document"
    assert isinstance(schema, GraphQLSchema)
    type_info = TypeInfo(schema)
    return visit_using_rules(schema, type_info, ast, rules)


def visit_using_rules(schema, type_info, ast, rules):
    # type: (GraphQLSchema, TypeInfo, Document, List[Type[ValidationRule]]) -> List
    """Instantiate every rule with a shared context and run them over ``ast``.

    All rule visitors are wrapped in a single ``ParallelVisitor`` (itself
    wrapped in a ``TypeInfoVisitor``) and handed to one ``visit`` call.

    :return: the errors accumulated in the shared ``ValidationContext``.
    """
    context = ValidationContext(schema, ast, type_info)
    visitors = [rule(context) for rule in rules]
    visit(ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors)))
    return context.get_errors()


class VariableUsage(object):
    """Pairs a variable node with the input type recorded where it was used."""

    __slots__ = "node", "type"

    def __init__(self, node, type):
        self.node = node
        self.type = type


class UsageVisitor(Visitor):
    """Visitor that appends a ``VariableUsage`` for every ``Variable`` node."""

    __slots__ = "usages", "type_info"

    def __init__(self, usages, type_info):
        # type: (List[VariableUsage], TypeInfo) -> None
        # ``usages`` is filled in place; the caller keeps the reference.
        self.usages = usages
        self.type_info = type_info

    def enter_VariableDefinition(self, node, key, parent, path, ancestors):
        # NOTE(review): returning False presumably makes ``visit`` skip this
        # subtree, so the variable inside its own definition is not counted
        # as a usage - confirm against the visitor implementation.
        return False

    def enter_Variable(self, node, key, parent, path, ancestors):
        # Record the variable together with the input type that type_info
        # reports at this position.
        usage = VariableUsage(node, type=self.type_info.get_input_type())
        self.usages.append(usage)


class ValidationContext(object):
    """Shared state handed to every validation rule.

    Holds the schema, the document, the ``TypeInfo`` tracker and the reported
    errors, plus per-node caches for fragment and variable-usage lookups.
    The lists returned by the ``get_*`` lookup helpers are the cached objects
    themselves, so callers should treat them as read-only.
    """

    __slots__ = (
        "_schema",
        "_ast",
        "_type_info",
        "_errors",
        "_fragments",
        "_fragment_spreads",
        "_recursively_referenced_fragments",
        "_variable_usages",
        "_recursive_variable_usages",
    )

    def __init__(self, schema, ast, type_info):
        # type: (GraphQLSchema, Document, TypeInfo) -> None
        self._schema = schema
        self._ast = ast
        self._type_info = type_info
        self._errors = []  # type: List[GraphQLError]
        # Built lazily by get_fragment(); None means "not indexed yet".
        self._fragments = None  # type: Optional[Dict[str, FragmentDefinition]]
        self._fragment_spreads = {}  # type: Dict[Node, List[FragmentSpread]]
        self._recursively_referenced_fragments = (
            {}
        )  # type: Dict[OperationDefinition, List[FragmentSpread]]
        self._variable_usages = {}  # type: Dict[Node, List[VariableUsage]]
        self._recursive_variable_usages = (
            {}
        )  # type: Dict[OperationDefinition, List[VariableUsage]]

    def report_error(self, error):
        """Append ``error`` to the list returned by ``get_errors``."""
        self._errors.append(error)

    def get_errors(self):
        # type: () -> List
        """Return the list of errors reported so far."""
        return self._errors

    def get_schema(self):
        # type: () -> GraphQLSchema
        """Return the schema this context was created with."""
        return self._schema

    def get_variable_usages(self, node):
        # type: (Union[OperationDefinition, FragmentDefinition]) -> List[VariableUsage]
        """Return the variable usages found directly under ``node``.

        The result is computed once per node by visiting it with a
        ``UsageVisitor`` and is cached afterwards. Fragment spreads are not
        followed here; see ``get_recursive_variable_usages`` for that.
        """
        usages = self._variable_usages.get(node)
        if usages is None:
            usages = []
            sub_visitor = UsageVisitor(usages, self._type_info)
            visit(node, TypeInfoVisitor(self._type_info, sub_visitor))
            self._variable_usages[node] = usages

        return usages

    def get_recursive_variable_usages(self, operation):
        # type: (OperationDefinition) -> List[VariableUsage]
        """Return usages in ``operation`` plus all fragments it references.

        :raises AssertionError: if ``operation`` is not an
            ``OperationDefinition``.
        """
        assert isinstance(operation, OperationDefinition)
        usages = self._recursive_variable_usages.get(operation)
        if usages is None:
            # Copy before extending: get_variable_usages() returns the list
            # stored in its own cache, and extending that object in place
            # would leak fragment usages into the per-operation result.
            usages = list(self.get_variable_usages(operation))
            fragments = self.get_recursively_referenced_fragments(operation)
            for fragment in fragments:
                usages.extend(self.get_variable_usages(fragment))
            self._recursive_variable_usages[operation] = usages

        return usages

    def get_recursively_referenced_fragments(self, operation):
        # type: (OperationDefinition) -> List
        """Return every fragment definition reachable from ``operation``.

        Spreads are followed transitively; each fragment name is collected at
        most once, and spreads naming a fragment that is not defined in the
        document are skipped.

        :raises AssertionError: if ``operation`` is not an
            ``OperationDefinition``.
        """
        assert isinstance(operation, OperationDefinition)
        fragments = self._recursively_referenced_fragments.get(operation)
        # ``is None`` rather than truthiness so that an operation with no
        # fragments (empty list) is still served from the cache.
        if fragments is None:
            fragments = []
            collected_names = set()  # type: Set[str]
            nodes_to_visit = [operation.selection_set]
            while nodes_to_visit:
                node = nodes_to_visit.pop()
                for spread in self.get_fragment_spreads(node):
                    frag_name = spread.name.value
                    if frag_name in collected_names:
                        continue
                    collected_names.add(frag_name)
                    fragment = self.get_fragment(frag_name)
                    if fragment:
                        fragments.append(fragment)
                        nodes_to_visit.append(fragment.selection_set)
            self._recursively_referenced_fragments[operation] = fragments
        return fragments

    def get_fragment_spreads(self, node):
        # type: (SelectionSet) -> List[FragmentSpread]
        """Return the fragment spreads inside the selection set ``node``.

        Nested selection sets are searched as well, but the fragments the
        spreads refer to are not entered.
        """
        spreads = self._fragment_spreads.get(node)
        # ``is None`` rather than truthiness so that a selection set without
        # spreads (empty list) is still served from the cache.
        if spreads is None:
            spreads = []
            sets_to_visit = [node]
            while sets_to_visit:
                _set = sets_to_visit.pop()
                for selection in _set.selections:
                    if isinstance(selection, FragmentSpread):
                        spreads.append(selection)
                    elif selection.selection_set:
                        sets_to_visit.append(selection.selection_set)

            self._fragment_spreads[node] = spreads
        return spreads

    def get_ast(self):
        """Return the document this context was created with."""
        return self._ast

    def get_fragment(self, name):
        """Return the fragment definition called ``name``, or ``None``.

        The name index is built from the document's definitions on first use.
        If several fragments share a name, the last one in the document wins.
        """
        fragments = self._fragments
        if fragments is None:
            self._fragments = fragments = {}
            for statement in self.get_ast().definitions:
                if isinstance(statement, FragmentDefinition):
                    fragments[statement.name.value] = statement
        return fragments.get(name)

    # The accessors below delegate to the TypeInfo tracker unchanged.

    def get_type(self):
        # type: () -> Optional[GraphQLType]
        return self._type_info.get_type()

    def get_parent_type(self):
        # type: () -> Union[GraphQLInterfaceType, GraphQLObjectType, None]
        return self._type_info.get_parent_type()

    def get_input_type(self):
        # type: () -> Optional[GraphQLInputObjectType]
        return self._type_info.get_input_type()  # type: ignore

    def get_field_def(self):
        # type: () -> Optional[GraphQLField]
        return self._type_info.get_field_def()

    def get_directive(self):
        # type: () -> Optional[Any]
        return self._type_info.get_directive()

    def get_argument(self):
        # type: () -> Optional[GraphQLArgument]
        return self._type_info.get_argument()