1# -*- coding: utf-8 -*-
2import logging
3from traceback import format_exception
4
5from ..error import GraphQLError
6from ..language import ast
7from ..pyutils.default_ordered_dict import DefaultOrderedDict
8from ..type.definition import GraphQLInterfaceType, GraphQLUnionType
9from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective
10from ..type.introspection import (
11    SchemaMetaFieldDef,
12    TypeMetaFieldDef,
13    TypeNameMetaFieldDef,
14)
15from ..utils.type_from_ast import type_from_ast
16from .values import get_argument_values, get_variable_values
17
18# Necessary for static type checking
19if False:  # flake8: noqa
20    from ..type.definition import GraphQLObjectType, GraphQLField
21    from ..type.schema import GraphQLSchema
22    from ..language.ast import (
23        Document,
24        OperationDefinition,
25        SelectionSet,
26        Directive,
27        FragmentDefinition,
28        InlineFragment,
29        Field,
30    )
31    from .base import ResolveInfo
32    from types import TracebackType
33    from typing import Any, List, Dict, Optional, Union, Callable, Set, Tuple
34
# Module logger; ExecutionContext.report_error writes the formatted traceback
# of each execution error here before recording the error.
logger = logging.getLogger(__name__)
36
37
class ExecutionContext(object):
    """Per-request state shared by every step of query execution.

    Holds the schema being executed against, the operation selected from the
    request document, that document's fragment definitions (keyed by name),
    the coerced variable values, the list of collected errors, and two
    memoization caches (argument values and sub-field collections)."""

    __slots__ = (
        "schema",
        "fragments",
        "root_value",
        "operation",
        "variable_values",
        "errors",
        "context_value",
        "argument_values_cache",
        "executor",
        "middleware",
        "allow_subscriptions",
        "_subfields_cache",
    )

    def __init__(
        self,
        schema,  # type: GraphQLSchema
        document_ast,  # type: Document
        root_value,  # type: Any
        context_value,  # type: Any
        variable_values,  # type: Optional[Dict[str, Any]]
        operation_name,  # type: Optional[str]
        executor,  # type: Any
        middleware,  # type: Optional[Any]
        allow_subscriptions,  # type: bool
    ):
        # type: (...) -> None
        """Build the context from the arguments handed to ``execute``.

        Walks ``document_ast.definitions`` once, remembering each fragment
        definition by name and choosing the operation to run.

        Raises:
            GraphQLError: if a definition is neither an operation nor a
                fragment; if the document holds several operations and no
                ``operation_name`` was given; if no operation is named
                ``operation_name``; or if the document has no operation.
        """
        chosen_operation = None
        fragments_by_name = {}  # type: Dict[str, FragmentDefinition]

        for node in document_ast.definitions:
            if isinstance(node, ast.OperationDefinition):
                # Without an explicit name only a single-operation document
                # is unambiguous.
                if not operation_name and chosen_operation:
                    raise GraphQLError(
                        "Must provide operation name if query contains multiple operations."
                    )
                if not operation_name or (
                    node.name and node.name.value == operation_name
                ):
                    chosen_operation = node
                continue

            if isinstance(node, ast.FragmentDefinition):
                fragments_by_name[node.name.value] = node
                continue

            # Anything else (e.g. type-system definitions) is not executable.
            raise GraphQLError(
                u"GraphQL cannot execute a request containing a {}.".format(
                    node.__class__.__name__
                ),
                node,
            )

        if not chosen_operation:
            if operation_name:
                raise GraphQLError(
                    u'Unknown operation named "{}".'.format(operation_name)
                )
            raise GraphQLError("Must provide an operation.")

        coerced_variables = get_variable_values(
            schema, chosen_operation.variable_definitions or [], variable_values
        )

        self.schema = schema
        self.fragments = fragments_by_name
        self.root_value = root_value
        self.operation = chosen_operation
        self.variable_values = coerced_variables
        self.errors = []  # type: List[Exception]
        self.context_value = context_value
        self.argument_values_cache = (
            {}
        )  # type: Dict[Tuple[GraphQLField, Field], Dict[str, Any]]
        self.executor = executor
        self.middleware = middleware
        self.allow_subscriptions = allow_subscriptions
        self._subfields_cache = (
            {}
        )  # type: Dict[Tuple[GraphQLObjectType, Tuple[Field, ...]], DefaultOrderedDict]

    def get_field_resolver(self, field_resolver):
        # type: (Callable) -> Callable
        """Return ``field_resolver``, wrapped by the middleware when one is set."""
        middleware = self.middleware
        if middleware:
            return middleware.get_field_resolver(field_resolver)
        return field_resolver

    def get_argument_values(self, field_def, field_ast):
        # type: (GraphQLField, Field) -> Dict[str, Any]
        """Return the argument values for ``field_ast``, memoized per
        ``(field_def, field_ast)`` pair."""
        cache = self.argument_values_cache
        cache_key = (field_def, field_ast)
        if cache_key in cache:
            return cache[cache_key]

        values = get_argument_values(
            field_def.args, field_ast.arguments, self.variable_values
        )
        cache[cache_key] = values
        return values

    def report_error(self, error, traceback=None):
        # type: (Exception, Optional[TracebackType]) -> None
        """Log ``error`` with its traceback and append it to ``self.errors``.

        A truthy ``stack`` attribute on the error takes precedence over the
        ``traceback`` argument.
        """
        tb = getattr(error, "stack", None) or traceback
        formatted = format_exception(type(error), error, tb)
        logger.error("".join(formatted))
        self.errors.append(error)

    def get_sub_fields(self, return_type, field_asts):
        # type: (GraphQLObjectType, List[Field]) -> DefaultOrderedDict
        """Collect the sub-selections of ``field_asts`` for ``return_type``.

        The result is memoized per ``(return_type, tuple(field_asts))``.
        """
        cache = self._subfields_cache
        cache_key = (return_type, tuple(field_asts))
        if cache_key in cache:
            return cache[cache_key]

        collected = DefaultOrderedDict(list)
        seen_fragments = set()  # type: Set[str]
        for field_ast in field_asts:
            sub_selection = field_ast.selection_set
            if not sub_selection:
                continue
            collected = collect_fields(
                self, return_type, sub_selection, collected, seen_fragments
            )

        cache[cache_key] = collected
        return collected
