1import collections
2
3try:
4    from collections.abc import Iterable
5except ImportError:  # Python < 3.3
6    from collections import Iterable
7import functools
8import logging
9import sys
10import warnings
11from rx import Observable
12
13from six import string_types
14from promise import Promise, promise_for_dict, is_thenable
15
16from ..error import GraphQLError, GraphQLLocatedError
17from ..pyutils.default_ordered_dict import DefaultOrderedDict
18from ..pyutils.ordereddict import OrderedDict
19from ..utils.undefined import Undefined
20from ..type import (
21    GraphQLEnumType,
22    GraphQLInterfaceType,
23    GraphQLList,
24    GraphQLNonNull,
25    GraphQLObjectType,
26    GraphQLScalarType,
27    GraphQLSchema,
28    GraphQLUnionType,
29)
30from .base import (
31    ExecutionContext,
32    ExecutionResult,
33    ResolveInfo,
34    collect_fields,
35    default_resolve_fn,
36    get_field_def,
37    get_operation_root_type,
38    SubscriberExecutionContext,
39)
40from .executors.sync import SyncExecutor
41from .middleware import MiddlewareManager
42
43# Necessary for static type checking
44if False:  # flake8: noqa
45    from typing import Any, Optional, Union, Dict, List, Callable
46    from ..language.ast import Document, OperationDefinition, Field
47
# Module-level logger; resolve_or_error reports resolver exceptions through it.
logger = logging.getLogger(__name__)
49
50
def subscribe(*args, **kwargs):
    # type: (*Any, **Any) -> Union[ExecutionResult, Observable]
    """Execute a GraphQL document with subscription operations permitted.

    Every positional and keyword argument is forwarded to ``execute``
    unchanged. The one difference is that ``allow_subscriptions`` defaults
    to ``True`` here (``execute`` defaults it to ``False``) when the caller
    does not pass it explicitly.
    """
    kwargs.setdefault("allow_subscriptions", True)
    return execute(*args, **kwargs)  # type: ignore
