1from typing import Any, Collection, Dict, List, Optional, cast
2
3from ..language import ast, DirectiveLocation
4from ..pyutils import inspect, is_description, FrozenList
5from .definition import GraphQLArgument, GraphQLInputType, GraphQLNonNull, is_input_type
6from .scalars import GraphQLBoolean, GraphQLString
7
# Public API of this module. DirectiveLocation is re-exported here for
# convenience, because directive constructors take its members as locations.
__all__ = [
    "is_directive",
    "assert_directive",
    "is_specified_directive",
    "specified_directives",
    "GraphQLDirective",
    "GraphQLIncludeDirective",
    "GraphQLSkipDirective",
    "GraphQLDeprecatedDirective",
    "GraphQLSpecifiedByDirective",
    "DirectiveLocation",
    "DEFAULT_DEPRECATION_REASON",
]
21
22
class GraphQLDirective:
    """GraphQL Directive

    Directives are used by the GraphQL runtime as a way of modifying execution behavior.
    Type system creators will usually not create these directly.
    """

    name: str
    locations: List[DirectiveLocation]
    is_repeatable: bool
    args: Dict[str, GraphQLArgument]
    description: Optional[str]
    extensions: Optional[Dict[str, Any]]
    ast_node: Optional[ast.DirectiveDefinitionNode]

    def __init__(
        self,
        name: str,
        locations: Collection[DirectiveLocation],
        args: Optional[Dict[str, GraphQLArgument]] = None,
        is_repeatable: bool = False,
        description: Optional[str] = None,
        extensions: Optional[Dict[str, Any]] = None,
        ast_node: Optional[ast.DirectiveDefinitionNode] = None,
    ) -> None:
        """Validate the given definition parts and store them on the instance.

        :param name: non-empty string naming the directive
        :param locations: collection whose items are ``DirectiveLocation``
            members or member-name strings; stored as a list of enum members
        :param args: dict mapping argument-name strings to ``GraphQLArgument``
            objects or bare input types; bare input types are wrapped in a
            ``GraphQLArgument``. ``None`` means no arguments.
        :param is_repeatable: flag that must be a ``bool``
        :param description: optional description accepted by ``is_description``
        :param extensions: optional dict with string keys
        :param ast_node: optional ``DirectiveDefinitionNode`` this directive
            was built from
        :raises TypeError: if any parameter fails validation
        """
        if not name:
            raise TypeError("Directive must be named.")
        elif not isinstance(name, str):
            raise TypeError("The directive name must be a string.")
        try:
            # Keep enum members as-is; look up plain strings by member name.
            locations = [
                value
                if isinstance(value, DirectiveLocation)
                else DirectiveLocation[cast(str, value)]
                for value in locations
            ]
        except (KeyError, TypeError) as error:
            # KeyError: unknown member name; TypeError: e.g. a non-iterable
            # ``locations`` or an unhashable item used as a lookup key.
            raise TypeError(
                f"{name} locations must be specified"
                " as a collection of DirectiveLocation enum values."
            ) from error
        if args is None:
            args = {}
        elif not isinstance(args, dict) or not all(
            isinstance(key, str) for key in args
        ):
            raise TypeError(f"{name} args must be a dict with argument names as keys.")
        elif not all(
            isinstance(value, GraphQLArgument) or is_input_type(value)
            for value in args.values()
        ):
            raise TypeError(
                f"{name} args must be GraphQLArgument or input type objects."
            )
        else:
            # Build a new dict so that every value is a GraphQLArgument.
            args = {
                arg_name: value
                if isinstance(value, GraphQLArgument)
                else GraphQLArgument(cast(GraphQLInputType, value))
                for arg_name, value in args.items()
            }
        if not isinstance(is_repeatable, bool):
            raise TypeError(f"{name} is_repeatable flag must be True or False.")
        # Compare with None explicitly: a truthiness test would let falsy
        # non-node values (e.g. "" or 0) bypass the type check and be stored.
        if ast_node is not None and not isinstance(
            ast_node, ast.DirectiveDefinitionNode
        ):
            raise TypeError(f"{name} AST node must be a DirectiveDefinitionNode.")
        if description is not None and not is_description(description):
            raise TypeError(f"{name} description must be a string.")
        if extensions is not None and (
            not isinstance(extensions, dict)
            or not all(isinstance(key, str) for key in extensions)
        ):
            raise TypeError(f"{name} extensions must be a dictionary with string keys.")
        self.name = name
        self.locations = locations
        self.args = args
        self.is_repeatable = is_repeatable
        self.description = description
        self.extensions = extensions
        self.ast_node = ast_node

    def __str__(self) -> str:
        """Return the directive name prefixed with ``@``."""
        return f"@{self.name}"

    def __repr__(self) -> str:
        """Return ``<ClassName(@name)>``."""
        return f"<{self.__class__.__name__}({self})>"

    # NOTE: defining __eq__ without __hash__ makes Python set __hash__ to None,
    # so instances of this class are unhashable.
    def __eq__(self, other: Any) -> bool:
        """Compare directives by value.

        All stored attributes except ``ast_node`` take part in the comparison.
        """
        return self is other or (
            isinstance(other, GraphQLDirective)
            and self.name == other.name
            and self.locations == other.locations
            and self.args == other.args
            and self.is_repeatable == other.is_repeatable
            and self.description == other.description
            and self.extensions == other.extensions
        )

    def to_kwargs(self) -> Dict[str, Any]:
        """Return the constructor keyword arguments describing this directive.

        The values are the stored objects themselves, not copies.
        """
        return dict(
            name=self.name,
            locations=self.locations,
            args=self.args,
            is_repeatable=self.is_repeatable,
            description=self.description,
            extensions=self.extensions,
            ast_node=self.ast_node,
        )

    def __copy__(self) -> "GraphQLDirective":  # pragma: no cover
        """Return a shallow copy by re-running the constructor.

        The constructor builds a new locations list and a new args dict, but
        the contained argument objects are shared.
        """
        return self.__class__(**self.to_kwargs())
