"""
This is a direct translation of nvvm.h
"""
import logging
import re
import sys
import threading
from ctypes import (c_void_p, c_int, POINTER, c_char_p, c_size_t, byref,
                    c_char)

from llvmlite import ir

from .error import NvvmError, NvvmSupportError
from .libs import get_libdevice, open_libdevice, open_cudalib
from numba.core import config


logger = logging.getLogger(__name__)

# NVVM address spaces (mirror the NVVM IR specification)
ADDRSPACE_GENERIC = 0
ADDRSPACE_GLOBAL = 1
ADDRSPACE_SHARED = 3
ADDRSPACE_CONSTANT = 4
ADDRSPACE_LOCAL = 5

# Opaque handle for compilation unit
nvvm_program = c_void_p

# Result code
nvvm_result = c_int

RESULT_CODE_NAMES = '''
NVVM_SUCCESS
NVVM_ERROR_OUT_OF_MEMORY
NVVM_ERROR_PROGRAM_CREATION_FAILURE
NVVM_ERROR_IR_VERSION_MISMATCH
NVVM_ERROR_INVALID_INPUT
NVVM_ERROR_INVALID_PROGRAM
NVVM_ERROR_INVALID_IR
NVVM_ERROR_INVALID_OPTION
NVVM_ERROR_NO_MODULE_IN_PROGRAM
NVVM_ERROR_COMPILATION
'''.split()

# Expose each result code as a module-level integer constant whose value is
# its index, matching the nvvmResult C enum ordering.
for i, k in enumerate(RESULT_CODE_NAMES):
    setattr(sys.modules[__name__], k, i)


def is_available():
    """
    Return if libNVVM is available
    """
    try:
        NVVM()
    except NvvmSupportError:
        return False
    else:
        return True


# Protects lazy creation of the NVVM singleton below.
_nvvm_lock = threading.Lock()


