1import json
2from typing import Any, Union
3
4import numpy as np
5from pydantic.json import pydantic_encoder
6
7from .importing import which_import
8
# msgpack is treated as an optional dependency: a missing module is tolerated
# here so that importing this file does not fail. The msgpack helpers below call
# which_import("msgpack", raise_error=True, ...) before touching the module.
try:
    import msgpack
except ModuleNotFoundError:
    pass

# Install hint handed to which_import as ``raise_msg`` by the msgpack helpers.
_msgpack_which_msg = "Please install via `conda install msgpack-python`."
15
16## MSGPackExt
17
18
def msgpackext_encode(obj: Any) -> Any:
    r"""
    Convert an object into a msgpack-friendly form, with a binary layout for NumPy arrays.

    Parameters
    ----------
    obj : Any
        The object to convert. It is first offered to ``pydantic_encoder``; NumPy arrays
        are handled here when that call raises ``TypeError``.

    Returns
    -------
    Any
        The ``pydantic_encoder`` result when it succeeds; for arrays with a non-empty shape,
        a dictionary with byte keys ``_nd_``, ``dtype``, ``data`` (plus ``shape`` when the
        array has more than one dimension); for zero-dimensional arrays the plain Python
        scalar; otherwise ``obj`` itself, unchanged.
    """

    # pydantic gets the first shot; a TypeError is taken to mean "not handled"
    try:
        return pydantic_encoder(obj)
    except TypeError:
        pass

    # Anything that is not an array is passed back untouched
    if not isinstance(obj, np.ndarray):
        return obj

    # Zero-dimensional array: np.array(5) -> 5
    if not obj.shape:
        return obj.tolist()

    raw = np.ascontiguousarray(obj).tobytes()
    packed = {b"_nd_": True, b"dtype": obj.dtype.str, b"data": raw}

    # The shape is only stored when the flat buffer alone cannot describe the array
    if obj.ndim > 1:
        packed[b"shape"] = obj.shape

    return packed
52
53
def msgpackext_decode(obj: Any) -> Any:
    r"""
    Decodes a msgpack object from its dictionary representation.

    Dictionaries carrying the ``b"_nd_"`` marker are rebuilt into NumPy arrays from
    their ``b"data"`` buffer and ``b"dtype"`` string; everything else is returned as is.

    Parameters
    ----------
    obj : Any
        An encoded object, likely a dictionary.

    Returns
    -------
    Any
        A NumPy array when ``obj`` carries the ``b"_nd_"`` marker, otherwise ``obj`` unchanged.
    """

    if b"_nd_" not in obj:
        return obj

    arr = np.frombuffer(obj[b"data"], dtype=obj[b"dtype"])

    # A shape entry is optional; without it the flat array is returned.
    # reshape is used rather than assigning to ``arr.shape``, which NumPy discourages.
    if b"shape" in obj:
        arr = arr.reshape(obj[b"shape"])

    return arr
77
78
def msgpackext_dumps(data: Any) -> bytes:
    r"""Serialize a Python object to msgpack bytes, using ``msgpackext_encode`` as the fallback encoder.
    NumPy arrays are therefore written in the specialized dictionary format that keeps dtype and shape.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    bytes
        A msgpack representation of the data in bytes.
    """
    # Checked on every call because the module-level msgpack import is optional;
    # presumably which_import raises with the install hint when msgpack is missing.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    packed = msgpack.dumps(data, use_bin_type=True, default=msgpackext_encode)
    return packed
96
97
def msgpackext_loads(data: bytes) -> Any:
    r"""Deserialize msgpack bytes, passing every decoded map through ``msgpackext_decode``.

    Parameters
    ----------
    data : bytes
        The serialized msgpack byte array.

    Returns
    -------
    Any
        The deserialized Python objects.
    """
    # Checked on every call because the module-level msgpack import is optional;
    # presumably which_import raises with the install hint when msgpack is missing.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    unpacked = msgpack.loads(data, raw=False, object_hook=msgpackext_decode)
    return unpacked
