1from copy import copy, deepcopy
2from enum import Enum
3from typing import Any, Dict, List, Optional, Union
4
5from .source import Source
6from .token_kind import TokenKind
7from ..pyutils import camel_to_snake, FrozenList
8
# Public API of this module: the token/location helpers, the OperationType enum,
# the TypeSystemExtensionNode alias and every AST node class defined below.
__all__ = [
    "Location",
    "Token",
    "Node",
    "NameNode",
    "DocumentNode",
    "DefinitionNode",
    "ExecutableDefinitionNode",
    "OperationDefinitionNode",
    "VariableDefinitionNode",
    "SelectionSetNode",
    "SelectionNode",
    "FieldNode",
    "ArgumentNode",
    "FragmentSpreadNode",
    "InlineFragmentNode",
    "FragmentDefinitionNode",
    "ValueNode",
    "VariableNode",
    "IntValueNode",
    "FloatValueNode",
    "StringValueNode",
    "BooleanValueNode",
    "NullValueNode",
    "EnumValueNode",
    "ListValueNode",
    "ObjectValueNode",
    "ObjectFieldNode",
    "DirectiveNode",
    "TypeNode",
    "NamedTypeNode",
    "ListTypeNode",
    "NonNullTypeNode",
    "TypeSystemDefinitionNode",
    "SchemaDefinitionNode",
    "OperationType",
    "OperationTypeDefinitionNode",
    "TypeDefinitionNode",
    "ScalarTypeDefinitionNode",
    "ObjectTypeDefinitionNode",
    "FieldDefinitionNode",
    "InputValueDefinitionNode",
    "InterfaceTypeDefinitionNode",
    "UnionTypeDefinitionNode",
    "EnumTypeDefinitionNode",
    "EnumValueDefinitionNode",
    "InputObjectTypeDefinitionNode",
    "DirectiveDefinitionNode",
    "SchemaExtensionNode",
    "TypeExtensionNode",
    "TypeSystemExtensionNode",
    "ScalarTypeExtensionNode",
    "ObjectTypeExtensionNode",
    "InterfaceTypeExtensionNode",
    "UnionTypeExtensionNode",
    "EnumTypeExtensionNode",
    "InputObjectTypeExtensionNode",
]
67
68
class Token:
    """AST Token

    Represents a range of characters represented by a lexical token within a Source.
    """

    __slots__ = "kind", "start", "end", "line", "column", "prev", "next", "value"

    kind: TokenKind  # the kind of token
    start: int  # the character offset at which this Token begins
    end: int  # the character offset at which this Token ends
    line: int  # the 1-indexed line number on which this Token appears
    column: int  # the 1-indexed column number at which this Token begins
    # for non-punctuation tokens, represents the interpreted value of the token:
    value: Optional[str]
    # Tokens exist as nodes in a double-linked-list amongst all tokens including
    # ignored tokens. <SOF> is always the first node and <EOF> the last.
    prev: Optional["Token"]
    next: Optional["Token"]

    def __init__(
        self,
        kind: TokenKind,
        start: int,
        end: int,
        line: int,
        column: int,
        prev: Optional["Token"] = None,
        value: Optional[str] = None,
    ) -> None:
        """Create a token; ``next`` is always initialized to ``None``."""
        self.kind = kind
        self.start, self.end = start, end
        self.line, self.column = line, column
        self.value = value
        self.prev = prev
        self.next = None

    def __str__(self) -> str:
        return self.desc

    def __repr__(self) -> str:
        """Print a simplified form when appearing in repr() or inspect()."""
        return f"<Token {self.desc} {self.line}:{self.column}>"

    def __inspect__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        """Compare with another token or with a description string.

        Token-to-token comparison looks at kind, offsets, position and value and
        ignores the ``prev``/``next`` links. A string compares equal when it
        matches ``desc``.
        """
        if isinstance(other, Token):
            return (
                self.kind == other.kind
                and self.start == other.start
                and self.end == other.end
                and self.line == other.line
                and self.column == other.column
                and self.value == other.value
            )
        elif isinstance(other, str):
            return other == self.desc
        return False

    def __hash__(self) -> int:
        """Hash the same fields that token-to-token equality compares."""
        return hash(
            (self.kind, self.start, self.end, self.line, self.column, self.value)
        )

    def __copy__(self) -> "Token":
        """Create a shallow copy of the token.

        Every attribute is carried over by reference, including both neighbour
        links. The copy therefore points at the same ``prev`` and ``next``
        tokens as the original. The neighbours themselves are not copied and do
        not point back at the copy.
        """
        token = self.__class__(
            self.kind,
            self.start,
            self.end,
            self.line,
            self.column,
            self.prev,
            self.value,
        )
        # __init__ always resets ``next``; restore it so no attribute is lost.
        token.next = self.next
        return token

    def __deepcopy__(self, memo: Dict) -> "Token":
        """Allow only shallow copies to avoid recursion.

        Deep-copying the ``prev``/``next`` links would walk (and copy) the whole
        token list, so a deep copy is deliberately the same as a shallow copy.
        """
        return copy(self)

    @property
    def desc(self) -> str:
        """A helper property to describe a token as a string for debugging"""
        kind, value = self.kind.value, self.value
        return f"{kind} {value!r}" if value else kind
