# -*- coding: utf-8 -*-
import logging
from traceback import format_exception

from ..error import GraphQLError
from ..language import ast
from ..pyutils.default_ordered_dict import DefaultOrderedDict
from ..type.definition import GraphQLInterfaceType, GraphQLUnionType
from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective
from ..type.introspection import (
    SchemaMetaFieldDef,
    TypeMetaFieldDef,
    TypeNameMetaFieldDef,
)
from ..utils.type_from_ast import type_from_ast
from .values import get_argument_values, get_variable_values

# Necessary for static type checking
if False:  # flake8: noqa
    from ..type.definition import GraphQLObjectType, GraphQLField
    from ..type.schema import GraphQLSchema
    from ..language.ast import (
        Document,
        OperationDefinition,
        SelectionSet,
        Directive,
        FragmentDefinition,
        InlineFragment,
        Field,
    )
    from .base import ResolveInfo
    from types import TracebackType
    from typing import Any, List, Dict, Optional, Union, Callable, Set, Tuple

logger = logging.getLogger(__name__)


class ExecutionContext(object):
    """Data that must be available at all points during query execution.

    Namely, schema of the type system that is currently executing,
    and the fragments defined in the query document"""

    __slots__ = (
        "schema",
        "fragments",
        "root_value",
        "operation",
        "variable_values",
        "errors",
        "context_value",
        "argument_values_cache",
        "executor",
        "middleware",
        "allow_subscriptions",
        "_subfields_cache",
    )

    def __init__(
        self,
        schema,  # type: GraphQLSchema
        document_ast,  # type: Document
        root_value,  # type: Any
        context_value,  # type: Any
        variable_values,  # type: Optional[Dict[str, Any]]
        operation_name,  # type: Optional[str]
        executor,  # type: Any
        middleware,  # type: Optional[Any]
        allow_subscriptions,  # type: bool
    ):
        # type: (...) -> None
        """Constructs a ExecutionContext object from the arguments passed
        to execute, which we will pass throughout the other execution
        methods.

        Walks the document's definitions to select the operation to run and
        to index fragment definitions by name, then coerces the supplied
        variable values against the operation's variable definitions.

        Raises GraphQLError when the document holds several operations and no
        operation name was given, when it holds a definition that is neither
        an operation nor a fragment, or when no matching operation is found.
        """
        errors = []  # type: List[Exception]
        operation = None
        fragments = {}  # type: Dict[str, FragmentDefinition]

        for definition in document_ast.definitions:
            if isinstance(definition, ast.OperationDefinition):
                # A second operation without an explicit name is ambiguous.
                if not operation_name and operation:
                    raise GraphQLError(
                        "Must provide operation name if query contains multiple operations."
                    )

                if (
                    not operation_name
                    or definition.name
                    and definition.name.value == operation_name
                ):
                    operation = definition

            elif isinstance(definition, ast.FragmentDefinition):
                fragments[definition.name.value] = definition

            else:
                raise GraphQLError(
                    u"GraphQL cannot execute a request containing a {}.".format(
                        definition.__class__.__name__
                    ),
                    definition,
                )

        if not operation:
            if operation_name:
                raise GraphQLError(
                    u'Unknown operation named "{}".'.format(operation_name)
                )

            else:
                raise GraphQLError("Must provide an operation.")

        variable_values = get_variable_values(
            schema, operation.variable_definitions or [], variable_values
        )

        self.schema = schema
        self.fragments = fragments
        self.root_value = root_value
        self.operation = operation
        self.variable_values = variable_values
        self.errors = errors
        self.context_value = context_value
        self.argument_values_cache = (
            {}
        )  # type: Dict[Tuple[GraphQLField, Field], Dict[str, Any]]
        self.executor = executor
        self.middleware = middleware
        self.allow_subscriptions = allow_subscriptions
        self._subfields_cache = (
            {}
        )  # type: Dict[Tuple[GraphQLObjectType, Tuple[Field, ...]], DefaultOrderedDict]

    def get_field_resolver(self, field_resolver):
        # type: (Callable) -> Callable
        """Returns the resolver to use for a field: the given resolver as-is
        when no middleware is configured, otherwise whatever the middleware's
        own get_field_resolver returns for it."""
        if not self.middleware:
            return field_resolver
        return self.middleware.get_field_resolver(field_resolver)

    def get_argument_values(self, field_def, field_ast):
        # type: (GraphQLField, Field) -> Dict[str, Any]
        """Returns the argument values for a field, computing them on first
        use and caching them per (field definition, field AST) pair."""
        k = field_def, field_ast
        if k not in self.argument_values_cache:
            self.argument_values_cache[k] = get_argument_values(
                field_def.args, field_ast.arguments, self.variable_values
            )

        return self.argument_values_cache[k]

    def report_error(self, error, traceback=None):
        # type: (Exception, Optional[TracebackType]) -> None
        """Logs the formatted exception at error level and records it in
        self.errors.

        The traceback used for formatting is the error's own "stack"
        attribute when it is present and truthy, otherwise the traceback
        argument."""
        exception = format_exception(
            type(error), error, getattr(error, "stack", None) or traceback
        )
        logger.error("".join(exception))
        self.errors.append(error)

    def get_sub_fields(self, return_type, field_asts):
        # type: (GraphQLObjectType, List[Field]) -> DefaultOrderedDict
        """Collects the sub-fields selected under the given field ASTs for
        return_type, caching the result per (return_type, field ASTs) pair."""
        k = return_type, tuple(field_asts)
        if k not in self._subfields_cache:
            subfield_asts = DefaultOrderedDict(list)
            # Shared across all field ASTs so each named fragment is visited
            # at most once for this group of fields.
            visited_fragment_names = set()  # type: Set[str]
            for field_ast in field_asts:
                selection_set = field_ast.selection_set
                if selection_set:
                    subfield_asts = collect_fields(
                        self,
                        return_type,
                        selection_set,
                        subfield_asts,
                        visited_fragment_names,
                    )
            self._subfields_cache[k] = subfield_asts
        return self._subfields_cache[k]


