# tag: run
# tag: openmp

# OpenMP-backed tests for cython.parallel (parallel(), prange(), threadid()).
# The docstrings below are doctests: their expected output is part of the test.

cimport cython.parallel
from cython.parallel import prange, threadid
cimport openmp
from libc.stdlib cimport malloc, free

# Module-import side effect: enable nested OpenMP parallelism for the process.
# NOTE(review): omp_set_nested() is deprecated as of OpenMP 5.0
# (omp_set_max_active_levels() is its replacement) — confirm the targeted
# OpenMP runtimes still honour it.
openmp.omp_set_nested(1)

# Identity function callable without the GIL. It lets threadid() appear as a
# function-call argument inside a parallel block (see the issue 3594 comment
# in test_parallel).
cdef int forward(int x) nogil:
    return x

def test_parallel():
    """
    >>> test_parallel()
    """
    # One int slot per potential OpenMP thread; inside the parallel region each
    # thread writes its own id into the slot indexed by that id.
    cdef int maxthreads = openmp.omp_get_max_threads()
    cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)

    if buf == NULL:
        raise MemoryError

    with nogil, cython.parallel.parallel():
        buf[threadid()] = threadid()
        # Recognise threadid() also when it's used in a function argument.
        # See https://github.com/cython/cython/issues/3594
        buf[forward(cython.parallel.threadid())] = forward(threadid())

    # NOTE(review): this check assumes the region above really ran
    # `maxthreads` threads. If the runtime started fewer (e.g. with dynamic
    # thread adjustment enabled), some slots were never written and the
    # malloc'd memory is read uninitialised. A failing assert also skips
    # free(buf) below.
    for i in range(maxthreads):
        assert buf[i] == i

    free(buf)

# Acquires the GIL, prints a marker line (it appears in the test_num_threads
# doctest output once per call), and returns 3 for use as a thread count.
cdef int get_num_threads() with gil:
    print "get_num_threads called"
    return 3

def test_num_threads():
    """
    >>> test_num_threads()
    1
    get_num_threads called
    3
    get_num_threads called
    3
    """
    # Save the current dynamic-adjustment setting; it is restored at the end.
    cdef int dyn = openmp.omp_get_dynamic()
    cdef int num_threads
    # The parallel blocks write through this pointer instead of assigning
    # num_threads directly. Presumably this ensures the write reaches this
    # function's variable rather than a thread-private copy, since Cython
    # treats variables assigned inside a parallel block as private.
    cdef int *p = &num_threads

    # Disable dynamic adjustment so the requested num_threads is used exactly
    # and the printed counts are deterministic.
    openmp.omp_set_dynamic(0)

    # Constant thread count: the doctest expects 1.
    with nogil, cython.parallel.parallel(num_threads=1):
        p[0] = openmp.omp_get_num_threads()

    print num_threads

    # Thread count computed by a GIL-acquiring call from a nogil context:
    # the doctest expects the marker line, then 3.
    with nogil, cython.parallel.parallel(num_threads=get_num_threads()):
        p[0] = openmp.omp_get_num_threads()

    print num_threads

    cdef int i
    # Sentinel: if the prange body below never ran, 0xbad would be returned
    # instead of the expected 3.
    num_threads = 0xbad
    # Same check for prange(num_threads=...); the range has a single
    # iteration and the body breaks out immediately.
    for i in prange(1, nogil=True, num_threads=get_num_threads()):
        p[0] = openmp.omp_get_num_threads()
        break

    # NOTE(review): not in a try/finally, so an exception raised above would
    # leave dynamic adjustment disabled.
    openmp.omp_set_dynamic(dyn)

    return num_threads

# Disabled test: the code below sits inside a module-level string literal, so
# it is neither compiled as code nor collected as a doctest.
# NOTE(review): the reason it is disabled is not recorded here — confirm
# before re-enabling.
'''
def test_parallel_catch():
    """
    >>> test_parallel_catch()
    True
    """
    cdef int i, j, num_threads
    exceptions = []

    for i in prange(100, nogil=True, num_threads=4):
        num_threads = openmp.omp_get_num_threads()

        with gil:
            try:
                for j in prange(100, nogil=True):
                    if i + j > 60:
                        with gil:
                            raise Exception("try and catch me if you can!")
            except Exception, e:
                exceptions.append(e)
                break

    print len(exceptions) == num_threads
    assert len(exceptions) == num_threads, (len(exceptions), num_threads)
'''


# Presumably read by the included sequential_parallel.pyx to switch on its
# OpenMP-specific expectations — verify against that file.
OPENMP_PARALLEL = True
include "sequential_parallel.pyx"