class NVVM(object):
    '''Process-wide singleton.
    '''
    # Maps C function name -> (restype, *argtypes) for every libNVVM entry
    # point we bind via ctypes in __new__.
    _PROTOTYPES = {

        # nvvmResult nvvmVersion(int *major, int *minor)
        'nvvmVersion': (nvvm_result, POINTER(c_int), POINTER(c_int)),

        # nvvmResult nvvmCreateProgram(nvvmProgram *cu)
        'nvvmCreateProgram': (nvvm_result, POINTER(nvvm_program)),

        # nvvmResult nvvmDestroyProgram(nvvmProgram *cu)
        'nvvmDestroyProgram': (nvvm_result, POINTER(nvvm_program)),

        # nvvmResult nvvmAddModuleToProgram(nvvmProgram cu, const char *buffer,
        #                                   size_t size, const char *name)
        'nvvmAddModuleToProgram': (
            nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p),

        # nvvmResult nvvmCompileProgram(nvvmProgram cu, int numOptions,
        #                               const char **options)
        'nvvmCompileProgram': (
            nvvm_result, nvvm_program, c_int, POINTER(c_char_p)),

        # nvvmResult nvvmGetCompiledResultSize(nvvmProgram cu,
        #                                      size_t *bufferSizeRet)
        'nvvmGetCompiledResultSize': (
            nvvm_result, nvvm_program, POINTER(c_size_t)),

        # nvvmResult nvvmGetCompiledResult(nvvmProgram cu, char *buffer)
        'nvvmGetCompiledResult': (nvvm_result, nvvm_program, c_char_p),

        # nvvmResult nvvmGetProgramLogSize(nvvmProgram cu,
        #                                  size_t *bufferSizeRet)
        'nvvmGetProgramLogSize': (nvvm_result, nvvm_program, POINTER(c_size_t)),

        # nvvmResult nvvmGetProgramLog(nvvmProgram cu, char *buffer)
        'nvvmGetProgramLog': (nvvm_result, nvvm_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        with _nvvm_lock:
            if cls.__INSTANCE is None:
                cls.__INSTANCE = inst = object.__new__(cls)
                try:
                    inst.driver = open_cudalib('nvvm')
                except OSError as e:
                    # Reset so that a later attempt can retry the load.
                    cls.__INSTANCE = None
                    errmsg = ("libNVVM cannot be found. Do `conda install "
                              "cudatoolkit`:\n%s")
                    raise NvvmSupportError(errmsg % e)

                # Find & populate functions
                for name, proto in inst._PROTOTYPES.items():
                    func = getattr(inst.driver, name)
                    func.restype = proto[0]
                    func.argtypes = proto[1:]
                    setattr(inst, name, func)

        return cls.__INSTANCE

    def get_version(self):
        """Return the (major, minor) version of the loaded libNVVM."""
        major = c_int()
        minor = c_int()
        err = self.nvvmVersion(byref(major), byref(minor))
        self.check_error(err, 'Failed to get version.')
        return major.value, minor.value

    def check_error(self, error, msg, exit=False):
        """Raise NvvmError for a non-zero nvvmResult code.

        With ``exit=True``, print the error and terminate the process
        instead of raising (used from __del__, where raising is unsafe).
        """
        if error:
            exc = NvvmError(msg, RESULT_CODE_NAMES[error])
            if exit:
                print(exc)
                sys.exit(1)
            else:
                raise exc


class CompilationUnit(object):
    """Wraps a single nvvmProgram compilation unit."""

    def __init__(self):
        self.driver = NVVM()
        self._handle = nvvm_program()
        err = self.driver.nvvmCreateProgram(byref(self._handle))
        self.driver.check_error(err, 'Failed to create CU')

    def __del__(self):
        driver = NVVM()
        err = driver.nvvmDestroyProgram(byref(self._handle))
        driver.check_error(err, 'Failed to destroy CU', exit=True)

    def add_module(self, buffer):
        """
        Add a module level NVVM IR to a compilation unit.
        - The buffer should contain an NVVM module IR either in the bitcode
          representation (LLVM3.0) or in the text representation.
        """
        err = self.driver.nvvmAddModuleToProgram(self._handle, buffer,
                                                 len(buffer), None)
        self.driver.check_error(err, 'Failed to add module')

    def compile(self, **options):
        """Perform Compilation

        The valid compiler options are

         *   - -g (enable generation of debugging information)
         *   - -opt=
         *     - 0 (disable optimizations)
         *     - 3 (default, enable optimizations)
         *   - -arch=
         *     - compute_20 (default)
         *     - compute_30
         *     - compute_35
         *   - -ftz=
         *     - 0 (default, preserve denormal values, when performing
         *          single-precision floating-point operations)
         *     - 1 (flush denormal values to zero, when performing
         *          single-precision floating-point operations)
         *   - -prec-sqrt=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point square root)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point square root)
         *   - -prec-div=
         *     - 0 (use a faster approximation for single-precision
         *          floating-point division and reciprocals)
         *     - 1 (default, use IEEE round-to-nearest mode for
         *          single-precision floating-point division and reciprocals)
         *   - -fma=
         *     - 0 (disable FMA contraction)
         *     - 1 (default, enable FMA contraction)
         *
        """

        # stringify options
        opts = []
        if 'debug' in options:
            if options.pop('debug'):
                opts.append('-g')

        if 'opt' in options:
            opts.append('-opt=%d' % options.pop('opt'))

        if options.get('arch'):
            opts.append('-arch=%s' % options.pop('arch'))

        other_options = (
            'ftz',
            'prec_sqrt',
            'prec_div',
            'fma',
        )

        for k in other_options:
            if k in options:
                # Normalize to 0/1; option names use '-' where the Python
                # keyword uses '_'.
                v = int(bool(options.pop(k)))
                opts.append('-%s=%d' % (k.replace('_', '-'), v))

        # If there are any option left
        if options:
            optstr = ', '.join(map(repr, options.keys()))
            raise NvvmError("unsupported option {0}".format(optstr))

        # compile
        c_opts = (c_char_p * len(opts))(*[c_char_p(x.encode('utf8'))
                                          for x in opts])
        err = self.driver.nvvmCompileProgram(self._handle, len(opts), c_opts)
        self._try_error(err, 'Failed to compile\n')

        # get result
        reslen = c_size_t()
        err = self.driver.nvvmGetCompiledResultSize(self._handle, byref(reslen))

        self._try_error(err, 'Failed to get size of compiled result.')

        ptxbuf = (c_char * reslen.value)()
        err = self.driver.nvvmGetCompiledResult(self._handle, ptxbuf)
        self._try_error(err, 'Failed to get compiled result.')

        # get log
        self.log = self.get_log()

        return ptxbuf[:]

    def _try_error(self, err, msg):
        # Append the NVVM program log to the error message for diagnosis.
        self.driver.check_error(err, "%s\n%s" % (msg, self.get_log()))

    def get_log(self):
        """Return the compilation log as a string ('' when empty)."""
        reslen = c_size_t()
        err = self.driver.nvvmGetProgramLogSize(self._handle, byref(reslen))
        self.driver.check_error(err, 'Failed to get compilation log size.')

        # Size includes the NUL terminator, hence > 1 for a non-empty log.
        if reslen.value > 1:
            logbuf = (c_char * reslen.value)()
            err = self.driver.nvvmGetProgramLog(self._handle, logbuf)
            self.driver.check_error(err, 'Failed to get compilation log.')

            return logbuf.value.decode('utf8')  # populate log attribute

        return ''


data_layout = {
    32: ('e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
         'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64'),
    64: ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-'
         'f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64')}

# tuple.__itemsize__ is the pointer size in bytes (4 or 8).
default_data_layout = data_layout[tuple.__itemsize__ * 8]

# Cache for get_supported_ccs()
_supported_cc = None


def get_supported_ccs():
    """Return the tuple of compute capabilities supported by the CUDA
    runtime found at import time, sorted ascending. Cached after the first
    successful (non-empty) computation.
    """
    global _supported_cc

    if _supported_cc:
        return _supported_cc

    try:
        from numba.cuda.cudadrv.runtime import runtime
        cudart_version_major = runtime.get_version()[0]
    except Exception:
        # The CUDA Runtime may not be present
        cudart_version_major = 0

    # List of supported compute capability in sorted order
    if cudart_version_major == 0:
        # FIX: was `_supported_cc = (),` which built ((),) — a 1-tuple
        # containing an empty tuple — instead of an empty tuple.
        _supported_cc = ()
    elif cudart_version_major < 9:
        # CUDA 8.x
        _supported_cc = (2, 0), (2, 1), (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2)
    elif cudart_version_major < 10:
        # CUDA 9.x
        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0)
    elif cudart_version_major < 11:
        # CUDA 10.x
        _supported_cc = (3, 0), (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5)
    else:
        # CUDA 11.0 and later
        _supported_cc = (3, 5), (5, 0), (5, 2), (5, 3), (6, 0), (6, 1), (6, 2), (7, 0), (7, 2), (7, 5), (8, 0)

    return _supported_cc


