1"""
2This scripts specifies all PTX special objects.
3"""
4import functools
5import llvmlite.llvmpy.core as lc
6import operator
7from numba.core.rewrites.macros import Macro
8from numba.core import types, typing, ir
9from .cudadrv import nvvm
10
11
class Stub(object):
    '''
    A stub object standing in for special values that only have meaning
    inside a CUDA kernel and cannot be used from host code.
    '''
    _description_ = '<ptx special value>'
    __slots__ = ()  # no per-instance __dict__

    def __new__(cls):
        # These objects exist purely for the typing machinery to recognise;
        # constructing one on the host is always an error.
        raise NotImplementedError("%s is not instantiable" % cls)

    def __repr__(self):
        return self._description_
25
26
def stub_function(fn):
    '''
    Wrap *fn* as a stub: the returned callable carries *fn*'s metadata but
    raises when invoked, since such functions are meaningless outside the
    context of a CUDA kernel.
    '''
    @functools.wraps(fn)
    def _host_call_error(*args, **kwargs):
        raise NotImplementedError("%s cannot be called from host code" % fn)
    return _host_call_error
36
37
38#-------------------------------------------------------------------------------
39# Thread and grid indices and dimensions
40
41
class Dim3(Stub):
    '''A triple, (x, y, z)'''
    _description_ = '<Dim3>'

    # The components are placeholder properties: they are resolved by the
    # CUDA typing machinery rather than evaluated on the host.
    @property
    def x(self):
        '''First component of the triple.'''

    @property
    def y(self):
        '''Second component of the triple.'''

    @property
    def z(self):
        '''Third component of the triple.'''
57
58
class threadIdx(Dim3):
    '''
    The indices of the current thread within its block. Each component is
    an integer ranging from 0 (inclusive) up to the matching component of
    :attr:`numba.cuda.blockDim` (exclusive).
    '''
    _description_ = '<threadIdx.{x,y,z}>'
66
67
class blockIdx(Dim3):
    '''
    The indices of the current block within the grid of thread blocks. Each
    component is an integer ranging from 0 (inclusive) up to the matching
    component of :attr:`numba.cuda.gridDim` (exclusive).
    '''
    _description_ = '<blockIdx.{x,y,z}>'
75
76
class blockDim(Dim3):
    '''
    The shape of each block of threads, as declared when the kernel was
    instantiated. Every thread in a given kernel launch observes the same
    value, regardless of which block it belongs to (i.e. each block is
    "full").
    '''
    _description_ = '<blockDim.{x,y,z}>'
84
85
class gridDim(Dim3):
    '''
    The shape of the grid of blocks, identical for every thread in a given
    kernel launch.
    '''
    _description_ = '<gridDim.{x,y,z}>'
92
93
class warpsize(Stub):
    '''
    The number of threads in a warp. Every architecture implemented to date
    uses a warp size of 32.
    '''
    _description_ = '<warpsize>'
100
101
class laneid(Stub):
    '''
    The lane of this thread within its warp, in the range 0 to
    :attr:`numba.cuda.warpsize` - 1 inclusive.
    '''
    _description_ = '<laneid>'
108
109
class grid(Stub):
    '''grid(ndim)

    Return the absolute position of the current thread in the entire grid of
    blocks.  *ndim* should correspond to the number of dimensions declared when
    instantiating the kernel. If *ndim* is 1, a single integer is returned.
    If *ndim* is 2 or 3, a tuple of the given number of integers is returned.

    Computation of the first integer is as follows::

        cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x

    and is similar for the other two indices, but using the ``y`` and ``z``
    attributes.
    '''
    _description_ = '<grid(ndim)>'
126
127
class gridsize(Stub):
    '''gridsize(ndim)

    Return the absolute size (or shape), in threads, of the entire grid of
    blocks. *ndim* should match the number of dimensions declared when the
    kernel was instantiated. With *ndim* equal to 1 a single integer is
    returned; with 2 or 3, a tuple of that many integers is returned.

    The first integer is computed as::

        cuda.blockDim.x * cuda.gridDim.x

    and the other two analogously, using the ``y`` and ``z`` attributes.
    '''
    _description_ = '<gridsize(ndim)>'
