1from __future__ import absolute_import, print_function
2from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems
3
4import io, re, os.path, sys, inspect, types, warnings
5
6from datetime import datetime
7from itertools import count as _count
8from inspect import isfunction
9from time import strptime
10from collections import defaultdict
11from functools import update_wrapper, wraps
12from xml.etree import cElementTree
13from copy import deepcopy
14
15import pony
16from pony import options
17
18from pony.thirdparty.compiler import ast
19from pony.thirdparty.decorator import decorator as _decorator
20
21if pony.MODE.startswith('GAE-'): localbase = object
22else: from threading import local as localbase
23
24
class PonyDeprecationWarning(DeprecationWarning):
    """Warning category for deprecated Pony features; emitted by ``deprecated()``."""
27
def deprecated(stacklevel, message):
    """Emit ``message`` as a ``PonyDeprecationWarning``.

    ``stacklevel`` is forwarded unchanged to ``warnings.warn`` and selects which
    stack frame the warning is attributed to.
    """
    warnings.warn(message, category=PonyDeprecationWarning, stacklevel=stacklevel)
30
# DeprecationWarning subclasses are hidden by Python's default warning filters in most
# contexts; this filter (inserted ahead of the defaults) makes Pony's deprecation
# warnings visible, showing each distinct message only once ('once' action)
# regardless of where it is triggered.
warnings.simplefilter('once', PonyDeprecationWarning)
32
def _improved_decorator(caller, func):
    """Return ``func`` wrapped so that each call becomes ``caller(func, *args, **kwargs)``.

    Plain Python functions are handed to ``pony.thirdparty.decorator`` (presumably to
    keep their signature; see that module). Any other callable, which that library may
    not accept, gets a simple ``*args, **kwargs`` forwarding wrapper instead.
    """
    if not isfunction(func):
        def pony_wrapper(*args, **kwargs):
            return caller(func, *args, **kwargs)
        return pony_wrapper
    return _decorator(caller, func)
39
def decorator(caller, func=None):
    """Turn ``caller(func, *args, **kwargs)`` into a decorator.

    If ``func`` is given, return it decorated by ``caller`` immediately. Otherwise
    return a one-argument decorator that does so. When ``caller`` is a plain function,
    that decorator receives ``caller``'s metadata (name, docstring, ...) via
    ``update_wrapper``.
    """
    if func is None:
        def new_decorator(func):
            return _improved_decorator(caller, func)
        if isfunction(caller):
            update_wrapper(new_decorator, caller)
        return new_decorator
    return _improved_decorator(caller, func)
48
def decorator_with_params(dec):
    """Build a decorator factory from ``dec``, a function that returns a ``caller``.

    The returned object can be used bare (``@d``, equivalent to ``dec()``) or with
    parameters (``@d(...)``, equivalent to ``dec(...)``). A call with one positional
    function argument and no keywords is always treated as the bare form, so a
    function cannot be passed as the only parameter.
    """
    def parameterized_decorator(*args, **kwargs):
        applied_bare = not kwargs and len(args) == 1 and isfunction(args[0])
        if applied_bare:
            return decorator(dec(), args[0])
        return decorator(dec(*args, **kwargs))
    return parameterized_decorator