def find_closest_arch(mycc):
    """
    Given a compute capability, return the closest compute capability supported
    by the CUDA toolkit.

    :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
    :return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
    """
    supported_cc = get_supported_ccs()

    for i, cc in enumerate(supported_cc):
        if cc == mycc:
            # Matches
            return cc
        elif cc > mycc:
            # Exceeded
            if i == 0:
                # CC lower than supported
                raise NvvmSupportError("GPU compute capability %d.%d is "
                                       "not supported (requires >=%d.%d)" %
                                       (mycc + cc))
            else:
                # return the previous CC
                return supported_cc[i - 1]

    # CC higher than supported
    return supported_cc[-1]  # Choose the highest


def get_arch_option(major, minor):
    """Matches with the closest architecture option
    """
    if config.FORCE_CUDA_CC:
        arch = config.FORCE_CUDA_CC
    else:
        arch = find_closest_arch((major, minor))
    return 'compute_%d%d' % arch


MISSING_LIBDEVICE_FILE_MSG = '''Missing libdevice file for {arch}.
Please ensure you have package cudatoolkit >= 8.
Install package by:

    conda install cudatoolkit
'''


class LibDevice(object):
    """Loads (and caches) the libdevice bitcode for a compute architecture."""

    # arch string -> libdevice bitcode, shared across instances
    _cache_ = {}
    _known_arch = [
        "compute_20",
        "compute_30",
        "compute_35",
        "compute_50",
    ]

    def __init__(self, arch):
        """
        arch --- must be result from get_arch_option()
        """
        if arch not in self._cache_:
            arch = self._get_closest_arch(arch)
            if get_libdevice(arch) is None:
                raise RuntimeError(MISSING_LIBDEVICE_FILE_MSG.format(arch=arch))
            self._cache_[arch] = open_libdevice(arch)

        self.arch = arch
        self.bc = self._cache_[arch]

    def _get_closest_arch(self, arch):
        # Pick the largest known arch not exceeding the requested one
        # (relies on lexicographic ordering of the "compute_NN" strings).
        res = self._known_arch[0]
        for potential in self._known_arch:
            if arch >= potential:
                res = potential
        return res

    def get(self):
        """Return the libdevice bitcode for self.arch."""
        return self.bc