57
58
def execute(
    schema,  # type: GraphQLSchema
    document_ast,  # type: Document
    root=None,  # type: Any
    context=None,  # type: Optional[Any]
    variables=None,  # type: Optional[Any]
    operation_name=None,  # type: Optional[str]
    executor=None,  # type: Any
    return_promise=False,  # type: bool
    middleware=None,  # type: Optional[Any]
    allow_subscriptions=False,  # type: bool
    **options  # type: Any
):
    # type: (...) -> Union[ExecutionResult, Promise[ExecutionResult]]
    """Execute ``document_ast`` against ``schema``.

    The legacy keywords ``root_value``, ``context_value`` and
    ``variable_values`` (passed via ``**options``) are honoured with a
    DeprecationWarning. Each applies only when its modern counterpart
    (``root``, ``context``, ``variables``) was left as ``None``.

    ``middleware`` may be a MiddlewareManager or an iterable of middlewares,
    which is wrapped in one. ``executor`` defaults to a SyncExecutor.

    Errors raised while running the operation are collected on the execution
    context. In that case ``data`` is ``None`` and the errors are attached to
    the ExecutionResult. A subscription's Observable is passed through
    unwrapped.

    Returns the resolved ExecutionResult (or Observable). If
    ``return_promise`` is true, it instead returns the promise that will
    produce it, after calling the executor's optional ``clean`` hook.
    """
    # Deprecated keyword aliases. stacklevel=2 attributes each warning to the
    # caller of execute().
    if root is None and "root_value" in options:
        warnings.warn(
            "root_value has been deprecated. Please use root=... instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        root = options["root_value"]
    if context is None and "context_value" in options:
        warnings.warn(
            "context_value has been deprecated. Please use context=... instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        context = options["context_value"]
    if variables is None and "variable_values" in options:
        warnings.warn(
            "variable_values has been deprecated. Please use variables=... instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
        variables = options["variable_values"]

    assert schema, "Must provide schema"
    assert isinstance(schema, GraphQLSchema), (
        "Schema must be an instance of GraphQLSchema. Also ensure that there are "
        "not multiple versions of GraphQL installed in your node_modules directory."
    )

    if middleware:
        # Accept a plain iterable of middlewares as a convenience.
        if not isinstance(middleware, MiddlewareManager):
            middleware = MiddlewareManager(*middleware)
        assert isinstance(middleware, MiddlewareManager), (
            "middlewares have to be an instance"
            ' of MiddlewareManager. Received "{}".'.format(middleware)
        )

    executor = SyncExecutor() if executor is None else executor

    exe_context = ExecutionContext(
        schema,
        document_ast,
        root,
        context,
        variables or {},
        operation_name,
        executor,
        middleware,
        allow_subscriptions,
    )

    def run_operation(_ignored):
        # type: (Optional[Any]) -> Union[Dict, Promise[Dict], Observable]
        return execute_operation(exe_context, exe_context.operation, root)

    def record_failure(error):
        # type: (Exception) -> None
        # Keep the error for the final result and continue with data=None.
        exe_context.errors.append(error)
        return None

    def build_result(data):
        # type: (Union[None, Dict, Observable]) -> Union[ExecutionResult, Observable]
        if isinstance(data, Observable):
            # Subscriptions hand back the stream itself.
            return data
        if exe_context.errors:
            return ExecutionResult(data=data, errors=exe_context.errors)
        return ExecutionResult(data=data)

    promise = Promise.resolve(None).then(run_operation)
    promise = promise.catch(record_failure).then(build_result)

    if return_promise:
        cleanup = getattr(exe_context.executor, "clean", None)
        if cleanup:
            cleanup()
        return promise

    exe_context.executor.wait_until_finished()
    return promise.get()
157
158
def execute_operation(
    exe_context,  # type: ExecutionContext
    operation,  # type: OperationDefinition
    root_value,  # type: Any
):
    # type: (...) -> Union[Dict, Promise[Dict]]
    """Run one operation definition starting from its schema root type.

    Mutations resolve their top-level fields one after another. Subscriptions
    are handed to ``subscribe_fields``, but only when the execution context
    allows subscriptions; otherwise a plain Exception is raised. Every other
    operation resolves its fields via ``execute_fields``.
    """
    root_type = get_operation_root_type(exe_context.schema, operation)
    fields = collect_fields(
        exe_context, root_type, operation.selection_set, DefaultOrderedDict(list), set()
    )
    kind = operation.operation

    if kind == "mutation":
        return execute_fields_serially(exe_context, root_type, root_value, [], fields)

    if kind != "subscription":
        return execute_fields(exe_context, root_type, root_value, fields, [], None)

    if not exe_context.allow_subscriptions:
        raise Exception(
            "Subscriptions are not allowed. "
            "You will need to either use the subscribe function "
            "or pass allow_subscriptions=True"
        )
    return subscribe_fields(exe_context, root_type, root_value, fields)
183
184
def execute_fields_serially(
    exe_context,  # type: ExecutionContext
    parent_type,  # type: GraphQLObjectType
    source_value,  # type: Any
    path,  # type: List
    fields,  # type: DefaultOrderedDict
):
    # type: (...) -> Promise
    """Resolve ``fields`` strictly in order, as required for mutations.

    A promise chain is folded over the response names. Each field is resolved
    only after the previous link, including any promise the previous field
    returned, has settled. Fields unknown to ``parent_type`` (resolved as
    ``Undefined``) are omitted.

    Returns a promise for an ordered dict mapping response name to value.
    """

    def resolve_into(results, response_name):
        # type: (Dict, str) -> Union[Dict, Promise[Dict]]
        outcome = resolve_field(
            exe_context,
            parent_type,
            source_value,
            fields[response_name],
            None,
            path + [response_name],
        )
        if outcome is Undefined:
            return results

        if not is_thenable(outcome):
            results[response_name] = outcome
            return results

        def store(resolved):
            # type: (Dict) -> Dict
            results[response_name] = resolved
            return results

        return outcome.then(store, None)

    def chain(previous, response_name):
        # type: (Promise, str) -> Promise
        return previous.then(lambda results: resolve_into(results, response_name))

    seed = Promise.resolve(collections.OrderedDict())
    return functools.reduce(chain, fields.keys(), seed)
228
229
def execute_fields(
    exe_context,  # type: ExecutionContext
    parent_type,  # type: GraphQLObjectType
    source_value,  # type: Any
    fields,  # type: DefaultOrderedDict
    path,  # type: List[Union[int, str]]
    info,  # type: Optional[ResolveInfo]
):
    # type: (...) -> Union[Dict, Promise[Dict]]
    """Resolve every field in ``fields`` against ``source_value``.

    Fields unknown to ``parent_type`` (resolved as ``Undefined``) are
    skipped. If every resolved value is synchronous, the ordered dict is
    returned directly. Otherwise a promise for the dict is returned via
    ``promise_for_dict``.
    """
    resolved = OrderedDict()
    saw_promise = False

    for response_name, field_asts in fields.items():
        value = resolve_field(
            exe_context,
            parent_type,
            source_value,
            field_asts,
            info,
            path + [response_name],
        )
        if value is Undefined:
            continue

        resolved[response_name] = value
        if is_thenable(value):
            saw_promise = True

    return promise_for_dict(resolved) if saw_promise else resolved
263
264
def subscribe_fields(
    exe_context,  # type: ExecutionContext
    parent_type,  # type: GraphQLObjectType
    source_value,  # type: Any
    fields,  # type: DefaultOrderedDict
):
    # type: (...) -> Observable
    """Start the subscription for a subscription operation's root selection.

    Only the first selected field that exists on ``parent_type`` is
    subscribed, and its Observable is returned. Any further root fields are
    ignored; the GraphQL spec requires a subscription to select a single root
    field. If none of the selected fields exist, an empty merged Observable
    is returned.

    Each emitted value is wrapped as ``ExecutionResult(data={name: value})``,
    together with any errors collected since the previous emission. The
    subscriber context is reset after every emission. An error raised by the
    source stream is recorded and replaced by a single ``None`` payload.
    """
    subscriber_exe_context = SubscriberExecutionContext(exe_context)

    def map_result(data):
        # type: (Dict[str, Any]) -> ExecutionResult
        if subscriber_exe_context.errors:
            result = ExecutionResult(data=data, errors=subscriber_exe_context.errors)
        else:
            result = ExecutionResult(data=data)
        # Errors are per-emission: clear them once they have been delivered.
        subscriber_exe_context.reset()
        return result

    def catch_error(error):
        # Record the failure and continue with one ``None`` payload, so the
        # subscriber receives an ExecutionResult carrying the error.
        subscriber_exe_context.errors.append(error)
        return Observable.just(None)

    def mapper_for(response_name):
        # Bind response_name now instead of closing over the loop variable.
        return lambda data: map_result({response_name: data})

    for response_name, field_asts in fields.items():
        result = subscribe_field(
            subscriber_exe_context,
            parent_type,
            source_value,
            field_asts,
            [response_name],
        )
        if result is Undefined:
            continue

        # First subscribable root field wins; see docstring.
        return result.catch_exception(catch_error).map(mapper_for(response_name))

    # No selected field exists on parent_type: nothing to subscribe to.
    return Observable.merge([])
313
314
def resolve_field(
    exe_context,  # type: ExecutionContext
    parent_type,  # type: GraphQLObjectType
    source,  # type: Any
    field_asts,  # type: List[Field]
    parent_info,  # type: Optional[ResolveInfo]
    field_path,  # type: List[Union[int, str]]
):
    # type: (...) -> Any
    """Resolve one selected field on ``source`` and complete its value.

    ``field_asts`` holds every AST node merged under the same response name.
    The first node supplies the field name and the argument values.

    Returns ``Undefined`` when ``parent_type`` has no such field. Otherwise
    returns whatever ``complete_value_catching_error`` produces: a plain
    value, or a promise for one. ``parent_info`` is accepted but not used
    here.
    """
    lead_ast = field_asts[0]
    field_name = lead_ast.name.value

    field_def = get_field_def(exe_context.schema, parent_type, field_name)
    if not field_def:
        return Undefined
    return_type = field_def.type

    # Fall back to the default resolver, then let the middleware wrap
    # whichever one applies.
    wrapped_resolver = exe_context.get_field_resolver(
        field_def.resolver or default_resolve_fn
    )

    # Argument values come from the field's AST. Variable references are
    # filled in from the operation's variables.
    arg_values = exe_context.get_argument_values(field_def, lead_ast)

    # Per-execution context value shared by all resolvers. It reaches them
    # through the ResolveInfo ``context`` argument below.
    shared_context = exe_context.context_value

    # Description of the current execution state handed to the resolver.
    info = ResolveInfo(
        field_name,
        field_asts,
        return_type,
        parent_type,
        schema=exe_context.schema,
        fragments=exe_context.fragments,
        root_value=exe_context.root_value,
        operation=exe_context.operation,
        variable_values=exe_context.variable_values,
        context=shared_context,
        path=field_path,
    )

    raw_value = resolve_or_error(
        wrapped_resolver, source, info, arg_values, exe_context.executor
    )
    return complete_value_catching_error(
        exe_context, return_type, field_asts, info, field_path, raw_value
    )
368
369
def subscribe_field(
    exe_context,  # type: SubscriberExecutionContext
    parent_type,  # type: GraphQLObjectType
    source,  # type: Any
    field_asts,  # type: List[Field]
    path,  # type: List[str]
):
    # type: (...) -> Observable
    """Resolve a subscription root field to an Observable of completed values.

    Returns ``Undefined`` when ``parent_type`` has no such field.

    Raises:
        Exception: the error the resolver raised, re-raised as-is.
        GraphQLError: if the resolver returned something other than an
            Observable.

    Each value the Observable emits is passed through
    ``complete_value_catching_error`` for the field's return type.
    """
    lead_ast = field_asts[0]
    field_name = lead_ast.name.value

    field_def = get_field_def(exe_context.schema, parent_type, field_name)
    if not field_def:
        return Undefined
    return_type = field_def.type

    # Fall back to the default resolver, then let the middleware wrap it.
    wrapped_resolver = exe_context.get_field_resolver(
        field_def.resolver or default_resolve_fn
    )

    # Argument values come from the field's AST. Variable references are
    # filled in from the operation's variables.
    arg_values = exe_context.get_argument_values(field_def, lead_ast)

    # Per-execution context value, handed over through ResolveInfo.
    shared_context = exe_context.context_value

    info = ResolveInfo(
        field_name,
        field_asts,
        return_type,
        parent_type,
        schema=exe_context.schema,
        fragments=exe_context.fragments,
        root_value=exe_context.root_value,
        operation=exe_context.operation,
        variable_values=exe_context.variable_values,
        context=shared_context,
        path=path,
    )

    stream = resolve_or_error(
        wrapped_resolver, source, info, arg_values, exe_context.executor
    )

    # resolve_or_error returns resolver errors instead of raising them.
    if isinstance(stream, Exception):
        raise stream

    if not isinstance(stream, Observable):
        raise GraphQLError(
            "Subscription must return Async Iterable or Observable. Received: {}".format(
                repr(stream)
            )
        )

    complete_each = functools.partial(
        complete_value_catching_error,
        exe_context,
        return_type,
        field_asts,
        info,
        path,
    )
    return stream.map(complete_each)
439
440
def resolve_or_error(
    resolve_fn,  # type: Callable
    source,  # type: Any
    info,  # type: ResolveInfo
    args,  # type: Dict
    executor,  # type: Any
):
    # type: (...) -> Any
    """Run ``resolve_fn`` through ``executor``, returning errors instead of raising.

    The resolver is invoked as ``executor.execute(resolve_fn, source, info, **args)``.

    Returns:
        Whatever the executor returns. If that call raises an ``Exception``,
        the error is logged with its traceback and the exception instance is
        returned. Its traceback is also stored on the instance as ``stack``.
    """
    try:
        return executor.execute(resolve_fn, source, info, **args)
    except Exception as e:
        # Lazy %-style arguments: the message is formatted only if the
        # record is actually emitted by a handler.
        logger.exception(
            "An error occurred while resolving field %s.%s",
            info.parent_type.name,
            info.field_name,
        )
        # Keep the traceback on the exception so later reporting can use it.
        e.stack = sys.exc_info()[2]  # type: ignore
        return e
459
460
def complete_value_catching_error(
    exe_context,  # type: ExecutionContext
    return_type,  # type: Any
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Any
    """Complete ``result``, converting errors on nullable fields into ``None``.

    Non-null fields are completed without protection, so any error
    propagates to the caller. For nullable fields, both synchronous
    exceptions and promise rejections are reported via
    ``exe_context.report_error`` (with a traceback), and the field's value
    becomes ``None``.
    """
    if isinstance(return_type, GraphQLNonNull):
        return complete_value(exe_context, return_type, field_asts, info, path, result)

    try:
        completed = complete_value(
            exe_context, return_type, field_asts, info, path, result
        )
        if not is_thenable(completed):
            return completed

        def report_rejection(error):
            # type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any]
            exe_context.report_error(error, completed._traceback)  # type: ignore
            return None

        return completed.catch(report_rejection)
    except Exception as error:
        exe_context.report_error(error, sys.exc_info()[2])
        return None
496
497
def complete_value(
    exe_context,  # type: ExecutionContext
    return_type,  # type: Any
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Any
    """
    Implements the instructions for completeValue as defined in the
    "Field entries" section of the spec.

    If the result is a promise/thenable, completion is deferred until it settles. A rejection
    is re-raised as a GraphQLLocatedError carrying the field ASTs and path.

    If the result is an Exception instance, it is raised wrapped in a GraphQLLocatedError.

    If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field
    error if that completion returns null, as per the "Nullability" section of the spec.

    If the field type is a List, then this recursively completes the value for the inner type on each item in the
    list.

    If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the
    `serialize` method of GraphQL type definition.

    If the field is an abstract type, determine the runtime type of the value and then complete based on that type.

    Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all
    sub-selections.

    Raises:
        AssertionError: if ``return_type`` is none of the supported type kinds.
    """
    # A pending resolver result: finish completion once it settles.
    if is_thenable(result):
        return Promise.resolve(result).then(
            lambda resolved: complete_value(
                exe_context, return_type, field_asts, info, path, resolved
            ),
            lambda error: Promise.rejected(
                GraphQLLocatedError(field_asts, original_error=error, path=path)
            ),
        )

    # resolve_or_error returns (rather than raises) resolver exceptions.
    if isinstance(result, Exception):
        raise GraphQLLocatedError(field_asts, original_error=result, path=path)

    # If field type is NonNull, complete for inner type, and throw field error
    # if result is null.
    if isinstance(return_type, GraphQLNonNull):
        return complete_nonnull_value(
            exe_context, return_type, field_asts, info, path, result
        )

    # If result is null-like, return null.
    if result is None:
        return None

    # If field type is List, complete each item in the list with the inner type
    if isinstance(return_type, GraphQLList):
        return complete_list_value(
            exe_context, return_type, field_asts, info, path, result
        )

    # If field type is Scalar or Enum, serialize to a valid value; an error is
    # raised if serialization yields None.
    if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)):
        return complete_leaf_value(return_type, path, result)

    if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)):
        return complete_abstract_value(
            exe_context, return_type, field_asts, info, path, result
        )

    if isinstance(return_type, GraphQLObjectType):
        return complete_object_value(
            exe_context, return_type, field_asts, info, path, result
        )

    # Explicit raise rather than ``assert False``: asserts are stripped under
    # ``python -O``, which would make this silently return None.
    raise AssertionError(
        u'Cannot complete value of unexpected type "{}".'.format(return_type)
    )
572
573
def complete_list_value(
    exe_context,  # type: ExecutionContext
    return_type,  # type: GraphQLList
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> List[Any]
    """
    Complete a list value by completing each item with the list's inner type.

    Each item is completed at ``path + [index]``. If any item completes to a
    promise, a single promise for the whole list is returned. Otherwise the
    plain list is returned.
    """
    assert isinstance(result, Iterable), (
        "User Error: expected iterable, but did not find one for field {}.{}."
    ).format(info.parent_type, info.field_name)

    item_type = return_type.of_type
    completed_items = []
    any_pending = False

    for position, item in enumerate(result):
        completed = complete_value_catching_error(
            exe_context, item_type, field_asts, info, path + [position], item
        )
        if not any_pending and is_thenable(completed):
            any_pending = True
        completed_items.append(completed)

    if any_pending:
        return Promise.all(completed_items)
    return completed_items
606
607
def complete_leaf_value(
    return_type,  # type: Union[GraphQLEnumType, GraphQLScalarType]
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Union[int, str, float, bool]
    """
    Serialize a Scalar or Enum value with the type's ``serialize`` method.

    Any non-``None`` serialized value (including falsy ones such as ``0`` or
    ``False``) is returned as-is. A ``None`` from ``serialize`` means the
    value could not be represented, and a GraphQLError is raised at ``path``.
    """
    assert hasattr(return_type, "serialize"), "Missing serialize method on type"
    serialized = return_type.serialize(result)

    if serialized is not None:
        return serialized

    raise GraphQLError(
        'Expected a value of type "{}" but received: {}'.format(return_type, result),
        path=path,
    )
628
629
def complete_abstract_value(
    exe_context,  # type: ExecutionContext
    return_type,  # type: Union[GraphQLInterfaceType, GraphQLUnionType]
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Dict[str, Any]
    """
    Complete a value of an abstract type. First determine the runtime object
    type of that value, then complete based on that type.

    The runtime type comes from ``return_type.resolve_type`` when set.
    Otherwise it is the first possible type whose ``is_type_of`` accepts the
    value. A type *name* returned by ``resolve_type`` is looked up in the
    schema.

    Raises:
        GraphQLError: if no Object type is found, or the found type is not a
            possible type of ``return_type``. The error carries the field ASTs
            and ``path``.
    """
    runtime_type = None  # type: Union[str, GraphQLObjectType, None]

    # Field type must be Object, Interface or Union and expect sub-selections.
    if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)):
        if return_type.resolve_type:
            runtime_type = return_type.resolve_type(result, info)
        else:
            runtime_type = get_default_resolve_type_fn(result, info, return_type)

    # resolve_type may return a type name instead of a type object.
    if isinstance(runtime_type, string_types):
        runtime_type = info.schema.get_type(runtime_type)  # type: ignore

    if not isinstance(runtime_type, GraphQLObjectType):
        raise GraphQLError(
            (
                "Abstract type {} must resolve to an Object type at runtime "
                + 'for field {}.{} with value "{}", received "{}".'
            ).format(
                return_type, info.parent_type, info.field_name, result, runtime_type
            ),
            field_asts,
            path=path,
        )

    if not exe_context.schema.is_possible_type(return_type, runtime_type):
        raise GraphQLError(
            u'Runtime Object type "{}" is not a possible type for "{}".'.format(
                runtime_type, return_type
            ),
            field_asts,
            path=path,
        )

    return complete_object_value(
        exe_context, runtime_type, field_asts, info, path, result
    )
677
678
def get_default_resolve_type_fn(
    value,  # type: Any
    info,  # type: ResolveInfo
    abstract_type,  # type: Union[GraphQLInterfaceType, GraphQLUnionType]
):
    # type: (...) -> Optional[GraphQLObjectType]
    """Pick the runtime type for ``value`` when the abstract type has no resolve_type.

    Returns the first of the schema's possible types for ``abstract_type``
    that has a callable ``is_type_of`` accepting ``value``, or ``None`` if
    none does. Types without a callable ``is_type_of`` are skipped.
    """
    candidates = info.schema.get_possible_types(abstract_type)
    return next(
        (
            candidate
            for candidate in candidates
            if callable(candidate.is_type_of) and candidate.is_type_of(value, info)
        ),
        None,
    )
690
691
def complete_object_value(
    exe_context,  # type: ExecutionContext
    return_type,  # type: GraphQLObjectType
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Dict[str, Any]
    """
    Complete an Object value by evaluating all sub-selections.

    If the type defines ``is_type_of`` and it rejects ``result``, a
    GraphQLError is raised. The error carries the field ASTs and ``path``,
    like the other completion errors in this module. Otherwise the merged
    sub-fields are executed against ``result``. The return value is a dict,
    or a promise for one when any sub-field is asynchronous.
    """
    if return_type.is_type_of and not return_type.is_type_of(result, info):
        raise GraphQLError(
            u'Expected value of type "{}" but got: {}.'.format(
                return_type, type(result).__name__
            ),
            field_asts,
            path=path,
        )

    # Collect sub-fields to execute to complete this value.
    subfield_asts = exe_context.get_sub_fields(return_type, field_asts)
    return execute_fields(exe_context, return_type, result, subfield_asts, path, info)
715
716
def complete_nonnull_value(
    exe_context,  # type: ExecutionContext
    return_type,  # type: GraphQLNonNull
    field_asts,  # type: List[Field]
    info,  # type: ResolveInfo
    path,  # type: List[Union[int, str]]
    result,  # type: Any
):
    # type: (...) -> Any
    """
    Complete a NonNull value by completing its wrapped (inner) type.

    Raises:
        GraphQLError: if the inner completion yields ``None``. The error
            carries the field ASTs and ``path``.
    """
    inner = complete_value(
        exe_context, return_type.of_type, field_asts, info, path, result
    )
    if inner is not None:
        return inner

    raise GraphQLError(
        "Cannot return null for non-nullable field {}.{}.".format(
            info.parent_type, info.field_name
        ),
        field_asts,
        path=path,
    )
742