55
@decorator
def cut_traceback(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, hiding Pony-internal frames from its errors.

    When ``options.CUT_TRACEBACK`` is false the call is made unchanged. Otherwise:

    * ``AssertionError`` propagates untouched;
    * for any other exception, the traceback is walked to find the deepest frame
      (excluding the innermost one) that belongs to the ``pony`` package or a
      ``pony.*`` module. If the innermost frame is ``throw`` from ``pony.utils``,
      the exception is re-raised with its traceback starting at that Pony frame,
      so the Pony frames between this wrapper and it are not shown. Otherwise
      it is re-raised with the full traceback. If no such Pony frame exists, it
      is re-raised unchanged.
    """
    if not options.CUT_TRACEBACK:
        return func(*args, **kwargs)

    try: return func(*args, **kwargs)
    except AssertionError: raise
    except Exception:
        exc_type, exc, tb = sys.exc_info()
        full_tb = tb
        last_pony_tb = None
        try:
            while tb.tb_next:
                # .get(): globals of exec'd code may have no '__name__'. A KeyError here
                # would replace the user's exception with an unrelated one.
                module_name = tb.tb_frame.f_globals.get('__name__')
                if module_name == 'pony' or (module_name is not None  # may be None during import
                                             and module_name.startswith('pony.')):
                    last_pony_tb = tb
                tb = tb.tb_next
            if last_pony_tb is None: raise
            module_name = tb.tb_frame.f_globals.get('__name__') or ''
            if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw':
                reraise(exc_type, exc, last_pony_tb)
            reraise(exc_type, exc, full_tb)
        finally:
            # Drop local references to the exception and traceback objects so this
            # frame does not keep them (and the frames they reference) alive.
            del exc, full_tb, tb, last_pony_tb
81
# cut_traceback_depth: presumably the number of extra stack frames added by the
# cut_traceback wrapper, for code that inspects frames. TODO confirm against its users.
# Outside interactive mode tracebacks are never trimmed: cut_traceback is replaced by an
# identity decorator and the depth is 0.
if pony.MODE == 'INTERACTIVE':
    cut_traceback_depth = 2
else:
    cut_traceback_depth = 0
    def cut_traceback(func):
        return func
88
# reraise(exc_type, exc, tb): raise ``exc`` again with ``tb`` as its traceback.
# Python 2 needs the three-expression ``raise`` statement, which is a syntax error
# on Python 3, so that variant is compiled from a string with exec().
if PY2:
    exec('''def reraise(exc_type, exc, tb):
    try: raise exc_type, exc, tb
    finally: del tb''')
else:
    def reraise(exc_type, exc, tb):
        # exc_type is unused on Python 3: the instance already carries its type.
        # The finally clause drops the local references to the exception and its
        # traceback, avoiding a frame <-> traceback reference cycle.
        try: raise exc.with_traceback(tb)
        finally: del exc, tb
97
def throw(exc_type, *args, **kwargs):
    """Raise an exception built from ``exc_type``.

    ``exc_type`` is either an exception instance, in which case no further
    arguments are allowed, or an exception class, which is instantiated as
    ``exc_type(*args, **kwargs)``.

    Raising through this helper lets ``cut_traceback`` recognise the raise point
    (it looks for an innermost ``throw`` frame in a ``pony.utils`` module) and
    trim the Pony-internal frames in front of it.
    """
    if isinstance(exc_type, Exception):
        # Checked with assert only, so it is skipped under ``python -O``.
        assert not args and not kwargs
        exc = exc_type
    else: exc = exc_type(*args, **kwargs)
    # On Python 3, assigning __cause__ (even None) also sets __suppress_context__
    # (PEP 415), so an exception being handled when throw() is called is not shown
    # as implicit context.
    exc.__cause__ = None
    try:
        # Both branches raise the same exception. Presumably the second exists so that,
        # when tracebacks are trimmed, the source line the user sees is the hint comment.
        if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK):
            raise exc
        else:
            raise exc  # Set "pony.options.CUT_TRACEBACK = False" to see full traceback
    finally: del exc
110
def truncate_repr(s, max_len=100):
    """Return ``repr(s)``, shortened to ``max_len`` characters ending in '...' if longer."""
    text = repr(s)
    if len(text) > max_len:
        text = text[:max_len - 3] + '...'
    return text
114
# id(code object) -> code object. The strong reference keeps each registered code
# object alive, so its id() cannot be reused by another object while it is a key.
codeobjects = {}

def get_codeobject_id(codeobject):
    """Register ``codeobject`` in ``codeobjects`` (first one wins) and return its id()."""
    key = id(codeobject)
    codeobjects.setdefault(key, codeobject)
    return key
122
# Memoized get_lambda_args results. The key is the code-object id (for Python
# functions) or the ast.Lambda node itself; the value is the list of argument names.
lambda_args_cache = {}

def get_lambda_args(func):
    """Return the list of positional argument names of ``func``.

    ``func`` must be a plain Python function or a compiler ``ast.Lambda`` node.
    Only simple positional parameters are accepted: ``*args``, ``**kwargs``,
    default values, and (when ``inspect.signature`` exists) positional-only or
    keyword-only parameters raise ``TypeError`` via ``throw``. Successful
    results are cached in ``lambda_args_cache``.
    """
    is_function = type(func) is types.FunctionType
    if is_function:
        code = func.func_code if PY2 else func.__code__
        # get_codeobject_id keeps the code object alive, so its id is a stable key
        cache_key = get_codeobject_id(code)
    elif isinstance(func, ast.Lambda):
        cache_key = func
    else: assert False  # pragma: no cover

    cached = lambda_args_cache.get(cache_key)
    if cached is not None:
        return cached

    if is_function and hasattr(inspect, 'signature'):
        names, argsname, kwname, defaults = [], None, None, []
        for param in inspect.signature(func).parameters.values():
            if param.default is not param.empty:
                defaults.append(param.default)
            kind = param.kind
            if kind == param.POSITIONAL_OR_KEYWORD:
                names.append(param.name)
            elif kind == param.VAR_POSITIONAL:
                argsname = param.name
            elif kind == param.VAR_KEYWORD:
                kwname = param.name
            elif kind == param.POSITIONAL_ONLY:
                throw(TypeError, 'Positional-only arguments like %s are not supported' % param.name)
            elif kind == param.KEYWORD_ONLY:
                throw(TypeError, 'Keyword-only arguments like %s are not supported' % param.name)
            else: assert False
    elif is_function:
        # inspect without signature() (Python 2): fall back to getargspec
        names, argsname, kwname, defaults = inspect.getargspec(func)
    else:
        # ast.Lambda: argnames ends with the **kwargs name (if any), preceded by the
        # *args name (if any); peel them off in that order.
        names = func.argnames
        kwname = argsname = None
        if func.kwargs: names, kwname = names[:-1], names[-1]
        if func.varargs: names, argsname = names[:-1], names[-1]
        defaults = func.defaults

    if argsname: throw(TypeError, '*%s is not supported' % argsname)
    if kwname: throw(TypeError, '**%s is not supported' % kwname)
    if defaults: throw(TypeError, 'Defaults are not supported')

    lambda_args_cache[cache_key] = names
    return names
170
def error_method(*args, **kwargs):
    """Stand-in method that rejects every call with an argument-less TypeError."""
    raise TypeError
173
_ident_re = re.compile(r'^[A-Za-z_]\w*\Z')

def is_ident(string):
    """Return True if ``string`` is an ASCII letter or underscore followed by word characters.

    ``\\w`` is Unicode-aware for str patterns on Python 3; keywords are not excluded.
    """
    return _ident_re.match(string) is not None

_name_parts_re = re.compile(r'''
            [A-Z][A-Z0-9]+(?![a-z]) # ACRONYM
        |   [A-Z][a-z]*             # Capitalized or single capital
        |   [a-z]+                  # all-lowercase
        |   [0-9]+                  # numbers
        |   _+                      # underscores
        ''', re.VERBOSE)

def split_name(name):
    """split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']

    Split an identifier into acronym / capitalized / lowercase / numeric parts,
    dropping underscore separators. Raises ValueError if ``name`` is not an
    identifier or starts or ends with underscores.
    """
    if _ident_re.match(name) is None:
        raise ValueError('Name is not correct Python identifier')
    parts = _name_parts_re.findall(name)
    if not parts[0].strip('_') or not parts[-1].strip('_'):
        raise ValueError('Name must not starting or ending with underscores')
    return [part for part in parts if part.strip('_')]
197
def uppercase_name(name):
    """uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'

    Upper-case every part produced by split_name() and join them with '_'.
    """
    parts = [part.upper() for part in split_name(name)]
    return '_'.join(parts)
201
def lowercase_name(name):
    """lowercase_name('Some_FUNNYName') -> 'some_funny_name'

    Lower-case every part produced by split_name() and join them with '_'.
    Raises ValueError for the same inputs split_name() rejects.
    """
    return '_'.join(part.lower() for part in split_name(name))
205
def camelcase_name(name):
    """camelcase_name('Some_FUNNYName') -> 'SomeFunnyName'

    Capitalize every part produced by split_name() (first letter upper, rest
    lower) and concatenate them. Raises ValueError for the same inputs
    split_name() rejects.
    """
    return ''.join(part.capitalize() for part in split_name(name))
209
def mixedcase_name(name):
    """mixedcase_name('Some_FUNNYName') -> 'someFunnyName'

    Like camel case, but the first part produced by split_name() is fully lower-cased.
    """
    parts = split_name(name)
    head, tail = parts[0], parts[1:]
    return head.lower() + ''.join([part.capitalize() for part in tail])
214
def import_module(name):
    """import_module('a.b.c') -> <module a.b.c>

    Return the module for the dotted path ``name``, importing it if it is not
    yet in ``sys.modules``. Import errors from ``__import__`` propagate.
    """
    mod = sys.modules.get(name)
    if mod is not None: return mod
    __import__(name)
    # __import__('a.b.c') returns the top-level package 'a'. Walking attributes
    # (a.b.c) can yield a non-module when a package __init__ rebinds a submodule's
    # name (e.g. ``from .c import c``). The import system registers the module in
    # sys.modules under its full dotted name, so take it from there.
    return sys.modules[name]
223
# Windows: optional drive letter, then a forward or back slash. Elsewhere: a leading '/'.
_absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]' if sys.platform == 'win32' else r'^/')

def is_absolute_path(filename):
    """Return True if ``filename`` starts like an absolute path on this platform."""
    return _absolute_re.match(filename) is not None
230
def absolutize_path(filename, frame_depth):
    """Resolve ``filename`` against the directory of a calling frame's source file.

    Absolute paths are returned unchanged. Otherwise ``frame_depth`` selects the
    frame whose code file is used as the base: 0 is the direct caller of this
    function, 1 is that caller's caller, and so on. A relative (non-pseudo) code
    file name yields a relative result.

    Raises ValueError in interactive mode, and EnvironmentError otherwise, when that
    file name is a pseudo-name such as ``<stdin>`` rather than a real path.
    """
    if is_absolute_path(filename):
        return filename
    # sys._getframe(0) is this function itself, hence the +1
    base_file = sys._getframe(frame_depth + 1).f_code.co_filename
    is_pseudo_file = (not is_absolute_path(base_file)
                      and base_file.startswith('<') and base_file.endswith('>'))
    if is_pseudo_file:
        if pony.MODE == 'INTERACTIVE':
            raise ValueError('When in interactive mode, please provide absolute file path. Got: %r' % filename)
        raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % base_file)
    return os.path.join(os.path.dirname(base_file), filename)
241
def current_timestamp():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS.ffffff' string."""
    now = datetime.now()
    return datetime2timestamp(now)
244
def datetime2timestamp(d):
    """Format ``d`` via ``isoformat(' ')``, always with a six-digit fraction.

    ``isoformat`` omits microseconds when they are zero (a 19-character result),
    so '.000000' is appended in that case.
    """
    text = d.isoformat(' ')
    return text + '.000000' if len(text) == 19 else text
249
def timestamp2datetime(t):
    """Parse 'YYYY-MM-DD HH:MM:SS[.fraction]' into a naive ``datetime``.

    Up to six fraction digits are read and right-padded with zeros, so '.5' means
    500000 microseconds; a missing fraction means 0.
    """
    parsed = strptime(t[:19], '%Y-%m-%d %H:%M:%S')
    fraction = t[20:26].ljust(6, '0')
    return datetime(*parsed[:6], microsecond=int(fraction))
254
# Start of an expression for parse_expr: an identifier (group 1) or an opening "(" (group 2).
expr1_re = re.compile(r'''
        ([A-Za-z_]\w*)  # identifier (group 1)
    |   ([(])           # open parenthesis (group 2)
    ''', re.VERBOSE)

# Continuation of an expression, after optional whitespace: an explicit ";" terminator
# (group 1), a ".name" attribute access (group 2), or an opening "(" or "[" that starts
# a bracket group (group 3).
expr2_re = re.compile(r'''
     \s*(?:
            (;)                 # semicolon (group 1)
        |   (\.\s*[A-Za-z_]\w*) # dot + identifier (group 2)
        |   ([([])              # open parenthesis or braces (group 3)
        )
    ''', re.VERBOSE)

# Tokens searched for inside a bracket group: a single bracket character, or a complete
# quoted string literal (consumed whole, so brackets inside strings are skipped).
# NOTE: despite the "(group 1)" remark inside the pattern, it has no capturing groups;
# parse_expr reads the token with match.group().
expr3_re = re.compile(r"""
        [()[\]]                   # parenthesis or braces (group 1)
    |   '''(?:[^\\]|\\.)*?'''     # '''triple-quoted string'''
    |   \"""(?:[^\\]|\\.)*?\"""   # \"""triple-quoted string\"""
    |   '(?:[^'\\]|\\.)*?'        # 'string'
    |   "(?:[^"\\]|\\.)*?"        # "string"
    """, re.VERBOSE)
275
def parse_expr(s, pos=0):
    """Scan one Python-like expression in ``s`` starting at index ``pos``.

    The expression must begin with an identifier or ``(`` and may continue with
    ``.name`` attribute accesses and bracket groups ``(...)`` / ``[...]``. Inside a
    group only the bracket type that opened it is counted, and quoted string
    literals are skipped as whole tokens.

    Returns ``(text, is_simple_call)``:

    * ``text`` is ``s[pos:end]``. When the expression is ended explicitly by ``;``,
      the returned text includes that ``;`` (and any whitespace before it).
    * ``is_simple_call`` is True only when scanning stopped naturally (no ``;``)
      and the expression has the form ``identifier(...)``: a bare name followed
      by exactly one parenthesized group. It is always False after ``;``.

    Raises ValueError if ``s[pos:]`` does not start with an identifier or ``(``,
    or if a bracket group is never closed.
    """
    # z classifies what has been scanned so far: 0 = a bare identifier, 1 = identifier
    # followed by exactly one "(...)" group, >= 2 = anything else (attribute access,
    # "[...]", more than one group, or an expression that started with "(").
    z = 0
    match = expr1_re.match(s, pos)
    if match is None: raise ValueError()
    start = pos
    i = match.lastindex
    if i == 1: pos = match.end()  # identifier
    elif i == 2: z = 2  # "(" - pos is not advanced, so expr2_re below consumes it as a group
    else: assert False  # pragma: no cover
    while True:
        match = expr2_re.match(s, pos)
        if match is None: return s[start:pos], z==1
        pos = match.end()
        i = match.lastindex
        if i == 1: return s[start:pos], False  # ";" - explicit end of expression (included in text)
        elif i == 2: z = 2  # .identifier
        elif i == 3:  # "(" or "["
            pos = match.end()
            counter = 1  # nesting depth of the bracket type that opened this group
            open = match.group(i)
            if open == '(': close = ')'
            elif open == '[': close = ']'; z = 2  # subscripts never count as a simple call
            else: assert False  # pragma: no cover
            while True:
                # next bracket character or whole string literal at or after pos
                match = expr3_re.search(s, pos)
                if match is None: raise ValueError()  # group never closed
                pos = match.end()
                x = match.group()
                if x == open: counter += 1
                elif x == close:
                    counter -= 1
                    if not counter: z += 1; break
        else: assert False  # pragma: no cover
309
def tostring(x):
    """Convert ``x`` to a string for display, falling back through several strategies.

    Order: strings are returned as-is; objects with ``__unicode__`` use ``unicode()``;
    ElementTree elements (anything with ``makeelement``) are serialized with
    ``cElementTree.tostring``; otherwise ``str()`` then ``repr()`` are tried. If all
    of these fail, a generic ``<Class object at 0xID>`` string is returned, so the
    conversion itself never raises for ordinary exceptions.
    """
    if isinstance(x, basestring): return x
    if hasattr(x, '__unicode__'):
        try: return unicode(x)
        except Exception: pass
    if hasattr(x, 'makeelement'): return cElementTree.tostring(x)
    # Catch Exception rather than everything so KeyboardInterrupt/SystemExit propagate.
    try: return str(x)
    except Exception: pass
    try: return repr(x)
    except Exception: pass
    # types.InstanceType (old-style class instances) exists only on Python 2.
    instance_type = getattr(types, 'InstanceType', None)
    if instance_type is not None and type(x) is instance_type:
        return '<%s instance at 0x%X>' % (x.__class__.__name__, id(x))
    return '<%s object at 0x%X>' % (x.__class__.__name__, id(x))
322
def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None):
    """Join ``strings`` with ``sep``, tolerating a mix of unicode and byte strings.

    A plain ``sep.join`` is tried first. If it raises UnicodeDecodeError (Python 2
    implicit decoding of byte strings), every ``str`` item is decoded with
    ``source_encoding``, replacement characters become ``'?'``, and the join is retried.
    NOTE(review): ``dest_encoding`` is applied (with 'replace') only on that fallback
    path; a join that succeeds immediately is returned unencoded.
    """
    items = list(strings)
    try:
        return sep.join(items)
    except UnicodeDecodeError:
        pass
    items = [item.decode(source_encoding, 'replace').replace(u'\ufffd', '?')
             if isinstance(item, str) else item
             for item in items]
    joined = sep.join(items)
    return joined if dest_encoding is None else joined.encode(dest_encoding, 'replace')