ir_numba_cas_hack = """
define internal i32 @___numba_cas_hack(i32* %ptr, i32 %cmp, i32 %val) alwaysinline {
    %out = cmpxchg volatile i32* %ptr, i32 %cmp, i32 %val monotonic
    ret i32 %out
}
"""

# Translation of code from CUDA Programming Guide v6.5, section B.12
ir_numba_atomic_double_add = """
define internal double @___numba_atomic_double_add(double* %ptr, double %val) alwaysinline {
entry:
    %iptr = bitcast double* %ptr to i64*
    %old2 = load volatile i64, i64* %iptr
    br label %attempt

attempt:
    %old = phi i64 [ %old2, %entry ], [ %cas, %attempt ]
    %dold = bitcast i64 %old to double
    %dnew = fadd double %dold, %val
    %new = bitcast double %dnew to i64
    %cas = cmpxchg volatile i64* %iptr, i64 %old, i64 %new monotonic
    %repeat = icmp ne i64 %cas, %old
    br i1 %repeat, label %attempt, label %done

done:
    %result = bitcast i64 %old to double
    ret double %result
}
"""


ir_numba_atomic_minmax = """
define internal {T} @___numba_atomic_{T}_{NAN}{FUNC}({T}* %ptr, {T} %val) alwaysinline {{
entry:
    %ptrval = load volatile {T}, {T}* %ptr
    ; Return early when:
    ; - For nanmin / nanmax when val is a NaN
    ; - For min / max when val or ptr is a NaN
    %early_return = fcmp uno {T} %val, %{PTR_OR_VAL}val
    br i1 %early_return, label %done, label %lt_check

lt_check:
    %dold = phi {T} [ %ptrval, %entry ], [ %dcas, %attempt ]
    ; Continue attempts if dold less or greater than val (depending on whether min or max)
    ; or if dold is NaN (for nanmin / nanmax)
    %cmp = fcmp {OP} {T} %dold, %val
    br i1 %cmp, label %attempt, label %done

attempt:
    ; Attempt to swap in the value
    %iold = bitcast {T} %dold to {Ti}
    %iptr = bitcast {T}* %ptr to {Ti}*
    %ival = bitcast {T} %val to {Ti}
    %cas = cmpxchg volatile {Ti}* %iptr, {Ti} %iold, {Ti} %ival monotonic
    %dcas = bitcast {Ti} %cas to {T}
    br label %lt_check

done:
    ret {T} %ptrval
}}
"""


