import json
from typing import Any, Union

import numpy as np
from pydantic.json import pydantic_encoder

from .importing import which_import

try:
    import msgpack
except ModuleNotFoundError:
    # msgpack is optional; every msgpack entry point below calls
    # ``which_import(..., raise_error=True)`` before touching the module.
    pass

_msgpack_which_msg = "Please install via `conda install msgpack-python`."

## MSGPackExt


def msgpackext_encode(obj: Any) -> Any:
    r"""
    Encodes an object using pydantic and NumPy array serialization techniques suitable for msgpack.

    Non-scalar NumPy arrays become a tagged dictionary holding the raw bytes, the dtype
    string and (for arrays with more than one dimension) the shape, so that
    ``msgpackext_decode`` can rebuild them exactly. 0-d arrays become Python scalars.

    Parameters
    ----------
    obj : Any
        Any object that can be serialized with pydantic and NumPy encoding techniques.

    Returns
    -------
    Any
        A msgpack compatible form of the object. Objects that neither pydantic nor the
        NumPy branch handle are returned unchanged.
    """

    # First try pydantic base objects
    try:
        return pydantic_encoder(obj)
    except TypeError:
        pass

    if isinstance(obj, np.ndarray):
        if obj.shape:
            # Raw C-order bytes plus the dtype string are enough for np.frombuffer to rebuild the data.
            data = {b"_nd_": True, b"dtype": obj.dtype.str, b"data": np.ascontiguousarray(obj).tobytes()}
            # 1-D arrays omit the shape: np.frombuffer already yields a 1-D array.
            if len(obj.shape) > 1:
                data[b"shape"] = obj.shape
            return data

        # Converts np.array(5) -> 5
        return obj.tolist()

    # Unhandled type: returned as-is (presumably msgpack then reports it as unserializable).
    return obj


def msgpackext_decode(obj: Any) -> Any:
    r"""
    Decodes a msgpack object from its dictionary representation.

    Dictionaries tagged with ``b"_nd_"`` (as produced by ``msgpackext_encode``) are turned
    back into NumPy arrays; every other object is returned untouched.

    Parameters
    ----------
    obj : Any
        An encoded object, likely a dictionary.

    Returns
    -------
    Any
        The decoded form of the object. Decoded arrays are views over the serialized
        buffer (``np.frombuffer``) and are therefore read-only when that buffer is an
        immutable ``bytes`` object; call ``.copy()`` if mutation is required.
    """

    if b"_nd_" in obj:
        arr = np.frombuffer(obj[b"data"], dtype=obj[b"dtype"])
        if b"shape" in obj:
            # frombuffer returns a contiguous 1-D array, so this reshape is a view, never a copy.
            arr = arr.reshape(obj[b"shape"])

        return arr

    return obj


def msgpackext_dumps(data: Any) -> bytes:
    r"""Safe serialization of a Python object to msgpack binary representation using all known encoders.
    For NumPy, encodes a specialized object format to encode all shape and type data.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    bytes
        A msgpack representation of the data in bytes.

    Raises
    ------
    Whatever ``which_import`` raises when msgpack is not installed.
    """
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # use_bin_type=True keeps bytes (e.g. array payloads) distinct from str on the wire.
    return msgpack.dumps(data, default=msgpackext_encode, use_bin_type=True)


def msgpackext_loads(data: bytes) -> Any:
    r"""Deserializes a msgpack byte representation of known objects into those objects.

    Parameters
    ----------
    data : bytes
        The serialized msgpack byte array.

    Returns
    -------
    Any
        The deserialized Python objects.

    Raises
    ------
    Whatever ``which_import`` raises when msgpack is not installed.
    """
    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # raw=False decodes msgpack str back to Python str, while bin payloads stay bytes.
    return msgpack.loads(data, object_hook=msgpackext_decode, raw=False)


## JSON Ext


