1"""
2This is a direct translation of nvvm.h
3"""
4import sys, logging, re
5from ctypes import (c_void_p, c_int, POINTER, c_char_p, c_size_t, byref,
6                    c_char)
7
8import threading
9
10from llvmlite import ir
11
12from .error import NvvmError, NvvmSupportError
13from .libs import get_libdevice, open_libdevice, open_cudalib
14from numba.core import config
15
16
17logger = logging.getLogger(__name__)
18
# NVVM address space enumerants (values mirror the libNVVM IR spec).
ADDRSPACE_GENERIC = 0
ADDRSPACE_GLOBAL = 1
ADDRSPACE_SHARED = 3
ADDRSPACE_CONSTANT = 4
ADDRSPACE_LOCAL = 5

# Opaque handle for compilation unit
nvvm_program = c_void_p

# Result code
nvvm_result = c_int

# Names of the nvvmResult enum members, in numeric order; used to render
# error codes as readable strings in NvvmError messages (see
# NVVM.check_error).
RESULT_CODE_NAMES = '''
NVVM_SUCCESS
NVVM_ERROR_OUT_OF_MEMORY
NVVM_ERROR_PROGRAM_CREATION_FAILURE
NVVM_ERROR_IR_VERSION_MISMATCH
NVVM_ERROR_INVALID_INPUT
NVVM_ERROR_INVALID_PROGRAM
NVVM_ERROR_INVALID_IR
NVVM_ERROR_INVALID_OPTION
NVVM_ERROR_NO_MODULE_IN_PROGRAM
NVVM_ERROR_COMPILATION
'''.split()

# Export each result code as a module-level integer constant whose value is
# its position in the enum (e.g. NVVM_SUCCESS == 0).
for i, k in enumerate(RESULT_CODE_NAMES):
    setattr(sys.modules[__name__], k, i)
46
47
def is_available():
    """Return whether libNVVM can be loaded on this system."""
    # Constructing the NVVM singleton attempts to load the library; failure
    # to locate it is reported as NvvmSupportError.
    try:
        NVVM()
    except NvvmSupportError:
        return False
    return True