def _replace_datalayout(llvmir):
    """
    Find the line containing the datalayout and replace it
    """
    lines = llvmir.splitlines()
    for i, ln in enumerate(lines):
        if ln.startswith("target datalayout"):
            tmp = 'target datalayout = "{0}"'
            lines[i] = tmp.format(default_data_layout)
            break
    return '\n'.join(lines)


def llvm_to_ptx(llvmir, **opts):
    """Compile LLVM IR text to PTX via NVVM.

    Keyword options are forwarded to CompilationUnit.compile();
    ``fastmath=True`` expands to ftz/fma on and precise div/sqrt off.
    """
    if opts.pop('fastmath', False):
        opts.update({
            'ftz': True,
            'fma': True,
            'prec_div': False,
            'prec_sqrt': False,
        })

    cu = CompilationUnit()
    libdevice = LibDevice(arch=opts.get('arch', 'compute_20'))
    # New LLVM generate a shorthand for datalayout that NVVM does not know
    llvmir = _replace_datalayout(llvmir)
    # Replace with our cmpxchg and atomic implementations because LLVM 3.5 has
    # a new semantic for cmpxchg.
    replacements = [
        ('declare i32 @___numba_cas_hack(i32*, i32, i32)',
         ir_numba_cas_hack),
        ('declare double @___numba_atomic_double_add(double*, double)',
         ir_numba_atomic_double_add),
        ('declare float @___numba_atomic_float_max(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='',
                                       OP='nnan olt', PTR_OR_VAL='ptr',
                                       FUNC='max')),
        ('declare double @___numba_atomic_double_max(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='',
                                       OP='nnan olt', PTR_OR_VAL='ptr',
                                       FUNC='max')),
        ('declare float @___numba_atomic_float_min(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='',
                                       OP='nnan ogt', PTR_OR_VAL='ptr',
                                       FUNC='min')),
        ('declare double @___numba_atomic_double_min(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='',
                                       OP='nnan ogt', PTR_OR_VAL='ptr',
                                       FUNC='min')),
        ('declare float @___numba_atomic_float_nanmax(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='nan',
                                       OP='ult', PTR_OR_VAL='',
                                       FUNC='max')),
        ('declare double @___numba_atomic_double_nanmax(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='nan',
                                       OP='ult', PTR_OR_VAL='',
                                       FUNC='max')),
        ('declare float @___numba_atomic_float_nanmin(float*, float)',
         ir_numba_atomic_minmax.format(T='float', Ti='i32', NAN='nan',
                                       OP='ugt', PTR_OR_VAL='',
                                       FUNC='min')),
        ('declare double @___numba_atomic_double_nanmin(double*, double)',
         ir_numba_atomic_minmax.format(T='double', Ti='i64', NAN='nan',
                                       OP='ugt', PTR_OR_VAL='',
                                       FUNC='min')),
        ('immarg', '')
    ]

    for decl, fn in replacements:
        llvmir = llvmir.replace(decl, fn)

    # llvm.numba_nvvm.atomic is used to prevent LLVM 9 onwards auto-upgrading
    # these intrinsics into atomicrmw instructions, which are not recognized by
    # NVVM. We can now replace them with the real intrinsic names, ready to
    # pass to NVVM.
    llvmir = llvmir.replace('llvm.numba_nvvm.atomic', 'llvm.nvvm.atomic')

    llvmir = llvm39_to_34_ir(llvmir)
    cu.add_module(llvmir.encode('utf8'))
    cu.add_module(libdevice.get())

    ptx = cu.compile(**opts)
    # XXX remove debug_pubnames seems to be necessary sometimes
    return patch_ptx_debug_pubnames(ptx)


