"""
Miscellaneous helpers shared by the json-tricks encoders and decoders.
"""

import gzip
import io
import warnings
from collections import OrderedDict
from functools import partial
from importlib import import_module
from sys import version_info, version


class JsonTricksDeprecation(UserWarning):
    """ Special deprecation warning because the built-in one is ignored by default """
    def __init__(self, msg):
        super(JsonTricksDeprecation, self).__init__(msg)


class hashodict(OrderedDict):
    """
    This dictionary is hashable. It should NOT be mutated, or all kinds of weird
    bugs may appear. This is not enforced though, it's only used for encoding.
    """
    def __hash__(self):
        # Order-independent hash; equal dicts always hash equal, which is all hashing requires.
        return hash(frozenset(self.items()))


try:
    from inspect import signature
except ImportError:
    try:
        from inspect import getfullargspec
    except ImportError:
        from inspect import getargspec, isfunction

        def get_arg_names(callable):
            """
            Return the set of positional argument names of `callable` (oldest-Python fallback).

            `functools.partial` objects are unwrapped (with a one-time warning), because
            `getargspec` cannot inspect them on Python 2.
            """
            if type(callable) == partial and version_info[0] == 2:
                if not hasattr(get_arg_names, '__warned_partial_argspec'):
                    get_arg_names.__warned_partial_argspec = True
                    warnings.warn("'functools.partial' and 'inspect.getargspec' are not compatible in this Python version; "
                        "ignoring the 'partial' wrapper when inspecting arguments of {}, which can lead to problems".format(callable))
                return set(getargspec(callable.func).args)
            if isfunction(callable):
                argspec = getargspec(callable)
            else:
                # Callable objects / classes: inspect their __call__ method instead.
                argspec = getargspec(callable.__call__)
            return set(argspec.args)
    else:
        #todo: this is not covered in test case (py 3+ uses `signature`, py2 `getfullargspec`); consider removing it
        def get_arg_names(callable):
            """ Return the set of positional and keyword-only argument names of `callable`. """
            argspec = getfullargspec(callable)
            return set(argspec.args) | set(argspec.kwonlyargs)
else:
    def get_arg_names(callable):
        """ Return the set of all parameter names reported by `inspect.signature(callable)`. """
        sig = signature(callable)
        return set(sig.parameters.keys())


def filtered_wrapper(encoder):
    """
    Filter kwargs passed to encoder.

    If `encoder` has a `default` attribute, that attribute is wrapped instead of the
    object itself. The returned wrapper forwards all positional arguments unchanged,
    but only those keyword arguments whose names are parameters of the wrapped callable.

    :param encoder: a callable, or an object exposing a `default` method.
    :return: a wrapper function with the filtering behaviour described above.
    :raises TypeError: if `encoder` has no `default` attribute and is not callable.
    """
    if hasattr(encoder, "default"):
        encoder = encoder.default
    elif not callable(encoder):
        # Format with `encoder` (previously an undefined name, which raised NameError instead).
        raise TypeError('`obj_encoder` {0:} does not have `default` method and is not callable'.format(encoder))
    names = get_arg_names(encoder)

    def wrapper(*args, **kwargs):
        return encoder(*args, **{k: v for k, v in kwargs.items() if k in names})
    return wrapper


class NoNumpyException(Exception):
    """ Trying to use numpy features, but numpy cannot be found. """


class NoPandasException(Exception):
    """ Trying to use pandas features, but pandas cannot be found. """


class NoEnumException(Exception):
    """ Trying to use enum features, but enum cannot be found. """


class NoPathlibException(Exception):
    """ Trying to use pathlib features, but pathlib cannot be found. """


class ClassInstanceHookBase(object):
    """ Base class providing class lookup by module path and name for decoding hooks. """

    def get_cls_from_instance_type(self, mod, name, cls_lookup_map):
        """
        Find the class called `name` in module `mod`.

        :param mod: import path of the module that defines the class; None means the class
            is looked up in `__main__` (see `get_module_name_from_object`).
        :param name: name of the class inside that module.
        :param cls_lookup_map: mapping from class name to class, used as a fallback when the
            class cannot be found through importing.
        :return: the class object.
        :raises ImportError: if the class cannot be located and is not in `cls_lookup_map`.
        """
        if mod is None:
            try:
                return getattr(__import__('__main__'), name)
            except (ImportError, AttributeError):
                if name not in cls_lookup_map:
                    raise ImportError(('class {0:s} seems to have been exported from the main file, which means '
                        'it has no module/import path set; you need to provide loads argument '
                        '`cls_lookup_map={{"{0}": Class}}` to locate the class').format(name))
                return cls_lookup_map[name]

        try:
            module = import_module(str(mod))
        except ImportError as err:
            imp_err = ('encountered import error "{0:}" while importing "{1:}" to decode a json file; perhaps '
                'it was encoded in a different environment where {1:}.{2:} was available').format(err, mod, name)
        else:
            if hasattr(module, name):
                return getattr(module, name)
            imp_err = 'imported "{0:}" but could not find "{1:}" inside while decoding a json file (found {2:})'.format(
                module, name, ', '.join(attr for attr in dir(module) if not attr.startswith('_')))

        # Importing failed (module or attribute missing); fall back to the user-supplied map.
        Cls = cls_lookup_map.get(name, None)
        if Cls is None:
            raise ImportError('{}; add the class to `cls_lookup_map={{"{}": Class}}` argument'.format(imp_err, name))
        return Cls