58
59
# Guards one-time construction of the NVVM singleton across threads.
_nvvm_lock = threading.Lock()
61
class NVVM(object):
    '''Process-wide singleton.

    Loads the libNVVM shared library once and binds the entry points listed
    in ``_PROTOTYPES`` as bound ctypes functions on the instance.
    '''
    # ctypes prototype table: function name -> (restype, *argtypes),
    # mirroring the declarations in nvvm.h.
    _PROTOTYPES = {

        # nvvmResult nvvmVersion(int *major, int *minor)
        'nvvmVersion': (nvvm_result, POINTER(c_int), POINTER(c_int)),

        # nvvmResult nvvmCreateProgram(nvvmProgram *cu)
        'nvvmCreateProgram': (nvvm_result, POINTER(nvvm_program)),

        # nvvmResult nvvmDestroyProgram(nvvmProgram *cu)
        'nvvmDestroyProgram': (nvvm_result, POINTER(nvvm_program)),

        # nvvmResult nvvmAddModuleToProgram(nvvmProgram cu, const char *buffer,
        #                                   size_t size, const char *name)
        'nvvmAddModuleToProgram': (
            nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p),

        # nvvmResult nvvmCompileProgram(nvvmProgram cu, int numOptions,
        #                          const char **options)
        'nvvmCompileProgram': (
            nvvm_result, nvvm_program, c_int, POINTER(c_char_p)),

        # nvvmResult nvvmGetCompiledResultSize(nvvmProgram cu,
        #                                      size_t *bufferSizeRet)
        'nvvmGetCompiledResultSize': (
            nvvm_result, nvvm_program, POINTER(c_size_t)),

        # nvvmResult nvvmGetCompiledResult(nvvmProgram cu, char *buffer)
        'nvvmGetCompiledResult': (nvvm_result, nvvm_program, c_char_p),

        # nvvmResult nvvmGetProgramLogSize(nvvmProgram cu,
        #                                      size_t *bufferSizeRet)
        'nvvmGetProgramLogSize': (nvvm_result, nvvm_program, POINTER(c_size_t)),

        # nvvmResult nvvmGetProgramLog(nvvmProgram cu, char *buffer)
        'nvvmGetProgramLog': (nvvm_result, nvvm_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        # Build the singleton on first use, under a lock so concurrent first
        # calls cannot construct two instances or race on function binding.
        with _nvvm_lock:
            if cls.__INSTANCE is None:
                cls.__INSTANCE = inst = object.__new__(cls)
                try:
                    inst.driver = open_cudalib('nvvm')
                except OSError as e:
                    # Reset so a later attempt can retry loading the library.
                    cls.__INSTANCE = None
                    errmsg = ("libNVVM cannot be found. Do `conda install "
                              "cudatoolkit`:\n%s")
                    raise NvvmSupportError(errmsg % e)

                # Find & populate functions
                for name, proto in inst._PROTOTYPES.items():
                    func = getattr(inst.driver, name)
                    func.restype = proto[0]
                    func.argtypes = proto[1:]
                    setattr(inst, name, func)

        return cls.__INSTANCE

    def get_version(self):
        """Return the ``(major, minor)`` version of the loaded libNVVM."""
        major = c_int()
        minor = c_int()
        err = self.nvvmVersion(byref(major), byref(minor))
        self.check_error(err, 'Failed to get version.')
        return major.value, minor.value

    def check_error(self, error, msg, exit=False):
        """Raise NvvmError for a non-zero nvvmResult code *error*.

        When *exit* is true the error is printed and the process terminates
        instead of raising (used from contexts where raising is unsafe,
        e.g. destructors).
        """
        if error:
            exc = NvvmError(msg, RESULT_CODE_NAMES[error])
            if exit:
                print(exc)
                sys.exit(1)
            else:
                raise exc
142
class CompilationUnit(object):
    """Wrapper over an NVVM program handle (``nvvmProgram``).

    Collects NVVM IR modules via :meth:`add_module`, compiles them to PTX
    via :meth:`compile`, and exposes the compilation log.
    """

    def __init__(self):
        # NVVM() raises NvvmSupportError when libNVVM is unavailable.
        self.driver = NVVM()
        self._handle = nvvm_program()
        err = self.driver.nvvmCreateProgram(byref(self._handle))
        self.driver.check_error(err, 'Failed to create CU')

    def __del__(self):
        # NOTE(review): this re-enters the NVVM() singleton during
        # finalization; assumes the library is still loaded at interpreter
        # shutdown -- confirm.  exit=True means errors are printed rather
        # than raised, since raising from __del__ is ineffective.
        driver = NVVM()
        err = driver.nvvmDestroyProgram(byref(self._handle))
        driver.check_error(err, 'Failed to destroy CU', exit=True)

    def add_module(self, buffer):
        """
         Add a module level NVVM IR to a compilation unit.
         - The buffer should contain an NVVM module IR either in the bitcode
           representation (LLVM3.0) or in the text representation.
        """
        err = self.driver.nvvmAddModuleToProgram(self._handle, buffer,
                                                 len(buffer), None)
        self.driver.check_error(err, 'Failed to add module')

    def compile(self, **options):
        """Perform Compilation

        The valid compiler options are

         *   - -g (enable generation of debugging information)
         *   - -opt=
         *     - 0 (disable optimizations)
         *     - 3 (default, enable optimizations)
         *   - -arch=
         *     - compute_20 (default)
         *     - compute_30
         *     - compute_35
         *   - -ftz=
         *     - 0 (default, preserve denormal values, when performing
         *          single-precision floating-point operations)
         *     - 1 (flush denormal values to zero, when performing
         *          single-precision floating-point operations)
         *   - -prec-sqrt=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point square root)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point square root)
         *   - -prec-div=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point division and reciprocals)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point division and reciprocals)
         *   - -fma=
         *     - 0 (disable FMA contraction)
         *     - 1 (default, enable FMA contraction)
         *
         """

        # stringify options
        opts = []
        # 'debug' is consumed even when false, so it is not reported as an
        # unsupported leftover option below.
        if 'debug' in options:
            if options.pop('debug'):
                opts.append('-g')

        if 'opt' in options:
            opts.append('-opt=%d' % options.pop('opt'))

        if options.get('arch'):
            opts.append('-arch=%s' % options.pop('arch'))

        # Boolean flags forwarded as -flag=0/1 (underscores become dashes).
        other_options = (
            'ftz',
            'prec_sqrt',
            'prec_div',
            'fma',
        )

        for k in other_options:
            if k in options:
                v = int(bool(options.pop(k)))
                opts.append('-%s=%d' % (k.replace('_', '-'), v))

        # If there are any option left
        if options:
            optstr = ', '.join(map(repr, options.keys()))
            raise NvvmError("unsupported option {0}".format(optstr))

        # compile
        c_opts = (c_char_p * len(opts))(*[c_char_p(x.encode('utf8'))
                                          for x in opts])
        err = self.driver.nvvmCompileProgram(self._handle, len(opts), c_opts)
        self._try_error(err, 'Failed to compile\n')

        # get result
        reslen = c_size_t()
        err = self.driver.nvvmGetCompiledResultSize(self._handle, byref(reslen))

        self._try_error(err, 'Failed to get size of compiled result.')

        ptxbuf = (c_char * reslen.value)()
        err = self.driver.nvvmGetCompiledResult(self._handle, ptxbuf)
        self._try_error(err, 'Failed to get compiled result.')

        # get log
        self.log = self.get_log()

        # Return the PTX as raw bytes.
        return ptxbuf[:]

    def _try_error(self, err, msg):
        """Check *err*, appending the current compilation log to *msg*."""
        self.driver.check_error(err, "%s\n%s" % (msg, self.get_log()))

    def get_log(self):
        """Return the program's compilation log ('' when the log is empty)."""
        reslen = c_size_t()
        err = self.driver.nvvmGetProgramLogSize(self._handle, byref(reslen))
        self.driver.check_error(err, 'Failed to get compilation log size.')

        # Reported size includes the terminating NUL; > 1 means non-empty.
        if reslen.value > 1:
            logbuf = (c_char * reslen.value)()
            err = self.driver.nvvmGetProgramLog(self._handle, logbuf)
            self.driver.check_error(err, 'Failed to get compilation log.')

            return logbuf.value.decode('utf8')  # .value stops at first NUL

        return ''
265
266
# LLVM data layout strings for the 32- and 64-bit NVVM targets, keyed by
# pointer width in bits.
data_layout = {
    32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
         'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
    64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
         'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}

# Layout matching the host pointer width (tuple.__itemsize__ is the size in
# bytes of a pointer-sized tuple item: 4 or 8).
default_data_layout = data_layout[tuple.__itemsize__ * 8]
274
# Cached result of get_supported_ccs(); None until first computed.
_supported_cc = None


def get_supported_ccs():
    """Return the tuple of compute capabilities ``(major, minor)`` supported
    by the installed CUDA Runtime, in ascending order.

    The result is cached in the module global ``_supported_cc``.  When the
    CUDA Runtime is not present, an empty tuple is returned.
    """
    global _supported_cc

    if _supported_cc:
        return _supported_cc

    try:
        from numba.cuda.cudadrv.runtime import runtime
        cudart_version_major = runtime.get_version()[0]
    except Exception:
        # The CUDA Runtime may not be present / importable
        cudart_version_major = 0

    # List of supported compute capability in sorted order
    if cudart_version_major == 0:
        # No runtime: nothing is supported.  (Fix: this used to be `(),`,
        # i.e. a 1-tuple containing an empty tuple, which made
        # find_closest_arch() return an empty CC and crash downstream
        # 'compute_%d%d' formatting.)
        _supported_cc = ()
    elif cudart_version_major < 9:
        # CUDA 8.x
        _supported_cc = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2)
    elif cudart_version_major < 10:
        # CUDA 9.x
        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0)
    elif cudart_version_major < 11:
        # CUDA 10.x
        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5)
    else:
        # CUDA 11.0 and later
        _supported_cc = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0)

    return _supported_cc