176
177
class SubscriberExecutionContext(object):
    """Wrap an ``ExecutionContext`` with an error list of its own.

    Attribute lookups that miss on the wrapper fall through to the wrapped
    context, so the wrapper can be read like the context it wraps.
    """

    __slots__ = ("exe_context", "errors")

    def __init__(self, exe_context):
        # type: (ExecutionContext) -> None
        self.errors = []  # type: List[Exception]
        self.exe_context = exe_context

    def reset(self):
        # type: () -> None
        """Start a fresh error list. The previous list object is left untouched."""
        self.errors = []

    def __getattr__(self, name):
        # type: (str) -> Any
        # Only reached when normal lookup fails, i.e. for names other than the
        # two slots and the class attributes.
        wrapped = self.exe_context
        return getattr(wrapped, name)
193
194
def get_operation_root_type(schema, operation):
    # type: (GraphQLSchema, OperationDefinition) -> GraphQLObjectType
    """Return the schema root object type that ``operation`` executes against.

    Queries use the query type. Mutations and subscriptions use their
    respective root types, which must be configured on the schema.

    Raises:
        GraphQLError: if the schema lacks the mutation/subscription root type,
            or if the operation kind is not one of the three supported ones.
    """
    kind = operation.operation

    if kind == "query":
        return schema.get_query_type()

    if kind == "mutation":
        root_type = schema.get_mutation_type()
        missing_message = "Schema is not configured for mutations"
    elif kind == "subscription":
        root_type = schema.get_subscription_type()
        missing_message = "Schema is not configured for subscriptions"
    else:
        raise GraphQLError(
            "Can only execute queries, mutations and subscriptions", [operation]
        )

    if not root_type:
        raise GraphQLError(missing_message, [operation])
    return root_type
222
223
def collect_fields(
    ctx,  # type: ExecutionContext
    runtime_type,  # type: GraphQLObjectType
    selection_set,  # type: SelectionSet
    fields,  # type: DefaultOrderedDict
    prev_fragment_names,  # type: Set[str]
):
    # type: (...) -> DefaultOrderedDict
    """
    Given a selection set, add all of the fields in that selection to the
    passed-in map of fields (keyed by response key), and return that map.

    collect_fields requires the "runtime type" of an object. For a field which
    returns an Interface or Union type, the "runtime type" will be the actual
    Object type returned by that field.

    ``prev_fragment_names`` is mutated: each fragment spread is expanded at
    most once per call tree, which also guards against fragment cycles.
    A spread naming a fragment the document does not define is skipped.
    """
    for selection in selection_set.selections:
        directives = selection.directives

        if isinstance(selection, ast.Field):
            if not should_include_node(ctx, directives):
                continue

            name = get_field_entry_key(selection)
            fields[name].append(selection)

        elif isinstance(selection, ast.InlineFragment):
            if not should_include_node(
                ctx, directives
            ) or not does_fragment_condition_match(ctx, selection, runtime_type):
                continue

            collect_fields(
                ctx, runtime_type, selection.selection_set, fields, prev_fragment_names
            )

        elif isinstance(selection, ast.FragmentSpread):
            frag_name = selection.name.value

            if frag_name in prev_fragment_names or not should_include_node(
                ctx, directives
            ):
                continue

            prev_fragment_names.add(frag_name)
            # Use .get(): an unvalidated document may spread an undefined
            # fragment. Indexing would raise KeyError and made the
            # `not fragment` guard below unreachable.
            fragment = ctx.fragments.get(frag_name)
            if (
                not fragment
                or not should_include_node(ctx, fragment.directives)
                or not does_fragment_condition_match(ctx, fragment, runtime_type)
            ):
                continue

            collect_fields(
                ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names
            )

    return fields