334
def count(*args, **kwargs):
    """Polymorphic count.

    * Any keyword arguments, or a number of positional arguments other than one:
      behave like ``itertools.count``.
    * One argument with a ``count`` attribute: return ``arg.count()``.
    * One iterable argument: return the number of distinct items in it.
    * Anything else (e.g. a number): ``itertools.count(arg)``.

    NOTE(review): list/tuple/str have a ``count(x)`` method that needs an argument,
    so passing them raises TypeError instead of reaching the distinct-count branch.
    """
    if kwargs or len(args) != 1:
        return _count(*args, **kwargs)
    (arg,) = args
    if hasattr(arg, 'count'):
        return arg.count()
    try:
        iterator = iter(arg)
    except TypeError:
        return _count(arg)
    return len(set(iterator))
343
def avg(iter):
    """Return the arithmetic mean of the non-None items of ``iter``, or None if there are none.

    Accumulation starts from 0.0, so the result is a float for numeric input.
    NOTE(review): Decimal items raise TypeError (float + Decimal).
    """
    total, n = 0.0, 0
    for value in iter:
        if value is not None:
            total += value
            n += 1
    return total / n if n else None
353
def group_concat(items, sep=','):
    """Join ``str()`` of each item with ``str(sep)``; return None when ``items`` is None."""
    if items is not None:
        return str(sep).join(map(str, items))
    return None