114
115
116## JSON Ext
117
118
class JSONExtArrayEncoder(json.JSONEncoder):
    """JSON encoder that writes NumPy arrays as a tagged dictionary with hex-encoded data."""

    def default(self, obj: Any) -> Any:
        """Encode ``obj`` via pydantic, then as a NumPy array, then via the base encoder."""
        # pydantic gets the first shot; a TypeError is taken to mean "not handled"
        try:
            return pydantic_encoder(obj)
        except TypeError:
            pass

        # Non-array objects are delegated to the stock JSON encoder
        if not isinstance(obj, np.ndarray):
            return json.JSONEncoder.default(self, obj)

        # Zero-dimensional array: np.array(5) -> 5
        if not obj.shape:
            return obj.tolist()

        hex_data = np.ascontiguousarray(obj).tobytes().hex()
        encoded = {"_nd_": True, "dtype": obj.dtype.str, "data": hex_data}

        # The shape is only stored for arrays with more than one dimension
        if obj.ndim > 1:
            encoded["shape"] = obj.shape

        return encoded
138
139
def jsonext_decode(obj: Any) -> Any:
    r"""
    Decodes a JSON-ext object from its dictionary representation.

    Dictionaries carrying the ``"_nd_"`` marker are rebuilt into NumPy arrays from
    their hex-encoded ``"data"`` string and ``"dtype"`` string; everything else is
    returned as is.

    Parameters
    ----------
    obj : Any
        A decoded JSON object, likely a dictionary.

    Returns
    -------
    Any
        A NumPy array when ``obj`` carries the ``"_nd_"`` marker, otherwise ``obj`` unchanged.
    """

    if "_nd_" not in obj:
        return obj

    arr = np.frombuffer(bytes.fromhex(obj["data"]), dtype=obj["dtype"])

    # A shape entry is optional; without it the flat array is returned.
    # reshape is used rather than assigning to ``arr.shape``, which NumPy discourages.
    if "shape" in obj:
        arr = arr.reshape(obj["shape"])

    return arr
150
151
def jsonext_dumps(data: Any) -> str:
    r"""Serialize Python objects to a JSON string with ``JSONExtArrayEncoder``.
    Arrays are written in that encoder's custom array syntax rather than as flat JSON lists.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    str
        A JSON representation of the data.
    """

    text = json.dumps(data, cls=JSONExtArrayEncoder)
    return text
168
169
def jsonext_loads(data: Union[str, bytes]) -> Any:
    r"""Deserialize a JSON blob, passing every decoded object through ``jsonext_decode``.

    Parameters
    ----------
    data : str or bytes
        The byte-serialized JSON blob.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    loaded = json.loads(data, object_hook=jsonext_decode)
    return loaded
185
186
187## JSON
188
189
class JSONArrayEncoder(json.JSONEncoder):
    """JSON encoder that writes NumPy arrays as flattened plain lists."""

    def default(self, obj: Any) -> Any:
        """Encode ``obj`` via pydantic, then as a NumPy array, then via the base encoder."""
        # pydantic gets the first shot; a TypeError is taken to mean "not handled"
        try:
            return pydantic_encoder(obj)
        except TypeError:
            pass

        # Non-array objects are delegated to the stock JSON encoder
        if not isinstance(obj, np.ndarray):
            return json.JSONEncoder.default(self, obj)

        # Arrays with a shape are flattened first; zero-dimensional ones become a scalar
        flat = obj.ravel() if obj.shape else obj
        return flat.tolist()
204
205
def json_dumps(data: Any) -> str:
    r"""Serialize a Python object to a JSON string with ``JSONArrayEncoder``.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    str
        A JSON representation of the data.
    """

    text = json.dumps(data, cls=JSONArrayEncoder)
    return text
221
222
def json_loads(data: str) -> Any:
    r"""Deserialize a JSON blob into Python objects.

    Parameters
    ----------
    data : str
        The serialized JSON blob.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    # The JSON-ext hook is applied too, so blobs in that format are also decoded
    loaded = json.loads(data, object_hook=jsonext_decode)
    return loaded