class JSONExtArrayEncoder(json.JSONEncoder):
    """JSON encoder that stores non-scalar NumPy arrays as hex-encoded raw bytes.

    The emitted ``{"_nd_": True, "dtype": ..., "data": <hex>, ["shape": ...]}`` objects
    are understood by ``jsonext_decode``. 0-d arrays become Python scalars.
    """

    def default(self, obj: Any) -> Any:
        try:
            return pydantic_encoder(obj)
        except TypeError:
            pass

        if isinstance(obj, np.ndarray):
            if obj.shape:
                # JSON has no binary type, so the raw bytes are carried as a hex string.
                data = {"_nd_": True, "dtype": obj.dtype.str, "data": np.ascontiguousarray(obj).tobytes().hex()}
                # 1-D arrays omit the shape: np.frombuffer already yields a 1-D array.
                if len(obj.shape) > 1:
                    data["shape"] = obj.shape
                return data

            # Converts np.array(5) -> 5
            return obj.tolist()

        # Defer to the base class, which raises TypeError for unsupported objects.
        return super().default(obj)


def jsonext_decode(obj: Any) -> Any:
    r"""Decodes a JSON object produced by ``JSONExtArrayEncoder``.

    Parameters
    ----------
    obj : Any
        A decoded JSON object (``json.loads`` passes dictionaries to an ``object_hook``).

    Returns
    -------
    Any
        A NumPy array when ``obj`` carries the ``"_nd_"`` tag, otherwise ``obj`` unchanged.
        Decoded arrays are read-only views over the ``bytes`` built from the hex payload;
        call ``.copy()`` if mutation is required.
    """

    if "_nd_" in obj:
        arr = np.frombuffer(bytes.fromhex(obj["data"]), dtype=obj["dtype"])
        if "shape" in obj:
            # frombuffer returns a contiguous 1-D array, so this reshape is a view, never a copy.
            arr = arr.reshape(obj["shape"])

        return arr

    return obj


def jsonext_dumps(data: Any) -> str:
    r"""Safe serialization of Python objects to JSON string representation using all known encoders.
    The JSON serializer uses a custom array syntax rather than flat JSON lists.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    str
        A JSON representation of the data.
    """

    return json.dumps(data, cls=JSONExtArrayEncoder)