144
145
146#-------------------------------------------------------------------------------
147# Array creation
148
class shared(Stub):
    '''
    Namespace for shared memory operations.
    '''
    _description_ = '<shared>'

    @stub_function
    def array(shape, dtype):
        '''
        Allocate a shared array with the given *shape* and *dtype*. *shape*
        is an integer or a tuple of integers giving the array's dimensions,
        and *dtype* is the :ref:`Numba type <numba-types>` of the elements
        to be stored in the array.

        The object returned behaves like an ordinary device array and can
        be read and written through indexing.
        '''
166
167
class local(Stub):
    '''
    Namespace for local memory operations.
    '''
    _description_ = '<local>'

    @stub_function
    def array(shape, dtype):
        '''
        Allocate a local array with the given *shape* and *dtype*. The
        array is private to the calling thread and resides in global
        memory; the returned array-like object can be read and written
        like any standard array (e.g. through indexing).
        '''
182
183
class const(Stub):
    '''
    Constant memory namespace
    '''
    # Consistency fix: the sibling namespaces (shared, local, atomic) each
    # carry a specific description; previously this class fell back to the
    # generic '<ptx special value>' inherited from Stub.
    _description_ = '<const>'

    @stub_function
    def array_like(ndarray):
        '''
        Create a const array from *ndarray*. The resulting const array will
        have the same shape, type, and values as *ndarray*.
        '''
195
196
197#-------------------------------------------------------------------------------
198# syncthreads
199
class syncthreads(Stub):
    '''
    Barrier for all threads of the same thread block: as with barriers in
    traditional multi-threaded programming, the call blocks until every
    thread in the block has reached it, and only then returns control to
    all its callers.
    '''
    _description_ = '<syncthreads()>'
208
209
class syncthreads_count(Stub):
    '''
    syncthreads_count(predicate)

    An extension to numba.cuda.syncthreads where the return value is a count
    of the threads where predicate is true.
    '''
    _description_ = '<syncthreads_count()>'
218
219
class syncthreads_and(Stub):
    '''
    syncthreads_and(predicate)

    An extension to numba.cuda.syncthreads where 1 is returned if predicate is
    true for all threads or 0 otherwise.
    '''
    _description_ = '<syncthreads_and()>'
228
229
class syncthreads_or(Stub):
    '''
    syncthreads_or(predicate)

    An extension to numba.cuda.syncthreads where 1 is returned if predicate is
    true for any thread or 0 otherwise.
    '''
    _description_ = '<syncthreads_or()>'
238
239
240# -------------------------------------------------------------------------------
241# warp level operations
242
class syncwarp(Stub):
    '''
    syncwarp(mask)

    Synchronize the subset of threads of a warp selected by *mask*.
    '''
    _description_ = '<warp_sync()>'
250
251
class shfl_sync_intrinsic(Stub):
    '''
    shfl_sync_intrinsic(mask, mode, value, mode_offset, clamp)

    NVVM intrinsic for moving data between the threads of a warp; see
    docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-datamove
    '''
    _description_ = '<shfl_sync()>'
260
261
class vote_sync_intrinsic(Stub):
    '''
    vote_sync_intrinsic(mask, mode, predicate)

    Nvvm intrinsic for performing a reduce and broadcast across a warp
    docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-vote
    '''
    _description_ = '<vote_sync()>'
270
271
class match_any_sync(Stub):
    '''
    match_any_sync(mask, value)

    NVVM intrinsic performing a compare-and-broadcast across a warp:
    returns the mask of those threads within the masked warp whose value
    is the same as the given value.
    '''
    _description_ = '<match_any_sync()>'