358
def coalesce(*args):
    """Return the first argument that is not None (SQL COALESCE), or None."""
    return next((arg for arg in args if arg is not None), None)
364
def distinct(iter):
    """Return a ``defaultdict(int)`` mapping each item of ``iter`` to its number of occurrences."""
    counts = defaultdict(int)
    for element in iter:
        counts[element] += 1
    return counts
370
def concat(*args):
    """Concatenate the ``tostring()`` forms of all arguments."""
    return ''.join(map(tostring, args))
373
def between(x, a, b):
    """Return whether ``a <= x <= b`` (inclusive at both ends, like SQL BETWEEN)."""
    return a <= x and x <= b
376
def is_utf8(encoding):
    """Return True if ``encoding`` names UTF-8, ignoring case, '_' and '-'."""
    normalized = encoding.upper()
    for separator in ('_', '-'):
        normalized = normalized.replace(separator, '')
    return normalized in ('UTF8', 'UTF', 'U8')
379
def _persistent_id(obj):
    """Pickler hook: store Ellipsis as the persistent id "Ellipsis"; pickle other objects normally.

    Presumably needed because Ellipsis could not be pickled on Python 2.
    """
    return "Ellipsis" if obj is Ellipsis else None

def _persistent_load(persid):
    """Unpickler hook matching ``_persistent_id``; any other persistent id is rejected."""
    if persid != "Ellipsis":
        raise pickle.UnpicklingError("unsupported persistent object")
    return Ellipsis

