1cimport cython
2
# Limits of the C integer type INT under test. INT is not declared in this
# file -- presumably a ctypedef in the .pyx that includes it; confirm there.

# A Python-level 2, so that ``two ** n`` below is evaluated with Python
# integer arithmetic before being cast to INT.
cdef object two = 2
cdef int size_in_bits = sizeof(INT) * 8

# INT is signed iff (INT)-1 is not positive.
cdef bint is_signed_ = not ((<INT>-1) > 0)
# 2**(bits-1) - 1 for a signed INT, 2**bits - 1 for an unsigned INT.
cdef INT max_value_ = <INT>(two ** (size_in_bits - is_signed_) - 1)
# All bits of max_value_ flipped: 0 for an unsigned INT, and -2**(bits-1)
# for a signed INT on a two's-complement target.
cdef INT min_value_ = ~max_value_
# max_value_ is odd (2**n - 1), so ``half_ + half_ == max_value_ - 1``.
cdef INT half_ = max_value_ // <INT>2

# Python visible.
is_signed = is_signed_
max_value = max_value_
min_value = min_value_
half = half_
16
17
18import operator
19from libc.math cimport sqrt
20
cpdef check(func, op, a, b):
    """Assert that ``func(a, b)`` agrees with the reference ``op(a, b)``.

    Both calls must either raise OverflowError or produce equal INT values.
    The reference result is stored in an INT, so ``op`` "overflows" when its
    Python result does not fit into INT.
    """
    cdef INT func_result = 0
    cdef INT op_result = 0
    cdef bint func_overflowed = False
    cdef bint op_overflowed = False

    try:
        func_result = func(a, b)
    except OverflowError:
        func_overflowed = True

    try:
        op_result = op(a, b)
    except OverflowError:
        op_overflowed = True

    assert func_overflowed == op_overflowed, "Inconsistent overflow: %s(%s, %s)" % (func, a, b)
    if func_overflowed:
        # Both sides overflowed consistently; there are no values to compare.
        return
    assert func_result == op_result, "Inconsistent values: %s(%s, %s) == %s != %s" % (func, a, b, func_result, op_result)
