1###################################################################
2#  Numexpr - Fast numerical array expression evaluator for NumPy.
3#
4#      License: MIT
5#      Author:  See AUTHORS.txt
6#
7#  See LICENSE.txt and LICENSES/*.txt for details about copyright and
8#  rights to use.
9####################################################################
10
11from __future__ import print_function
12import timeit, numpy
13
# --- Benchmark configuration ---------------------------------------------
# Length of every benchmark array.  Kept as a float: the setup templates
# interpolate it with "%f".
array_size = 1e6
# Passed to timeit as `number` for each timing, and also used by the
# __main__ block as the number of full benchmark passes.
iterations = 2

# Element type of every benchmark array.  Other choices to try:
# 'int8', 'int16', 'int32', 'int64', 'float64'.
dtype = 'float32'
24
def compare_times(setup, expr, number=None):
    """Time `expr` with NumPy, scipy.weave (if available) and numexpr.

    All three variants run through ``timeit.Timer`` with the same `setup`
    source, so `setup` must define every name `expr` uses plus
    ``evaluate``; the weave variant additionally needs ``blitz`` (and the
    ``result`` array it assigns into).  Per-run times and speed-ups are
    printed as they are measured.

    Parameters
    ----------
    setup : str
        Python source executed (untimed) before each timing run.
    expr : str
        Expression to benchmark.
    number : int, optional
        How many times each variant is executed.  Defaults to the
        module-level ``iterations``.

    Returns
    -------
    float
        Speed-up of numexpr over NumPy (NumPy time / numexpr time).
    """
    if number is None:
        number = iterations
    print("Expression:", expr)

    numpy_time = timeit.Timer(expr, setup).timeit(number=number)
    print('numpy:', numpy_time / number)

    # %r produces a correctly quoted string literal even when `expr`
    # itself contains quote characters.
    weave_stmt = 'blitz(%r)' % ('result=' + expr,)
    try:
        weave_time = timeit.Timer(weave_stmt, setup).timeit(number=number)
        print("Weave:", weave_time/number)

        print("Speed-up of weave over numpy:", round(numpy_time/weave_time, 2))
    except Exception:
        # Weave is optional: when the setup could not import it, `blitz`
        # is undefined and this timing is skipped.  Catching Exception
        # (not a bare except) keeps Ctrl-C / SystemExit working.
        print("Skipping weave timing")

    numexpr_stmt = 'evaluate(%r, optimization="aggressive")' % (expr,)
    numexpr_time = timeit.Timer(numexpr_stmt, setup).timeit(number=number)
    print("numexpr:", numexpr_time/number)

    tratio = numpy_time/numexpr_time
    print("Speed-up of numexpr over numpy:", round(tratio, 2))
    return tratio
