CUDA Kernel API
===============

Kernel declaration
------------------

The ``@cuda.jit`` decorator is used to create a CUDA dispatcher object that can
be configured and launched:

.. autofunction:: numba.cuda.jit


Dispatcher objects
------------------

The usual syntax for configuring a Dispatcher with a launch configuration uses
subscripting, with the arguments as follows:

.. code-block:: python

   # func is some function decorated with @cuda.jit
   func[griddim, blockdim, stream, sharedmem]


The ``griddim`` and ``blockdim`` arguments specify the size of the grid and of
each thread block, and may be either integers or tuples of length up to 3. The
``stream`` parameter is an optional stream on which the kernel will be launched,
and the ``sharedmem`` parameter specifies the size of dynamic shared memory in
bytes.

Subscripting the Dispatcher returns a configuration object that can be called
with the kernel arguments:

.. code-block:: python

   configured = func[griddim, blockdim, stream, sharedmem]
   configured(x, y, z)


However, it is more idiomatic to configure and call the kernel within a single
statement:

.. code-block:: python

   func[griddim, blockdim, stream, sharedmem](x, y, z)

This is similar to launch configuration in CUDA C/C++:

.. code-block:: cuda

   func<<<griddim, blockdim, sharedmem, stream>>>(x, y, z)

.. note:: The order of ``stream`` and ``sharedmem`` is reversed in Numba
   compared to CUDA C/C++.

Dispatcher objects also provide several utility methods for inspection and
creating a specialized instance:

.. autoclass:: numba.cuda.compiler.Dispatcher
   :members: inspect_asm, inspect_llvm, inspect_sass, inspect_types,
             specialize, specialized, extensions


Intrinsic Attributes and Functions
----------------------------------

The remainder of the attributes and functions in this section may only be called
from within a CUDA kernel.

Thread Indexing
~~~~~~~~~~~~~~~

.. attribute:: numba.cuda.threadIdx

    The thread indices in the current thread block, accessed through the
    attributes ``x``, ``y``, and ``z``. Each index is an integer spanning the
    range from 0 inclusive to the corresponding value of the attribute in
    :attr:`numba.cuda.blockDim` exclusive.

.. attribute:: numba.cuda.blockIdx

    The block indices in the grid of thread blocks, accessed through the
    attributes ``x``, ``y``, and ``z``. Each index is an integer spanning the
    range from 0 inclusive to the corresponding value of the attribute in
    :attr:`numba.cuda.gridDim` exclusive.

.. attribute:: numba.cuda.blockDim

    The shape of a block of threads, as declared when instantiating the
    kernel.  This value is the same for all threads in a given kernel, even
    if they belong to different blocks (i.e. each block is "full").

.. attribute:: numba.cuda.gridDim

    The shape of the grid of blocks, accessed through the attributes ``x``,
    ``y``, and ``z``.

.. attribute:: numba.cuda.laneid

    The thread index in the current warp, as an integer spanning the range
    from 0 inclusive to :attr:`numba.cuda.warpsize` exclusive.

.. attribute:: numba.cuda.warpsize

    The size in threads of a warp on the GPU. Currently this is always 32.

.. function:: numba.cuda.grid(ndim)

   Return the absolute position of the current thread in the entire
   grid of blocks.  *ndim* should correspond to the number of dimensions
   declared when instantiating the kernel.  If *ndim* is 1, a single integer
   is returned.  If *ndim* is 2 or 3, a tuple of the given number of
   integers is returned.

   Computation of the first integer is as follows::

      cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

   and is similar for the other two indices, but using the ``y`` and ``z``
   attributes.

.. function:: numba.cuda.gridsize(ndim)

   Return the absolute size (or shape) in threads of the entire grid of
   blocks. *ndim* should correspond to the number of dimensions declared when
   instantiating the kernel.

   Computation of the first integer is as follows::

       cuda.blockDim.x * cuda.gridDim.x

   and is similar for the other two indices, but using the ``y`` and ``z``
   attributes.