308
309
def find_closest_arch(mycc):
    """
    Given a compute capability, return the closest compute capability supported
    by the CUDA toolkit.

    :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
    :return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
    """
    supported_cc = get_supported_ccs()

    for idx, entry in enumerate(supported_cc):
        if entry == mycc:
            # Exact match
            return entry
        if entry > mycc:
            # First supported CC above the requested one
            if idx > 0:
                # Round down to the previous (highest not-exceeding) CC
                return supported_cc[idx - 1]
            # Requested CC is below the minimum supported
            raise NvvmSupportError("GPU compute capability %d.%d is "
                                   "not supported (requires >=%d.%d)"
                                   % (mycc + entry))

    # Requested CC exceeds everything supported; use the highest available
    return supported_cc[-1]
336
337
def get_arch_option(major, minor):
    """Return the ``compute_XY`` architecture string closest to the given
    compute capability.

    An explicitly forced CC (config.FORCE_CUDA_CC) takes precedence over
    matching against the toolkit's supported list.
    """
    arch = config.FORCE_CUDA_CC or find_closest_arch((major, minor))
    return 'compute_%d%d' % arch
346
347
# Error message used by LibDevice when the libdevice bitcode file for a
# given architecture cannot be located.
MISSING_LIBDEVICE_FILE_MSG = '''Missing libdevice file for {arch}.
Please ensure you have package cudatoolkit >= 8.
Install package by:

    conda install cudatoolkit
'''
354
355
class LibDevice(object):
    """Loader for the libdevice bitcode matching a given architecture."""

    # Class-wide cache of loaded libdevice bitcode, keyed by arch string.
    _cache_ = {}
    # Architectures for which a dedicated libdevice file exists, ascending.
    _known_arch = [
        "compute_20",
        "compute_30",
        "compute_35",
        "compute_50",
    ]

    def __init__(self, arch):
        """
        arch --- must be result from get_arch_option()
        """
        if arch not in self._cache_:
            # Round down to the nearest known libdevice arch, then load and
            # cache its bitcode.
            arch = self._get_closest_arch(arch)
            if get_libdevice(arch) is None:
                raise RuntimeError(MISSING_LIBDEVICE_FILE_MSG.format(arch=arch))
            self._cache_[arch] = open_libdevice(arch)

        self.arch = arch
        self.bc = self._cache_[arch]

    def _get_closest_arch(self, arch):
        # Keep the highest known arch that does not exceed the requested one
        # (string comparison works for the fixed-width compute_XX names);
        # fall back to the lowest known arch.
        best = self._known_arch[0]
        for candidate in self._known_arch:
            if arch >= candidate:
                best = candidate
        return best

    def get(self):
        """Return the libdevice bitcode for this arch."""
        return self.bc