def jsonext_loads(data: Union[str, bytes]) -> Any:
    r"""Deserializes a json representation of known objects into those objects.

    Parameters
    ----------
    data : str or bytes
        The byte-serialized JSON blob.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    return json.loads(data, object_hook=jsonext_decode)


## JSON


class JSONArrayEncoder(json.JSONEncoder):
    """JSON encoder that writes NumPy arrays as plain JSON lists.

    Non-scalar arrays are flattened with ``ravel()`` first, so the original shape is
    NOT preserved; 0-d arrays become Python scalars.
    """

    def default(self, obj: Any) -> Any:
        try:
            return pydantic_encoder(obj)
        except TypeError:
            pass

        if isinstance(obj, np.ndarray):
            if obj.shape:
                return obj.ravel().tolist()
            else:
                return obj.tolist()

        # Defer to the base class, which raises TypeError for unsupported objects.
        return super().default(obj)


def json_dumps(data: Any) -> str:
    r"""Safe serialization of a Python dictionary to JSON string representation using all known encoders.

    NumPy arrays are written as flat lists (see ``JSONArrayEncoder``).

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    str
        A JSON representation of the data.
    """

    return json.dumps(data, cls=JSONArrayEncoder)


def json_loads(data: str) -> Any:
    r"""Deserializes a json representation of known objects into those objects.

    Parameters
    ----------
    data : str
        The serialized JSON blob.

    Returns
    -------
    Any
        The deserialized Python objects.
    """

    # Doesn't hurt anything to try to load JSONext as well
    return json.loads(data, object_hook=jsonext_decode)


## MSGPack


def msgpack_encode(obj: Any) -> Any:
    r"""
    Encodes an object using pydantic. Converts numpy arrays to plain python lists.

    Non-scalar arrays are flattened with ``ravel()`` first, so the original shape is
    NOT preserved; 0-d arrays become Python scalars.

    Parameters
    ----------
    obj : Any
        Any object that can be serialized with pydantic and NumPy encoding techniques.

    Returns
    -------
    Any
        A msgpack compatible form of the object. Objects that neither pydantic nor the
        NumPy branch handle are returned unchanged.
    """

    try:
        return pydantic_encoder(obj)
    except TypeError:
        pass

    if isinstance(obj, np.ndarray):
        if obj.shape:
            return obj.ravel().tolist()
        else:
            return obj.tolist()

    return obj


def msgpack_dumps(data: Any) -> bytes:
    r"""Safe serialization of a Python object to msgpack binary representation using all known encoders.
    For NumPy, converts to lists.

    Parameters
    ----------
    data : Any
        A encodable python object.

    Returns
    -------
    bytes
        A msgpack representation of the data in bytes.

    Raises
    ------
    Whatever ``which_import`` raises when msgpack is not installed.
    """

    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    return msgpack.dumps(data, default=msgpack_encode, use_bin_type=True)


def msgpack_loads(data: bytes) -> Any:
    r"""Deserializes a msgpack byte representation of known objects into those objects.

    Parameters
    ----------
    data : bytes
        The serialized msgpack byte array.

    Returns
    -------
    Any
        The deserialized Python objects.

    Raises
    ------
    Whatever ``which_import`` raises when msgpack is not installed.
    """

    which_import("msgpack", raise_error=True, raise_msg=_msgpack_which_msg)

    # Doesn't hurt anything to try to load msgpack-ext as well
    return msgpack.loads(data, object_hook=msgpackext_decode, raw=False)


## Helper functions


def serialize(data: Any, encoding: str) -> Union[str, bytes]:
    r"""Encoding Python objects using the provided encoder.

    Parameters
    ----------
    data : Any
        A encodable python object.
    encoding : str
        The type of encoding to perform (case-insensitive): {'json', 'json-ext', 'msgpack', 'msgpack-ext'}

    Returns
    -------
    Union[str, bytes]
        A serialized representation of the data: ``str`` for the JSON encodings,
        ``bytes`` for the msgpack encodings.

    Raises
    ------
    KeyError
        If ``encoding`` is not one of the supported options.
    """
    enc = encoding.lower()
    if enc == "json":
        return json_dumps(data)
    elif enc == "json-ext":
        return jsonext_dumps(data)
    elif enc == "msgpack":
        return msgpack_dumps(data)
    elif enc == "msgpack-ext":
        return msgpackext_dumps(data)
    else:
        raise KeyError(
            f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'"
        )


def deserialize(blob: Union[str, bytes], encoding: str) -> Any:
    r"""Decode a blob produced by ``serialize`` back into Python objects.

    Parameters
    ----------
    blob : Union[str, bytes]
        The serialized data: ``str`` for 'json', ``str`` or ``bytes`` for 'json-ext',
        ``bytes`` for the msgpack encodings.
    encoding : str
        The type of encoding of the blob (case-insensitive): {'json', 'json-ext', 'msgpack', 'msgpack-ext'}

    Returns
    -------
    Any
        The deserialized Python objects.

    Raises
    ------
    KeyError
        If ``encoding`` is not one of the supported options.
    AssertionError
        If ``blob`` has the wrong type for the requested encoding.
    """
    # NOTE(review): these type checks are asserts and disappear under ``python -O``.
    enc = encoding.lower()
    if enc == "json":
        assert isinstance(blob, str)
        return json_loads(blob)
    elif enc == "json-ext":
        assert isinstance(blob, (str, bytes))
        return jsonext_loads(blob)
    elif enc == "msgpack":
        assert isinstance(blob, bytes)
        return msgpack_loads(blob)
    elif enc == "msgpack-ext":
        assert isinstance(blob, bytes)
        return msgpackext_loads(blob)
    else:
        raise KeyError(
            f"Encoding '{encoding}' not understood, valid options: 'json', 'json-ext', 'msgpack', 'msgpack-ext'"
        )