def pickle_ast(val):
    """Pickle ``val`` into a new BytesIO (left positioned at its end) and return the buffer."""
    buffer = io.BytesIO()
    pickler = pickle.Pickler(buffer)
    pickler.persistent_id = _persistent_id
    pickler.dump(val)
    return buffer

def unpickle_ast(pickled):
    """Rewind the ``pickle_ast`` buffer ``pickled`` and load the object from it.

    Only for trusted data produced by ``pickle_ast``: unpickling can run arbitrary code.
    """
    pickled.seek(0)
    unpickler = pickle.Unpickler(pickled)
    unpickler.persistent_load = _persistent_load
    return unpickler.load()

def copy_ast(tree):
    """Deep-copy ``tree`` by a pickle round trip."""
    pickled = pickle_ast(tree)
    return unpickle_ast(pickled)
404
def _hashable_wrap(func):
    """Wrap a mutating dict method so it fails once the instance has been hashed.

    The wrapper checks that ``self._hash`` (set by ``HashableDict.__hash__``) is still
    unset. The check is an ``assert``, so it raises AssertionError only when
    assertions are enabled (not under ``python -O``).
    """
    @wraps(func, assigned=('__name__', '__doc__'))
    def new_func(self, *args, **kwargs):
        if getattr(self, '_hash', None) is not None:
            assert False, 'Cannot mutate HashableDict instance after the hash value is calculated'
        return func(self, *args, **kwargs)
    return new_func