387
388
# LLVM IR for an i32 compare-and-swap helper; spliced into modules by
# llvm_to_ptx() to replace the corresponding extern declaration.
ir_numba_cas_hack = """
define internal i32 @___numba_cas_hack(i32* %ptr, i32 %cmp, i32 %val) alwaysinline {
    %out = cmpxchg volatile i32* %ptr, i32 %cmp, i32 %val monotonic
    ret i32 %out
}
"""

# Translation of code from CUDA Programming Guide v6.5, section B.12
# (atomic double-precision add built on an i64 cmpxchg retry loop).
ir_numba_atomic_double_add = """
define internal double @___numba_atomic_double_add(double* %ptr, double %val) alwaysinline {
entry:
    %iptr = bitcast double* %ptr to i64*
    %old2 = load volatile i64, i64* %iptr
    br label %attempt

attempt:
    %old = phi i64 [ %old2, %entry ], [ %cas, %attempt ]
    %dold = bitcast i64 %old to double
    %dnew = fadd double %dold, %val
    %new = bitcast double %dnew to i64
    %cas = cmpxchg volatile i64* %iptr, i64 %old, i64 %new monotonic
    %repeat = icmp ne i64 %cas, %old
    br i1 %repeat, label %attempt, label %done

done:
    %result = bitcast i64 %old to double
    ret double %result
}
"""


