1from collections import OrderedDict
2
3from hypothesis import given
4from hypothesis.strategies import floats, integers
5import numpy
6import pytest
7
8import snuggs
9
10
@pytest.fixture
def ones():
    """Provide a fresh 2x2 array filled with 1.0 (numpy's default float dtype)."""
    shape = (2, 2)
    return numpy.ones(shape)
14
15
@pytest.fixture
def truetrue():
    """Provide a length-2 boolean array that is True in both positions."""
    flags = [True, True]
    return numpy.array(flags)
19
20
@pytest.fixture
def truefalse():
    """Provide a length-2 boolean array: True first, False second."""
    flags = [True, False]
    return numpy.array(flags)
24
25
@given(integers())
def test_integer_operand(num):
    """The repr of any integer parses back to that same integer."""
    tokens = snuggs.operand.parseString(repr(num))
    assert list(tokens) == [num]
29
30
@given(floats(allow_infinity=False, allow_nan=False))
def test_real_operand(num):
    """The repr of any finite float parses back to an equal value."""
    tokens = snuggs.operand.parseString(repr(num))
    assert list(tokens) == [num]
34
35
def test_int_expr():
    """Two-operand integer addition."""
    total = snuggs.eval('(+ 1 2)')
    assert total == 3
38
39
def test_int_mult_expr():
    """``+`` accepts more than two integer operands."""
    total = snuggs.eval('(+ 1 2 3)')
    assert total == 6
42
43
def test_real_expr():
    """Float multiplication, compared after rounding off binary noise."""
    product = snuggs.eval('(* 0.1 0.2)')
    assert round(product, 3) == 0.02
46
47
def test_int_real_expr():
    """An integer operand followed by a float operand."""
    total = snuggs.eval('(+ 2 1.1)')
    assert total == 3.1
50
51
def test_real_int_expr():
    """A float operand followed by an integer operand."""
    total = snuggs.eval('(+ 1.1 2)')
    assert total == 3.1
54
55
def test_arr_var(ones):
    """A keyword argument is usable as a named variable in the expression."""
    out = snuggs.eval('(+ foo 0)', foo=ones)
    assert list(out.ravel()) == [1, 1, 1, 1]
59
60
def test_arr_lookup(ones):
    """``(read 1)`` yields the first entry of an ordered context mapping.

    The mapping is passed positionally rather than expanded into keyword
    arguments, so its insertion order is fully controlled by this test.
    """
    context = OrderedDict()
    context['foo'] = ones
    context['bar'] = 2.0 * ones
    context['a'] = 3.0 * ones
    first = snuggs.eval('(read 1)', context)
    assert list(first.ravel()) == [1, 1, 1, 1]
67
def test_arr_var_long(ones):
    """Variable names may combine upper case letters, underscores and digits."""
    out = snuggs.eval('(+ FOO_BAR_42 0)', FOO_BAR_42=ones)
    assert list(out.ravel()) == [1, 1, 1, 1]
71
72
@pytest.mark.xfail(reason="Keyword argument order can't be relied on")
def test_arr_lookup_kwarg_order(ones):
    """``(read 1)`` against a context supplied via ``**kwargs`` expansion.

    NOTE(review): PEP 468 (Python 3.6+) guarantees that ``**kwargs`` keeps
    insertion order, so on modern interpreters this may XPASS; consider
    dropping the marker once older Pythons are no longer supported.
    """
    context = OrderedDict()
    context['foo'] = ones
    context['bar'] = 2.0 * ones
    context['a'] = 3.0 * ones
    first = snuggs.eval('(read 1)', **context)
    assert list(first.ravel()) == [1, 1, 1, 1]
80
81
def test_arr_lookup_2(ones):
    """``(read 1 1)`` returns a length-2 slice of the first variable."""
    part = snuggs.eval('(read 1 1)', foo=ones)
    assert list(part.ravel()) == [1, 1]
