import abc
import dataclasses
import functools
import inspect
from dataclasses import Field, fields
from typing import Any, Callable, Dict, Optional, Tuple, get_type_hints
from enum import Enum

from marshmallow import ValidationError

from dataclasses_json.utils import CatchAllVar

KnownParameters = Dict[str, Any]
UnknownParameters = Dict[str, Any]


class _UndefinedParameterAction(abc.ABC):
    """
    Strategy describing how a dataclass reacts to parameters that do not
    correspond to any of its fields.
    """

    @staticmethod
    @abc.abstractmethod
    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
        """
        Return the parameters to initialize the class with.
        """
        pass

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Return the parameters that will be written to the output dict.
        """
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the parameters that will be added to the schema dump.
        """
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Return the ``__init__`` to use for ``obj``; by default the class's
        own ``__init__`` is returned unchanged.
        """
        return obj.__init__

    @staticmethod
    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
            Tuple[KnownParameters, UnknownParameters]:
        """
        Split ``kvs`` into two dictionaries: entries whose key is a field name
        of ``cls`` (known) and all remaining entries (unknown).
        """
        # A set gives O(1) membership tests instead of scanning a list.
        field_names = {field.name for field in fields(cls)}
        known_given_parameters = {k: v for k, v in kvs.items()
                                  if k in field_names}
        unknown_given_parameters = {k: v for k, v in kvs.items()
                                    if k not in field_names}
        return known_given_parameters, unknown_given_parameters


class _RaiseUndefinedParameters(_UndefinedParameterAction):
    """
    This action raises UndefinedParameterError if it encounters an undefined
    parameter during initialization.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the entries of ``kvs`` that are fields of ``cls``.

        Raises:
            UndefinedParameterError: if ``kvs`` contains any other key.
        """
        known, unknown = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        if len(unknown) > 0:
            raise UndefinedParameterError(
                f"Received undefined initialization arguments {unknown}")
        return known


CatchAll = Optional[CatchAllVar]


class _IgnoreUndefinedParameters(_UndefinedParameterAction):
    """
    This action does nothing when it encounters undefined parameters.
    The undefined parameters can not be retrieved after the class has been
    created.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return only the entries of ``kvs`` that are fields of ``cls``.
        """
        known_given_parameters, _ = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        return known_given_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Wrap ``obj.__init__`` so that surplus positional arguments and unknown
        keyword arguments are silently dropped.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        @functools.wraps(obj.__init__)
        def _ignore_init(self, *args, **kwargs):
            known_kwargs, _ = \
                _UndefinedParameterAction._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            # Positional slots not already claimed by known keyword
            # arguments; positional arguments beyond these are discarded.
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args = args[:num_args_takeable]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)
            bound_parameters.apply_defaults()

            arguments = bound_parameters.arguments
            arguments.pop("self", None)
            final_parameters = \
                _IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
            original_init(self, **final_parameters)

        return _ignore_init


class _CatchAllUndefinedParameters(_UndefinedParameterAction):
    """
    This action lets the class declare a single field of type ``CatchAll``,
    which acts as a dictionary into which all undefined parameters are
    written. These parameters are not affected by LetterCase.
    If no undefined parameters are given, this dictionary will be empty.
    """

    class _SentinelNoDefault:
        pass

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the known parameters of ``kvs`` with the catch-all field set
        to the collected unknown parameters.

        Raises:
            UndefinedParameterError: if the catch-all field receives a value
                that is neither its default nor a dict.
        """
        known, unknown = _UndefinedParameterAction \
            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)

        if catch_all_field.name in known:

            already_parsed = isinstance(known[catch_all_field.name], dict)
            default_value = _CatchAllUndefinedParameters._get_default(
                catch_all_field=catch_all_field)
            received_default = default_value == known[catch_all_field.name]

            value_to_write: Any
            if received_default and len(unknown) == 0:
                value_to_write = default_value
            elif received_default and len(unknown) > 0:
                value_to_write = unknown
            elif already_parsed:
                # Did not receive default
                value_to_write = known[catch_all_field.name]
                if len(unknown) > 0:
                    # Merge into a fresh dict so the mapping supplied by the
                    # caller is not mutated as a side effect.
                    value_to_write = {**value_to_write, **unknown}
            else:
                error_message = f"Received input field with " \
                                f"same name as catch-all field: " \
                                f"'{catch_all_field.name}': " \
                                f"'{known[catch_all_field.name]}'"
                raise UndefinedParameterError(error_message)
        else:
            value_to_write = unknown

        known[catch_all_field.name] = value_to_write
        return known

    @staticmethod
    def _get_default(catch_all_field: Field) -> Any:
        """
        Return the declared default of the catch-all field, or the
        ``_SentinelNoDefault`` class if it has neither ``default`` nor
        ``default_factory``.
        """
        # access to the default factory currently causes
        # a false-positive mypy error (16. Dec 2019):
        # https://github.com/python/mypy/issues/6910

        # noinspection PyProtectedMember
        has_default = not isinstance(catch_all_field.default,
                                     dataclasses._MISSING_TYPE)
        # noinspection PyProtectedMember
        has_default_factory = not isinstance(
            catch_all_field.default_factory,  # type: ignore
            dataclasses._MISSING_TYPE)
        default_value = _CatchAllUndefinedParameters._SentinelNoDefault
        if has_default:
            default_value = catch_all_field.default
        elif has_default_factory:
            # This might be unwanted if the default factory constructs
            # something expensive,
            # because we have to construct it again just for this test
            default_value = catch_all_field.default_factory()  # type: ignore

        return default_value

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Remove the catch-all field from ``kvs`` and, if it holds a dict,
        merge its entries into ``kvs`` at the top level.
        """
        catch_all_field = \
            _CatchAllUndefinedParameters._get_catch_all_field(obj)
        undefined_parameters = kvs.pop(catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            kvs.update(
                undefined_parameters)  # If desired handle letter case here
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the contents of the catch-all field, or ``{}`` if it is None.
        """
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=obj)
        undefined_parameters = getattr(obj, catch_all_field.name)
        # The catch-all field is Optional; report "nothing to add" instead of
        # returning None from a method typed to return a dict.
        return {} if undefined_parameters is None else undefined_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Wrap ``obj.__init__`` so that surplus positional arguments and unknown
        keyword arguments are collected into the catch-all field.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        @functools.wraps(obj.__init__)
        def _catch_all_init(self, *args, **kwargs):
            known_kwargs, unknown_kwargs = \
                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            # The catch-all field is filled by this wrapper, so it does not
            # consume a positional slot unless passed explicitly by keyword.
            if _CatchAllUndefinedParameters._get_catch_all_field(
                    obj).name not in known_kwargs:
                num_params_takeable -= 1
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args, unknown_args = args[:num_args_takeable], args[
                                                          num_args_takeable:]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)

            # Surplus positional arguments have no name; give them
            # synthetic keys so they can be stored in the catch-all dict.
            unknown_args = {f"_UNKNOWN{i}": v for i, v in
                            enumerate(unknown_args)}
            arguments = bound_parameters.arguments
            arguments.update(unknown_args)
            arguments.update(unknown_kwargs)
            arguments.pop("self", None)
            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
                obj, arguments)
            original_init(self, **final_parameters)

        return _catch_all_init

    @staticmethod
    def _resolve_field_types(cls, class_fields) -> Dict[str, Any]:
        """
        Map each field name to its type, resolving string annotations
        (e.g. produced by ``from __future__ import annotations``).

        If the annotations cannot be resolved, the raw ``Field.type`` values
        are returned unchanged.
        """
        field_types = {f.name: f.type for f in class_fields}
        if not any(isinstance(t, str) for t in field_types.values()):
            return field_types
        # get_type_hints needs the class itself to find its module globals;
        # callers may pass an instance.
        klass = cls if isinstance(cls, type) else type(cls)
        try:
            hints = get_type_hints(klass)
        except (NameError, TypeError, SyntaxError):
            # Unresolvable annotation (e.g. a forward reference to a name
            # that is not defined yet): fall back to the raw types.
            return field_types
        field_types.update(
            {name: hints[name] for name in field_types if name in hints})
        return field_types

    @staticmethod
    def _get_catch_all_field(cls) -> Field:
        """
        Return the single field of ``cls`` (a dataclass or an instance of
        one) whose type is ``CatchAll``.

        Raises:
            UndefinedParameterError: if there is no such field, or more than
                one.
        """
        class_fields = fields(cls)
        field_types = _CatchAllUndefinedParameters._resolve_field_types(
            cls, class_fields)
        catch_all_fields = [f for f in class_fields
                            if field_types[f.name] == Optional[CatchAllVar]]
        number_of_catch_all_fields = len(catch_all_fields)
        if number_of_catch_all_fields == 0:
            raise UndefinedParameterError(
                "No field of type dataclasses_json.CatchAll defined")
        elif number_of_catch_all_fields > 1:
            raise UndefinedParameterError(
                f"Multiple catch-all fields supplied: "
                f"{number_of_catch_all_fields}.")
        else:
            return catch_all_fields[0]


class Undefined(Enum):
    """
    Choose the behavior what happens when an undefined parameter is encountered
    during class initialization.
    """
    INCLUDE = _CatchAllUndefinedParameters
    RAISE = _RaiseUndefinedParameters
    EXCLUDE = _IgnoreUndefinedParameters


class UndefinedParameterError(ValidationError):
    """
    Raised when something has gone wrong handling undefined parameters.
    """
    pass