281
282
class match_all_sync(Stub):
    '''
    match_all_sync(mask, value)

    Nvvm intrinsic for performing a compare and broadcast across a warp.
    Returns a tuple of (mask, pred), where mask is a mask of threads that have
    same value as the given value from within the masked warp, if they
    all have the same value, otherwise it is 0. Pred is a boolean of whether
    or not all threads in the masked warp have the same value.
    '''
    _description_ = '<match_all_sync()>'
294
295
296# -------------------------------------------------------------------------------
297# memory fences
298
class threadfence_block(Stub):
    '''
    Memory fence scoped to the current thread block.
    '''
    _description_ = '<threadfence_block()>'
304
305
class threadfence_system(Stub):
    '''
    Memory fence at system level, ordering accesses across devices.
    '''
    _description_ = '<threadfence_system()>'
311
312
class threadfence(Stub):
    '''
    Memory fence scoped to the whole device.
    '''
    _description_ = '<threadfence()>'
318
319
320#-------------------------------------------------------------------------------
321# bit manipulation
322
class popc(Stub):
    """
    popc(val)

    Return the count of bits set to 1 in the given value.
    """
329
330
class brev(Stub):
    """
    brev(val)

    Return the integer whose bit pattern is the reverse of the input's;
    for example 0b10110110 becomes 0b01101101.
    """
338
339
class clz(Stub):
    """
    clz(val)

    Count the leading zero bits of a value.
    """
346
347
class ffs(Stub):
    """
    ffs(val)

    Return the position of the least significant bit set to 1 in an
    integer.
    """
354
355#-------------------------------------------------------------------------------
356# comparison and selection instructions
357
class selp(Stub):
    """
    selp(a, b, c)

    Choose between two source operands according to the value of the
    predicate source operand.
    """
365
366#-------------------------------------------------------------------------------
367# single / double precision arithmetic
368
class fma(Stub):
    """
    fma(a, b, c)

    Compute the fused multiply-add of the operands.
    """
375
376#-------------------------------------------------------------------------------
377# atomic
378
class atomic(Stub):
    """Namespace for atomic operations
    """
    _description_ = '<atomic>'

    class add(Stub):
        """add(ary, idx, val)

        Perform atomic ary[idx] += val. Supported on int32, float32, and
        float64 operands only.

        Returns the old value at the index location as if it is loaded
        atomically.
        """

    class max(Stub):
        """max(ary, idx, val)

        Perform atomic ary[idx] = max(ary[idx], val).

        Supported on int32, int64, uint32, uint64, float32, float64 operands
        only.

        Returns the old value at the index location as if it is loaded
        atomically.
        """

    class min(Stub):
        """min(ary, idx, val)

        Perform atomic ary[idx] = min(ary[idx], val).

        Supported on int32, int64, uint32, uint64, float32, float64 operands
        only.

        Returns the old value at the index location as if it is loaded
        atomically.
        """

    class nanmax(Stub):
        """nanmax(ary, idx, val)

        Perform atomic ary[idx] = max(ary[idx], val).

        NOTE: NaN is treated as a missing value such that:
        nanmax(NaN, n) == n, nanmax(n, NaN) == n

        Supported on int32, int64, uint32, uint64, float32, float64 operands
        only.

        Returns the old value at the index location as if it is loaded
        atomically.
        """

    class nanmin(Stub):
        """nanmin(ary, idx, val)

        Perform atomic ary[idx] = min(ary[idx], val).

        NOTE: NaN is treated as a missing value, such that:
        nanmin(NaN, n) == n, nanmin(n, NaN) == n

        Supported on int32, int64, uint32, uint64, float32, float64 operands
        only.

        Returns the old value at the index location as if it is loaded
        atomically.
        """

    class compare_and_swap(Stub):
        """compare_and_swap(ary, old, val)

        Conditionally assign ``val`` to the first element of a 1D array
        ``ary`` if the current value matches ``old``.

        Returns the current value as if it is loaded atomically.
        """
456