1from numba import jit
2import unittest
3import numpy as np
4import copy
5from numba.tests.support import MemoryLeakMixin
6
7
# Python 2/3 compatibility shim: ``xrange`` is a builtin only on Python 2.
# On Python 3 the bare name lookup raises NameError, in which case the
# module-level name ``xrange`` is bound to ``range`` instead.
try:
    xrange
except NameError:
    xrange = range
12
13
@jit
def inc(a):
    """Add one to every element of *a* in place and hand back *a* itself.

    The very same object that was passed in is returned (no copy is made),
    so callers that bind the result to several names get aliases of *a*.
    """
    size = len(a)
    idx = 0
    while idx < size:
        # Explicit read-modify-write; equivalent to ``a[idx] += 1``.
        a[idx] = a[idx] + 1
        idx += 1
    return a
19
@jit
def inc1(a):
    """Bump the first element of *a* by one in place; return its new value."""
    a[0] = a[0] + 1
    # Read back from the array so the result has the array's element type.
    return a[0]
24
@jit
def inc2(a):
    """Bump ``a[0]`` by one in place and return a 2-tuple.

    The tuple holds the updated ``a[0]`` followed by that value plus one.
    """
    a[0] = a[0] + 1
    first = a[0]
    return first, first + 1
29
30
def chain1(a):
    """Chained assignment of a single call result to three names.

    The right-hand side ``inc(a)`` is evaluated exactly once and the result
    is bound to ``x``, ``y`` and ``z``. ``inc`` increments *a* in place and
    returns it, so all three names alias the (mutated) argument.
    """
    x = y = z = inc(a)
    return x + y + z
34
35
def chain2(v):
    """Chained assignment mixing array-element targets with a local name.

    Targets are assigned left to right (``a[0]``, then ``x``, then ``a[1]``),
    each receiving *v*. ``a`` comes from ``np.zeros`` and so has the default
    float64 dtype, meaning an integer *v* is stored as a float in the array.
    """
    a = np.zeros(2)
    a[0] = x = a[1] = v
    return a[0] + a[1] + (x / 2)
40
41
def unpack1(x, y):
    """Baseline case: plain two-name unpacking of a tuple display."""
    a, b = x, y
    return a + b / 2
45
46
def unpack2(x, y):
    """Chained unpacking: one right-hand tuple unpacked into two target lists.

    The right-hand side is evaluated once, so ``inc1`` is called once per
    argument and ``a == c``, ``b == d`` afterwards.
    """
    a, b = c, d = inc1(x), inc1(y)
    return a + c / 2, b + d / 2
50
51
def chain3(x, y):
    """Chained assignment that binds a tuple both whole and unpacked.

    The first statement binds the whole tuple to ``a`` and unpacks it into
    ``b, c``. The second does the same with the target order reversed
    (unpacked first, whole tuple last). Each statement calls ``inc1`` on both
    arguments, so the second tuple's values are one larger than the first's.
    """
    a = (b, c) = (inc1(x), inc1(y))
    (d, e) = f = (inc1(x), inc1(y))
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])
56
57
def unpack3(x):
    """Unpack the 2-tuple returned by the jit-compiled helper ``inc2``."""
    a, b = inc2(x)
    return a + b / 2
61
62
def unpack4(x):
    """Chained unpacking of ``inc2``'s tuple result into two target lists.

    ``inc2`` is called once; ``a == c`` and ``b == d`` afterwards.
    """
    a, b = c, d = inc2(x)
    return a + c / 2, b + d / 2
66
67
def unpack5(x):
    """Bind a function-returned tuple both whole and unpacked.

    This mirrors ``chain3`` but uses the tuple produced by a single ``inc2``
    call. ``inc2`` runs once per statement, so *x* is mutated twice in total.
    """
    a = b, c = inc2(x)
    d, e = f = inc2(x)
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])
72
73
def unpack6(x, y):
    """Nested unpacking: two inner pairs unpacked in a single statement."""
    (a, b), (c, d) = (x, y), (y + 1, x + 1)
    return a + c / 2, b / 2 + d
77
78
class TestChainedAssign(MemoryLeakMixin, unittest.TestCase):
    """Check that numba compiles chained assignment and tuple unpacking
    identically to the interpreter.

    Each test feeds integer and floating-point variants of the same inputs to
    a module-level function, both interpreted and jit-compiled, and requires
    matching results.
    """

    def test_chain1(self):
        args = [
            [np.arange(2)],
            [np.arange(4, dtype=np.double)],
        ]
        self._test_template(chain1, args)

    def test_chain2(self):
        args = [
            [3],
            [3.0],
        ]
        self._test_template(chain2, args)

    def test_unpack1(self):
        args = [
            [1, 3.0],
            [1.0, 3],
        ]
        self._test_template(unpack1, args)

    def test_unpack2(self):
        args = [
            [np.array([2]), np.array([4.0])],
            [np.array([2.0]), np.array([4])],
        ]
        self._test_template(unpack2, args)

    def test_chain3(self):
        args = [
            [np.array([0]), np.array([1.5])],
            [np.array([0.5]), np.array([1])],
        ]
        self._test_template(chain3, args)

    def test_unpack3(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack3, args)

    def test_unpack4(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack4, args)

    def test_unpack5(self):
        args = [
            [np.array([2])],
            [np.array([2.0])],
        ]
        self._test_template(unpack5, args)

    def test_unpack6(self):
        args1 = 3.0, 2
        args2 = 3.0, 2.0
        self._test_template(unpack6, [args1, args2])

    def _test_template(self, pyfunc, argcases):
        """Run *pyfunc* interpreted and jit-compiled on every case.

        :param pyfunc: plain Python function to compare against its compiled
            counterpart.
        :param argcases: iterable of positional-argument sequences.

        Each case is deep-copied separately for the two runs, so in-place
        mutation by one run cannot leak into the other. Both the return
        values and the post-call state of the arguments must agree, because
        several subjects mutate their array arguments.
        """
        cfunc = jit(pyfunc)
        for args in argcases:
            py_args = copy.deepcopy(args)
            jit_args = copy.deepcopy(args)
            expected = pyfunc(*py_args)
            got = cfunc(*jit_args)
            # assert_allclose takes (actual, desired).
            np.testing.assert_allclose(
                got, expected,
                err_msg='return value mismatch for args %r' % (args,))
            # In-place side effects on the arguments must match as well.
            for py_arg, jit_arg in zip(py_args, jit_args):
                np.testing.assert_allclose(
                    jit_arg, py_arg,
                    err_msg='argument mutation mismatch for args %r' % (args,))
147
148
# Allow running this test module directly as a script via unittest's runner.
if __name__ == '__main__':
    unittest.main()
151
152