283
284
def should_include_node(ctx, directives):
    # type: (ExecutionContext, Optional[List[Directive]]) -> bool
    """Decide whether a node survives its ``@skip`` / ``@include`` directives.

    ``@skip`` takes precedence and is evaluated first: if its ``if`` argument
    is exactly ``True`` the node is excluded. Otherwise an ``@include`` whose
    ``if`` argument is exactly ``False`` excludes it. Only the first occurrence
    of each directive is consulted. A node without directives is included."""
    # TODO: Refactor based on latest code
    if not directives:
        return True

    skip_ast = next(
        (d for d in directives if d.name.value == GraphQLSkipDirective.name), None
    )
    if skip_ast:
        skip_args = get_argument_values(
            GraphQLSkipDirective.args, skip_ast.arguments, ctx.variable_values
        )
        if skip_args.get("if") is True:
            return False

    include_ast = next(
        (d for d in directives if d.name.value == GraphQLIncludeDirective.name),
        None,
    )
    if include_ast:
        include_args = get_argument_values(
            GraphQLIncludeDirective.args, include_ast.arguments, ctx.variable_values
        )
        if include_args.get("if") is False:
            return False

    return True
321
322
def does_fragment_condition_match(
    ctx,  # type: ExecutionContext
    fragment,  # type: Union[FragmentDefinition, InlineFragment]
    type_,  # type: GraphQLObjectType
):
    # type: (...) -> bool
    """Decide whether ``fragment`` applies to an object of runtime type ``type_``.

    A fragment without a type condition always applies. Otherwise it applies
    when its condition names ``type_`` itself, or names an interface or union
    that the schema reports ``type_`` as a possible type of. A condition that
    does not resolve to a schema type never applies.
    """
    type_condition_ast = fragment.type_condition
    if not type_condition_ast:
        return True

    conditional_type = type_from_ast(ctx.schema, type_condition_ast)
    if conditional_type is None:
        # The condition names a type unknown to the schema (e.g. in an
        # unvalidated document). It can never match. Bail out instead of
        # failing with AttributeError on `None.is_same_type`.
        return False

    if conditional_type.is_same_type(type_):
        return True

    if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)):
        return ctx.schema.is_possible_type(conditional_type, type_)

    return False
341
342
def get_field_entry_key(node):
    # type: (Field) -> str
    """Return the response key for a field: its alias if present, else its name."""
    alias = node.alias
    return alias.value if alias else node.name.value
349
350
def default_resolve_fn(source, info, **args):
    # type: (Any, ResolveInfo, **Any) -> Optional[Any]
    """Resolver used when a field defines none.

    Looks up ``info.field_name`` on ``source``, by key when ``source`` is a
    dict and by attribute otherwise (``None`` when absent). If the value found
    is callable it is called with no arguments and its result is returned.
    Field arguments in ``args`` are ignored."""
    attr_name = info.field_name
    if isinstance(source, dict):
        value = source.get(attr_name)
    else:
        value = getattr(source, attr_name, None)
    return value() if callable(value) else value
363
364
def get_field_def(
    schema,  # type: GraphQLSchema
    parent_type,  # type: GraphQLObjectType
    field_name,  # type: str
):
    # type: (...) -> Optional[GraphQLField]
    """Look up ``field_name`` on ``parent_type``, returning ``None`` if absent.

    Three introspection meta-fields are special-cased. ``__typename`` can be
    queried on any type, even where no other fields are allowed (e.g. on a
    Union). ``__schema`` and ``__type`` resolve only on the schema's query
    type. They could instead be added to the query type directly, but that
    would require mutating type definitions, which would cause issues."""
    if field_name == "__typename":
        return TypeNameMetaFieldDef

    if field_name in ("__schema", "__type") and schema.get_query_type() == parent_type:
        return SchemaMetaFieldDef if field_name == "__schema" else TypeMetaFieldDef

    return parent_type.fields.get(field_name)
385