class SubscriberExecutionContext(object):
    """Wrapper around an ExecutionContext that keeps its own error list.

    Attributes not found on the wrapper itself are looked up on the wrapped
    execution context."""

    __slots__ = "exe_context", "errors"

    def __init__(self, exe_context):
        # type: (ExecutionContext) -> None
        self.exe_context = exe_context
        self.errors = []  # type: List[Exception]

    def reset(self):
        # type: () -> None
        """Replaces the error list with a fresh empty one."""
        self.errors = []

    def __getattr__(self, name):
        # type: (str) -> Any
        # Only invoked when normal lookup fails; delegate to the wrapped context.
        return getattr(self.exe_context, name)


def get_operation_root_type(schema, operation):
    # type: (GraphQLSchema, OperationDefinition) -> GraphQLObjectType
    """Returns the schema's root type for the operation's kind.

    Raises GraphQLError when the operation is a mutation or subscription and
    the schema has no corresponding root type, or when the operation kind is
    not one of query, mutation or subscription."""
    op = operation.operation
    if op == "query":
        return schema.get_query_type()

    elif op == "mutation":
        mutation_type = schema.get_mutation_type()

        if not mutation_type:
            raise GraphQLError("Schema is not configured for mutations", [operation])

        return mutation_type

    elif op == "subscription":
        subscription_type = schema.get_subscription_type()

        if not subscription_type:
            raise GraphQLError(
                "Schema is not configured for subscriptions", [operation]
            )

        return subscription_type

    raise GraphQLError(
        "Can only execute queries, mutations and subscriptions", [operation]
    )


def collect_fields(
    ctx,  # type: ExecutionContext
    runtime_type,  # type: GraphQLObjectType
    selection_set,  # type: SelectionSet
    fields,  # type: DefaultOrderedDict
    prev_fragment_names,  # type: Set[str]
):
    # type: (...) -> DefaultOrderedDict
    """
    Given a selectionSet, adds all of the fields in that selection to
    the passed in map of fields, and returns it at the end.

    collect_fields requires the "runtime type" of an object. For a field which
    returns an Interface or Union type, the "runtime type" will be the actual
    Object type returned by that field.

    Both `fields` and `prev_fragment_names` are mutated in place. A fragment
    spread whose name is not defined in the document is skipped.
    """
    for selection in selection_set.selections:
        directives = selection.directives

        if isinstance(selection, ast.Field):
            if not should_include_node(ctx, directives):
                continue

            name = get_field_entry_key(selection)
            fields[name].append(selection)

        elif isinstance(selection, ast.InlineFragment):
            if not should_include_node(
                ctx, directives
            ) or not does_fragment_condition_match(ctx, selection, runtime_type):
                continue

            collect_fields(
                ctx, runtime_type, selection.selection_set, fields, prev_fragment_names
            )

        elif isinstance(selection, ast.FragmentSpread):
            frag_name = selection.name.value

            if frag_name in prev_fragment_names or not should_include_node(
                ctx, directives
            ):
                continue

            prev_fragment_names.add(frag_name)
            # Use .get() and test the result before touching its attributes,
            # so a spread naming an undefined fragment is skipped rather than
            # raising KeyError here.
            fragment = ctx.fragments.get(frag_name)
            if (
                not fragment
                or not should_include_node(ctx, fragment.directives)
                or not does_fragment_condition_match(ctx, fragment, runtime_type)
            ):
                continue

            collect_fields(
                ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names
            )

    return fields


