1import gzip
2import io
3import warnings
4from collections import OrderedDict
5from functools import partial
6from importlib import import_module
7from sys import version_info, version
8
9
class JsonTricksDeprecation(UserWarning):
	"""
	Deprecation warning issued by json-tricks.

	It derives from UserWarning instead of DeprecationWarning on purpose: the
	built-in DeprecationWarning is ignored by default, and these should be seen.

	:param msg: the warning text.
	"""
	def __init__(self, msg):
		parent = super(JsonTricksDeprecation, self)
		parent.__init__(msg)
14
15
class hashodict(OrderedDict):
	"""
	OrderedDict variant that can be hashed (e.g. stored in a set or used as a key).

	The hash is derived from the current (key, value) pairs and ignores their order.
	Nothing prevents mutation, but mutating an instance after it has been hashed
	corrupts any set/dict holding it. Intended only for use during encoding.
	"""
	def __hash__(self):
		pairs = frozenset(self.items())
		return hash(pairs)
23
24
# Define `get_arg_names(callable)` using the best introspection tool this Python offers:
# `inspect.signature` when importable, else `inspect.getfullargspec`, else
# `inspect.getargspec`. Every variant returns a set of parameter names; `filtered_wrapper`
# uses that set to decide which keyword arguments to forward to an encoder.
try:
	from inspect import signature
except ImportError:
	try:
		from inspect import getfullargspec
	except ImportError:
		# Oldest fallback (`getargspec` only, i.e. Python 2).
		from inspect import getargspec, isfunction
		def get_arg_names(callable):
			"""
			Return the set of named arguments of `callable` as reported by `getargspec`
			(`argspec.args`; the names of `*args`/`**kwargs` are not included).
			"""
			if type(callable) == partial and version_info[0] == 2:
				# getargspec cannot inspect a functools.partial here, so the wrapped function is
				# inspected instead; arguments already bound by the partial are still listed.
				# The warning is emitted only once, tracked via an attribute on this function.
				if not hasattr(get_arg_names, '__warned_partial_argspec'):
					get_arg_names.__warned_partial_argspec = True
					warnings.warn("'functools.partial' and 'inspect.getargspec' are not compatible in this Python version; "
						"ignoring the 'partial' wrapper when inspecting arguments of {}, which can lead to problems".format(callable))
				return set(getargspec(callable.func).args)
			if isfunction(callable):
				argspec = getargspec(callable)
			else:
				# Non-function callables: inspect their `__call__`.
				# NOTE(review): for bound methods the result likely includes `self` — confirm.
				argspec = getargspec(callable.__call__)
			return set(argspec.args)
	else:
		#todo: this branch only runs where `getfullargspec` exists but `signature` does not (py 3.0-3.2; py 3.3+ uses `signature`, py2 uses `getargspec`); it is not covered by the test cases, consider removing it
		def get_arg_names(callable):
			"""
			Return the positional-or-keyword plus keyword-only argument names of `callable`
			(the names of `*args`/`**kwargs` are not included).
			"""
			argspec = getfullargspec(callable)
			return set(argspec.args) | set(argspec.kwonlyargs)
else:
	def get_arg_names(callable):
		"""
		Return the names of all parameters in `inspect.signature(callable)`.
		Unlike the fallbacks above, this also contains the names of any `*args`/`**kwargs`
		parameters themselves (e.g. 'kwargs').
		"""
		sig = signature(callable)
		return set(sig.parameters.keys())
53
54
def filtered_wrapper(encoder):
	"""
	Wrap an encoder so it is only passed the keyword arguments it declares.

	:param encoder: an object with a `default` method (that method is wrapped), or a
		plain callable.
	:return: a function that forwards all positional arguments unchanged, plus only
		those keyword arguments whose names appear in `get_arg_names(encoder)`.
	:raises TypeError: if `encoder` has no `default` method and is not callable.
	"""
	if hasattr(encoder, "default"):
		encoder = encoder.default
	elif not hasattr(encoder, '__call__'):
		# Format the offending object itself so the caller sees a TypeError naming it.
		raise TypeError('`obj_encoder` {0:} does not have `default` method and is not callable'.format(encoder))
	# Introspect once here instead of on every call of the wrapper.
	# NOTE(review): an encoder declaring **kwargs does not receive arbitrary kwargs,
	# since only literally-declared names pass the filter — confirm this is intended.
	names = get_arg_names(encoder)

	def wrapper(*args, **kwargs):
		accepted = {k: v for k, v in kwargs.items() if k in names}
		return encoder(*args, **accepted)
	return wrapper
