import abc
import copy
import dataclasses
import functools
import inspect
from dataclasses import Field, fields
from enum import Enum
from typing import Any, Callable, Dict, Optional, Tuple, get_type_hints

from marshmallow import ValidationError

from dataclasses_json.utils import CatchAllVar
12
# Type aliases for the two dicts returned by
# _UndefinedParameterAction._separate_defined_undefined_kvs: the entries whose
# key is a field name of the dataclass, and the entries whose key is not.
KnownParameters = Dict[str, Any]
UnknownParameters = Dict[str, Any]
15
16
class _UndefinedParameterAction(abc.ABC):
    """
    Base strategy for handling parameters that are not fields of the target
    dataclass.

    Every hook is a static method that receives the dataclass (or an
    instance of it) explicitly. Only ``handle_from_dict`` is abstract; the
    other hooks default to leaving the data and ``__init__`` unchanged.
    """

    @staticmethod
    @abc.abstractmethod
    def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
        """
        Return the parameters to initialize the class with.
        """
        pass

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Return the parameters that will be written to the output dict.

        The default implementation returns ``kvs`` unchanged.
        """
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the parameters that will be added to the schema dump.

        The default implementation adds nothing.
        """
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Return the ``__init__`` to use for ``obj``.

        The default implementation returns the existing ``obj.__init__``.
        """
        return obj.__init__

    @staticmethod
    def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
            Tuple[KnownParameters, UnknownParameters]:
        """
        Split ``kvs`` into two dictionaries: defined and undefined parameters.

        A key is "defined" when it is the name of a field of ``cls`` (a
        dataclass or an instance of one, as accepted by ``fields``). Both
        returned dicts keep the relative key order of ``kvs``.
        """
        # A set keeps each membership test O(1) regardless of field count.
        field_names = {field.name for field in fields(cls)}
        known_given_parameters: KnownParameters = {}
        unknown_given_parameters: UnknownParameters = {}
        # Single pass over kvs, routing each entry to one of the two dicts.
        for key, value in kvs.items():
            if key in field_names:
                known_given_parameters[key] = value
            else:
                unknown_given_parameters[key] = value
        return known_given_parameters, unknown_given_parameters
57
58
class _RaiseUndefinedParameters(_UndefinedParameterAction):
    """
    Undefined-parameter policy that rejects unknown keys.

    ``handle_from_dict`` raises UndefinedParameterError whenever ``kvs``
    contains a key that is not a field name of the target dataclass;
    otherwise it returns the field entries.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        defined, undefined = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        # Guard clause: the happy path is "nothing undefined".
        if not undefined:
            return defined
        raise UndefinedParameterError(
            f"Received undefined initialization arguments {undefined}")
