1cimport cython
2
# Compile-time facts about INT, the integer type exercised by this file.
# NOTE(review): INT is not declared here; presumably the including .pyx
# provides it (e.g. via ctypedef) -- confirm against the includer.

# A Python object, so ``two ** n`` below is evaluated with arbitrary-precision
# Python ints rather than in C (2**64 does not fit in a 64-bit C integer).
cdef object two = 2

cdef int size_in_bits_ = sizeof(INT) * 8
# -1 cast to an unsigned type wraps to that type's maximum (> 0); for a signed
# type it stays negative, so this is True exactly when INT is signed.
cdef bint is_signed_ = not ((<INT>-1) > 0)
# Largest INT: 2**(bits - 1) - 1 when signed, 2**bits - 1 when unsigned.
cdef INT max_value_ = <INT>(two ** (size_in_bits_ - is_signed_) - 1)
# Bitwise complement of max_value_: 0 for unsigned INT, and -max_value_ - 1
# for signed INT on two's-complement targets.
cdef INT min_value_ = ~max_value_
cdef INT half_ = max_value_ // <INT>2

# Python visible copies of the constants above (the doctests refer to these).
size_in_bits = size_in_bits_
is_signed = is_signed_
max_value = max_value_
min_value = min_value_
half = half_
17
18
19import operator
20from libc.math cimport sqrt
21
cpdef check(func, op, a, b):
    """Assert that ``func(a, b)`` and ``op(a, b)`` agree when stored in INT.

    ``func`` is an overflow-checked test function and ``op`` the matching
    Python operator.  Both must either raise OverflowError, or both succeed
    with equal INT results.
    """
    cdef INT func_result = 0
    cdef INT op_result = 0
    cdef bint func_overflowed = False
    cdef bint op_overflowed = False

    try:
        func_result = func(a, b)
    except OverflowError:
        func_overflowed = True

    try:
        # Python ints are arbitrary precision, so the OverflowError expected
        # here is the one raised when the exact result is coerced into INT.
        op_result = op(a, b)
    except OverflowError:
        op_overflowed = True

    assert func_overflowed == op_overflowed, "Inconsistent overflow: %s(%s, %s)" % (func, a, b)
    if func_overflowed:
        # Both sides overflowed: there are no values to compare.
        return
    assert func_result == op_result, "Inconsistent values: %s(%s, %s) == %s != %s" % (func, a, b, func_result, op_result)
