1import re
2from functools import partial
3from typing import Any, Optional
4
5from ...error import GraphQLError
6from ...language import TypeDefinitionNode, TypeExtensionNode
7from ...pyutils import did_you_mean, inspect, suggestion_list
8from ...type import (
9    is_enum_type,
10    is_input_object_type,
11    is_interface_type,
12    is_object_type,
13    is_scalar_type,
14    is_union_type,
15)
16from . import SDLValidationContext, SDLValidationRule
17
# Only the validation rule class is exported; the kind-mapping helper
# functions defined below are not part of the public API of this module.
__all__ = ["PossibleTypeExtensionsRule"]
19
20
class PossibleTypeExtensionsRule(SDLValidationRule):
    """Possible type extension

    A type extension is only valid if the type is defined and has the same kind.

    The extended type is looked up first among the type definitions of the SDL
    document being validated and, if not found there, in the schema attached to
    the validation context (when there is one).
    """

    def __init__(self, context: SDLValidationContext):
        super().__init__(context)
        self.schema = context.schema
        # Type definitions of this document keyed by type name; if a name is
        # defined more than once, the last definition is the one kept.
        self.defined_types = {
            node.name.value: node
            for node in context.document.definitions
            if isinstance(node, TypeDefinitionNode)
        }

    def check_extension(self, node: TypeExtensionNode, *_args: Any) -> None:
        """Validate a single type extension node.

        Reports an error when the extended type exists but is of a different
        kind than the extension, or when no type of that name exists at all
        (in which case similar known type names are suggested).
        """
        schema = self.schema
        type_name = node.name.value
        local_def = self.defined_types.get(type_name)
        schema_type = schema.get_type(type_name) if schema else None

        # A definition in the same document takes precedence over the schema.
        expected_kind: Optional[str] = None
        if local_def:
            expected_kind = def_kind_to_ext_kind(local_def.kind)
        elif schema_type:
            expected_kind = type_to_ext_kind(schema_type)

        if not expected_kind:
            # Unknown type: collect every known name to build suggestions.
            known_names = list(self.defined_types)
            if schema:
                known_names.extend(schema.type_map)
            suggestions = suggestion_list(type_name, known_names)
            self.report_error(
                GraphQLError(
                    f"Cannot extend type '{type_name}' because it is not defined."
                    + did_you_mean(suggestions),
                    node.name,
                )
            )
            return

        if expected_kind == node.kind:
            return

        kind_str = extension_kind_to_type_name(node.kind)
        # Point at the local definition too when the conflict is in-document.
        error_nodes = [local_def, node] if local_def else node
        self.report_error(
            GraphQLError(
                f"Cannot extend non-{kind_str} type '{type_name}'.",
                error_nodes,
            )
        )

    # Every kind of type extension is validated by the same handler.
    enter_scalar_type_extension = check_extension
    enter_object_type_extension = check_extension
    enter_interface_type_extension = check_extension
    enter_union_type_extension = check_extension
    enter_enum_type_extension = check_extension
    enter_input_object_type_extension = check_extension
75
76
# Turns a type *definition* node kind into the matching *extension* kind,
# e.g. "union_type_definition" -> "union_type_extension". Kinds that do not
# end in "_type_definition" are returned unchanged.
_type_definition_suffix = re.compile(r"(?<=_type_)definition$")
def_kind_to_ext_kind = partial(_type_definition_suffix.sub, "extension")
78
79
def type_to_ext_kind(type_: Any) -> str:
    """Return the type extension node kind matching the kind of ``type_``.

    Args:
        type_: expected to be a named GraphQL type.

    Returns:
        The ``*_type_extension`` node kind an extension of ``type_`` must have.

    Raises:
        TypeError: if ``type_`` matches none of the six named type kinds.
    """
    # Checked in this order; the first predicate that matches wins.
    predicates_and_kinds = (
        (is_scalar_type, "scalar_type_extension"),
        (is_object_type, "object_type_extension"),
        (is_interface_type, "interface_type_extension"),
        (is_union_type, "union_type_extension"),
        (is_enum_type, "enum_type_extension"),
        (is_input_object_type, "input_object_type_extension"),
    )
    for matches, ext_kind in predicates_and_kinds:
        if matches(type_):
            return ext_kind

    # Not reachable. All possible types have been considered.
    raise TypeError(f"Unexpected type: {inspect(type_)}.")
96
97
# Human-readable type names for each type extension node kind, as used in the
# "Cannot extend non-<name> type" error message.
_type_names_for_extension_kinds = {
    "scalar_type_extension": "scalar",
    "object_type_extension": "object",
    "interface_type_extension": "interface",
    "union_type_extension": "union",
    "enum_type_extension": "enum",
    "input_object_type_extension": "input object",
}


def extension_kind_to_type_name(kind: str) -> str:
    """Return the readable type name for a type extension node kind.

    Kinds not present in the mapping yield the placeholder ``"unknown type"``
    instead of raising.
    """
    try:
        return _type_names_for_extension_kinds[kind]
    except KeyError:
        return "unknown type"
110