def patch_ptx_debug_pubnames(ptx):
    """
    Patch PTX to workaround .debug_pubnames NVVM error::

        ptxas fatal   : Internal error: overlapping non-identical data

    """
    while True:
        # Repeatedly remove debug_pubnames sections
        start = ptx.find(b'.section .debug_pubnames')
        if start < 0:
            break
        stop = ptx.find(b'}', start)
        if stop < 0:
            raise ValueError('missing "}"')
        ptx = ptx[:start] + ptx[stop + 1:]
    return ptx


re_metadata_def = re.compile(r"\!\d+\s*=")
re_metadata_correct_usage = re.compile(r"metadata\s*\![{'\"0-9]")
re_metadata_ref = re.compile(r"\!\d+")
re_metadata_debuginfo = re.compile(
    r"\!{i32 \d, \!\"Debug Info Version\", i32 \d}".replace(' ', r'\s+'))

re_attributes_def = re.compile(r"^attributes #\d+ = \{ ([\w\s]+)\ }")
# Function attributes understood by LLVM 3.4; anything else is stripped.
# FIX: 'optisze' corrected to 'optsize' (the typo caused the legitimate
# optsize attribute to be stripped).
supported_attributes = {'alwaysinline', 'cold', 'inlinehint', 'minsize',
                        'noduplicate', 'noinline', 'noreturn', 'nounwind',
                        'optnone', 'optsize', 'readnone', 'readonly'}

re_getelementptr = re.compile(r"\bgetelementptr\s(?:inbounds )?\(?")

re_load = re.compile(r"=\s*\bload\s(?:\bvolatile\s)?")

re_call = re.compile(r"(call\s[^@]+\))(\s@)")
re_range = re.compile(r"\s*!range\s+!\d+")

re_type_tok = re.compile(r"[,{}()[\]]")

re_annotations = re.compile(r"\bnonnull\b")

re_unsupported_keywords = re.compile(r"\b(local_unnamed_addr|writeonly)\b")

re_parenthesized_list = re.compile(r"\((.*)\)")