156
157
class Location:
    """AST Location

    Records the character offsets and the first/last tokens that delimit the
    region of the source from which an AST node was derived.
    """

    __slots__ = ("start", "end", "start_token", "end_token", "source")

    start: int  # character offset at which this Node begins
    end: int  # character offset at which this Node ends
    start_token: Token  # Token at which this Node begins
    end_token: Token  # Token at which this Node ends
    source: Source  # Source document the AST represents

    def __init__(self, start_token: Token, end_token: Token, source: Source) -> None:
        # The offsets are taken directly from the two boundary tokens.
        self.start_token, self.end_token = start_token, end_token
        self.start, self.end = start_token.start, end_token.end
        self.source = source

    def __str__(self) -> str:
        start, end = self.start, self.end
        return f"{start}:{end}"

    def __repr__(self) -> str:
        """Print a simplified form when appearing in repr() or inspect()."""
        start, end = self.start, self.end
        return f"<Location {start}:{end}>"

    def __inspect__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        """Equal to another Location or a 2-item (start, end) list/tuple."""
        if isinstance(other, Location):
            return self.start == other.start and self.end == other.end
        if isinstance(other, (list, tuple)) and len(other) == 2:
            other_start, other_end = other
            return self.start == other_start and self.end == other_end
        return False

    def __ne__(self, other: Any) -> bool:
        return not (self == other)

    def __hash__(self) -> int:
        span = (self.start, self.end)
        return hash(span)
208
209
class OperationType(Enum):
    """The kind of an operation; each member's value is its lowercase keyword."""

    QUERY = "query"
    MUTATION = "mutation"
    SUBSCRIPTION = "subscription"