# Template for atomic float/double min/max (and NaN-aware variants), built on
# a cmpxchg retry loop.  Placeholders filled by .format() in llvm_to_ptx():
#   {T}          element type ('float' / 'double')
#   {Ti}         same-width integer type ('i32' / 'i64')
#   {NAN}        'nan' for the NaN-aware variants, '' otherwise
#   {FUNC}       'min' or 'max'
#   {OP}         fcmp predicate selecting when to keep attempting the swap
#   {PTR_OR_VAL} 'ptr' or '' -- chooses which operand the early NaN test uses
ir_numba_atomic_minmax = """
define internal {T} @___numba_atomic_{T}_{NAN}{FUNC}({T}* %ptr, {T} %val) alwaysinline {{
entry:
    %ptrval = load volatile {T}, {T}* %ptr
    ; Return early when:
    ; - For nanmin / nanmax when val is a NaN
    ; - For min / max when val or ptr is a NaN
    %early_return = fcmp uno {T} %val, %{PTR_OR_VAL}val
    br i1 %early_return, label %done, label %lt_check

lt_check:
    %dold = phi {T} [ %ptrval, %entry ], [ %dcas, %attempt ]
    ; Continue attempts if dold less or greater than val (depending on whether min or max)
    ; or if dold is NaN (for nanmin / nanmax)
    %cmp = fcmp {OP} {T} %dold, %val
    br i1 %cmp, label %attempt, label %done

attempt:
    ; Attempt to swap in the value
    %iold = bitcast {T} %dold to {Ti}
    %iptr = bitcast {T}* %ptr to {Ti}*
    %ival = bitcast {T} %val to {Ti}
    %cas = cmpxchg volatile {Ti}* %iptr, {Ti} %iold, {Ti} %ival monotonic
    %dcas = bitcast {Ti} %cas to {T}
    br label %lt_check

done:
    ret {T} %ptrval
}}
"""
450
451
def _replace_datalayout(llvmir):
    """
    Rewrite the first ``target datalayout`` line of *llvmir* with the layout
    string NVVM expects for the host pointer width.
    """
    out = []
    replaced = False
    for ln in llvmir.splitlines():
        if not replaced and ln.startswith("target datalayout"):
            ln = 'target datalayout = "{0}"'.format(default_data_layout)
            replaced = True
        out.append(ln)
    return '\n'.join(out)