68
69
class NoNumpyException(Exception):
	"""
	Signals that a numpy-related feature was requested while numpy is not importable.
	"""
	pass
72
73
class NoPandasException(Exception):
	"""
	Signals that a pandas-related feature was requested while pandas is not importable.
	"""
	pass
76
77
class NoEnumException(Exception):
	"""
	Signals that an enum-related feature was requested while enum is not importable.
	"""
	pass
80
81
class NoPathlibException(Exception):
	"""
	Signals that a pathlib-related feature was requested while pathlib is not importable.
	"""
	pass
84
85
class ClassInstanceHookBase(object):
	"""
	Base class providing class lookup by module path and name, with a fallback to a
	user-supplied `cls_lookup_map`.
	"""
	def get_cls_from_instance_type(self, mod, name, cls_lookup_map):
		"""
		Locate the class called `name` that was recorded as living in module `mod`.

		:param mod: import path of the defining module, or None when the class has no
			import path (it was defined in the main script).
		:param name: the class name.
		:param cls_lookup_map: mapping of class name -> class, used when the class cannot
			be found by import. None is treated as an empty mapping.
		:return: the class object.
		:raises ImportError: if the class is found neither by import nor in `cls_lookup_map`.
		"""
		if cls_lookup_map is None:
			cls_lookup_map = {}
		if mod is None:
			# No import path: look in the running main script first, then the lookup map.
			try:
				return getattr((__import__('__main__')), name)
			except (ImportError, AttributeError):
				if name not in cls_lookup_map:
					raise ImportError(('class {0:s} seems to have been exported from the main file, which means '
						'it has no module/import path set; you need to provide loads argument '
						'`cls_lookup_map={{"{0}": Class}}` to locate the class').format(name))
				return cls_lookup_map[name]
		try:
			module = import_module('{0:}'.format(mod))
		except ImportError as err:
			imp_err = ('encountered import error "{0:}" while importing "{1:}" to decode a json file; perhaps '
				'it was encoded in a different environment where {1:}.{2:} was available').format(err, mod, name)
		else:
			if hasattr(module, name):
				return getattr(module, name)
			imp_err = 'imported "{0:}" but could not find "{1:}" inside while decoding a json file (found {2:})'.format(
				module, name, ', '.join(attr for attr in dir(module) if not attr.startswith('_')))
		# Import path failed; fall back to the user-provided map.
		cls = cls_lookup_map.get(name, None)
		if cls is None:
			raise ImportError('{}; add the class to `cls_lookup_map={{"{}": Class}}` argument'.format(imp_err, name))
		return cls
116
117
def get_scalar_repr(npscalar):
	"""
	Build the json-tricks representation of a numpy scalar.

	:param npscalar: a numpy scalar (anything with `.item()` and `.dtype`).
	:return: a hashodict with keys '__ndarray__' (the Python value from `.item()`),
		'dtype' (the dtype as a string) and 'shape' (an empty tuple), in that order.
	"""
	fields = [
		('__ndarray__', npscalar.item()),
		('dtype', str(npscalar.dtype)),
		('shape', ()),
	]
	return hashodict(fields)
124
125
def encode_scalars_inplace(obj):
	"""
	Searches a data structure of lists, tuples and dicts for numpy scalars
	and replaces them by their dictionary representation, which can be loaded
	by json-tricks. This happens in-place (the object is changed, use a copy).

	Lists and dicts are modified in place and returned. Tuples (including namedtuples)
	and sets are rebuilt, so a new container of the same type is returned for them.

	:param obj: the value to scan.
	:return: the (possibly modified) `obj`, or its replacement.
	"""
	from numpy import generic
	# complex64/complex128 are numpy.generic subclasses, so this single check covers them.
	if isinstance(obj, generic):
		return get_scalar_repr(obj)
	if isinstance(obj, dict):
		# Iterate over a snapshot because values are reassigned during the loop.
		for key, val in tuple(obj.items()):
			obj[key] = encode_scalars_inplace(val)
		return obj
	if isinstance(obj, list):
		for index, val in enumerate(obj):
			obj[index] = encode_scalars_inplace(val)
		return obj
	if isinstance(obj, (tuple, set)):
		encoded = (encode_scalars_inplace(val) for val in obj)
		if isinstance(obj, tuple) and hasattr(obj, '_make'):
			# namedtuples take their fields as separate arguments, not a single iterable.
			return obj._make(encoded)
		return type(obj)(encoded)
	return obj