def llvm39_to_34_ir(ir):
    """
    Convert LLVM 3.9 IR for LLVM 3.4.
    """
    # NOTE: the parameter name `ir` shadows the module-level llvmlite.ir
    # import inside this function; kept for interface compatibility.
    def parse_out_leading_type(s):
        # Strip the leading <ty> operand (which may be an aggregate type)
        # and return the remainder of the string.
        par_level = 0
        pos = 0
        # Parse out the first <ty> (which may be an aggregate type)
        while True:
            m = re_type_tok.search(s, pos)
            if m is None:
                # End of line
                raise RuntimeError("failed parsing leading type: %s" % (s,))
            pos = m.end()
            tok = m.group(0)
            if tok == ',':
                if par_level == 0:
                    # End of operand
                    break
            elif tok in '{[(':
                par_level += 1
            elif tok in ')]}':
                par_level -= 1
        return s[pos:].lstrip()

    buf = []
    for line in ir.splitlines():

        # Fix llvm.dbg.cu
        if line.startswith('!numba.llvm.dbg.cu'):
            line = line.replace('!numba.llvm.dbg.cu', '!llvm.dbg.cu')

        # We insert a dummy inlineasm to put debuginfo
        if (line.lstrip().startswith('tail call void asm sideeffect "// dbg')
                and '!numba.dbg' in line):
            # Fix the metadata
            line = line.replace('!numba.dbg', '!dbg')
        if re_metadata_def.match(line):
            # Rewrite metadata since LLVM 3.7 dropped the "metadata" type
            # prefix.
            if None is re_metadata_correct_usage.search(line):
                # Reintroduce the "metadata" prefix
                line = line.replace('!{', 'metadata !{')
                line = line.replace('!"', 'metadata !"')

                assigpos = line.find('=')
                lhs, rhs = line[:assigpos + 1], line[assigpos + 1:]

                # Fix metadata reference
                def fix_metadata_ref(m):
                    return 'metadata ' + m.group(0)
                line = ' '.join((lhs,
                                 re_metadata_ref.sub(fix_metadata_ref, rhs)))
        if line.startswith('source_filename ='):
            continue    # skip line
        if re_unsupported_keywords.search(line) is not None:
            line = re_unsupported_keywords.sub(lambda m: '', line)

        if line.startswith('attributes #'):
            # Remove function attributes unsupported pre-3.8
            m = re_attributes_def.match(line)
            attrs = m.group(1).split()
            attrs = ' '.join(a for a in attrs if a in supported_attributes)
            line = line.replace(m.group(1), attrs)
        if 'getelementptr ' in line:
            # Rewrite "getelementptr ty, ty* ptr, ..."
            #      to "getelementptr ty *ptr, ..."
            m = re_getelementptr.search(line)
            if m is None:
                raise RuntimeError("failed parsing getelementptr: %s" %
                                   (line,))
            pos = m.end()
            line = line[:pos] + parse_out_leading_type(line[pos:])
        if 'load ' in line:
            # Rewrite "load ty, ty* ptr"
            #      to "load ty *ptr"
            m = re_load.search(line)
            if m:
                pos = m.end()
                line = line[:pos] + parse_out_leading_type(line[pos:])
        if 'call ' in line:
            # Rewrite "call ty (...) @foo"
            #      to "call ty (...)* @foo"
            line = re_call.sub(r"\1*\2", line)

            # no !range metadata on calls
            line = re_range.sub('', line).rstrip(',')

            if '@llvm.memset' in line:
                line = re_parenthesized_list.sub(
                    _replace_llvm_memset_usage,
                    line,
                )
        if 'declare' in line:
            if '@llvm.memset' in line:
                line = re_parenthesized_list.sub(
                    _replace_llvm_memset_declaration,
                    line,
                )

        # Remove unknown annotations
        line = re_annotations.sub('', line)

        buf.append(line)

    return '\n'.join(buf)


def _replace_llvm_memset_usage(m):
    """Replace `llvm.memset` usage for llvm7+.

    Used as functor for `re.sub`.
    """
    params = list(m.group(1).split(','))
    # Older LLVM passes the alignment as an explicit i32 argument; recover it
    # from the `align N` attribute on the destination pointer.
    align_attr = re.search(r'align (\d+)', params[0])
    if not align_attr:
        raise ValueError("No alignment attribute found on memset dest")
    else:
        align = align_attr.group(1)
    params.insert(-1, 'i32 {}'.format(align))
    out = ', '.join(params)
    return '({})'.format(out)


def _replace_llvm_memset_declaration(m):
    """Replace `llvm.memset` declaration for llvm7+.

    Used as functor for `re.sub`.
    """
    params = list(m.group(1).split(','))
    # Insert the alignment parameter type expected by older LLVM.
    params.insert(-1, 'i32')
    out = ', '.join(params)
    return '({})'.format(out)


def set_cuda_kernel(lfunc):
    """Mark an LLVM function as a CUDA kernel via nvvm.annotations metadata
    and record the NVVM IR version on the module.
    """
    from llvmlite.llvmpy.core import MetaData, MetaDataString, Constant, Type

    m = lfunc.module

    ops = lfunc, MetaDataString.get(m, "kernel"), Constant.int(Type.int(), 1)
    md = MetaData.get(m, ops)

    nmd = m.get_or_insert_named_metadata('nvvm.annotations')
    nmd.add(md)

    # set nvvm ir version
    i32 = ir.IntType(32)
    md_ver = m.add_metadata([i32(1), i32(2), i32(2), i32(0)])
    m.add_named_metadata('nvvmir.version', md_ver)


def fix_data_layout(module):
    """Overwrite a module's data layout with the NVVM-compatible one."""
    module.data_layout = default_data_layout