215
216
217# Base AST Node
218
219
class Node:
    """AST nodes

    Base class of every AST node. Subclasses declare their attributes in
    ``__slots__``. ``__init_subclass__`` collects these, together with the keys
    of the Node base classes, into ``keys``. That list drives construction,
    equality, hashing and copying.
    """

    # allow custom attributes and weak references (not used internally)
    __slots__ = "__dict__", "__weakref__", "loc"

    loc: Optional[Location]

    kind: str = "ast"  # the kind of the node as a snake_case string
    keys = ["loc"]  # the names of the attributes of this node

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the node with the given keyword arguments.

        Every attribute named in ``keys`` is set. Keys without a matching keyword
        argument become ``None``, and keyword arguments not in ``keys`` are
        ignored. Plain lists are wrapped in a ``FrozenList``.
        """
        for key in self.keys:
            value = kwargs.get(key)
            if isinstance(value, list) and not isinstance(value, FrozenList):
                value = FrozenList(value)
            setattr(self, key, value)

    def __repr__(self) -> str:
        """Get a simple representation of the node."""
        name, loc = self.__class__.__name__, getattr(self, "loc", None)
        return f"{name} at {loc}" if loc else name

    def __eq__(self, other: Any) -> bool:
        """Test whether two nodes are equal (recursively)."""
        return (
            isinstance(other, Node)
            and self.__class__ == other.__class__
            and all(getattr(self, key) == getattr(other, key) for key in self.keys)
        )

    def __hash__(self) -> int:
        """Hash over the values of all ``keys`` (equal nodes hash equally)."""
        return hash(tuple(getattr(self, key) for key in self.keys))

    def __copy__(self) -> "Node":
        """Create a shallow copy of the node."""
        return self.__class__(**{key: getattr(self, key) for key in self.keys})

    def __deepcopy__(self, memo: Dict) -> "Node":
        """Create a deep copy of the node"""
        # noinspection PyArgumentList
        return self.__class__(
            **{key: deepcopy(getattr(self, key), memo) for key in self.keys}
        )

    def __init_subclass__(cls) -> None:
        """Derive ``kind`` and ``keys`` for every new node class.

        ``kind`` is the snake_case class name without a trailing ``Node``.
        ``keys`` lists the keys of the Node base classes, followed by the slots
        declared directly on this class. Each name appears only once.
        """
        super().__init_subclass__()
        name = cls.__name__
        if name.endswith("Node"):
            name = name[:-4]
        cls.kind = camel_to_snake(name)
        keys: List[str] = []
        for base in cls.__bases__:
            # only Node classes carry a ``keys`` list of attribute names
            if issubclass(base, Node):
                for key in base.keys:
                    if key not in keys:
                        keys.append(key)
        # Use only the slots declared on this very class. If a subclass declares
        # no __slots__, ``cls.__slots__`` would resolve to the parent's tuple and
        # duplicate its keys. A bare string is a single slot name, not a
        # sequence of one-character names.
        own_slots = cls.__dict__.get("__slots__", ())
        if isinstance(own_slots, str):
            own_slots = (own_slots,)
        for key in own_slots:
            # __dict__/__weakref__ are storage slots, not node attributes
            if key not in keys and key not in ("__dict__", "__weakref__"):
                keys.append(key)
        cls.keys = keys
278
279
280# Name
281
282
class NameNode(Node):
    """A name (identifier); the text is stored in ``value``."""

    __slots__ = ("value",)

    value: str  # the name itself
287
288
289# Document
290
291
class DocumentNode(Node):
    """A whole document, i.e. the sequence of its definitions."""

    __slots__ = ("definitions",)

    definitions: FrozenList["DefinitionNode"]  # in document order
296
297
class DefinitionNode(Node):
    """Common base of definition nodes; adds no attributes of its own."""

    __slots__ = ()
300
301
class ExecutableDefinitionNode(DefinitionNode):
    """Common base of operation definitions and fragment definitions."""

    __slots__ = ("name", "directives", "variable_definitions", "selection_set")

    # optional here; FragmentDefinitionNode narrows it to a required NameNode
    name: Optional[NameNode]
    directives: FrozenList["DirectiveNode"]
    variable_definitions: FrozenList["VariableDefinitionNode"]
    selection_set: "SelectionSetNode"
309
310
class OperationDefinitionNode(ExecutableDefinitionNode):
    """An operation definition; ``operation`` says which OperationType it is."""

    __slots__ = ("operation",)

    operation: OperationType
315
316
class VariableDefinitionNode(Node):
    """A variable definition: variable, type, optional default and directives."""

    __slots__ = ("variable", "type", "default_value", "directives")

    variable: "VariableNode"
    type: "TypeNode"
    default_value: Optional["ValueNode"]  # None when no default is given
    directives: FrozenList["DirectiveNode"]
324
325
class SelectionSetNode(Node):
    """A selection set, i.e. an ordered list of selections."""

    __slots__ = ("selections",)

    selections: FrozenList["SelectionNode"]
330
331
class SelectionNode(Node):
    """Common base of fields, fragment spreads and inline fragments."""

    __slots__ = ("directives",)

    directives: FrozenList["DirectiveNode"]  # shared by every selection kind
336
337
class FieldNode(SelectionNode):
    """A field selection with optional alias, arguments and sub-selections."""

    __slots__ = ("alias", "name", "arguments", "selection_set")

    alias: Optional[NameNode]  # None when the field is not aliased
    name: NameNode
    arguments: FrozenList["ArgumentNode"]
    selection_set: Optional[SelectionSetNode]  # None for leaf fields
345
346
class ArgumentNode(Node):
    """A single argument: a name paired with a value."""

    __slots__ = ("name", "value")

    name: NameNode
    value: "ValueNode"
352
353
354# Fragments
355
356
class FragmentSpreadNode(SelectionNode):
    """A spread of a fragment referenced by name."""

    __slots__ = ("name",)

    name: NameNode  # name of the spread fragment
361
362
class InlineFragmentNode(SelectionNode):
    """An inline fragment: a selection set with an optional type condition."""

    __slots__ = "type_condition", "selection_set"

    # None when the inline fragment has no ``on Type`` condition
    type_condition: Optional["NamedTypeNode"]
    selection_set: SelectionSetNode
368
369
class FragmentDefinitionNode(ExecutableDefinitionNode):
    """A named fragment definition with its type condition."""

    __slots__ = ("type_condition",)

    name: NameNode  # unlike operations, fragments always carry a name
    type_condition: "NamedTypeNode"
375
376
377# Values
378
379
class ValueNode(Node):
    """Common base of all value nodes; adds no attributes of its own."""

    __slots__ = ()
382
383
class VariableNode(ValueNode):
    """A reference to a variable by its name."""

    __slots__ = ("name",)

    name: NameNode
388
389
class IntValueNode(ValueNode):
    """An integer literal; ``value`` holds it as a string."""

    __slots__ = ("value",)

    value: str
394
395
class FloatValueNode(ValueNode):
    """A float literal; ``value`` holds it as a string."""

    __slots__ = ("value",)

    value: str
400
401
class StringValueNode(ValueNode):
    """A string literal with an optional ``block`` flag."""

    __slots__ = ("value", "block")

    value: str
    block: Optional[bool]  # presumably set for block strings -- verify in parser
407
408
class BooleanValueNode(ValueNode):
    """A boolean literal."""

    __slots__ = ("value",)

    value: bool
413
414
class NullValueNode(ValueNode):
    """The null literal; carries no attributes besides ``loc``."""

    __slots__ = ()
417
418
class EnumValueNode(ValueNode):
    """An enum value literal, stored as its name string."""

    __slots__ = ("value",)

    value: str
423
424
class ListValueNode(ValueNode):
    """A list literal holding its element value nodes."""

    __slots__ = ("values",)

    values: FrozenList[ValueNode]
429
430
class ObjectValueNode(ValueNode):
    """An object literal made up of name/value fields."""

    __slots__ = ("fields",)

    fields: FrozenList["ObjectFieldNode"]
435
436
class ObjectFieldNode(Node):
    """One name/value field inside an object literal."""

    __slots__ = ("name", "value")

    name: NameNode
    value: ValueNode
442
443
444# Directives
445
446
class DirectiveNode(Node):
    """An applied directive: its name plus arguments."""

    __slots__ = ("name", "arguments")

    name: NameNode
    arguments: FrozenList[ArgumentNode]
452
453
454# Type Reference
455
456
class TypeNode(Node):
    """Common base of named, list and non-null type references."""

    __slots__ = ()
459
460
class NamedTypeNode(TypeNode):
    """A reference to a type by its name."""

    __slots__ = ("name",)

    name: NameNode
465
466
class ListTypeNode(TypeNode):
    """A list type wrapping an inner type reference."""

    __slots__ = ("type",)

    type: TypeNode  # the wrapped element type
471
472
class NonNullTypeNode(TypeNode):
    """A non-null wrapper around a named or list type (never another non-null)."""

    __slots__ = ("type",)

    type: Union[NamedTypeNode, ListTypeNode]
477
478
479# Type System Definition
480
481
class TypeSystemDefinitionNode(DefinitionNode):
    """Common base of schema, type and directive definitions."""

    __slots__ = ()
484
485
class SchemaDefinitionNode(TypeSystemDefinitionNode):
    """The schema definition with its root operation types."""

    __slots__ = ("description", "directives", "operation_types")

    description: Optional[StringValueNode]
    directives: FrozenList[DirectiveNode]
    operation_types: FrozenList["OperationTypeDefinitionNode"]
492
493
class OperationTypeDefinitionNode(Node):
    """Binds one OperationType to the named type that serves it."""

    __slots__ = ("operation", "type")

    operation: OperationType
    type: NamedTypeNode
499
500
501# Type Definition
502
503
class TypeDefinitionNode(TypeSystemDefinitionNode):
    """Common base of named type definitions (description, name, directives)."""

    __slots__ = ("description", "name", "directives")

    description: Optional[StringValueNode]
    name: NameNode
    directives: FrozenList[DirectiveNode]
510
511
class ScalarTypeDefinitionNode(TypeDefinitionNode):
    """A scalar type definition; nothing beyond the common type attributes."""

    __slots__ = ()
514
515
class ObjectTypeDefinitionNode(TypeDefinitionNode):
    """An object type definition with implemented interfaces and fields."""

    __slots__ = ("interfaces", "fields")

    interfaces: FrozenList[NamedTypeNode]
    fields: FrozenList["FieldDefinitionNode"]
521
522
class FieldDefinitionNode(DefinitionNode):
    """A field definition: name, arguments, result type, directives."""

    __slots__ = ("description", "name", "directives", "arguments", "type")

    description: Optional[StringValueNode]
    name: NameNode
    directives: FrozenList[DirectiveNode]
    arguments: FrozenList["InputValueDefinitionNode"]
    type: TypeNode
531
532
class InputValueDefinitionNode(DefinitionNode):
    """An input value definition (used for arguments and input object fields)."""

    __slots__ = ("description", "name", "directives", "type", "default_value")

    description: Optional[StringValueNode]
    name: NameNode
    directives: FrozenList[DirectiveNode]
    type: TypeNode
    default_value: Optional[ValueNode]  # None when no default is given
541
542
class InterfaceTypeDefinitionNode(TypeDefinitionNode):
    """An interface type definition with its fields and implemented interfaces."""

    # NB: slot order (fields before interfaces) determines the order of ``keys``
    __slots__ = ("fields", "interfaces")

    fields: FrozenList["FieldDefinitionNode"]
    interfaces: FrozenList[NamedTypeNode]
548
549
class UnionTypeDefinitionNode(TypeDefinitionNode):
    """A union type definition listing its member types."""

    __slots__ = ("types",)

    types: FrozenList[NamedTypeNode]
554
555
class EnumTypeDefinitionNode(TypeDefinitionNode):
    """An enum type definition listing its values."""

    __slots__ = ("values",)

    values: FrozenList["EnumValueDefinitionNode"]
560
561
class EnumValueDefinitionNode(DefinitionNode):
    """A single value inside an enum type definition."""

    __slots__ = ("description", "name", "directives")

    description: Optional[StringValueNode]
    name: NameNode
    directives: FrozenList[DirectiveNode]
568
569
class InputObjectTypeDefinitionNode(TypeDefinitionNode):
    """An input object type definition made of input value fields."""

    __slots__ = ("fields",)

    fields: FrozenList[InputValueDefinitionNode]
574
575
576# Directive Definitions
577
578
class DirectiveDefinitionNode(TypeSystemDefinitionNode):
    """A directive definition: arguments, repeatability and allowed locations."""

    __slots__ = ("description", "name", "arguments", "repeatable", "locations")

    description: Optional[StringValueNode]
    name: NameNode
    arguments: FrozenList[InputValueDefinitionNode]
    repeatable: bool
    locations: FrozenList[NameNode]  # location names, stored as NameNodes
587
588
589# Type System Extensions
590
591
class SchemaExtensionNode(Node):
    """An extension of the schema with extra directives / operation types."""

    __slots__ = ("directives", "operation_types")

    directives: FrozenList[DirectiveNode]
    operation_types: FrozenList[OperationTypeDefinitionNode]
597
598
599# Type Extensions
600
601
class TypeExtensionNode(TypeSystemDefinitionNode):
    """Common base of type extensions: the extended type's name and directives."""

    __slots__ = ("name", "directives")

    name: NameNode
    directives: FrozenList[DirectiveNode]