Memory Management
~~~~~~~~~~~~~~~~~

.. function:: numba.cuda.shared.array(shape, dtype)

   Creates an array in the shared memory space of the CUDA kernel with
   the given ``shape`` and ``dtype``.

   Returns an array with its content uninitialized.

   .. note:: All threads in the same thread block see the same array.

.. function:: numba.cuda.local.array(shape, dtype)

   Creates an array in the local memory space of the CUDA kernel with the
   given ``shape`` and ``dtype``.

   Returns an array with its content uninitialized.

   .. note:: Each thread sees a unique array.

.. function:: numba.cuda.const.array_like(ary)

   Copies ``ary`` into constant memory space on the CUDA device at compile
   time.

   Returns an array like the ``ary`` argument.

   .. note:: All threads and blocks see the same array.

Synchronization and Atomic Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. function:: numba.cuda.atomic.add(array, idx, value)

    Perform ``array[idx] += value``. Supported on int32, int64, float32 and
    float64 operands only. The ``idx`` argument can be an integer or a tuple
    of integer indices for indexing into multi-dimensional arrays. The number
    of elements in ``idx`` must match the number of dimensions of ``array``.

    Returns the value of ``array[idx]`` before storing the new value.
    Behaves like an atomic load.

.. function:: numba.cuda.atomic.max(array, idx, value)

    Perform ``array[idx] = max(array[idx], value)``. Supported on int32,
    int64, float32 and float64 operands only. The ``idx`` argument can be an
    integer or a tuple of integer indices for indexing into multi-dimensional
    arrays. The number of elements in ``idx`` must match the number of
    dimensions of ``array``.

    Returns the value of ``array[idx]`` before storing the new value.
    Behaves like an atomic load.


.. function:: numba.cuda.syncthreads

    Synchronize all threads in the same thread block.  This function implements
    the same pattern as barriers in traditional multi-threaded programming: this
    function waits until all threads in the block call it, at which point it
    returns control to all its callers.

.. function:: numba.cuda.syncthreads_count(predicate)

    An extension to :func:`numba.cuda.syncthreads` where the return value is a
    count of the threads for which ``predicate`` is true.

.. function:: numba.cuda.syncthreads_and(predicate)

    An extension to :func:`numba.cuda.syncthreads` where 1 is returned if
    ``predicate`` is true for all threads, or 0 otherwise.

.. function:: numba.cuda.syncthreads_or(predicate)

    An extension to :func:`numba.cuda.syncthreads` where 1 is returned if
    ``predicate`` is true for any thread, or 0 otherwise.

    .. warning:: All syncthreads functions must be called by every thread in
                 the thread block. Failing to do so may result in undefined
                 behavior.

Memory Fences
~~~~~~~~~~~~~

The memory fences are used to guarantee that the effects of memory operations
are visible to other threads within the same thread block, the same GPU device,
and the same system (across GPUs, on global memory). Memory loads and stores
are guaranteed not to be moved across the memory fences by optimization passes.

.. warning:: The memory fences are considered to be an advanced API and most
             use cases should use the thread barrier (e.g. ``syncthreads()``).


.. function:: numba.cuda.threadfence

   A memory fence at device level (within the GPU).

.. function:: numba.cuda.threadfence_block

   A memory fence at thread block level.

.. function:: numba.cuda.threadfence_system

   A memory fence at system level (across GPUs).
Warp Intrinsics
~~~~~~~~~~~~~~~

All warp-level operations require at least CUDA 9. The argument ``membermask``
is a 32-bit integer mask with each bit corresponding to a thread in the warp,
with 1 meaning the thread is in the subset of threads participating in the
function call. The ``membermask`` must have all bits set if the GPU compute
capability is below 7.x.

.. function:: numba.cuda.syncwarp(membermask)

   Synchronize a masked subset of the threads in a warp.

