1from contextlib import contextmanager
2import os
3import shutil
4import tempfile
5import struct
6
7
def raises(exc, lamda):
    """ Report whether calling ``lamda()`` raises ``exc``

    Returns True when ``lamda()`` raises an instance of ``exc`` (which may be
    an exception class or a tuple of classes), and False when it returns
    normally.  Any other exception propagates to the caller unchanged.

    >>> raises(ZeroDivisionError, lambda: 1 / 0)
    True
    >>> raises(ZeroDivisionError, lambda: 1)
    False
    """
    try:
        lamda()
    except exc:
        return True
    else:
        return False
14
15
@contextmanager
def tmpfile(extension=''):
    """ Yield a fresh temporary path and delete whatever is there on exit

    The path is reserved with ``tempfile.mkstemp`` and the file is then
    removed immediately, so the yielded name does not exist yet.  The caller
    may create a file or a directory at that path.  On exit, whichever is
    present is removed.

    Parameters
    ----------
    extension : str, optional
        File extension with or without a leading dot (``'csv'`` and
        ``'.csv'`` are equivalent).  When empty, the path gets no suffix.

    Yields
    ------
    filename : str
        Absolute path to a currently non-existent temporary location.
    """
    # Normalise to "" or ".ext".  Unconditionally prepending "." would give
    # extension-less paths a stray trailing dot.
    extension = extension.lstrip('.')
    if extension:
        extension = '.' + extension
    handle, filename = tempfile.mkstemp(extension)
    # Only the unique name is wanted; release the descriptor and the file.
    os.close(handle)
    os.remove(filename)

    try:
        yield filename
    finally:
        if os.path.exists(filename):
            if os.path.isdir(filename):
                shutil.rmtree(filename)
            else:
                os.remove(filename)
31
32
def frame(bytes):
    """ Prefix ``bytes`` with its length so frames can be concatenated

    The length is packed with ``struct`` format ``'Q'`` (native-order
    unsigned long long) and placed in front of the payload.

    TODO: Concatenating the header and payload copies the whole payload.
    Consider inlining this wherever it is used.  (Observed bandwidth on one
    laptop was about 2GB/s.)
    """
    header = struct.pack('Q', len(bytes))
    return header + bytes
41
42
def framesplit(bytes):
    """ Split a buffer of length-prefixed frames back into their payloads

    Each frame is a ``struct`` ``'Q'`` length header followed by that many
    payload bytes, as produced by ``frame``.  Payloads are yielded lazily,
    one slice of the input per frame.

    Raises ``struct.error`` if the buffer ends partway through a header.

    >>> data = frame(b'Hello') + frame(b'World')
    >>> list(framesplit(data))
    [b'Hello', b'World']
    """
    header = struct.Struct('Q')
    i = 0
    n = len(bytes)
    while i < n:
        # Read the length header in place instead of slicing a copy of it.
        nbytes, = header.unpack_from(bytes, i)
        i += header.size
        yield bytes[i:i + nbytes]
        i += nbytes
57
58
def partition_all(n, bytes):
    """ Cut ``bytes`` into consecutive blocks of ``n`` items

    Every block has length ``n`` except possibly the last, which holds the
    remainder.  If the input is strictly shorter than ``n``, the input object
    itself is yielded unsliced (no copy).  For positive ``n`` this means an
    empty input yields a single empty block.

    >>> list(partition_all(2, b'Hello'))
    [b'He', b'll', b'o']

    See Also:
        toolz.partition_all
    """
    total = len(bytes)
    if n > total:
        # Zero-copy fast path for the common short-input case.
        yield bytes
        return
    for start in range(0, total, n):
        yield bytes[start:start + n]
75
76
@contextmanager
def ignoring(*exc):
    """ Context manager that silently swallows the given exception types

    Any exception raised in the ``with`` body that is an instance of one of
    ``exc`` is discarded.  Other exceptions propagate.  With no arguments,
    nothing is swallowed.

    >>> with ignoring(KeyError):
    ...     {}['missing']
    """
    swallowed = tuple(exc)
    try:
        yield
    except swallowed:
        return
83
84
@contextmanager
def do_nothing(*args, **kwargs):
    """ No-op context manager: accepts and ignores any arguments, binds None """
    yield None
88
89
def nested_get(ind, coll, lazy=False):
    """ Index ``coll`` with a scalar index or an arbitrarily nested list of them

    A non-list ``ind`` is used directly as ``coll[ind]``.  A list ``ind`` is
    mapped element-wise, recursing into sub-lists, so the result mirrors the
    list nesting of ``ind``.  With ``lazy=True``, every level is returned as a
    generator instead of a list.

    Examples
    --------

    >>> nested_get(1, 'abc')
    'b'
    >>> nested_get([1, 0], 'abc')
    ['b', 'a']
    >>> nested_get([[1, 0], [0, 1]], 'abc')
    [['b', 'a'], ['a', 'b']]
    """
    if not isinstance(ind, list):
        return coll[ind]
    children = (nested_get(sub, coll, lazy=lazy) for sub in ind)
    return children if lazy else list(children)
110
111
def flatten(seq):
    """ Lazily flatten arbitrarily nested lists

    Only ``list`` instances are descended into.  Tuples, strings and every
    other type are yielded as single items.  Items come out in depth-first,
    left-to-right order.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2)))) # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4]))) # support heterogeneous
    [1, 2, 3, 4]
    """
    for element in seq:
        if not isinstance(element, list):
            yield element
            continue
        yield from flatten(element)
136
137
def suffix(key, term):
    """ Append ``term`` to the (last element of the) key

    A string key is concatenated with ``term``.  For a tuple key, only the
    last element is suffixed (recursively, so nested tuples work too).  Any
    other key is converted with ``str`` first.

    >>> suffix('x', '.dtype')
    'x.dtype'
    >>> suffix(('a', 'b', 'c'), '.dtype')
    ('a', 'b', 'c.dtype')
    """
    if isinstance(key, tuple):
        head, last = key[:-1], key[-1]
        return head + (suffix(last, term),)
    if isinstance(key, str):
        return key + term
    return str(key) + term
154
155
def extend(key, term):
    """ Extend a key with one or more extra trailing elements

    A non-tuple ``key`` is first wrapped as a 1-tuple.  A tuple ``term`` is
    appended element-wise.  A string ``term`` is appended as a single
    element.  Any other ``term`` is converted with ``str`` and appended as a
    single element.

    >>> extend('x', '.dtype')
    ('x', '.dtype')
    >>> extend(('a', 'b', 'c'), '.dtype')
    ('a', 'b', 'c', '.dtype')
    """
    if not isinstance(term, tuple):
        term = (term if isinstance(term, str) else str(term),)
    prefix = key if isinstance(key, tuple) else (key,)
    return prefix + term
177