85
86
def test_arr_take(ones):
    """``take`` with index 1 and index 2 each yields a length-2 slice of ones."""
    for source in ('(take foo 1)', '(take foo 2)'):
        part = snuggs.eval(source, foo=ones)
        assert list(part.ravel()) == [1, 1]
92
93
def test_int_arr_expr(ones):
    """An integer scalar is added element-wise to an array variable."""
    shifted = snuggs.eval('(+ foo 1)', foo=ones)
    assert list(shifted.ravel()) == [2, 2, 2, 2]
97
98
def test_int_arr_expr_by_name(ones):
    """A float scalar is added to an array fetched with ``(read 1)``.

    Despite the test name, ``foo`` is looked up by position, not by name.
    """
    shifted = snuggs.eval('(+ (read 1) 1.5)', foo=ones)
    assert list(shifted.ravel()) == [2.5] * 4
102
103
def test_int_arr_read(ones):
    """Arithmetic applies to the slice returned by ``(read 1 1)``."""
    shifted = snuggs.eval('(+ (read 1 1) 1.5)', foo=ones)
    assert list(shifted.ravel()) == [2.5] * 2
107
108
def test_list(ones):
    """``asarray`` gathers several ``take`` results into a single array."""
    expression = ('(asarray (take foo 1) (take foo 1) '
                  '(take bar 1) (take bar 1))')
    stacked = snuggs.eval(expression, foo=ones, bar=ones)
    assert list(stacked.ravel()) == [1.0] * 8
114
115
def test_eq(ones):
    """``==`` compares an array against a scalar element by element."""
    ones[0, 0] = 2
    matches = snuggs.eval('(== foo 1)', foo=ones)
    assert list(matches.ravel()) == [False, True, True, True]
120
121
def test_or(truetrue, truefalse):
    """``|`` combines two boolean arrays element by element."""
    combined = snuggs.eval('(| foo bar)', foo=truetrue, bar=truefalse)
    assert list(combined.ravel()) == [True, True]
126
127
def test_and(truetrue, truefalse):
    """``&`` combines two boolean arrays element by element."""
    combined = snuggs.eval('(& foo bar)', foo=truetrue, bar=truefalse)
    assert list(combined.ravel()) == [True, False]
132
133
def test_ones_like(truefalse):
    """``ones_like`` matches its input's shape and honours the dtype argument.

    The value check alone cannot detect a dropped dtype: a bool result
    ``[True, True]`` also compares equal to ``[1.0, 1.0]``. The dtype is
    therefore asserted explicitly. The quoted ``'uint8'`` is presumably
    forwarded to numpy as the dtype; confirm if the function mapping changes.
    """
    result = snuggs.eval("(ones_like foo 'uint8')", foo=truefalse)
    assert result.dtype == numpy.uint8
    assert list(result.flatten()) == [1.0, 1.0]
137
138
def test_full_like(truefalse):
    """``full_like`` accepts its dtype string in single or double quotes."""
    for source in ("(full_like foo 3.14 'float64')",
                   '(full_like foo 3.14 "float64")'):
        filled = snuggs.eval(source, foo=truefalse)
        assert list(filled.ravel()) == [3.14, 3.14]
144
145
def test_ufunc(truetrue, truefalse):
    """``where`` picks between two scalars using a nested boolean expression."""
    picked = snuggs.eval(
        '(where (& foo bar) 1 0)', foo=truetrue, bar=truefalse)
    assert list(picked.ravel()) == [1.0, 0.0]
150
151
def test_partial():
    """A ``partial`` form can itself be called as the head of an expression."""
    assert snuggs.eval('((partial * 2) 2)') == 4
155
156
def test_map_func():
    """``map`` applies a named function to every element of an array."""
    roots = snuggs.eval('(map sqrt (asarray 1 4 9))')
    assert list(roots) == [1, 2, 3]
160
161
def test_map_partial():
    """``map`` also accepts a ``partial`` form as the mapped function."""
    doubled = snuggs.eval('(map (partial * 2) (asarray 1 2 3))')
    assert list(doubled) == [2, 4, 6]