463
464
def llvm_to_ptx(llvmir, **opts):
    """Compile LLVM IR text to PTX with NVVM, linking in libdevice.

    Keyword options are forwarded to CompilationUnit.compile();
    ``fastmath=True`` expands to ftz/fma enabled and precise div/sqrt
    disabled.  Returns the PTX bytes.
    """
    if opts.pop('fastmath', False):
        opts.update({
            'ftz': True,
            'fma': True,
            'prec_div': False,
            'prec_sqrt': False,
        })

    cu = CompilationUnit()
    libdevice = LibDevice(arch=opts.get('arch', 'compute_20'))
    # New LLVM generate a shorthand for datalayout that NVVM does not know
    llvmir = _replace_datalayout(llvmir)
    # Replace with our cmpxchg and atomic implementations because LLVM 3.5 has
    # a new semantic for cmpxchg.  Each (declaration, definition) pair swaps
    # an extern declaration emitted by the lowering for an inline IR body.
    replacements = [
        ('declare i32 @___numba_cas_hack(i32*, i32, i32)',
         ir_numba_cas_hack),
        ('declare double @___numba_atomic_double_add(double*, double)',
         ir_numba_atomic_double_add),
        ('declare float @___numba_atomic_float_max(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='', OP='nnan olt',
                                    PTR_OR_VAL='ptr', FUNC='max')),
        ('declare double @___numba_atomic_double_max(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='', OP='nnan olt',
                                    PTR_OR_VAL='ptr', FUNC='max')),
        ('declare float @___numba_atomic_float_min(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='', OP='nnan ogt',
                                    PTR_OR_VAL='ptr', FUNC='min')),
        ('declare double @___numba_atomic_double_min(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='', OP='nnan ogt',
                                    PTR_OR_VAL='ptr', FUNC='min')),
        ('declare float @___numba_atomic_float_nanmax(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='nan', OP='ult',
                                    PTR_OR_VAL='', FUNC='max')),
        ('declare double @___numba_atomic_double_nanmax(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='nan', OP='ult',
                                    PTR_OR_VAL='', FUNC='max')),
        ('declare float @___numba_atomic_float_nanmin(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='nan', OP='ugt',
                                    PTR_OR_VAL='', FUNC='min')),
        ('declare double @___numba_atomic_double_nanmin(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='nan', OP='ugt',
                                    PTR_OR_VAL='', FUNC='min')),
        # 'immarg' is an argument attribute NVVM does not understand; drop it.
        ('immarg', '')
    ]

    for decl, fn in replacements:
        llvmir = llvmir.replace(decl, fn)

    # llvm.numba_nvvm.atomic is used to prevent LLVM 9 onwards auto-upgrading
    # these intrinsics into atomicrmw instructions, which are not recognized by
    # NVVM. We can now replace them with the real intrinsic names, ready to
    # pass to NVVM.
    llvmir = llvmir.replace('llvm.numba_nvvm.atomic', 'llvm.nvvm.atomic')

    llvmir = llvm39_to_34_ir(llvmir)
    cu.add_module(llvmir.encode('utf8'))
    cu.add_module(libdevice.get())

    ptx = cu.compile(**opts)
    # XXX remove debug_pubnames seems to be necessary sometimes
    return patch_ptx_debug_pubnames(ptx)
528
529
def patch_ptx_debug_pubnames(ptx):
    """
    Patch PTX to workaround .debug_pubnames NVVM error::

        ptxas fatal   : Internal error: overlapping non-identical data

    Every ``.debug_pubnames`` section (up to and including its closing
    brace) is stripped from the PTX bytes.
    """
    while True:
        begin = ptx.find(b'.section .debug_pubnames')
        if begin < 0:
            # No more sections to remove
            return ptx
        end = ptx.find(b'}', begin)
        if end < 0:
            raise ValueError('missing "}"')
        # Splice out the section, including the closing brace
        ptx = ptx[:begin] + ptx[end + 1:]
547
548
# Matches a metadata definition line, e.g. "!0 = ...".
re_metadata_def = re.compile(r"\!\d+\s*=")
# Matches metadata usage that already carries the "metadata" prefix.
re_metadata_correct_usage = re.compile(r"metadata\s*\![{'\"0-9]")
# Matches a bare metadata reference such as "!42".
re_metadata_ref = re.compile(r"\!\d+")
# Matches the "Debug Info Version" module-flag metadata.
re_metadata_debuginfo = re.compile(r"\!{i32 \d, \!\"Debug Info Version\", i32 \d}".replace(' ', r'\s+'))