def should_include_node(ctx, directives):
    # type: (ExecutionContext, Optional[List[Directive]]) -> bool
    """Determines if a field should be included based on the @include and
    @skip directives, where @skip has higher precedence than @include."""
    # TODO: Refactor based on latest code
    if directives:
        skip_ast = None

        # Only the first @skip directive is considered.
        for directive in directives:
            if directive.name.value == GraphQLSkipDirective.name:
                skip_ast = directive
                break

        if skip_ast:
            args = get_argument_values(
                GraphQLSkipDirective.args, skip_ast.arguments, ctx.variable_values
            )
            if args.get("if") is True:
                return False

        include_ast = None

        # Only the first @include directive is considered.
        for directive in directives:
            if directive.name.value == GraphQLIncludeDirective.name:
                include_ast = directive
                break

        if include_ast:
            args = get_argument_values(
                GraphQLIncludeDirective.args, include_ast.arguments, ctx.variable_values
            )

            if args.get("if") is False:
                return False

    return True


def does_fragment_condition_match(
    ctx,  # type: ExecutionContext
    fragment,  # type: Union[FragmentDefinition, InlineFragment]
    type_,  # type: GraphQLObjectType
):
    # type: (...) -> bool
    """Determines whether a fragment applies to the given runtime type.

    A fragment without a type condition always matches. Otherwise it matches
    when the condition names the same type, or names an interface or union
    for which the schema reports `type_` as a possible type."""
    type_condition_ast = fragment.type_condition
    if not type_condition_ast:
        return True

    conditional_type = type_from_ast(ctx.schema, type_condition_ast)
    if conditional_type.is_same_type(type_):
        return True

    if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)):
        return ctx.schema.is_possible_type(conditional_type, type_)

    return False


def get_field_entry_key(node):
    # type: (Field) -> str
    """Implements the logic to compute the key of a given field's entry:
    the alias when one is given, otherwise the field name."""
    if node.alias:
        return node.alias.value
    return node.name.value


def default_resolve_fn(source, info, **args):
    # type: (Any, ResolveInfo, **Any) -> Optional[Any]
    """If a resolve function is not given, then a default resolve behavior is used which takes the property of the source object
    of the same name as the field and returns it as the result, or if it's a function, returns the result of calling that function.

    Dict sources are read by key, other sources by attribute; a missing
    key or attribute yields None. Callables are invoked with no arguments."""
    name = info.field_name
    if isinstance(source, dict):
        value = source.get(name)
    else:
        value = getattr(source, name, None)
    if callable(value):
        return value()
    return value


def get_field_def(
    schema,  # type: GraphQLSchema
    parent_type,  # type: GraphQLObjectType
    field_name,  # type: str
):
    # type: (...) -> Optional[GraphQLField]
    """This method looks up the field on the given type definition.
    It has special casing for the two introspection fields, __schema
    and __typename. __typename is special because it can always be
    queried as a field, even in situations where no other fields
    are allowed, like on a Union. __schema could get automatically
    added to the query type, but that would require mutating type
    definitions, which would cause issues."""
    if field_name == "__schema" and schema.get_query_type() == parent_type:
        return SchemaMetaFieldDef
    elif field_name == "__type" and schema.get_query_type() == parent_type:
        return TypeMetaFieldDef
    elif field_name == "__typename":
        return TypeNameMetaFieldDef
    return parent_type.fields.get(field_name)