74
75
# Annotation for the single catch-all field of a dataclass using
# Undefined.INCLUDE; _CatchAllUndefinedParameters locates that field by
# comparing field types against Optional[CatchAllVar].
CatchAll = Optional[CatchAllVar]
77
78
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
    """
    Undefined-parameter policy that silently drops unknown parameters.

    The undefined parameters can not be retrieved after the instance has
    been created.
    """

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return only the entries of ``kvs`` whose key is a field name of
        ``cls``.
        """
        known_given_parameters, _ = \
            _UndefinedParameterAction._separate_defined_undefined_kvs(
                cls=cls, kvs=kvs)
        return known_given_parameters

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build a replacement ``__init__`` for ``obj`` that discards arguments
        the original ``__init__`` cannot accept.

        Keyword arguments are kept only if they name a parameter of the
        original ``__init__``; positional arguments beyond the remaining
        parameter slots are dropped. Accepted names come from the
        ``__init__`` signature rather than from ``fields(obj)``: ``fields``
        omits InitVar pseudo-fields (which ``__init__`` needs) and includes
        ``init=False`` fields (which ``__init__`` rejects).
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)
        # Everything after the leading ``self`` parameter.
        init_params = list(init_signature.parameters.values())[1:]
        num_params_takeable = len(init_params)
        # Parameters that can be supplied by keyword.
        keyword_names = {
            param.name for param in init_params
            if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                              inspect.Parameter.KEYWORD_ONLY)
        }

        @functools.wraps(obj.__init__)
        def _ignore_init(self, *args, **kwargs):
            known_kwargs = {key: value for key, value in kwargs.items()
                            if key in keyword_names}
            # Keyword arguments occupy parameter slots, leaving fewer for
            # positional ones. known_kwargs is a subset of the parameters,
            # so this is never negative.
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args = args[:num_args_takeable]
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)
            bound_parameters.apply_defaults()

            # Forward everything by keyword; this also drops ``self`` and
            # any *args/**kwargs catch-alls in the signature.
            final_parameters = {
                name: value
                for name, value in bound_parameters.arguments.items()
                if name in keyword_names
            }
            original_init(self, **final_parameters)

        return _ignore_init
119
120
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
    """
    Undefined-parameter policy that collects undefined parameters into a
    catch-all field.

    The dataclass must declare exactly one field annotated as
    ``dataclasses_json.CatchAll``. It acts as a dictionary into which all
    undefined parameters are written. These parameters are not affected by
    LetterCase. If no undefined parameters are given, this dictionary will
    be empty, unless the field's default value was passed explicitly. In
    that case the default is kept.
    """

    class _SentinelNoDefault:
        # Marker returned by _get_default when the catch-all field has
        # neither a default nor a default_factory.
        pass

    @staticmethod
    def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
        """
        Return the field entries of ``kvs``, with every undefined parameter
        stored in a dict under the catch-all field's name.

        If ``kvs`` already has a value for the catch-all field:

        * a value equal to the field's default is replaced by the undefined
          parameters, or kept if there are none;
        * any other dict is merged with the undefined parameters, which win
          on key clashes. The merge goes into a copy, so the caller's dict
          is not mutated;
        * anything else is rejected.

        :raises UndefinedParameterError: for a non-dict, non-default value
            under the catch-all name, or when ``cls`` does not have exactly
            one catch-all field.
        """
        known, unknown = _UndefinedParameterAction \
            ._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=cls)

        if catch_all_field.name in known:
            received = known[catch_all_field.name]
            already_parsed = isinstance(received, dict)
            default_value = _CatchAllUndefinedParameters._get_default(
                catch_all_field=catch_all_field)
            received_default = default_value == received

            value_to_write: Any
            if received_default and not unknown:
                value_to_write = default_value
            elif received_default:
                value_to_write = unknown
            elif already_parsed:
                # Did not receive default
                value_to_write = received
                if unknown:
                    # Merge into a same-type shallow copy instead of updating
                    # the caller-supplied dict in place.
                    value_to_write = copy.copy(received)
                    value_to_write.update(unknown)
            else:
                error_message = f"Received input field with " \
                                f"same name as catch-all field: " \
                                f"'{catch_all_field.name}': " \
                                f"'{received}'"
                raise UndefinedParameterError(error_message)
        else:
            value_to_write = unknown

        known[catch_all_field.name] = value_to_write
        return known

    @staticmethod
    def _get_default(catch_all_field: Field) -> Any:
        """
        Return the catch-all field's default value.

        If the field only has a default_factory, the factory is called to
        build a fresh value. Returns ``_SentinelNoDefault`` when the field
        has neither.
        """
        if catch_all_field.default is not dataclasses.MISSING:
            return catch_all_field.default
        # Accessing default_factory triggers a false-positive mypy error:
        # https://github.com/python/mypy/issues/6910
        if catch_all_field.default_factory is not \
                dataclasses.MISSING:  # type: ignore
            # This might be unwanted if the default factory constructs
            # something expensive, because it is constructed again just for
            # this comparison.
            return catch_all_field.default_factory()  # type: ignore
        return _CatchAllUndefinedParameters._SentinelNoDefault

    @staticmethod
    def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Replace the catch-all entry of ``kvs`` with its contents.

        Pops the catch-all field's key from ``kvs``. If the popped value is
        a dict, its items are merged into ``kvs`` and overwrite same-named
        keys. ``kvs`` is mutated and returned.
        """
        catch_all_field = \
            _CatchAllUndefinedParameters._get_catch_all_field(obj)
        undefined_parameters = kvs.pop(catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            kvs.update(
                undefined_parameters)  # If desired handle letter case here
        return kvs

    @staticmethod
    def handle_dump(obj) -> Dict[Any, Any]:
        """
        Return the undefined parameters stored on ``obj`` for the schema dump.

        When the catch-all field does not hold a dict (e.g. a ``None``
        default that was kept by ``handle_from_dict``), an empty dict is
        returned, matching ``handle_to_dict`` and the ``Dict`` return type.
        """
        catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
            cls=obj)
        undefined_parameters = getattr(obj, catch_all_field.name)
        if isinstance(undefined_parameters, dict):
            return undefined_parameters
        return {}

    @staticmethod
    def create_init(obj) -> Callable:
        """
        Build a replacement ``__init__`` for ``obj`` that routes unexpected
        arguments into the catch-all field.

        Surplus positional arguments are collected under the synthetic keys
        ``_UNKNOWN0``, ``_UNKNOWN1``, and so on. Keyword arguments that are
        not field names are collected under their own names.
        """
        original_init = obj.__init__
        init_signature = inspect.signature(original_init)

        @functools.wraps(obj.__init__)
        def _catch_all_init(self, *args, **kwargs):
            # NOTE(review): known/unknown is decided with fields(obj), which
            # per the dataclasses docs omits InitVar pseudo-fields, so InitVar
            # arguments presumably end up in the catch-all. Confirm whether
            # this is intended.
            known_kwargs, unknown_kwargs = \
                _CatchAllUndefinedParameters._separate_defined_undefined_kvs(
                    obj, kwargs)
            num_params_takeable = len(
                init_signature.parameters) - 1  # don't count self
            if _CatchAllUndefinedParameters._get_catch_all_field(
                    obj).name not in known_kwargs:
                # The catch-all slot is filled by this wrapper, not by a
                # positional argument.
                num_params_takeable -= 1
            # Keyword arguments occupy parameter slots, leaving fewer for
            # positional ones.
            num_args_takeable = num_params_takeable - len(known_kwargs)

            args, unknown_args = (args[:num_args_takeable],
                                  args[num_args_takeable:])
            bound_parameters = init_signature.bind_partial(self, *args,
                                                           **known_kwargs)

            unknown_args = {f"_UNKNOWN{i}": v for i, v in
                            enumerate(unknown_args)}
            arguments = bound_parameters.arguments
            arguments.update(unknown_args)
            arguments.update(unknown_kwargs)
            arguments.pop("self", None)
            final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
                obj, arguments)
            original_init(self, **final_parameters)

        return _catch_all_init

    @staticmethod
    def _get_catch_all_field(cls) -> Field:
        """
        Return the single field of ``cls`` whose type is ``CatchAll``.

        ``cls`` may be a dataclass or an instance of one. String annotations,
        such as those produced by ``from __future__ import annotations``,
        are resolved with ``typing.get_type_hints`` before comparing. If
        resolution fails, the raw annotation is compared instead.

        :raises UndefinedParameterError: if there is no catch-all field or
            more than one.
        """
        # get_type_hints needs the class itself to find its module globals.
        klass = cls if isinstance(cls, type) else type(cls)
        class_fields = fields(klass)

        resolved_types: Dict[str, Any] = {}
        if any(isinstance(f.type, str) for f in class_fields):
            try:
                resolved_types = get_type_hints(klass)
            except (NameError, TypeError, SyntaxError, AttributeError):
                # Deliberate best effort: unresolvable forward references
                # fall back to comparing the raw annotation.
                resolved_types = {}

        def _field_type(f: Field) -> Any:
            # Only string annotations are replaced by their resolved form.
            if isinstance(f.type, str):
                return resolved_types.get(f.name, f.type)
            return f.type

        catch_all_fields = [f for f in class_fields
                            if _field_type(f) == Optional[CatchAllVar]]
        number_of_catch_all_fields = len(catch_all_fields)
        if number_of_catch_all_fields == 0:
            raise UndefinedParameterError(
                "No field of type dataclasses_json.CatchAll defined")
        elif number_of_catch_all_fields > 1:
            raise UndefinedParameterError(
                f"Multiple catch-all fields supplied: "
                f"{number_of_catch_all_fields}.")
        else:
            return catch_all_fields[0]
257
258
class Undefined(Enum):
    """
    Selects what happens when a parameter that is not a field of the
    dataclass (an undefined parameter) is encountered.

    Each member's value is the handler class implementing that policy.
    """
    # Collect undefined parameters into the class's single CatchAll field.
    INCLUDE = _CatchAllUndefinedParameters
    # Raise UndefinedParameterError from handle_from_dict when undefined
    # parameters are present.
    RAISE = _RaiseUndefinedParameters
    # Silently drop undefined parameters.
    EXCLUDE = _IgnoreUndefinedParameters
267
268
class UndefinedParameterError(ValidationError):
    """
    Error signalling a problem with undefined (non-field) parameters.

    Raised by the undefined-parameter handlers, for example for unexpected
    keys under the RAISE policy or for a missing or duplicated catch-all
    field under the INCLUDE policy. As a subclass of marshmallow's
    ValidationError, it is also caught by handlers for that exception.
    """
274