133
134
def is_directive(directive: Any) -> bool:
    """Return True when ``directive`` is a ``GraphQLDirective`` (or subclass) instance."""
    if isinstance(directive, GraphQLDirective):
        return True
    return False
138
139
def assert_directive(directive: Any) -> GraphQLDirective:
    """Return ``directive`` unchanged if it is a GraphQL directive.

    :raises TypeError: if ``directive`` is not a ``GraphQLDirective``
    """
    if is_directive(directive):
        return cast(GraphQLDirective, directive)
    raise TypeError(f"Expected {inspect(directive)} to be a GraphQL directive.")
144
145
# Built-in @include: keeps a field or fragment only when its non-null Boolean
# `if` argument is true.
GraphQLIncludeDirective = GraphQLDirective(
    name="include",
    description="Directs the executor to include this field or fragment"
    " only when the `if` argument is true.",
    locations=(
        DirectiveLocation.FIELD,
        DirectiveLocation.FRAGMENT_SPREAD,
        DirectiveLocation.INLINE_FRAGMENT,
    ),
    args={
        "if": GraphQLArgument(
            GraphQLNonNull(GraphQLBoolean),
            description="Included when true.",
        ),
    },
)
162
163
# Built-in @skip: excludes a field or fragment when its non-null Boolean
# `if` argument is true.
GraphQLSkipDirective = GraphQLDirective(
    name="skip",
    description="Directs the executor to skip this field or fragment"
    " when the `if` argument is true.",
    locations=(
        DirectiveLocation.FIELD,
        DirectiveLocation.FRAGMENT_SPREAD,
        DirectiveLocation.INLINE_FRAGMENT,
    ),
    args={
        "if": GraphQLArgument(
            GraphQLNonNull(GraphQLBoolean),
            description="Skipped when true.",
        ),
    },
)
180
181
# Default value of the `reason` argument of @deprecated.
DEFAULT_DEPRECATION_REASON = "No longer supported"

# Built-in @deprecated: used to mark an element of a GraphQL schema as
# deprecated, with an optional `reason` string.
GraphQLDeprecatedDirective = GraphQLDirective(
    name="deprecated",
    description="Marks an element of a GraphQL schema as no longer supported.",
    locations=(
        DirectiveLocation.FIELD_DEFINITION,
        DirectiveLocation.ARGUMENT_DEFINITION,
        DirectiveLocation.INPUT_FIELD_DEFINITION,
        DirectiveLocation.ENUM_VALUE,
    ),
    args={
        "reason": GraphQLArgument(
            GraphQLString,
            default_value=DEFAULT_DEPRECATION_REASON,
            description="Explains why this element was deprecated,"
            " usually also including a suggestion for how to access"
            " supported similar data."
            " Formatted using the Markdown syntax, as specified by"
            " [CommonMark](https://commonmark.org/).",
        ),
    },
)
207
# Built-in @specifiedBy: attaches to a scalar definition a required `url`
# pointing at the specification of that scalar's behaviour.
GraphQLSpecifiedByDirective = GraphQLDirective(
    name="specifiedBy",
    description="Exposes a URL that specifies the behaviour of this scalar.",
    locations=(DirectiveLocation.SCALAR,),
    args={
        "url": GraphQLArgument(
            GraphQLNonNull(GraphQLString),
            description="The URL that specifies the behaviour of this scalar.",
        ),
    },
)
220
221
# All built-in directives defined in this module, collected in a FrozenList.
# is_specified_directive() below checks directive names against this list.
specified_directives: FrozenList[GraphQLDirective] = FrozenList(
    [
        GraphQLIncludeDirective,
        GraphQLSkipDirective,
        GraphQLDeprecatedDirective,
        GraphQLSpecifiedByDirective,
    ]
)
# Attach a docstring to this module-level constant (it is an instance, so it
# has no docstring of its own from a definition).
specified_directives.__doc__ = """The full list of specified directives."""
231
232
def is_specified_directive(directive: GraphQLDirective) -> bool:
    """Return True if a built-in directive has the same name as ``directive``.

    Only names are compared, so a custom directive that reuses a built-in
    name (e.g. ``include``) is also reported as specified.
    """
    for candidate in specified_directives:
        if candidate.name == directive.name:
            return True
    return False
239