146
147
def encode_intenums_inplace(obj):
	"""
	Searches a data structure of lists, tuples and dicts for IntEnum
	and replaces them by their dictionary representation, which can be loaded
	by json-tricks. This happens in-place (the object is changed, use a copy).

	Lists and dicts are modified in place and returned. Tuples (including namedtuples)
	and sets are rebuilt, so a new container of the same type is returned for them.

	:param obj: the value to scan.
	:return: the (possibly modified) `obj`, or its replacement.
	"""
	from enum import IntEnum
	if isinstance(obj, IntEnum):
		# Imported lazily and only when needed; presumably kept local to avoid a circular
		# import between this module and json_tricks.encoders — confirm before hoisting.
		from json_tricks import encoders
		return encoders.enum_instance_encode(obj)
	if isinstance(obj, dict):
		# Iterate over a snapshot because values are reassigned during the loop.
		for key, val in tuple(obj.items()):
			obj[key] = encode_intenums_inplace(val)
		return obj
	if isinstance(obj, list):
		for index, val in enumerate(obj):
			obj[index] = encode_intenums_inplace(val)
		return obj
	if isinstance(obj, (tuple, set)):
		encoded = (encode_intenums_inplace(val) for val in obj)
		if isinstance(obj, tuple) and hasattr(obj, '_make'):
			# namedtuples take their fields as separate arguments, not a single iterable.
			return obj._make(encoded)
		return type(obj)(encoded)
	return obj
169
170
def get_module_name_from_object(obj):
	"""
	Return the module path recorded on `obj`'s class.

	:param obj: any object.
	:return: `type(obj).__module__`, or None when that is '__main__'. In the None case
		a warning is emitted, since decoding will then need `cls_lookup_map`.
	"""
	klass = obj.__class__
	module_path = klass.__module__
	if module_path != '__main__':
		return module_path
	warnings.warn(('class {0:} seems to have been defined in the main file; unfortunately this means'
		' that it\'s module/import path is unknown, so you might have to provide cls_lookup_map when '
		'decoding').format(klass))
	return None
179
180
def nested_index(collection, indices):
	"""
	Follow a sequence of keys/indices into nested containers.

	For example, `nested_index(x, (1, 'a'))` returns `x[1]['a']`. An empty `indices`
	returns `collection` itself.
	"""
	current = collection
	for step in indices:
		current = current[step]
	return current
185
186
def dict_default(dictionary, key, default_value):
	"""
	Set `dictionary[key] = default_value` unless `key` is already present.

	Modifies `dictionary` in place and returns None.
	"""
	if key in dictionary:
		return
	dictionary[key] = default_value
190
191
def gzip_compress(data, compresslevel):
	"""
	Gzip-compress `data` with the header timestamp fixed at 0, so identical input gives
	identical output. Similar to gzip.compress, but without timestamp, and also
	before py3.2.

	:param data: bytes to compress.
	:param compresslevel: compression level, passed to gzip.GzipFile.
	:return: the compressed bytes.
	"""
	out = io.BytesIO()
	writer = gzip.GzipFile(fileobj=out, mode='wb', compresslevel=compresslevel, mtime=0)
	try:
		writer.write(data)
	finally:
		# Closing flushes the gzip trailer into `out` before it is read below.
		writer.close()
	return out.getvalue()
200
201
def gzip_decompress(data):
	"""
	Decompress gzip-compressed bytes. Equivalent to gzip.decompress, which only
	exists from py3.2 on.

	:param data: gzip-compressed bytes.
	:return: the decompressed bytes.
	"""
	source = io.BytesIO(data)
	reader = gzip.GzipFile(fileobj=source)
	try:
		return reader.read()
	finally:
		reader.close()
208
209
# True on Python 3 and later. The numeric major version is compared, rather than a
# string prefix of `sys.version`, so that a later major version is not treated as
# Python 2 (where `basestring`/`unicode` would be undefined).
is_py3 = (version_info[0] >= 3)
# Text type(s) for isinstance checks: `str` on Python 3, `(basestring, unicode)` on Python 2.
str_type = str if is_py3 else (basestring, unicode,)
212
213