"""Tests for chained assignment and tuple unpacking under numba ``jit``.

Each test compiles a plain-Python function with ``jit`` and checks that the
compiled version agrees with the interpreted one, both in its return value
and in the state of its (possibly mutated) arguments afterwards.
"""
from numba import jit
import unittest
import numpy as np
import copy
from numba.tests.support import MemoryLeakMixin


@jit
def inc(a):
    """Increment every element of ``a`` in place and return ``a`` itself."""
    for i in range(len(a)):
        a[i] += 1
    return a


@jit
def inc1(a):
    """Increment ``a[0]`` in place and return its new value."""
    a[0] += 1
    return a[0]


@jit
def inc2(a):
    """Increment ``a[0]`` in place; return ``(a[0], a[0] + 1)``."""
    a[0] += 1
    return a[0], a[0] + 1


def chain1(a):
    """Chained assignment of one side-effecting call to three names."""
    # inc() runs once; x, y and z all bind the same (incremented) array.
    x = y = z = inc(a)
    return x + y + z


def chain2(v):
    """Chained assignment mixing array-element targets and a local name."""
    a = np.zeros(2)
    a[0] = x = a[1] = v
    return a[0] + a[1] + (x / 2)


def unpack1(x, y):
    """Plain two-element tuple unpacking."""
    a, b = x, y
    return a + b / 2


def unpack2(x, y):
    """Chained tuple unpacking of two side-effecting calls."""
    # The right-hand side is evaluated once, so x and y are each
    # incremented exactly once even though there are two unpack targets.
    a, b = c, d = inc1(x), inc1(y)
    return a + c / 2, b + d / 2


def chain3(x, y):
    """Bind a tuple both as a whole and unpacked, in both target orders."""
    # Each statement calls inc1 on x and y once, so each input array is
    # incremented twice in total.
    a = (b, c) = (inc1(x), inc1(y))
    (d, e) = f = (inc1(x), inc1(y))
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])


def unpack3(x):
    """Unpack a tuple returned from a jitted function."""
    a, b = inc2(x)
    return a + b / 2


def unpack4(x):
    """Chained unpacking of a tuple returned from a jitted function."""
    a, b = c, d = inc2(x)
    return a + c / 2, b + d / 2


def unpack5(x):
    """Bind a returned tuple both as a whole and unpacked."""
    a = b, c = inc2(x)
    d, e = f = inc2(x)
    return (a[0] + b / 2 + d + f[0]), (a[1] + c + e / 2 + f[1])


def unpack6(x, y):
    """Nested tuple unpacking."""
    (a, b), (c, d) = (x, y), (y + 1, x + 1)
    return a + c / 2, b / 2 + d


class TestChainedAssign(MemoryLeakMixin, unittest.TestCase):
    """Compare interpreted and jit-compiled chained/unpacking assignments."""

    def test_chain1(self):
        args = [
            [np.arange(2)],
            [np.arange(4, dtype=np.double)],
        ]
        self._test_template(chain1, args)

    def test_chain2(self):
        args = [
            [3],
            [3.0],
        ]
        self._test_template(chain2, args)

    def test_unpack1(self):
        args = [
            [1, 3.0],
            [1.0, 3],
        ]
        self._test_template(unpack1, args)

    def test_unpack2(self):
        args = [
            [np.array([2]), np.array([4.0])],
            [np.array([2.0]), np.array([4])],
        ]
        self._test_template(unpack2, args)

    def test_chain3(self):
        args = [
            [np.array([0]), np.array([1.5])],
            [np.array([0.5]), np.array([1])],
        ]
        self._test_template(chain3, args)

    def test_unpack3(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack3, args)

    def test_unpack4(self):
        args = [
            [np.array([1])],
            [np.array([1.0])],
        ]
        self._test_template(unpack4, args)

    def test_unpack5(self):
        args = [
            [np.array([2])],
            [np.array([2.0])],
        ]
        self._test_template(unpack5, args)

    def test_unpack6(self):
        args1 = 3.0, 2
        args2 = 3.0, 2.0
        self._test_template(unpack6, [args1, args2])

    def _test_template(self, pyfunc, argcases):
        """Check that ``jit(pyfunc)`` matches ``pyfunc`` on every arg case.

        Each case runs on two independent deep copies, so in-place mutation
        by one run cannot leak into the other. Both the return values and
        the arguments after the call are compared.
        """
        cfunc = jit(pyfunc)
        for args in argcases:
            py_args = copy.deepcopy(args)
            jit_args = copy.deepcopy(args)
            expected = pyfunc(*py_args)
            got = cfunc(*jit_args)
            # assert_allclose(actual, desired): the compiled result is the
            # value under test; the interpreted result is the reference.
            np.testing.assert_allclose(got, expected)
            # Several helpers mutate their array arguments in place; the
            # compiled run must leave them in the same state as Python does.
            for py_arg, jit_arg in zip(py_args, jit_args):
                np.testing.assert_allclose(jit_arg, py_arg)


if __name__ == '__main__':
    unittest.main()