50
# Benchmark 1: sum of two elementwise products over contiguous arrays.
# `result` is preallocated because the weave variant assigns into it.
# Weave is optional: the import is guarded with `except ImportError` (not a
# bare except) so that only a missing module is ignored; `blitz` then stays
# undefined and the weave timing is skipped.
setup1 = """\
from numpy import arange
try: from scipy.weave import blitz
except ImportError: pass
from numexpr import evaluate
result = arange(%f, dtype='%s')
b = arange(%f, dtype='%s')
c = arange(%f, dtype='%s')
d = arange(%f, dtype='%s')
e = arange(%f, dtype='%s')
""" % ((array_size, dtype)*5)
expr1 = 'b*c+d*e'
63
# Benchmark 2: linear combination of two contiguous arrays.
# The optional weave import only ignores ImportError; if it fails, `blitz`
# stays undefined and the weave timing is skipped.
setup2 = """\
from numpy import arange
try: from scipy.weave import blitz
except ImportError: pass
from numexpr import evaluate
a = arange(%f, dtype='%s')
b = arange(%f, dtype='%s')
result = arange(%f, dtype='%s')
""" % ((array_size, dtype)*3)
expr2 = '2*a+3*b'
74
75
# Benchmark 3: transcendental functions.  `a` is a stride-2 view of a
# twice-as-long array, so this also exercises non-contiguous input.
# The optional weave import only ignores ImportError; if it fails, `blitz`
# stays undefined and the weave timing is skipped.
setup3 = """\
from numpy import arange, sin, cos, sinh
try: from scipy.weave import blitz
except ImportError: pass
from numexpr import evaluate
a = arange(2*%f, dtype='%s')[::2]
b = arange(%f, dtype='%s')
result = arange(%f, dtype='%s')
""" % ((array_size, dtype)*3)
expr3 = '2*a + (cos(3)+5)*sinh(cos(b))'
86
87
# Benchmark 4: two-argument arctan2 on a strided (`[::2]`) and a
# contiguous array.
# The optional weave import only ignores ImportError; if it fails, `blitz`
# stays undefined and the weave timing is skipped.
setup4 = """\
from numpy import arange, sin, cos, sinh, arctan2
try: from scipy.weave import blitz
except ImportError: pass
from numexpr import evaluate
a = arange(2*%f, dtype='%s')[::2]
b = arange(%f, dtype='%s')
result = arange(%f, dtype='%s')
""" % ((array_size, dtype)*3)
expr4 = '2*a + arctan2(a, b)'
98
99
# Shared setup for benchmarks 5-12: where() selections, powers and sqrt.
# `a` is a stride-2 view, as in setups 3 and 4.
# The optional weave import only ignores ImportError; if it fails, `blitz`
# stays undefined and the weave timing is skipped.
setup5 = """\
from numpy import arange, sin, cos, sinh, arctan2, sqrt, where
try: from scipy.weave import blitz
except ImportError: pass
from numexpr import evaluate
a = arange(2*%f, dtype='%s')[::2]
b = arange(%f, dtype='%s')
result = arange(%f, dtype='%s')
""" % ((array_size, dtype)*3)
# where() whose condition and branches all involve array expressions.
expr5 = 'where(0.1*a > arctan2(a, b), 2*a, arctan2(a,b))'

# where() with a scalar in the true branch.
expr6 = 'where(a != 0.0, 2, b)'

# where() with a scalar in the false branch.
expr7 = 'where(a-10 != 0.0, a, 2)'

# where() with a modulo condition; scalar in the false branch.
expr8 = 'where(a%2 != 0.0, b+5, 2)'

# Same condition as expr8 with the branches swapped.
expr9 = 'where(a%2 != 0.0, 2, b+5)'

# Integer and negative fractional powers.
expr10 = 'a**2 + (b+1)**-2.5'

# Large integer power.
expr11 = '(a+1)**50'

# Euclidean norm of the two arrays, elementwise.
expr12 = 'sqrt(a**2 + b**2)'
124
def compare(check_only=False):
    """Run every benchmark once and return the mean numexpr speed-up.

    Each (setup, expression) pair is timed by ``compare_times``, which
    prints the per-library timings.  The numexpr-over-NumPy ratios it
    returns are averaged, and the average is printed and returned.

    Parameters
    ----------
    check_only : bool, optional
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    float
        Mean numexpr-over-NumPy speed-up across all experiments.
    """
    experiments = [(setup1, expr1), (setup2, expr2), (setup3, expr3),
                   (setup4, expr4), (setup5, expr5), (setup5, expr6),
                   (setup5, expr7), (setup5, expr8), (setup5, expr9),
                   (setup5, expr10), (setup5, expr11), (setup5, expr12),
                   ]
    total = 0
    for params in experiments:
        total += compare_times(*params)
        # Under `print_function` a bare `print` is a no-op expression;
        # call it to actually emit a blank line between experiments.
        print()
    average = total / len(experiments)
    print("Average =", round(average, 2))
    return average
138
if __name__ == '__main__':
    import numexpr
    print("Numexpr version: ", numexpr.__version__)

    # Repeat the whole benchmark suite `iterations` times; each pass
    # contributes its mean numexpr speed-up.
    averages = [compare() for _ in range(iterations)]
    summary = ', '.join(["%.2f" % avg for avg in averages])
    print("Averages:", summary)
147