1"""
2Dynamically generate the NRT module
3"""
4
5
6from numba.core.config import MACHINE_BITS
7from numba.core import types, cgutils
8from llvmlite import ir, binding
9
# Flag to enable debug print in NRT_incref and NRT_decref
_debug_print = False

# Integer as wide as a machine word (MACHINE_BITS); used where the C side
# of the field comments below says ``size_t``.
_word_type = ir.IntType(MACHINE_BITS)
# Pointer to an 8-bit integer, used as the generic ``void *`` stand-in.
_pointer_type = ir.IntType(8).as_pointer()

# Field layout used by the generated IR to address a MemInfo record.
# The order matters: the functions below index into it by position.
_meminfo_struct_type = ir.LiteralStructType((
    _word_type,     # size_t refct
    _pointer_type,  # dtor_function dtor
    _pointer_type,  # void *dtor_info
    _pointer_type,  # void *data
    _word_type,     # size_t size
))


# void (void *): signature shared by NRT_incref and NRT_decref
incref_decref_ty = ir.FunctionType(ir.VoidType(), (_pointer_type,))
# void *(void *): signature of NRT_MemInfo_data_fast
meminfo_data_ty = ir.FunctionType(_pointer_type, (_pointer_type,))
27
28
def _define_nrt_meminfo_data(module):
    """
    Emit the body of ``NRT_MemInfo_data_fast`` into *module*.

    The generated function receives an opaque MemInfo pointer, views it as
    a ``_meminfo_struct_type`` record and returns the pointer stored in its
    ``data`` field.  Having the definition available as IR allows LLVM to
    inline the lookup of the data pointer.
    """
    # Position of ``void *data`` inside _meminfo_struct_type.
    data_field_index = 3

    fn = module.get_or_insert_function(meminfo_data_ty,
                                       name="NRT_MemInfo_data_fast")
    entry_block = fn.append_basic_block()
    builder = ir.IRBuilder(entry_block)

    meminfo, = fn.args
    typed_meminfo = builder.bitcast(meminfo,
                                    _meminfo_struct_type.as_pointer())
    data_field_addr = cgutils.gep(builder, typed_meminfo, 0,
                                  data_field_index)
    builder.ret(builder.load(data_field_addr))
41
42
def _define_nrt_incref(module, atomic_incr):
    """
    Implement NRT_incref in the module.

    Parameters
    ----------
    module
        IR module that receives the ``NRT_incref`` definition.
    atomic_incr
        IR function called to increment the refcount.  The MemInfo pointer
        is bitcast to the type of its first parameter before the call.

    The generated function takes one opaque pointer, returns immediately
    when that pointer is null, and otherwise calls ``atomic_incr`` on it.
    """
    fn_incref = module.get_or_insert_function(incref_decref_ty,
                                              name="NRT_incref")
    # Cannot inline this for refcount pruning to work
    fn_incref.attributes.add('noinline')
    builder = ir.IRBuilder(fn_incref.append_basic_block())
    [ptr] = fn_incref.args
    is_null = builder.icmp_unsigned("==", ptr, cgutils.get_null_value(ptr.type))
    with cgutils.if_unlikely(builder, is_null):
        builder.ret_void()

    # Pointer type through which the refcount word is accessed.
    refct_ptr_type = atomic_incr.args[0].type
    if _debug_print:
        # ``ptr`` is an i8*, so loading through it directly would read a
        # single byte and hand an i8 to printf for "%zu".  Load the whole
        # refcount word instead, through the same pointer type the atomic
        # helper uses.
        refct = builder.load(builder.bitcast(ptr, refct_ptr_type))
        cgutils.printf(builder, "*** NRT_Incref %zu [%p]\n", refct, ptr)
    builder.call(atomic_incr, [builder.bitcast(ptr, refct_ptr_type)])
    builder.ret_void()