.. function:: numba.cuda.all_sync(membermask, predicate)

    If the ``predicate`` is true for all threads in the masked warp, then
    a non-zero value is returned, otherwise 0 is returned.

.. function:: numba.cuda.any_sync(membermask, predicate)

    If the ``predicate`` is true for any thread in the masked warp, then
    a non-zero value is returned, otherwise 0 is returned.

.. function:: numba.cuda.eq_sync(membermask, predicate)

    If the boolean ``predicate`` is the same for all threads in the masked warp,
    then a non-zero value is returned, otherwise 0 is returned.

.. function:: numba.cuda.ballot_sync(membermask, predicate)

    Returns a mask of all threads in the warp whose ``predicate`` is true,
    and are within the given mask.

.. function:: numba.cuda.shfl_sync(membermask, value, src_lane)

    Shuffles ``value`` across the masked warp and returns the ``value``
    from ``src_lane``. If this is outside the warp, then the
    given ``value`` is returned.

.. function:: numba.cuda.shfl_up_sync(membermask, value, delta)

    Shuffles ``value`` across the masked warp and returns the ``value``
    from ``laneid - delta``. If this is outside the warp, then the
    given ``value`` is returned.

.. function:: numba.cuda.shfl_down_sync(membermask, value, delta)

    Shuffles ``value`` across the masked warp and returns the ``value``
    from ``laneid + delta``. If this is outside the warp, then the
    given ``value`` is returned.

.. function:: numba.cuda.shfl_xor_sync(membermask, value, lane_mask)

    Shuffles ``value`` across the masked warp and returns the ``value``
    from ``laneid ^ lane_mask``.

.. function:: numba.cuda.match_any_sync(membermask, value, lane_mask)

    Returns a mask of threads that have the same ``value`` as the given
    ``value`` from within the masked warp.

.. function:: numba.cuda.match_all_sync(membermask, value, lane_mask)

    Returns a tuple of ``(mask, pred)``, where ``mask`` is a mask of the
    threads in the masked warp that have the same ``value`` as the given
    ``value`` if all of them do, and 0 otherwise; and ``pred`` is a boolean
    indicating whether all threads in the masked warp have the same ``value``.


Integer Intrinsics
~~~~~~~~~~~~~~~~~~

A subset of the CUDA Math API's integer intrinsics is available. For further
documentation, including semantics, please refer to the `CUDA Toolkit
documentation
<https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__INTRINSIC__INT.html>`_.


.. function:: numba.cuda.popc

   Returns the number of set bits in the given value.

.. function:: numba.cuda.brev

   Reverses the bit pattern of an integer value; for example, 0b10110110
   becomes 0b01101101.

.. function:: numba.cuda.clz

   Counts the number of leading zeros in a value.

.. function:: numba.cuda.ffs

   Finds the position of the least significant bit set to 1 in an integer.


Floating Point Intrinsics
~~~~~~~~~~~~~~~~~~~~~~~~~

A subset of the CUDA Math API's floating point intrinsics is available. For
further documentation, including semantics, please refer to the `single
<https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html>`_ and
`double <https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html>`_
precision parts of the CUDA Toolkit documentation.


.. function:: numba.cuda.fma

   Performs the fused multiply-add operation. Named after the ``fma`` and
   ``fmaf`` functions in the C API, but maps to the ``fma.rn.f32`` and
   ``fma.rn.f64`` (round-to-nearest-even) PTX instructions.


Control Flow Instructions
~~~~~~~~~~~~~~~~~~~~~~~~~

A subset of CUDA's control flow instructions is directly available as
intrinsics. Avoiding branches is a key way to improve CUDA performance, and
using these intrinsics means you don't have to rely on the ``nvcc`` optimizer
identifying and removing branches. For further documentation, including
semantics, please refer to the `relevant CUDA Toolkit documentation
<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions>`_.


.. function:: numba.cuda.selp

    Select between two expressions, depending on the value of the first
    argument. Similar to LLVM's ``select`` instruction.
