import re
from functools import partial
from typing import Any, Optional

from ...error import GraphQLError
from ...language import TypeDefinitionNode, TypeExtensionNode
from ...pyutils import did_you_mean, inspect, suggestion_list
from ...type import (
    is_enum_type,
    is_input_object_type,
    is_interface_type,
    is_object_type,
    is_scalar_type,
    is_union_type,
)
from . import SDLValidationContext, SDLValidationRule

__all__ = ["PossibleTypeExtensionsRule"]


class PossibleTypeExtensionsRule(SDLValidationRule):
    """Possible type extension

    A type extension is only valid if the type is defined and has the same kind.
    """

    def __init__(self, context: SDLValidationContext):
        super().__init__(context)
        # Schema being extended, if any; callers below test it for truthiness,
        # so it may be absent when validating a standalone SDL document.
        self.schema = context.schema
        # Type definitions in the document under validation, keyed by type name.
        # If a name is defined more than once, the last definition wins.
        self.defined_types = {
            def_.name.value: def_
            for def_ in context.document.definitions
            if isinstance(def_, TypeDefinitionNode)
        }

    def check_extension(self, node: TypeExtensionNode, *_args: Any) -> None:
        """Validate a single type extension node.

        The kind the extension must have is derived from a definition of the
        same name in the current document if one exists, otherwise from the
        type already present in the schema. An error is reported if the
        extension's kind differs from that expected kind, or if no type of
        that name is known at all (with "did you mean" suggestions).

        Extra positional arguments passed by the visitor are ignored.
        """
        schema = self.schema
        type_name = node.name.value
        def_node = self.defined_types.get(type_name)
        existing_type = schema.get_type(type_name) if schema else None

        expected_kind: Optional[str]
        if def_node:
            # A definition in the document takes precedence over the schema.
            expected_kind = def_kind_to_ext_kind(def_node.kind)
        elif existing_type:
            expected_kind = type_to_ext_kind(existing_type)
        else:
            expected_kind = None

        if expected_kind:
            if expected_kind != node.kind:
                kind_str = extension_kind_to_type_name(node.kind)
                self.report_error(
                    GraphQLError(
                        f"Cannot extend non-{kind_str} type '{type_name}'.",
                        [def_node, node] if def_node else node,
                    )
                )
        else:
            # Candidate names come from both the document and the schema. A
            # name can appear in both, so drop duplicates (keeping first-seen
            # order) to avoid suggesting the same type twice.
            all_type_names = list(self.defined_types)
            if schema:
                all_type_names.extend(schema.type_map)
            all_type_names = list(dict.fromkeys(all_type_names))
            suggested_types = suggestion_list(type_name, all_type_names)
            self.report_error(
                GraphQLError(
                    f"Cannot extend type '{type_name}' because it is not defined."
                    + did_you_mean(suggested_types),
                    node.name,
                )
            )

    # Every kind of type extension is validated by the same check.
    enter_scalar_type_extension = enter_object_type_extension = check_extension
    enter_interface_type_extension = enter_union_type_extension = check_extension
    enter_enum_type_extension = enter_input_object_type_extension = check_extension


# Map a type-definition node kind to the matching extension kind by replacing a
# trailing "definition" that follows "_type_", e.g.
# "object_type_definition" -> "object_type_extension".
def_kind_to_ext_kind = partial(re.compile(r"(?<=_type_)definition$").sub, "extension")


def type_to_ext_kind(type_: Any) -> str:
    """Return the extension node kind corresponding to a schema type.

    Raises:
        TypeError: if ``type_`` is not one of the six named GraphQL type kinds.
    """
    if is_scalar_type(type_):
        return "scalar_type_extension"
    if is_object_type(type_):
        return "object_type_extension"
    if is_interface_type(type_):
        return "interface_type_extension"
    if is_union_type(type_):
        return "union_type_extension"
    if is_enum_type(type_):
        return "enum_type_extension"
    if is_input_object_type(type_):
        return "input_object_type_extension"

    # Not reachable. All possible types have been considered.
    raise TypeError(f"Unexpected type: {inspect(type_)}.")


# Human-readable type-kind names used in "Cannot extend non-<kind> type" errors.
_type_names_for_extension_kinds = {
    "scalar_type_extension": "scalar",
    "object_type_extension": "object",
    "interface_type_extension": "interface",
    "union_type_extension": "union",
    "enum_type_extension": "enum",
    "input_object_type_extension": "input object",
}


def extension_kind_to_type_name(kind: str) -> str:
    """Return the readable type-kind name for an extension node kind.

    Unrecognized kinds yield ``"unknown type"`` rather than raising.
    """
    return _type_names_for_extension_kinds.get(kind, "unknown type")