# Matches an "attributes #N = { ... }" group definition line, capturing the
# attribute list.
re_attributes_def = re.compile(r"^attributes #\d+ = \{ ([\w\s]+)\ }")
# Function attributes kept in llvm39_to_34_ir; everything else is stripped.
# NOTE(review): 'optisze' looks like a typo of 'optsize' -- as written it
# can never match an LLVM attribute, so 'optsize' is always stripped;
# confirm intent before changing.
supported_attributes = {'alwaysinline', 'cold', 'inlinehint', 'minsize',
                        'noduplicate', 'noinline', 'noreturn', 'nounwind',
                        'optnone', 'optisze', 'readnone', 'readonly'}

# Matches the head of a getelementptr expression (optionally "inbounds").
re_getelementptr = re.compile(r"\bgetelementptr\s(?:inbounds )?\(?")

# Matches the head of a (possibly volatile) load on the RHS of an assignment.
re_load = re.compile(r"=\s*\bload\s(?:\bvolatile\s)?")

# Splits a call between its "call <ty> (...)" prefix and the " @callee" part.
re_call = re.compile(r"(call\s[^@]+\))(\s@)")
# Matches trailing "!range !N" metadata on call instructions.
re_range = re.compile(r"\s*!range\s+!\d+")

# Tokens that change nesting depth or end an operand when parsing a type.
re_type_tok = re.compile(r"[,{}()[\]]")

# Annotations unknown to older LLVM; removed wholesale.
re_annotations = re.compile(r"\bnonnull\b")

# Keywords unsupported by older LLVM; removed wholesale.
re_unsupported_keywords = re.compile(r"\b(local_unnamed_addr|writeonly)\b")

# Captures the parenthesized argument list of a call or declaration.
re_parenthesized_list = re.compile(r"\((.*)\)")
573
574
def llvm39_to_34_ir(ir):
    """
    Convert LLVM 3.9 IR for LLVM 3.4.

    Performs a line-by-line, regex-driven downgrade of the textual IR:
    reinstates the "metadata" prefix, strips attributes/keywords unknown to
    3.4, and rewrites getelementptr/load/call operand syntax.
    """
    def parse_out_leading_type(s):
        # Skip past the first <ty> operand (which may be an aggregate type
        # with nested braces/brackets/parens) and return the remainder.
        par_level = 0
        pos = 0
        # Parse out the first <ty> (which may be an aggregate type)
        while True:
            m = re_type_tok.search(s, pos)
            if m is None:
                # End of line
                raise RuntimeError("failed parsing leading type: %s" % (s,))
                break  # unreachable; kept as-is from the original
            pos = m.end()
            tok = m.group(0)
            if tok == ',':
                if par_level == 0:
                    # End of operand
                    break
            elif tok in '{[(':
                par_level += 1
            elif tok in ')]}':
                par_level -= 1
        return s[pos:].lstrip()

    buf = []
    for line in ir.splitlines():

        # Fix llvm.dbg.cu
        if line.startswith('!numba.llvm.dbg.cu'):
            line = line.replace('!numba.llvm.dbg.cu', '!llvm.dbg.cu')

        # We insert a dummy inlineasm to put debuginfo
        if (line.lstrip().startswith('tail call void asm sideeffect "// dbg') and
                '!numba.dbg' in line):
            # Fix the metadata
            line = line.replace('!numba.dbg', '!dbg')
        if re_metadata_def.match(line):
            # Rewrite metadata since LLVM 3.7 dropped the "metadata" type prefix.
            if None is re_metadata_correct_usage.search(line):
                # Reintroduce the "metadata" prefix
                line = line.replace('!{', 'metadata !{')
                line = line.replace('!"', 'metadata !"')

                assigpos = line.find('=')
                lhs, rhs = line[:assigpos + 1], line[assigpos + 1:]

                # Fix metadata reference
                def fix_metadata_ref(m):
                    return 'metadata ' + m.group(0)
                line = ' '.join((lhs, re_metadata_ref.sub(fix_metadata_ref, rhs)))
        if line.startswith('source_filename ='):
            continue    # skip line
        if re_unsupported_keywords.search(line) is not None:
            line = re_unsupported_keywords.sub(lambda m: '', line)

        if line.startswith('attributes #'):
            # Remove function attributes unsupported pre-3.8
            m = re_attributes_def.match(line)
            attrs = m.group(1).split()
            attrs = ' '.join(a for a in attrs if a in supported_attributes)
            line = line.replace(m.group(1), attrs)
        if 'getelementptr ' in line:
            # Rewrite "getelementptr ty, ty* ptr, ..."
            # to "getelementptr ty *ptr, ..."
            m = re_getelementptr.search(line)
            if m is None:
                raise RuntimeError("failed parsing getelementptr: %s" % (line,))
            pos = m.end()
            line = line[:pos] + parse_out_leading_type(line[pos:])
        if 'load ' in line:
            # Rewrite "load ty, ty* ptr"
            # to "load ty *ptr"
            m = re_load.search(line)
            if m:
                pos = m.end()
                line = line[:pos] + parse_out_leading_type(line[pos:])
        if 'call ' in line:
            # Rewrite "call ty (...) @foo"
            # to "call ty (...)* @foo"
            line = re_call.sub(r"\1*\2", line)

            # no !range metadata on calls
            line = re_range.sub('', line).rstrip(',')

            if '@llvm.memset' in line:
                # Re-insert the explicit alignment argument (see helper).
                line = re_parenthesized_list.sub(
                    _replace_llvm_memset_usage,
                    line,
                    )
        if 'declare' in line:
            if '@llvm.memset' in line:
                line = re_parenthesized_list.sub(
                    _replace_llvm_memset_declaration,
                    line,
                    )

        # Remove unknown annotations
        line = re_annotations.sub('', line)

        buf.append(line)

    return '\n'.join(buf)
