1from math import sin, cos
2import pytest
3
4from funcy.calc import *
5
6
def test_memoize():
    """memoize caches per call signature; keyword calls get their own entries."""
    calls = []

    @memoize
    def inc(x):
        calls.append(x)
        return x + 1

    # Positional calls: the repeated inc(0) must be served from the cache.
    for arg, expected in [(0, 1), (1, 2), (0, 1)]:
        assert inc(arg) == expected
    assert calls == [0, 1]

    # Keyword calls are keyed separately from positional ones, so each
    # distinct keyword call runs the wrapped function once more.
    for arg, expected in [(0, 1), (1, 2), (0, 1)]:
        assert inc(x=arg) == expected
    assert calls == [0, 1, 0, 1]
24
25
def test_memoize_args_kwargs():
    """memoize handles defaults, explicit positional args and keyword args."""
    calls = []

    @memoize
    def mul(x, by=1):
        calls.append((x, by))
        return x * by

    # Relying on the default value of ``by``.
    assert mul(0) == 0
    assert mul(1) == 1
    assert mul(0) == 0
    assert calls == [(0, 1), (1, 1)]

    # Passing ``by`` positionally: a different cache key from omitting it,
    # so each argument tuple is computed once more, then cached.
    assert mul(0, 1) == 0
    assert mul(1, 1) == 1
    assert mul(0, 1) == 0
    assert calls == [(0, 1), (1, 1), (0, 1), (1, 1)]

    # Passing ``by`` as a keyword: computed on first use, cached afterwards.
    assert mul(2, by=3) == 6
    assert mul(2, by=3) == 6
    assert calls == [(0, 1), (1, 1), (0, 1), (1, 1), (2, 3)]
43
44
def test_memoize_memory():
    """Clearing ``.memory`` forces the next call to recompute."""
    calls = []

    @memoize
    def inc(x):
        calls.append(x)
        return x + 1

    inc(0)               # populate the cache
    inc.memory.clear()   # drop every cached entry
    inc(0)               # so this call runs the function again
    assert calls == [0, 0]
56
57
def test_memoize_key_func():
    """A custom key_func decides cache identity; here, equal-length strings share a slot."""
    calls = []

    @memoize(key_func=len)
    def double(s):
        calls.append(s)
        return s * 2

    assert double('a') == 'aa'
    # 'b' has the same length as 'a', so the cached 'aa' is returned as-is.
    assert double('b') == 'aa'
    # A new length means a new key, hence a real call.
    double('ab')
    assert calls == ['a', 'ab']
69
70
def test_make_lookuper():
    """make_lookuper turns a pair generator into a lookup that raises on misses."""
    @make_lookuper
    def letter_index():
        return ((letter, pos) for pos, letter in enumerate('abcdefghij'))

    assert letter_index('c') == 2
    with pytest.raises(LookupError):
        letter_index('_')
78
79
def test_make_lookuper_nested():
    """A parametrised lookuper builds one table per argument and reuses it."""
    # Single-element list used as a mutable counter the inner function can bump.
    build_count = [0]

    @make_lookuper
    def function_table(f):
        build_count[0] += 1
        return ((x, f(x)) for x in range(10))

    for func, arg in [(sin, 5), (cos, 3), (sin, 3)]:
        assert function_table(func)(arg) == func(arg)
    # sin was requested twice, but only two tables (sin, cos) were built.
    assert build_count[0] == 2

    # -1 lies outside the tabulated keys 0..9, so the lookup misses.
    with pytest.raises(LookupError):
        function_table(cos)(-1)
94
95
def test_silent_lookuper():
    """silent_lookuper returns None for keys missing from the table."""
    @silent_lookuper
    def letter_index():
        pairs = enumerate('abcdefghij')
        return ((letter, pos) for pos, letter in pairs)

    assert letter_index('c') == 2
    assert letter_index('_') is None
103
104
# NOTE(review): "silnent" is a typo for "silent"; the name is kept so any
# existing test selectors (e.g. ``pytest -k``) keep matching.
def test_silnent_lookuper_nested():
    """Parametrised silent lookuper: hits return table values, misses return None."""
    @silent_lookuper
    def function_table(f):
        return ((x, f(x)) for x in range(10))

    sin_lookup = function_table(sin)
    assert sin_lookup(5) == sin(5)
    # -1 lies outside the tabulated keys 0..9.
    assert function_table(cos)(-1) is None
112
113
def test_cache():
    """With a long timeout, cache serves repeated calls from memory."""
    calls = []

    @cache(timeout=60)
    def inc(x):
        calls.append(x)
        return x + 1

    for arg in (0, 1, 0):
        assert inc(arg) == arg + 1
    assert calls == [0, 1]
126
127
def test_cache_mixed_args():
    """cache accepts a mix of positional and keyword arguments."""
    @cache(timeout=60)
    def add(x, y):
        return x + y

    total = add(1, y=2)
    assert total == 3
134
135
def test_cache_timedout():
    """With timeout=0 nothing stays cached, so every call recomputes."""
    calls = []

    @cache(timeout=0)
    def inc(x):
        calls.append(x)
        return x + 1

    for _ in range(2):
        assert inc(0) == 1
    assert calls == [0, 0]
147
148
def test_cache_invalidate():
    """invalidate_all() drops every entry; invalidate(*args) drops just one."""
    calls = []

    @cache(timeout=60)
    def inc(x):
        calls.append(x)
        return x + 1

    assert inc(0) == 1
    assert inc(1) == 2
    assert inc(0) == 1
    assert calls == [0, 1]

    # Dropping everything makes each argument recompute once.
    inc.invalidate_all()
    assert inc(0) == 1
    assert inc(1) == 2
    assert inc(0) == 1
    assert calls == [0, 1, 0, 1]

    # Dropping a single key recomputes only that key.
    inc.invalidate(1)
    assert inc(0) == 1
    assert inc(1) == 2
    assert inc(0) == 1
    assert calls == [0, 1, 0, 1, 1]

    # invalidate() must be idempotent: a second call on an already-removed
    # key must not raise KeyError ...
    inc.invalidate(0)
    inc.invalidate(0)
    # ... and the entry must really be gone, so the next call recomputes it.
    assert inc(0) == 1
    assert calls == [0, 1, 0, 1, 1, 0]
177