def get_scalar_repr(npscalar):
    """
    Return a hashable dict representing a numpy scalar as a zero-dimensional array
    (keys `__ndarray__`, `dtype` and `shape`).
    """
    return hashodict((
        ('__ndarray__', npscalar.item()),
        ('dtype', str(npscalar.dtype)),
        ('shape', ()),
    ))


def encode_scalars_inplace(obj):
    """
    Searches a data structure of lists, tuples and dicts for numpy scalars
    and replaces them by their dictionary representation, which can be loaded
    by json-tricks. This happens in-place (the object is changed, use a copy).

    Dicts and lists are modified in place; tuples and sets are rebuilt, so always
    use the return value.
    """
    from numpy import generic, complex64, complex128
    if isinstance(obj, (generic, complex64, complex128)):
        return get_scalar_repr(obj)
    if isinstance(obj, dict):
        # Snapshot the items so the dict is not read through a live view while it is updated.
        for key, val in tuple(obj.items()):
            obj[key] = encode_scalars_inplace(val)
        return obj
    if isinstance(obj, list):
        for k, val in enumerate(obj):
            obj[k] = encode_scalars_inplace(val)
        return obj
    if isinstance(obj, (tuple, set)):
        return type(obj)(encode_scalars_inplace(val) for val in obj)
    return obj


def encode_intenums_inplace(obj):
    """
    Searches a data structure of lists, tuples and dicts for IntEnum
    and replaces them by their dictionary representation, which can be loaded
    by json-tricks. This happens in-place (the object is changed, use a copy).

    Dicts and lists are modified in place; tuples and sets are rebuilt, so always
    use the return value.
    """
    from enum import IntEnum
    from json_tricks import encoders
    if isinstance(obj, IntEnum):
        return encoders.enum_instance_encode(obj)
    if isinstance(obj, dict):
        # Snapshot the items (consistent with encode_scalars_inplace) before assigning.
        for key, val in tuple(obj.items()):
            obj[key] = encode_intenums_inplace(val)
        return obj
    if isinstance(obj, list):
        for index, val in enumerate(obj):
            obj[index] = encode_intenums_inplace(val)
        return obj
    if isinstance(obj, (tuple, set)):
        return type(obj)(encode_intenums_inplace(val) for val in obj)
    return obj


def get_module_name_from_object(obj):
    """
    Return the module path of `obj`'s class, or None (with a warning) if the class was
    defined in `__main__`, whose import path cannot be recovered when decoding.
    """
    mod = obj.__class__.__module__
    if mod == '__main__':
        mod = None
        warnings.warn(('class {0:} seems to have been defined in the main file; unfortunately this means'
            ' that it\'s module/import path is unknown, so you might have to provide cls_lookup_map when '
            'decoding').format(obj.__class__))
    return mod


def nested_index(collection, indices):
    """ Index `collection` successively with each element of `indices`, e.g. `c[i][j][k]`. """
    for i in indices:
        collection = collection[i]
    return collection


def dict_default(dictionary, key, default_value):
    """ Set `dictionary[key] = default_value` only if `key` is not already present. """
    if key not in dictionary:
        dictionary[key] = default_value


def gzip_compress(data, compresslevel):
    """
    Do gzip compression, without the timestamp. Similar to gzip.compress, but without timestamp, and also before py3.2.
    """
    buf = io.BytesIO()
    # mtime=0 keeps the output deterministic for identical input.
    with gzip.GzipFile(fileobj=buf, mode='wb', compresslevel=compresslevel, mtime=0) as fh:
        fh.write(data)
    return buf.getvalue()


def gzip_decompress(data):
    """
    Do gzip decompression, without the timestamp. Just like gzip.decompress, but that's py3.2+.
    """
    with gzip.GzipFile(fileobj=io.BytesIO(data)) as f:
        return f.read()


# Use the numeric major version instead of a string prefix of `sys.version`.
is_py3 = version_info[0] >= 3
# The Python 2 names are only evaluated on Python 2, where they exist.
str_type = str if is_py3 else (basestring, unicode,)