165
166
def test_map_asarray():
    """The result of ``map`` can be fed back into ``asarray``."""
    doubled = snuggs.eval('(asarray (map (partial * 2) (asarray 1 2 3)))')
    assert list(doubled) == [2, 4, 6]
170
171
def test_multi_operator_array(ones):
    """Variadic ``/`` and ``*`` nest inside a variadic ``+``.

    Per element: 1 + (1 / 1 / 0.5) + (1 * 1 * 3) == 1 + 2 + 3 == 6.
    """
    total = snuggs.eval('(+ ones (/ ones 1 0.5) (* ones 1 3))', ones=ones)
    assert list(total.ravel()) == [6.0] * 4
176
177
def test_nil():
    """``nil`` equals itself and is unequal to a number, in either order."""
    assert snuggs.eval('(== nil nil)')
    for source in ('(== 1 nil)', '(== nil 1)'):
        assert not snuggs.eval(source)
    for source in ('(!= 1 nil)', '(!= nil 1)'):
        assert snuggs.eval(source)
184
185
def test_masked_arr():
    """Arithmetic on a masked array preserves the mask.

    The three zeros are masked on input. After ``(+ foo 1)`` they must still
    be masked, and only the unmasked element is incremented.
    """
    foo = numpy.ma.masked_equal(numpy.array([0, 0, 0, 1], dtype='uint8'), 0)
    r = snuggs.eval('(+ foo 1)', foo=foo)
    # Check the mask explicitly. Comparing against ``numpy.ma.masked``
    # elements only passes via CPython's list identity shortcut, because
    # ``masked == masked`` is itself masked (falsy).
    assert list(numpy.ma.getmaskarray(r).flatten()) == [True, True, True, False]
    # The underlying data of the masked slots is left as it was in ``foo``.
    assert list(r.data.flatten()) == [0, 0, 0, 2]
    assert list(r.compressed()) == [2]
191
192
193# Parse and syntax error testing.
def test_missing_closing_paren():
    """An unterminated expression is reported just past the end of the input."""
    with pytest.raises(SyntaxError) as excinfo:
        snuggs.eval("(+ 1 2")
    err = excinfo.value
    assert (err.lineno, err.offset) == (1, 7)
199
200
def test_missing_func():
    """A number in operator position is rejected with a descriptive message."""
    with pytest.raises(SyntaxError) as excinfo:
        snuggs.eval("(0 1 2)")
    err = excinfo.value
    assert (err.lineno, err.offset) == (1, 2)
    assert str(err) == "'0' is not a function or operator"
207
208
def test_missing_func2():
    """An unknown symbol in operator position is reported at column 2."""
    with pytest.raises(SyntaxError) as excinfo:
        snuggs.eval("(# 1 2)")
    err = excinfo.value
    assert (err.lineno, err.offset) == (1, 2)
214
215
def test_undefined_var():
    """Referencing an unbound name raises SyntaxError at that name's column."""
    with pytest.raises(SyntaxError) as excinfo:
        snuggs.eval("(+ 1 bogus)")
    err = excinfo.value
    assert (err.lineno, err.offset) == (1, 6)
    assert str(err) == "name 'bogus' is not defined"
222
223
def test_bogus_higher_order_func():
    """An unknown name in a nested head position is reported at column 3."""
    with pytest.raises(SyntaxError) as excinfo:
        snuggs.eval("((bogus * 2) 2)")
    err = excinfo.value
    assert (err.lineno, err.offset) == (1, 3)
229
230
def test_type_error():
    """Adding an integer and a quoted string raises TypeError."""
    expression = "(+ 1 'bogus')"
    with pytest.raises(TypeError):
        snuggs.eval(expression)
234
235
def test_negative_decimal():
    """A leading minus sign is parsed as part of a decimal literal."""
    is_less = snuggs.eval("(< -0.9 0)")
    assert is_less
239