37
# Operands of moderate magnitude: half and a third of the range, half of the
# minimum, and just either side of sqrt(max_value_) so that products of two of
# them land on both sides of the limit.  Floor division (``//``) keeps these
# integral under any language_level (``/`` on C integers is true division
# under language_level=3, which would yield floats that cannot become INT).
medium_values = (max_value_ // 2, max_value_ // 3, min_value_ // 2, <INT>sqrt(<long double>max_value_) - <INT>1, <INT>sqrt(<long double>max_value_) + 1)

def run_test(func, op):
    """Compare ``func`` with ``op`` via ``check`` on operands near INT's limits.

    ``func`` is one of the overflow-checked test functions in this file and
    ``op`` the corresponding function from the ``operator`` module.
    """
    cdef INT offset, b
    check(func, op, 300, 200)
    check(func, op, max_value_, max_value_)
    check(func, op, max_value_, min_value_)
    # NOTE(review): min_value_ - min_value_ is skipped for signed test_sub;
    # the reason is not visible here -- confirm before removing the guard.
    if not is_signed_ or func is not test_sub:
        check(func, op, min_value_, min_value_)

    # Small offsets from the extremes.
    for offset in range(5):
        check(func, op, max_value_ - <INT>1, offset)
        check(func, op, min_value_ + <INT>1, offset)
        if is_signed_:
            check(func, op, max_value_ - 1, 2 - offset)
            check(func, op, min_value_ + 1, 2 - offset)

    # Small offsets from fractions of the range (floor division throughout).
    for offset in range(9):
        check(func, op, max_value_ // <INT>2, offset)
        check(func, op, min_value_ // <INT>3, offset)
        check(func, op, max_value_ // <INT>4, offset)
        check(func, op, min_value_ // <INT>5, offset)
        if is_signed_:
            check(func, op, max_value_ // 2, 4 - offset)
            check(func, op, min_value_ // 3, 4 - offset)
            check(func, op, max_value_ // -4, 3 - offset)
            check(func, op, min_value_ // -5, 3 - offset)

    # All pairs of medium values, with the second operand nudged by -3..3.
    for offset in range(-3, 4):
        for a in medium_values:
            for b in medium_values:
                check(func, op, a, b + offset)
69
# Overflow-checked addition of two INT values.  The doctests check a small sum,
# that max_value + max_value raises OverflowError, and (via run_test) that the
# result agrees with operator.add for operands near the INT limits.
@cython.overflowcheck(True)
def test_add(INT a, INT b):
    """
    >>> test_add(1, 2)
    3
    >>> test_add(max_value, max_value)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_add, operator.add)
    """
    # ``a + b`` is INT arithmetic checked by the overflowcheck directive;
    # int() hands the INT result back as a Python int.
    return int(a + b)
82
# Overflow-checked subtraction of two INT values.  The doctests check a small
# difference, that min_value - 1 raises OverflowError, and (via run_test) that
# the result agrees with operator.sub for operands near the INT limits.
@cython.overflowcheck(True)
def test_sub(INT a, INT b):
    """
    >>> test_sub(10, 1)
    9
    >>> test_sub(min_value, 1)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_sub, operator.sub)
    """
    # ``a - b`` is INT arithmetic checked by the overflowcheck directive.
    return int(a - b)
95
# Overflow-checked multiplication of two INT values.  The doctests check a small
# product, that squaring half the maximum raises OverflowError, and (via
# run_test) that the result agrees with operator.mul near the INT limits.
@cython.overflowcheck(True)
def test_mul(INT a, INT b):
    """
    >>> test_mul(11, 13)
    143
    >>> test_mul(max_value // 2, max_value // 2)   #doctest: +ELLIPSIS
    Traceback (most recent call last):
    ...
    OverflowError: value too large
    >>> run_test(test_mul, operator.mul)
    """
    # The doctest uses ``//``: under Python 3, ``max_value / 2`` is a float,
    # and coercing a float to an INT parameter is not guaranteed to succeed
    # (it can raise TypeError before the multiplication is ever reached).
    return int(a * b)
108
# Two chained overflow-checked additions.  In the doctests, the first
# expect_overflow case overflows at the first addition ((half + 1) * 2 ==
# max_value + 1), the second only at the second addition.
@cython.overflowcheck(True)
def test_nested_add(INT a, INT b, INT c):
    """
    >>> test_nested_add(1, 2, 3)
    6
    >>> expect_overflow(test_nested_add, half + 1, half + 1, half + 1)
    >>> expect_overflow(test_nested_add, half - 1, half - 1, half - 1)
    """
    # Evaluated left to right as (a + b) + c, each step overflow-checked.
    return int(a + b + c)
118
def expect_overflow(func, *args):
    """Call ``func(*args)`` and fail unless it raises OverflowError.

    Returns None when the call raises OverflowError.  Any other exception
    propagates unchanged.

    Raises:
        AssertionError: if the call returns normally.  It is raised
            explicitly rather than via ``assert`` so the check does not depend
            on assertions being enabled.
    """
    try:
        res = func(*args)
    except OverflowError:
        return
    # Wrap ``res`` in a 1-tuple: a bare ``% res`` raises TypeError when the
    # unexpected result is itself a tuple.
    raise AssertionError("Expected OverflowError, got %s" % (res,))
125
cpdef format(INT value):
    """
    >>> format(1)
    '1'
    >>> format(0)
    '0'
    >>> format(-1) if is_signed else '-1'
    '-1'
    >>> format(min_value) == str(min_value)
    True
    >>> format(half - 1)
    'half - 1'
    >>> format(half)
    'half'
    >>> format(half + 2)
    'half + 2'
    >>> format(half + half - 3)
    'half + half - 3'
    >>> format(max_value)
    'max_value'
    """
    # Values in the upper part of the range are rendered relative to ``half``
    # / ``max_value``, so expected doctest output can be written independently
    # of INT's width; anything smaller is rendered as a plain number.
    #
    # Thresholds are written as ``value >= max_value_ - X`` instead of
    # ``max_value_ - value <= X``.  The latter overflows (undefined behaviour
    # for signed INT) when ``value`` is negative.  Each X used here is at most
    # max_value_, so ``max_value_ - X`` cannot overflow.  For non-negative
    # values, both forms select the same branch.
    if value == max_value_:
        return "max_value"
    elif value == half_:
        return "half"
    elif value >= max_value_ - max_value_ // <INT>4:
        return "half + half - %s" % (half_ + half_ - value)
    elif value >= max_value_ - half_:
        return "half + %s" % (value - half_)
    elif value >= max_value_ - (half_ + max_value_ // <INT>4):
        return "half - %s" % (half_ - value)
    else:
        return "%s" % value
153
cdef INT called(INT value):
    # Trace helper for the nested-call doctests: print "called(<value>)", with
    # the value rendered by ``format``, and pass ``value`` through unchanged.
    rendered = format(value)
    print("called(%s)" % rendered)
    return value
157
# Overflow-checked ``a * b + c * d``.  In the doctests, the expect_overflow cases
# overflow at the first product (half * half), at the second product, and at
# the final addition (2 * half + 2 * half).  The last three show that results
# landing exactly on half + half are not falsely reported as overflow.
@cython.overflowcheck(True)
def test_nested(INT a, INT b, INT c, INT d):
    """
    >>> test_nested(1, 2, 3, 4)
    14
    >>> expect_overflow(test_nested, half, half, 1, 1)
    >>> expect_overflow(test_nested, half, 1, half, half)
    >>> expect_overflow(test_nested, half, 2, half, 2)

    >>> print(format(test_nested(half, 2, 0, 1)))
    half + half - 0
    >>> print(format(test_nested(1, 0, half, 2)))
    half + half - 0
    >>> print(format(test_nested(half, 1, 1, half)))
    half + half - 0
    """
    # Both products and the sum are INT arithmetic checked by overflowcheck.
    return int(a * b + c * d)
176
# Overflow-checked addition whose right operand comes from a function call.
# ``called`` prints the inner sum before returning it.  The doctest for
# (half - 1) * 3 therefore shows the inner ``b + c`` succeeding and being
# printed before the outer addition raises OverflowError.
@cython.overflowcheck(True)
def test_nested_func(INT a, INT b, INT c):
    """
    >>> test_nested_func(1, 2, 3)
    called(5)
    6
    >>> expect_overflow(test_nested_func, half + 1, half + 1, half + 1)
    >>> expect_overflow(test_nested_func, half - 1, half - 1, half - 1)
    called(half + half - 2)
    >>> print(format(test_nested_func(1, half - 1, half - 1)))
    called(half + half - 2)
    half + half - 1
    """
    # Both the inner ``b + c`` and the outer addition are overflow-checked.
    return int(a + called(b + c))
191
192
# Overflow-checked addition with a constant right operand (100 cast to INT).
# The doctests pin the boundary: max_value - 100 succeeds and max_value - 99
# overflows.
@cython.overflowcheck(True)
def test_add_const(INT a):
    """
    >>> test_add_const(1)
    101
    >>> expect_overflow(test_add_const, max_value)
    >>> expect_overflow(test_add_const , max_value - 99)
    >>> test_add_const(max_value - 100) == max_value
    True
    """
    return int(a + <INT>100)
204
# Overflow-checked subtraction of a constant (100 cast to INT).  The doctests
# pin the boundary: min_value + 100 succeeds and min_value + 99 overflows.
@cython.overflowcheck(True)
def test_sub_const(INT a):
    """
    >>> test_sub_const(101)
    1
    >>> expect_overflow(test_sub_const, min_value)
    >>> expect_overflow(test_sub_const, min_value + 99)
    >>> test_sub_const(min_value + 100) == min_value
    True
    """
    return int(a - <INT>100)
216
# Overflow-checked multiplication by a constant (100 cast to INT).  The doctests
# pin the boundary: max_value // 100 succeeds and max_value // 99 overflows.
@cython.overflowcheck(True)
def test_mul_const(INT a):
    """
    >>> test_mul_const(2)
    200
    >>> expect_overflow(test_mul_const, max_value)
    >>> expect_overflow(test_mul_const, max_value // 99)
    >>> test_mul_const(max_value // 100) == max_value - max_value % 100
    True
    """
    return int(a * <INT>100)
228
# Overflow-checked left shift of an INT by a C ``int`` shift count.  Per the
# doctests:
# - a shift count >= size_in_bits raises OverflowError even when shifting 0;
# - for signed INT, shifting 1 into the top bit overflows;
# - any shift that drops set bits (max_value << 1) overflows.
@cython.overflowcheck(True)
def test_lshift(INT a, int b):
    """
    >>> test_lshift(1, 10)
    1024
    >>> test_lshift(1, size_in_bits - 2) == 1 << (size_in_bits - 2)
    True
    >>> test_lshift(0, size_in_bits - 1)
    0
    >>> test_lshift(1, size_in_bits - 1) == 1 << (size_in_bits - 1) if not is_signed else True
    True
    >>> if is_signed: expect_overflow(test_lshift, 1, size_in_bits - 1)
    >>> expect_overflow(test_lshift, 0, size_in_bits)
    >>> expect_overflow(test_lshift, 1, size_in_bits)
    >>> expect_overflow(test_lshift, 0, size_in_bits + 1)
    >>> expect_overflow(test_lshift, 1, size_in_bits + 1)
    >>> expect_overflow(test_lshift, 1, 100)
    >>> expect_overflow(test_lshift, max_value, 1)
    >>> test_lshift(max_value, 0) == max_value
    True

    >>> check(test_lshift, operator.lshift, 10, 15)
    >>> check(test_lshift, operator.lshift, 10, 30)
    >>> check(test_lshift, operator.lshift, 100, 60)
    """
    # ``b`` is a plain C int, not INT; the result of the shift has INT's type.
    return int(a << b)
255