62
63
def _define_nrt_decref(module, atomic_decr):
    """
    Implement NRT_decref in the module.

    Parameters
    ----------
    module
        IR module that receives the ``NRT_decref`` definition.  A
        declaration of ``NRT_MemInfo_call_dtor`` is also added to it.
    atomic_decr
        IR function called to decrement the refcount.  The MemInfo pointer
        is bitcast to the type of its first parameter before the call, and
        its return value is compared against zero.

    The generated function takes one opaque pointer and returns immediately
    when it is null.  Otherwise it issues a release fence, calls
    ``atomic_decr``, and, if the returned refcount is zero, issues an
    acquire fence and calls ``NRT_MemInfo_call_dtor`` on the pointer.
    """
    fn_decref = module.get_or_insert_function(incref_decref_ty,
                                              name="NRT_decref")
    # Cannot inline this for refcount pruning to work
    fn_decref.attributes.add('noinline')
    calldtor = module.add_function(ir.FunctionType(ir.VoidType(), [_pointer_type]),
                                   name="NRT_MemInfo_call_dtor")

    builder = ir.IRBuilder(fn_decref.append_basic_block())
    [ptr] = fn_decref.args
    is_null = builder.icmp_unsigned("==", ptr, cgutils.get_null_value(ptr.type))
    with cgutils.if_unlikely(builder, is_null):
        builder.ret_void()

    # Pointer type through which the refcount word is accessed.
    refct_ptr_type = atomic_decr.args[0].type
    if _debug_print:
        # ``ptr`` is an i8*, so loading through it directly would read a
        # single byte and hand an i8 to printf for "%zu".  Load the whole
        # refcount word instead, through the same pointer type the atomic
        # helper uses.
        refct = builder.load(builder.bitcast(ptr, refct_ptr_type))
        cgutils.printf(builder, "*** NRT_Decref %zu [%p]\n", refct, ptr)

    # For memory fence usage, see https://llvm.org/docs/Atomics.html

    # A release fence is used before the relevant write operation.
    # No-op on x86.  On POWER, it lowers to lwsync.
    builder.fence("release")
    newrefct = builder.call(atomic_decr,
                            [builder.bitcast(ptr, refct_ptr_type)])

    refct_eq_0 = builder.icmp_unsigned("==", newrefct,
                                       ir.Constant(newrefct.type, 0))
    with cgutils.if_unlikely(builder, refct_eq_0):
        # An acquire fence is used after the relevant read operation.
        # No-op on x86.  On POWER, it lowers to lwsync.
        builder.fence("acquire")
        builder.call(calldtor, [ptr])
    builder.ret_void()
101
102
# Set this to True to measure the overhead of atomic refcounts compared
# to non-atomic.  When true, _define_atomic_inc_dec emits a plain
# load/modify/store sequence instead of an atomic read-modify-write.
_disable_atomicity = False
106
107
def _define_atomic_inc_dec(module, op, ordering):
    """Define a llvm function for atomic increment/decrement to the given module
    Argument ``op`` is the operation "add"/"sub".  Argument ``ordering`` is
    the memory ordering.  The generated function returns the new value.

    The function is named ``nrt_atomic_<op>``, takes a pointer to a
    machine-word integer and adds/subtracts one.  When the module-level
    ``_disable_atomicity`` flag is set, a plain load/modify/store sequence
    is emitted instead of an atomic read-modify-write; the return value is
    the updated value in both modes.

    Returns the created ``ir.Function``.
    """
    ftype = ir.FunctionType(_word_type, [_word_type.as_pointer()])
    fn_atomic = ir.Function(module, ftype, name="nrt_atomic_{0}".format(op))

    [ptr] = fn_atomic.args
    bb = fn_atomic.append_basic_block()
    builder = ir.IRBuilder(bb)
    ONE = ir.Constant(_word_type, 1)
    # ``op`` doubles as the name of the IRBuilder arithmetic method.
    apply_op = getattr(builder, op)
    if not _disable_atomicity:
        # The read-modify-write instruction yields the value held *before*
        # the update, so redo the operation on it to obtain the new value.
        oldval = builder.atomic_rmw(op, ptr, ONE, ordering=ordering)
        newval = apply_op(oldval, ONE)
    else:
        oldval = builder.load(ptr)
        newval = apply_op(oldval, ONE)
        builder.store(newval, ptr)
    # Both modes must return the updated value: NRT_decref compares the
    # result against zero to decide whether to run the destructor.
    builder.ret(newval)

    return fn_atomic