607
608
# Typing alias covering both kinds of type system extension nodes; it is a
# typing.Union, not a Node subclass, so it has no ``kind`` or ``keys``.
TypeSystemExtensionNode = Union[SchemaExtensionNode, TypeExtensionNode]
610
611
class ScalarTypeExtensionNode(TypeExtensionNode):
    """A scalar type extension; nothing beyond name and directives."""

    __slots__ = ()
614
615
class ObjectTypeExtensionNode(TypeExtensionNode):
    """An object type extension adding interfaces and/or fields."""

    __slots__ = ("interfaces", "fields")

    interfaces: FrozenList[NamedTypeNode]
    fields: FrozenList[FieldDefinitionNode]
621
622
class InterfaceTypeExtensionNode(TypeExtensionNode):
    """An interface type extension adding interfaces and/or fields."""

    __slots__ = ("interfaces", "fields")

    interfaces: FrozenList[NamedTypeNode]
    fields: FrozenList[FieldDefinitionNode]
628
629
class UnionTypeExtensionNode(TypeExtensionNode):
    """A union type extension adding member types."""

    __slots__ = ("types",)

    types: FrozenList[NamedTypeNode]
634
635
class EnumTypeExtensionNode(TypeExtensionNode):
    """An enum type extension adding values."""

    __slots__ = ("values",)

    values: FrozenList[EnumValueDefinitionNode]
640
641
class InputObjectTypeExtensionNode(TypeExtensionNode):
    """An input object type extension adding input fields."""

    __slots__ = ("fields",)

    fields: FrozenList[InputValueDefinitionNode]
646