239
240
241## MSGPack
242
243
def msgpack_encode(obj: Any) -> Any:
    r"""
    Convert an object into a msgpack-friendly form, turning NumPy arrays into plain Python lists.

    Parameters
    ----------
    obj : Any
        The object to convert. It is first offered to ``pydantic_encoder``; NumPy arrays
        are handled here when that call raises ``TypeError``.

    Returns
    -------
    Any
        The ``pydantic_encoder`` result when it succeeds; a flattened list for arrays with
        a non-empty shape; a plain scalar for zero-dimensional arrays; otherwise ``obj``
        itself, unchanged.
    """

    # pydantic gets the first shot; a TypeError is taken to mean "not handled"
    try:
        return pydantic_encoder(obj)
    except TypeError:
        pass

    # Anything that is not an array is passed back untouched
    if not isinstance(obj, np.ndarray):
        return obj

    # Arrays with a shape are flattened first; zero-dimensional ones become a scalar
    flat = obj.ravel() if obj.shape else obj
    return flat.tolist()
271
272
def msgpack_dumps(data: Any) -> bytes:
    r"""Safe serialization of a Python object to msgpack binary representation using all known encoders.
    For NumPy, converts to lists.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    bytes
        A msgpack representation of the data in bytes.
    """

    # Checked on every call because the module-level msgpack import is optional;
    # presumably which_import raises with the install hint when msgpack is missing.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    return msgpack.dumps(data, default=msgpack_encode, use_bin_type=True)
291
292
def msgpack_loads(data: bytes) -> Any:
    r"""Deserializes a msgpack byte representation of known objects into those objects.

    Parameters
    ----------
    data : bytes
        The serialized msgpack byte array.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    # Checked on every call because the module-level msgpack import is optional;
    # presumably which_import raises with the install hint when msgpack is missing.
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # Doesn't hurt anything to try to load msgpack-ext as well
    return msgpack.loads(data, object_hook=msgpackext_decode, raw=False)
311
312
313## Helper functions
314
315
def serialize(data: Any, encoding: str) -> Union[str, bytes]:
    r"""Encoding Python objects using the provided encoder.

    Parameters
    ----------
    data : Any
        A encodable python object.
    encoding : str
        The type of encoding to perform: {'json', 'json-ext', 'msgpack', 'msgpack-ext'}.
        Matching is case-insensitive.

    Returns
    -------
    Union[str, bytes]
        A serialized representation of the data.

    Raises
    ------
    KeyError
        If ``encoding`` is not one of the supported options.
    """
    # Normalize once so every comparison below is case-insensitive
    kind = encoding.lower()

    if kind == "json":
        return json_dumps(data)
    elif kind == "json-ext":
        return jsonext_dumps(data)
    elif kind == "msgpack":
        return msgpack_dumps(data)
    elif kind == "msgpack-ext":
        return msgpackext_dumps(data)
    else:
        # The message lists all four supported encodings, including plain 'msgpack'
        raise KeyError(
            f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'"
        )
342
343
def deserialize(blob: Union[str, bytes], encoding: str) -> Any:
    r"""Decode a serialized blob back into Python objects using the named encoding.

    Parameters
    ----------
    blob : Union[str, bytes]
        The serialized data. The type is checked with an ``assert`` per encoding:
        ``str`` for 'json', ``str`` or ``bytes`` for 'json-ext', and ``bytes`` for
        'msgpack' and 'msgpack-ext'.
    encoding : str
        The type of encoding of the blob: {'json', 'json-ext', 'msgpack', 'msgpack-ext'}.
        Matching is case-insensitive.

    Returns
    -------
    Any
        The deserialized Python objects.

    Raises
    ------
    KeyError
        If ``encoding`` is not one of the supported options.
    """
    # Normalize once so every comparison below is case-insensitive
    kind = encoding.lower()

    if kind == "json":
        assert isinstance(blob, str)
        return json_loads(blob)

    if kind == "json-ext":
        assert isinstance(blob, (str, bytes))
        return jsonext_loads(blob)

    if kind == "msgpack":
        assert isinstance(blob, bytes)
        return msgpack_loads(blob)

    if kind == "msgpack-ext":
        assert isinstance(blob, bytes)
        return msgpackext_loads(blob)

    raise KeyError(
        f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'"
    )
375