1# tag: run
2# tag: openmp
3
4cimport cython.parallel
5from cython.parallel import prange, threadid
6cimport openmp
7from libc.stdlib cimport malloc, free
8
# Allow parallel regions to nest (e.g. a prange inside another prange's body);
# presumably required by nested-parallelism tests in this module or in the
# included sequential_parallel.pyx -- confirm before removing.
# NOTE(review): OpenMP 5.0 deprecates omp_set_nested in favour of
# omp_set_max_active_levels; consider switching if the declaration is
# available in the openmp .pxd being used -- verify.
openmp.omp_set_nested(1)
10
cdef int forward(int x) nogil:
    """Identity helper: hand ``x`` back unchanged.

    It exists so that test_parallel can pass threadid() through an ordinary
    C function call (see https://github.com/cython/cython/issues/3594).
    """
    cdef int passthrough = x
    return passthrough
13
def test_parallel():
    """
    Check that threadid() gives each thread of a parallel() region its own
    slot, also when threadid() is passed through a function call.

    >>> test_parallel()
    """
    cdef int maxthreads = openmp.omp_get_max_threads()
    cdef int dyn = openmp.omp_get_dynamic()
    cdef int i
    cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)

    if buf == NULL:
        raise MemoryError

    # Disable dynamic team sizing so the region below runs with the full
    # omp_get_max_threads() team; otherwise some slots of buf might never be
    # written and the check below would read uninitialised memory.
    openmp.omp_set_dynamic(0)
    try:
        with nogil, cython.parallel.parallel():
            buf[threadid()] = threadid()
            # Recognise threadid() also when it's used in a function argument.
            # See https://github.com/cython/cython/issues/3594
            buf[forward(cython.parallel.threadid())] = forward(threadid())

        for i in range(maxthreads):
            assert buf[i] == i
    finally:
        # Restore the caller's setting and release buf even if the check fails.
        openmp.omp_set_dynamic(dyn)
        free(buf)
34
cdef int get_num_threads() with gil:
    """Return the team size (3) requested by test_num_threads.

    Prints a marker on every call so the doctest can verify that each
    num_threads expression is evaluated exactly once per parallel region.
    """
    # print() call form rather than the Python 2 print statement, so this
    # compiles under Cython 3's default language_level as well as level 2.
    print("get_num_threads called")
    return 3
38
def test_num_threads():
    """
    Check that the num_threads argument of parallel() and prange() sets the
    team size, and that a num_threads expression is evaluated once per region.

    >>> test_num_threads()
    1
    get_num_threads called
    3
    get_num_threads called
    3
    """
    cdef int dyn = openmp.omp_get_dynamic()
    cdef int num_threads
    # The team size is written through a pointer because Cython makes a
    # variable assigned directly inside a parallel block thread-private, so
    # its value would not be visible here after the block.
    cdef int *p = &num_threads
    cdef int i

    # With dynamic adjustment enabled the runtime may provide fewer threads
    # than requested, which would make the reported team sizes unreliable.
    openmp.omp_set_dynamic(0)
    try:
        with nogil, cython.parallel.parallel(num_threads=1):
            p[0] = openmp.omp_get_num_threads()

        print(num_threads)

        with nogil, cython.parallel.parallel(num_threads=get_num_threads()):
            p[0] = openmp.omp_get_num_threads()

        print(num_threads)

        # Sentinel: proves the value returned below was written by the prange
        # body rather than left over from the previous parallel region.
        num_threads = 0xbad
        for i in prange(1, nogil=True, num_threads=get_num_threads()):
            p[0] = openmp.omp_get_num_threads()
            break
    finally:
        # Restore the caller's dynamic-adjustment setting even on error.
        openmp.omp_set_dynamic(dyn)

    return num_threads
73
74'''
75def test_parallel_catch():
76    """
77    >>> test_parallel_catch()
78    True
79    """
80    cdef int i, j, num_threads
81    exceptions = []
82
83    for i in prange(100, nogil=True, num_threads=4):
84        num_threads = openmp.omp_get_num_threads()
85
86        with gil:
87            try:
88                for j in prange(100, nogil=True):
89                    if i + j > 60:
90                        with gil:
91                            raise Exception("try and catch me if you can!")
92            except Exception, e:
93                exceptions.append(e)
94                break
95
96    print len(exceptions) == num_threads
97    assert len(exceptions) == num_threads, (len(exceptions), num_threads)
98'''
99
100
# Module-level flag defined immediately before the shared tests in
# sequential_parallel.pyx are textually included; presumably those tests read
# it to know they are being built with OpenMP enabled -- confirm against that
# file.
OPENMP_PARALLEL = True
include "sequential_parallel.pyx"
103