1from . import ast
2
# Table of AST node classes -> tuple of attribute names listed for that node.
# Entries with an empty tuple (Name and the scalar *Value nodes) list no
# attributes at all.
# NOTE(review): the traversal that consumes this table is not in this chunk;
# presumably it descends into these attributes in tuple order — confirm.
QUERY_DOCUMENT_KEYS = {
    # Executable documents: operations, variables and selections.
    ast.Name: (),
    ast.Document: ("definitions",),
    ast.OperationDefinition: (
        "name",
        "variable_definitions",
        "directives",
        "selection_set",
    ),
    ast.VariableDefinition: ("variable", "type", "default_value"),
    ast.Variable: ("name",),
    ast.SelectionSet: ("selections",),
    ast.Field: ("alias", "name", "arguments", "directives", "selection_set"),
    ast.Argument: ("name", "value"),
    # Fragments.
    ast.FragmentSpread: ("name", "directives"),
    ast.InlineFragment: ("type_condition", "directives", "selection_set"),
    ast.FragmentDefinition: ("name", "type_condition", "directives", "selection_set"),
    # Literal values.
    ast.IntValue: (),
    ast.FloatValue: (),
    ast.StringValue: (),
    ast.BooleanValue: (),
    ast.EnumValue: (),
    ast.ListValue: ("values",),
    ast.ObjectValue: ("fields",),
    ast.ObjectField: ("name", "value"),
    # Directive usage.
    ast.Directive: ("name", "arguments"),
    # Type references.
    ast.NamedType: ("name",),
    ast.ListType: ("type",),
    ast.NonNullType: ("type",),
    # Schema / type-system definitions.
    ast.SchemaDefinition: ("directives", "operation_types"),
    ast.OperationTypeDefinition: ("type",),
    ast.ScalarTypeDefinition: ("name", "directives"),
    ast.ObjectTypeDefinition: ("name", "interfaces", "directives", "fields"),
    ast.FieldDefinition: ("name", "arguments", "directives", "type"),
    ast.InputValueDefinition: ("name", "type", "directives", "default_value"),
    ast.InterfaceTypeDefinition: ("name", "directives", "fields"),
    ast.UnionTypeDefinition: ("name", "directives", "types"),
    ast.EnumTypeDefinition: ("name", "directives", "values"),
    ast.EnumValueDefinition: ("name", "directives"),
    ast.InputObjectTypeDefinition: ("name", "directives", "fields"),
    ast.TypeExtensionDefinition: ("definition",),
    ast.DirectiveDefinition: ("name", "arguments", "locations"),
}

# Reverse index from a node class's ``__name__`` (e.g. "Field") to the class
# itself.  VisitorMeta resolves ``enter_<Kind>`` / ``leave_<Kind>`` method
# names through this mapping.
AST_KIND_TO_TYPE = dict(
    (node_type.__name__, node_type) for node_type in QUERY_DOCUMENT_KEYS
)
48
49
def _register_own_handlers(namespace, enter_handlers, leave_handlers):
    """Record the ``enter_<Kind>`` / ``leave_<Kind>`` entries of one class body.

    Args:
        namespace: A class ``__dict__``, i.e. only the attributes that class
            defines itself, not inherited ones.
        enter_handlers: Dict updated in place, AST node type -> enter handler.
        leave_handlers: Dict updated in place, AST node type -> leave handler.

    ``<Kind>`` is resolved through ``AST_KIND_TO_TYPE``.  Names whose suffix
    is not a known AST kind (typos, or ordinary helpers that happen to start
    with ``enter_``/``leave_``) are skipped.  Such names can never match a node
    type, and registering them all under a shared ``None`` key would only let
    one silently replace another.
    """
    for attr, value in namespace.items():
        if attr.startswith("enter_"):
            handlers, ast_kind = enter_handlers, attr[len("enter_"):]
        elif attr.startswith("leave_"):
            handlers, ast_kind = leave_handlers, attr[len("leave_"):]
        else:
            continue

        ast_type = AST_KIND_TO_TYPE.get(ast_kind)
        if ast_type is not None:
            handlers[ast_type] = value


class VisitorMeta(type):
    """Metaclass that builds per-class dispatch tables for visitor handlers.

    A class using this metaclass may define methods named ``enter_<Kind>``
    and ``leave_<Kind>``, where ``<Kind>`` is an AST node class name such as
    ``Field``.  The metaclass sets these class attributes:

    * ``_enter_handlers`` / ``_leave_handlers``: dicts mapping the AST node
      type to the raw class attribute (normally a plain function).
    * ``_get_enter_handler`` / ``_get_leave_handler``: the ``get`` method of
      those dicts, so a lookup returns ``None`` for node types that have no
      handler.

    Handlers are inherited from base classes built by this metaclass.  When
    several of them define a handler for the same kind, the handler from the
    class earliest in the new class's MRO wins.  That is the same function
    ordinary attribute lookup (``cls.enter_<Kind>``) returns among those
    classes.
    """

    def __new__(cls, name, bases, attrs):
        enter_handlers = {}
        leave_handlers = {}

        # Install the (still empty) tables before the class is created so the
        # attribute names exist from the start.  They are filled in place
        # below, once the MRO is known.
        attrs["_enter_handlers"] = enter_handlers
        attrs["_leave_handlers"] = leave_handlers
        attrs["_get_enter_handler"] = enter_handlers.get
        attrs["_get_leave_handler"] = leave_handlers.get
        new_cls = super(VisitorMeta, cls).__new__(cls, name, bases, attrs)

        # Walk the MRO from the most generic class to the new class itself, so
        # more specific classes overwrite less specific ones.  Merging the
        # bases' tables in declaration order instead would let a later base
        # override an earlier one, contradicting attribute lookup.  Plain
        # mixins that are not built by this metaclass contribute no handlers.
        for klass in reversed(new_cls.__mro__):
            if isinstance(klass, VisitorMeta):
                _register_own_handlers(vars(klass), enter_handlers, leave_handlers)

        return new_cls
78