class HashableDict(dict):
    """A dict that can be used as a hash key and is frozen after its first hash.

    ``__hash__`` computes an order-independent XOR of the hashes of all keys and
    values and caches it in ``self._hash``. From then on, every mutating method
    wrapped with ``_hashable_wrap`` fails, so the cached hash cannot go stale.
    NOTE(review): ``_hash`` lives in the instance ``__dict__`` and is therefore
    copied by pickling; across processes with hash randomization it could be stale.
    TODO confirm whether pickled HashableDicts ever cross processes.
    """
    def __hash__(self):
        result = getattr(self, '_hash', None)
        if result is None:
            result = 0
            for key, value in self.items():
                result ^= hash(key)
                result ^= hash(value)
            self._hash = result
        return result
    def __deepcopy__(self, memo):
        # A hashed instance can no longer change, so it is shared instead of copied.
        if getattr(self, '_hash', None) is not None:
            return self
        return HashableDict({deepcopy(key, memo): deepcopy(value, memo)
                            for key, value in iteritems(self)})
    __setitem__ = _hashable_wrap(dict.__setitem__)
    __delitem__ = _hashable_wrap(dict.__delitem__)
    clear = _hashable_wrap(dict.clear)
    pop = _hashable_wrap(dict.pop)
    popitem = _hashable_wrap(dict.popitem)
    setdefault = _hashable_wrap(dict.setdefault)
    update = _hashable_wrap(dict.update)
    if hasattr(dict, '__ior__'):
        # ``d |= other`` (Python 3.9+) also mutates in place and must be guarded too.
        __ior__ = _hashable_wrap(dict.__ior__)
435
def deref_proxy(value):
    """Return the object behind a known proxy type, or ``value`` itself.

    Proxies are recognised by their class name: a ``LocalProxy`` whose class defines
    ``_get_current_object`` (Flask), or an ``EntityProxy`` (Pony).
    """
    proxy_type = type(value)
    type_name = proxy_type.__name__
    if type_name == 'LocalProxy' and '_get_current_object' in proxy_type.__dict__:
        return value._get_current_object()  # Flask local proxy
    if type_name == 'EntityProxy':
        return value._get_object()  # Pony proxy
    return value
446
def deduplicate(value, deduplication_cache):
    """Return a canonical object equal to ``value``, taken from ``deduplication_cache``.

    ``deduplication_cache`` maps a type to a per-type mapping of already-seen values
    (presumably a dict, since ``setdefault`` is used). The first value seen is stored
    and returned for all later equal values. This is best effort: if the type has no
    entry, the value is unhashable, or hashing/comparison fails, ``value`` itself is
    returned.
    """
    t = type(value)
    try:
        return deduplication_cache[t].setdefault(value, value)
    except Exception:
        # Catch Exception, not everything, so KeyboardInterrupt/SystemExit propagate.
        return value
453