36
# Mid-range operands: simple fractions of the INT limits and values just
# around sqrt(max_value_), where products start to overflow.  ``//`` keeps
# them integral under every language_level (with language_level=3, ``/`` on
# C integers is true division and would produce floats).
medium_values = (max_value_ // 2, max_value_ // 3, min_value_ // 2, <INT>sqrt(<long double>max_value_) - <INT>1, <INT>sqrt(<long double>max_value_) + 1)

def run_test(func, op):
    """
    Cross-check ``func`` against the reference operator ``op`` via check().

    Operands cover the INT limits, small offsets from them, fractions of the
    range (integer division), and all pairs drawn from ``medium_values`` with
    the second operand shifted by offsets in range(-3, 4).
    """
    cdef INT offset, b
    check(func, op, 300, 200)
    check(func, op, max_value_, max_value_)
    check(func, op, max_value_, min_value_)
    # NOTE(review): the exclusion of test_sub for a signed INT is kept as-is;
    # its reason is not visible here.
    if not is_signed_ or func is not test_sub:
        check(func, op, min_value_, min_value_)

    for offset in range(5):
        check(func, op, max_value_ - <INT>1, offset)
        check(func, op, min_value_ + <INT>1, offset)
        if is_signed_:
            # Second operands 2..-2 only make sense for a signed INT.
            check(func, op, max_value_ - 1, 2 - offset)
            check(func, op, min_value_ + 1, 2 - offset)

    for offset in range(9):
        check(func, op, max_value_ // <INT>2, offset)
        check(func, op, min_value_ // <INT>3, offset)
        check(func, op, max_value_ // <INT>4, offset)
        check(func, op, min_value_ // <INT>5, offset)
        if is_signed_:
            check(func, op, max_value_ // 2, 4 - offset)
            check(func, op, min_value_ // 3, 4 - offset)
            check(func, op, max_value_ // -4, 3 - offset)
            check(func, op, min_value_ // -5, 3 - offset)

    for offset in range(-3, 4):
        for a in medium_values:
            for b in medium_values:
                check(func, op, a, b + offset)
68
@cython.overflowcheck(True)
def test_add(INT a, INT b):
    """
    Add two INT values with overflow checking; overflow raises OverflowError.

    >>> test_add(1, 2)
    3
    >>> test_add(max_value, max_value)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_add, operator.add)
    """
    cdef INT total = a + b
    return int(total)
81
@cython.overflowcheck(True)
def test_sub(INT a, INT b):
    """
    Subtract two INT values with overflow checking; overflow raises OverflowError.

    >>> test_sub(10, 1)
    9
    >>> test_sub(min_value, 1)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_sub, operator.sub)
    """
    cdef INT difference = a - b
    return int(difference)
94
@cython.overflowcheck(True)
def test_mul(INT a, INT b):
    """
    Multiply two INT values with overflow checking; overflow raises OverflowError.

    The overflow doctest uses floor division so that it passes integers:
    under Python 3, true division of max_value by 2 yields a float, which is
    not the intended INT operand.

    >>> test_mul(11, 13)
    143
    >>> test_mul(max_value // 2, max_value // 2)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_mul, operator.mul)
    """
    return int(a * b)
107
@cython.overflowcheck(True)
def test_nested_add(INT a, INT b, INT c):
    """
    Return ``a + b + c``; overflow in either addition raises OverflowError.

    >>> test_nested_add(1, 2, 3)
    6
    >>> expect_overflow(test_nested_add, half + 1, half + 1, half + 1)
    >>> expect_overflow(test_nested_add, half - 1, half - 1, half - 1)
    """
    # The nested expression is the subject of the test; keep it in one piece.
    return int(a + b + c)
117
def expect_overflow(func, *args):
    """Call ``func(*args)`` and require it to raise OverflowError.

    Returns None when OverflowError is raised, so doctests print nothing.
    If the call returns normally, raises AssertionError naming the result.
    Any other exception propagates unchanged.

    An explicit ``raise`` is used instead of ``assert False`` so that the
    check is not stripped when assertions are disabled (``python -O``).
    """
    try:
        res = func(*args)
    except OverflowError:
        return
    raise AssertionError("Expected OverflowError, got %s" % res)
124
cpdef format(INT value):
    """
    Render ``value`` symbolically relative to ``half`` and ``max_value``.

    The bands are checked in this order:

    - ``max_value`` and ``half`` are named directly.
    - Values within ``max_value // 4`` of the maximum render as
      ``half + half - k``.
    - Other values above ``half`` render as ``half + k``.
    - Values at most ``max_value // 4`` below ``half`` render as
      ``half - k``.
    - Everything else, including every negative value of a signed INT,
      renders as a plain number.

    >>> format(1)
    '1'
    >>> format(half - 1)
    'half - 1'
    >>> format(half)
    'half'
    >>> format(half + 2)
    'half + 2'
    >>> format(half + half - 3)
    'half + half - 3'
    >>> format(max_value)
    'max_value'
    >>> format(min_value) == str(min_value)
    True
    """
    # Each band is tested as ``value >= max_value_ - k``, which is
    # mathematically the same as ``max_value_ - value <= k``. The subtraction
    # form overflows (undefined behaviour in C) for a negative value of a
    # signed INT. Every threshold ``max_value_ - k`` here stays within range.
    if value == max_value_:
        return "max_value"
    if value == half_:
        return "half"
    if value >= max_value_ - max_value_ // <INT>4:
        return "half + half - %s" % (half_ + half_ - value)
    if value >= max_value_ - half_:
        return "half + %s" % (value - half_)
    if value >= max_value_ - (half_ + max_value_ // <INT>4):
        return "half - %s" % (half_ - value)
    return "%s" % value
152
cdef INT called(INT value):
    # Identity function with a visible side effect: print the argument in
    # symbolic form so doctests can tell whether (and with what) it ran.
    label = format(value)
    print("called(" + label + ")")
    return value
156
@cython.overflowcheck(True)
def test_nested(INT a, INT b, INT c, INT d):
    """
    Return ``a * b + c * d``; overflow in any of the three operations raises
    OverflowError.

    >>> test_nested(1, 2, 3, 4)
    14
    >>> expect_overflow(test_nested, half, half, 1, 1)
    >>> expect_overflow(test_nested, half, 1, half, half)
    >>> expect_overflow(test_nested, half, 2, half, 2)

    >>> print(format(test_nested(half, 2, 0, 1)))
    half + half - 0
    >>> print(format(test_nested(1, 0, half, 2)))
    half + half - 0
    >>> print(format(test_nested(half, 1, 1, half)))
    half + half - 0
    """
    # The nested expression is the subject of the test; keep it in one piece.
    return int(a * b + c * d)
175
@cython.overflowcheck(True)
def test_nested_func(INT a, INT b, INT c):
    """
    Return ``a + called(b + c)`` with overflow checking on both additions.

    As the doctests show, an overflow in ``b + c`` is raised before called()
    runs (nothing is printed). An overflow in the outer addition is raised
    only after called() has printed its argument.

    >>> test_nested_func(1, 2, 3)
    called(5)
    6
    >>> expect_overflow(test_nested_func, half + 1, half + 1, half + 1)
    >>> expect_overflow(test_nested_func, half - 1, half - 1, half - 1)
    called(half + half - 2)
    >>> print(format(test_nested_func(1, half - 1, half - 1)))
    called(half + half - 2)
    half + half - 1
    """
    # Keep the call nested inside the checked expression; that is what is tested.
    return int(a + called(b + c))
190
191
@cython.overflowcheck(True)
def test_add_const(INT a):
    """
    Add the INT constant 100 to ``a`` with overflow checking.

    >>> test_add_const(1)
    101
    >>> expect_overflow(test_add_const, max_value)
    >>> expect_overflow(test_add_const , max_value - 99)
    >>> test_add_const(max_value - 100) == max_value
    True
    """
    cdef INT total = a + <INT>100
    return int(total)
203
@cython.overflowcheck(True)
def test_sub_const(INT a):
    """
    Subtract the INT constant 100 from ``a`` with overflow checking.

    >>> test_sub_const(101)
    1
    >>> expect_overflow(test_sub_const, min_value)
    >>> expect_overflow(test_sub_const, min_value + 99)
    >>> test_sub_const(min_value + 100) == min_value
    True
    """
    cdef INT difference = a - <INT>100
    return int(difference)
215
@cython.overflowcheck(True)
def test_mul_const(INT a):
    """
    Multiply ``a`` by the INT constant 100 with overflow checking.

    >>> test_mul_const(2)
    200
    >>> expect_overflow(test_mul_const, max_value)
    >>> expect_overflow(test_mul_const, max_value // 99)
    >>> test_mul_const(max_value // 100) == max_value - max_value % 100
    True
    """
    cdef INT product = a * <INT>100
    return int(product)
227
@cython.overflowcheck(True)
def test_lshift(INT a, int b):
    """
    Return ``a << b`` with overflow checking.

    Shifting set bits out of the type raises OverflowError, as does a very
    large shift count (100 bits in the doctests).

    >>> test_lshift(1, 10)
    1024
    >>> expect_overflow(test_lshift, 1, 100)
    >>> expect_overflow(test_lshift, max_value, 1)
    >>> test_lshift(max_value, 0) == max_value
    True

    >>> check(test_lshift, operator.lshift, 10, 15)
    >>> check(test_lshift, operator.lshift, 10, 30)
    >>> check(test_lshift, operator.lshift, 100, 60)
    """
    # The shift is returned directly (not via an INT temporary) so its result
    # type follows Cython's promotion of INT and int operands.
    return int(a << b)
243