133
134
def _define_atomic_cas(module, ordering):
    """Define an LLVM function ``nrt_atomic_cas`` for atomic compare-and-swap.

    The generated function is a thin wrapper around the LLVM ``cmpxchg``
    instruction with two differences: it returns a 32-bit integer that
    indicates success (1) or failure (0), and its last argument is an
    output pointer that receives the old value.

    Generated signature: ``(word *ptr, word cmp, word repl, word *oldptr)``.

    Note
    ----
    On failure, the generated function behaves like an atomic load.  The
    loaded value is stored to the last argument.
    """
    word_ptr_type = _word_type.as_pointer()
    status_type = ir.IntType(32)
    ftype = ir.FunctionType(
        status_type,
        [word_ptr_type, _word_type, _word_type, word_ptr_type],
    )
    fn_cas = ir.Function(module, ftype, name="nrt_atomic_cas")

    target, expected, replacement, old_out = fn_cas.args
    builder = ir.IRBuilder(fn_cas.append_basic_block())

    # cmpxchg produces an (old value, success flag) pair.
    cas_result = builder.cmpxchg(target, expected, replacement,
                                 ordering=ordering)
    previous, succeeded = cgutils.unpack_tuple(builder, cas_result, 2)
    builder.store(previous, old_out)
    # Widen the success flag to the declared integer return type.
    status = builder.zext(succeeded, status_type)
    builder.ret(status)

    return fn_cas
160
161
def _define_nrt_unresolved_abort(ctx, module):
    """
    Define ``nrt_unresolved_abort``, the function used to abort on an
    unresolved symbol.

    The function is declared with the target's calling convention for a
    zero-argument function returning ``none``, and its body only raises a
    ``RuntimeError`` through that calling convention.
    It should be safe to call this function with incorrect number of
    arguments.

    Returns the created ``ir.Function``.
    """
    call_conv = ctx.call_conv
    abort_fnty = call_conv.get_function_type(types.none, ())
    abort_fn = ir.Function(module, abort_fnty, name="nrt_unresolved_abort")

    builder = ir.IRBuilder(abort_fn.append_basic_block())
    exc_args = ("numba jitted function aborted due to unresolved symbol",)
    call_conv.return_user_exc(builder, RuntimeError, exc_args)
    return abort_fn
176
177
def create_nrt_module(ctx):
    """
    Build the IR module that defines the LLVM-level NRT functions.

    A fresh "nrt" library is created from the codegen of *ctx*, and an IR
    module named "nrt_module" is populated with the atomic helpers, the
    MemInfo data accessor, NRT_incref/NRT_decref and the unresolved-symbol
    abort function.  The module is *not* added to the library here.

    Returns an ``(IR module, library)`` tuple.
    """
    library = ctx.codegen().create_library("nrt")
    ir_mod = library.create_ir_module("nrt_module")

    # Atomic primitives all use the same memory ordering.
    ordering = 'monotonic'
    atomics = {
        op: _define_atomic_inc_dec(ir_mod, op, ordering=ordering)
        for op in ("add", "sub")
    }
    _define_atomic_cas(ir_mod, ordering=ordering)

    # Functions built on top of the atomic primitives.
    _define_nrt_meminfo_data(ir_mod)
    _define_nrt_incref(ir_mod, atomics["add"])
    _define_nrt_decref(ir_mod, atomics["sub"])

    _define_nrt_unresolved_abort(ctx, ir_mod)

    return ir_mod, library
200
201
def compile_nrt_functions(ctx):
    """
    Compile every LLVM NRT function and hand back the library holding them.

    The IR module produced by ``create_nrt_module`` for the given target
    context is added to its library, and the library is finalized before
    being returned.
    """
    nrt_module, nrt_library = create_nrt_module(ctx)
    nrt_library.add_ir_module(nrt_module)
    nrt_library.finalize()
    return nrt_library
213