679
680
681def _replace_llvm_memset_usage(m):
682    """Replace `llvm.memset` usage for llvm7+.
683
684    Used as functor for `re.sub.
685    """
686    params = list(m.group(1).split(','))
687    align_attr = re.search(r'align (\d+)', params[0])
688    if not align_attr:
689        raise ValueError("No alignment attribute found on memset dest")
690    else:
691        align = align_attr.group(1)
692    params.insert(-1, 'i32 {}'.format(align))
693    out = ', '.join(params)
694    return '({})'.format(out)
695
696
697def _replace_llvm_memset_declaration(m):
698    """Replace `llvm.memset` declaration for llvm7+.
699
700    Used as functor for `re.sub.
701    """
702    params = list(m.group(1).split(','))
703    params.insert(-1, 'i32')
704    out = ', '.join(params)
705    return '({})'.format(out)
706
707
def set_cuda_kernel(lfunc):
    """Mark *lfunc* as a CUDA kernel in its module's ``nvvm.annotations``
    named metadata, and record the NVVM IR version on the module.
    """
    from llvmlite.llvmpy.core import MetaData, MetaDataString, Constant, Type

    module = lfunc.module

    # Add a !{ <func>, !"kernel", i32 1 } entry under !nvvm.annotations
    one = Constant.int(Type.int(), 1)
    md = MetaData.get(module, (lfunc, MetaDataString.get(module, "kernel"), one))
    module.get_or_insert_named_metadata('nvvm.annotations').add(md)

    # set nvvm ir version: !nvvmir.version = !{i32 1, i32 2, i32 2, i32 0}
    i32 = ir.IntType(32)
    version_md = module.add_metadata([i32(1), i32(2), i32(2), i32(0)])
    module.add_named_metadata('nvvmir.version', version_md)
723
724
def fix_data_layout(module):
    """Overwrite *module*'s data layout with the NVVM-compatible layout
    matching the host pointer width (``default_data_layout``)."""
    module.data_layout = default_data_layout
727