1"""
2Generic helpers for LLVM code generation.
3"""
4
5
6import collections
7from contextlib import contextmanager
8import functools
9
10from llvmlite import ir
11
12from numba.core import utils, types, config
13import numba.core.datamodel
14
15
# Commonly used LLVM scalar types.
bool_t = ir.IntType(1)
int8_t = ir.IntType(8)
int32_t = ir.IntType(32)
# Integer type matching the host machine's pointer width.
intp_t = ir.IntType(utils.MACHINE_BITS)
voidptr_t = int8_t.as_pointer()

# Pre-built boolean constants, in both i1 (bit) and i8 (byte) flavours.
true_bit = bool_t(1)
false_bit = bool_t(0)
true_byte = int8_t(1)
false_byte = int8_t(0)
26
27
def as_bool_bit(builder, value):
    """Reduce *value* to an i1 by comparing it against zero of its type."""
    zero = value.type(0)
    return builder.icmp_unsigned('!=', value, zero)
30
31
def make_anonymous_struct(builder, values, struct_type=None):
    """
    Create an anonymous struct containing the given LLVM *values*.
    """
    if struct_type is None:
        # Infer the literal struct type from the member types.
        member_types = [v.type for v in values]
        struct_type = ir.LiteralStructType(member_types)
    result = struct_type(ir.Undefined)
    for position, member in enumerate(values):
        result = builder.insert_value(result, member, position)
    return result
42
43
def make_bytearray(buf):
    """
    Make a byte array constant from *buf*.
    """
    data = bytearray(buf)
    array_ty = ir.ArrayType(ir.IntType(8), len(data))
    return ir.Constant(array_ty, data)
51
52
# Cache of generated proxy classes, keyed by (frontend type, kind).
_struct_proxy_cache = {}


def create_struct_proxy(fe_type, kind='value'):
    """
    Returns a specialized StructProxy subclass for the given fe_type.
    """
    key = (fe_type, kind)
    try:
        return _struct_proxy_cache[key]
    except KeyError:
        base = {'value': ValueStructProxy,
                'data': DataStructProxy,
                }[kind]
        # Synthesize a subclass dedicated to this frontend type.
        proxy_cls = type(base.__name__ + '_' + str(fe_type),
                         (base,),
                         dict(_fe_type=fe_type))
        _struct_proxy_cache[key] = proxy_cls
        return proxy_cls
73
74
def copy_struct(dst, src, repl=None):
    """
    Copy structure from *src* to *dst* with replacement from *repl*.

    Every field listed in *src*'s data model is copied over to *dst*,
    unless a replacement value is given in *repl*; any leftover entries
    in *repl* are then assigned to *dst* as well.  Returns *dst*.
    """
    # Avoid a mutable default argument; also copy the caller's mapping
    # so the pop() below doesn't mutate it.
    repl = dict(repl) if repl else {}
    # copy data from src or use those in repl
    for k in src._datamodel._fields:
        v = repl.pop(k, getattr(src, k))
        setattr(dst, k, v)
    # use remaining key-values in repl
    for k, v in repl.items():
        setattr(dst, k, v)
    return dst
88
89
class _StructProxy(object):
    """
    Creates a `Structure` like interface that is constructed with information
    from DataModel instance.  FE type must have a data model that is a
    subclass of StructModel.
    """
    # The following class members must be overridden by subclass
    _fe_type = None

    def __init__(self, context, builder, value=None, ref=None):
        # Proxy the struct at pointer *ref*, or at a fresh zero-filled
        # stack slot if *ref* is None.  If *value* is given, it is stored
        # into the structure first.
        self._context = context
        self._datamodel = self._context.data_model_manager[self._fe_type]
        if not isinstance(self._datamodel, numba.core.datamodel.StructModel):
            raise TypeError(
                "Not a structure model: {0}".format(self._datamodel))
        self._builder = builder

        self._be_type = self._get_be_type(self._datamodel)
        assert not is_pointer(self._be_type)

        # _make_refs() may be overridden to wrap the pointer; by default
        # both refs are the same pointer.
        outer_ref, ref = self._make_refs(ref)
        if ref.type.pointee != self._be_type:
            raise AssertionError("bad ref type: expected %s, got %s"
                                 % (self._be_type.as_pointer(), ref.type))

        if value is not None:
            # The initial value is stored through the *outer* ref, so it
            # must match the outer pointee type.
            if value.type != outer_ref.type.pointee:
                raise AssertionError("bad value type: expected %s, got %s"
                                     % (outer_ref.type.pointee, value.type))
            self._builder.store(value, outer_ref)

        self._value = ref
        self._outer_ref = outer_ref

    def _make_refs(self, ref):
        """
        Return an (outer ref, value ref) pair.  By default, these are
        the same pointers, but a derived class may override this.
        """
        if ref is None:
            ref = alloca_once(self._builder, self._be_type, zfill=True)
        return ref, ref

    def _get_be_type(self, datamodel):
        # Return the backend (LLVM) type being proxied; subclass hook.
        raise NotImplementedError

    def _cast_member_to_value(self, index, val):
        # Convert a raw loaded member to its user-facing form; subclass hook.
        raise NotImplementedError

    def _cast_member_from_value(self, index, val):
        # Convert a user-facing value to the raw stored form; subclass hook.
        raise NotImplementedError

    def _get_ptr_by_index(self, index):
        # Pointer to the *index*-th member of the structure.
        return gep_inbounds(self._builder, self._value, 0, index)

    def _get_ptr_by_name(self, attrname):
        # Pointer to the member named *attrname*.
        index = self._datamodel.get_field_position(attrname)
        return self._get_ptr_by_index(index)

    def __getattr__(self, field):
        """
        Load the LLVM value of the named *field*.
        """
        if not field.startswith('_'):
            return self[self._datamodel.get_field_position(field)]
        else:
            # Underscore names are plain attributes; reaching here means
            # the attribute genuinely doesn't exist.
            raise AttributeError(field)

    def __setattr__(self, field, value):
        """
        Store the LLVM *value* into the named *field*.
        """
        if field.startswith('_'):
            # Underscore names bypass the proxy and set a real attribute.
            return super(_StructProxy, self).__setattr__(field, value)
        self[self._datamodel.get_field_position(field)] = value

    def __getitem__(self, index):
        """
        Load the LLVM value of the field at *index*.
        """
        member_val = self._builder.load(self._get_ptr_by_index(index))
        return self._cast_member_to_value(index, member_val)

    def __setitem__(self, index, value):
        """
        Store the LLVM *value* into the field at *index*.
        """
        ptr = self._get_ptr_by_index(index)
        value = self._cast_member_from_value(index, value)
        if value.type != ptr.type.pointee:
            if (is_pointer(value.type) and is_pointer(ptr.type.pointee)
                    and value.type.pointee == ptr.type.pointee.pointee):
                # Differ by address-space only
                # Auto coerce it
                value = self._context.addrspacecast(self._builder,
                                                    value,
                                                    ptr.type.pointee.addrspace)
            else:
                raise TypeError("Invalid store of {value.type} to "
                                "{ptr.type.pointee} in "
                                "{self._datamodel} "
                                "(trying to write member #{index})"
                                .format(value=value, ptr=ptr, self=self,
                                        index=index))
        self._builder.store(value, ptr)

    def __len__(self):
        """
        Return the number of fields.
        """
        return self._datamodel.field_count

    def _getpointer(self):
        """
        Return the LLVM pointer to the underlying structure.
        """
        return self._outer_ref

    def _getvalue(self):
        """
        Load and return the value of the underlying LLVM structure.
        """
        return self._builder.load(self._outer_ref)

    def _setvalue(self, value):
        """
        Store the value in this structure.
        """
        assert not is_pointer(value.type)
        assert value.type == self._be_type, (value.type, self._be_type)
        self._builder.store(value, self._value)
222
class ValueStructProxy(_StructProxy):
    """
    Create a StructProxy suitable for accessing regular values
    (e.g. LLVM values or alloca slots).
    """
    def _get_be_type(self, datamodel):
        # Values use the model's "value" LLVM representation.
        return datamodel.get_value_type()

    def _cast_member_to_value(self, index, val):
        # Value representation: loads need no conversion.
        return val

    def _cast_member_from_value(self, index, val):
        # Value representation: stores need no conversion.
        return val
236
237
class DataStructProxy(_StructProxy):
    """
    Create a StructProxy suitable for accessing data persisted in memory.
    """
    def _get_be_type(self, datamodel):
        # In-memory data uses the model's "data" LLVM representation.
        return datamodel.get_data_type()

    def _cast_member_to_value(self, index, val):
        # Convert the stored data representation back into a value.
        return self._datamodel.get_model(index).from_data(self._builder, val)

    def _cast_member_from_value(self, index, val):
        # Convert a value into its stored data representation.
        return self._datamodel.get_model(index).as_data(self._builder, val)
252
253
class Structure(object):
    """
    A high-level object wrapping a alloca'ed LLVM structure, including
    named fields and attribute access.

    NOTE(review): subclasses appear to supply a ``_fields`` sequence of
    (name, type) pairs — inferred only from its use in __init__; confirm.
    """

    # XXX Should this warrant several separate constructors?
    def __init__(self, context, builder, value=None, ref=None, cast_ref=False):
        self._type = context.get_struct_type(self)
        self._context = context
        self._builder = builder
        if ref is None:
            # No pointer given: allocate a fresh, zero-filled stack slot.
            self._value = alloca_once(builder, self._type, zfill=True)
            if value is not None:
                assert not is_pointer(value.type)
                assert value.type == self._type, (value.type, self._type)
                builder.store(value, self._value)
        else:
            # Wrap an existing pointer; an initial *value* is not allowed.
            assert value is None
            assert is_pointer(ref.type)
            if self._type != ref.type.pointee:
                if cast_ref:
                    ref = builder.bitcast(ref, self._type.as_pointer())
                else:
                    raise TypeError(
                        "mismatching pointer type: got %s, expected %s"
                        % (ref.type.pointee, self._type))
            self._value = ref

        # Lookup tables: field name -> index, index -> GEP index pair,
        # index -> declared field type.
        self._namemap = {}
        self._fdmap = []
        self._typemap = []
        base = int32_t(0)
        for i, (k, tp) in enumerate(self._fields):
            self._namemap[k] = i
            self._fdmap.append((base, int32_t(i)))
            self._typemap.append(tp)

    def _get_ptr_by_index(self, index):
        # Pointer to the *index*-th field of the structure.
        ptr = self._builder.gep(self._value, self._fdmap[index], inbounds=True)
        return ptr

    def _get_ptr_by_name(self, attrname):
        # Pointer to the field named *attrname*.
        return self._get_ptr_by_index(self._namemap[attrname])

    def __getattr__(self, field):
        """
        Load the LLVM value of the named *field*.
        """
        if not field.startswith('_'):
            return self[self._namemap[field]]
        else:
            # Underscore names are plain attributes; reaching here means
            # the attribute genuinely doesn't exist.
            raise AttributeError(field)

    def __setattr__(self, field, value):
        """
        Store the LLVM *value* into the named *field*.
        """
        if field.startswith('_'):
            # Underscore names bypass the proxy and set a real attribute.
            return super(Structure, self).__setattr__(field, value)
        self[self._namemap[field]] = value

    def __getitem__(self, index):
        """
        Load the LLVM value of the field at *index*.
        """

        return self._builder.load(self._get_ptr_by_index(index))

    def __setitem__(self, index, value):
        """
        Store the LLVM *value* into the field at *index*.
        """
        ptr = self._get_ptr_by_index(index)
        if ptr.type.pointee != value.type:
            fmt = "Type mismatch: __setitem__(%d, ...) expected %r but got %r"
            raise AssertionError(fmt % (index,
                                        str(ptr.type.pointee),
                                        str(value.type)))
        self._builder.store(value, ptr)

    def __len__(self):
        """
        Return the number of fields.
        """
        return len(self._namemap)

    def _getpointer(self):
        """
        Return the LLVM pointer to the underlying structure.
        """
        return self._value

    def _getvalue(self):
        """
        Load and return the value of the underlying LLVM structure.
        """
        return self._builder.load(self._value)

    def _setvalue(self, value):
        """Store the value in this structure"""
        assert not is_pointer(value.type)
        assert value.type == self._type, (value.type, self._type)
        self._builder.store(value, self._value)

    # __iter__ is derived by Python from __len__ and __getitem__
360
361
def alloca_once(builder, ty, size=None, name='', zfill=False):
    """Allocate stack memory at the entry block of the current function
    pointed by ``builder`` with the LLVM type ``ty``.  The optional ``size``
    arg sets the number of elements to allocate.  The default is 1.  The
    optional ``name`` arg sets the symbol name inside the llvm IR for
    debugging.
    If ``zfill`` is set, fill the memory with zeros at the current
    use-site location.  Note that the memory is always zero-filled after the
    ``alloca`` at init-site (the entry block).
    """
    if isinstance(size, utils.INT_TYPES):
        size = ir.Constant(intp_t, size)
    with builder.goto_entry_block():
        ptr = builder.alloca(ty, size=size, name=name)
        # Always zero-fill at init-site.  This is safe.
        builder.store(ty(None), ptr)
    # Also zero-fill at the use-site
    if zfill:
        builder.store(ty(None), ptr)
    return ptr
381
382
def alloca_once_value(builder, value, name=''):
    """
    Like alloca_once(), but passing a *value* instead of a type.  The
    type is inferred and the allocated slot is also initialized with the
    given value.
    """
    # Forward *name* (it was previously accepted but silently dropped)
    # so the slot gets a debugging label in the IR.
    storage = alloca_once(builder, value.type, name=name)
    builder.store(value, storage)
    return storage
392
393
def insert_pure_function(module, fnty, name):
    """
    Insert a pure function (in the functional programming sense) in the
    given module.
    """
    fn = module.get_or_insert_function(fnty, name=name)
    # Mark the declaration so LLVM may optimize around it.
    for attr in ("readonly", "nounwind"):
        fn.attributes.add(attr)
    return fn
403
404
def terminate(builder, bbend):
    """Branch to *bbend* unless the current block is already terminated."""
    if builder.basic_block.terminator is None:
        builder.branch(bbend)
409
410
def get_null_value(ltype):
    """Return the NULL (zero-initializer) constant of LLVM type *ltype*."""
    null = ltype(None)
    return null
413
414
def is_null(builder, val):
    """Return an i1 predicate: *val* compares equal to its type's NULL."""
    return builder.icmp_unsigned('==', get_null_value(val.type), val)
418
419
def is_not_null(builder, val):
    """Return an i1 predicate: *val* differs from its type's NULL."""
    return builder.icmp_unsigned('!=', get_null_value(val.type), val)
423
424
def if_unlikely(builder, pred):
    """An if_then() context hinted as rarely taken."""
    return builder.if_then(pred, likely=False)
427
428
def if_likely(builder, pred):
    """An if_then() context hinted as usually taken."""
    return builder.if_then(pred, likely=True)
431
432
def ifnot(builder, pred):
    """An if_then() context entered when *pred* is false."""
    inverted = builder.not_(pred)
    return builder.if_then(inverted)
435
436
def increment_index(builder, val):
    """
    Increment an index *val*.
    """
    # We pass the "nsw" flag in the hope that LLVM understands the index
    # never changes sign.  Unfortunately this doesn't always work
    # (e.g. ndindex()).
    return builder.add(val, val.type(1), flags=['nsw'])
446
447
# Value yielded by for_range(): the loop index's LLVM value plus a
# no-argument callable that branches out of the loop early.
Loop = collections.namedtuple('Loop', ('index', 'do_break'))
449
450
@contextmanager
def for_range(builder, count, start=None, intp=None):
    """
    Generate LLVM IR for a for-loop in [start, count).
    *start* is equal to 0 by default.

    Yields a Loop namedtuple with the following members:
    - `index` is the loop index's value
    - `do_break` is a no-argument callable to break out of the loop
    """
    if intp is None:
        intp = count.type
    if start is None:
        start = intp(0)
    stop = count

    bbcond = builder.append_basic_block("for.cond")
    bbbody = builder.append_basic_block("for.body")
    bbend = builder.append_basic_block("for.end")

    def do_break():
        # Early exit: jump straight past the loop.
        builder.branch(bbend)

    bbstart = builder.basic_block
    builder.branch(bbcond)

    with builder.goto_block(bbcond):
        # Loop header: a phi tracks the index; exit when index >= stop.
        index = builder.phi(intp, name="loop.index")
        pred = builder.icmp_signed('<', index, stop)
        builder.cbranch(pred, bbbody, bbend)

    with builder.goto_block(bbbody):
        yield Loop(index, do_break)
        # Update bbbody as a new basic block may have been activated
        bbbody = builder.basic_block
        incr = increment_index(builder, index)
        terminate(builder, bbcond)

    # Wire the phi's incoming values now that both predecessors exist.
    index.add_incoming(start, bbstart)
    index.add_incoming(incr, bbbody)

    builder.position_at_end(bbend)
493
494
@contextmanager
def for_range_slice(builder, start, stop, step, intp=None, inc=True):
    """
    Generate LLVM IR for a for-loop based on a slice.  Yields a
    (index, count) tuple where `index` is the slice index's value
    inside the loop, and `count` the iteration count.

    Parameters
    -------------
    builder : object
        Builder object
    start : int
        The beginning value of the slice
    stop : int
        The end value of the slice
    step : int
        The step value of the slice
    intp :
        The data type
    inc : boolean, optional
        Signals whether the step is positive (True) or negative (False).

    Yields
    -----------
        (index, count) : a pair of LLVM values
    """
    if intp is None:
        intp = start.type

    bbcond = builder.append_basic_block("for.cond")
    bbbody = builder.append_basic_block("for.body")
    bbend = builder.append_basic_block("for.end")
    bbstart = builder.basic_block
    builder.branch(bbcond)

    with builder.goto_block(bbcond):
        # Loop header: two phis track the slice index and the trip count.
        index = builder.phi(intp, name="loop.index")
        count = builder.phi(intp, name="loop.count")
        if (inc):
            pred = builder.icmp_signed('<', index, stop)
        else:
            pred = builder.icmp_signed('>', index, stop)
        builder.cbranch(pred, bbbody, bbend)

    with builder.goto_block(bbbody):
        yield index, count
        # Re-read the current block: the body may have appended new ones.
        bbbody = builder.basic_block
        incr = builder.add(index, step)
        next_count = increment_index(builder, count)
        terminate(builder, bbcond)

    # Wire the phis now that both incoming edges exist.
    index.add_incoming(start, bbstart)
    index.add_incoming(incr, bbbody)
    count.add_incoming(ir.Constant(intp, 0), bbstart)
    count.add_incoming(next_count, bbbody)
    builder.position_at_end(bbend)
551
552
@contextmanager
def for_range_slice_generic(builder, start, stop, step):
    """
    A helper wrapper for for_range_slice().  This is a context manager which
    yields two for_range_slice()-alike context managers, the first for
    the positive step case, the second for the negative step case.

    Use:
        with for_range_slice_generic(...) as (pos_range, neg_range):
            with pos_range as (idx, count):
                ...
            with neg_range as (idx, count):
                ...
    """
    intp = start.type
    # Runtime test of the step's sign selects which loop body executes.
    is_pos_step = builder.icmp_signed('>=', step, ir.Constant(intp, 0))

    pos_for_range = for_range_slice(builder, start, stop, step, intp, inc=True)
    neg_for_range = for_range_slice(builder, start, stop, step, intp, inc=False)

    @contextmanager
    def cm_cond(cond, inner_cm):
        # Nest the loop context inside the if_else branch context.
        with cond:
            with inner_cm as value:
                yield value

    with builder.if_else(is_pos_step, likely=True) as (then, otherwise):
        yield cm_cond(then, pos_for_range), cm_cond(otherwise, neg_for_range)
581
582
@contextmanager
def loop_nest(builder, shape, intp, order='C'):
    """
    Generate a loop nest walking a N-dimensional array.
    Yields a tuple of N indices for use in the inner loop body,
    iterating over the *shape* space.

    If *order* is 'C' (the default), indices are incremented inside-out
    (i.e. (0,0), (0,1), (0,2), (1,0) etc.).
    If *order* is 'F', they are incremented outside-in
    (i.e. (0,0), (1,0), (2,0), (0,1) etc.).
    This has performance implications when walking an array as it impacts
    the spatial locality of memory accesses.
    """
    assert order in 'CF'
    if not shape:
        # 0-d array: a single iteration with no indices.
        yield ()
        return
    # 'F' order walks the dimensions in reverse; 'C' keeps them as-is.
    if order == 'F':
        reorder = lambda seq: seq[::-1]
    else:
        reorder = lambda seq: seq
    with _loop_nest(builder, reorder(shape), intp) as indices:
        assert len(indices) == len(shape)
        yield reorder(indices)
609
610
@contextmanager
def _loop_nest(builder, shape, intp):
    # Recursively emit one for_range() per dimension and yield the
    # tuple of loop indices, outermost dimension first.
    with for_range(builder, shape[0], intp=intp) as loop:
        if len(shape) > 1:
            with _loop_nest(builder, shape[1:], intp) as indices:
                yield (loop.index,) + indices
        else:
            yield (loop.index,)
619
620
def pack_array(builder, values, ty=None):
    """
    Pack a sequence of values in a LLVM array.  *ty* should be given
    if the array may be empty, in which case the type can't be inferred
    from the values.
    """
    elem_ty = values[0].type if ty is None else ty
    result = ir.ArrayType(elem_ty, len(values))(ir.Undefined)
    for position, item in enumerate(values):
        result = builder.insert_value(result, item, position)
    return result
634
635
def pack_struct(builder, values):
    """
    Pack a sequence of values into a LLVM struct.
    """
    struct_ty = ir.LiteralStructType([v.type for v in values])
    packed = struct_ty(ir.Undefined)
    for position, item in enumerate(values):
        packed = builder.insert_value(packed, item, position)
    return packed
645
646
def unpack_tuple(builder, tup, count=None):
    """
    Unpack an array or structure of values, return a Python list.

    If *count* is not given, it is inferred from *tup*'s aggregate type.
    (Note: despite the historical docstring, a list — not a tuple — is
    returned, and callers rely on that.)
    """
    if count is None:
        # Assuming *tup* is an aggregate
        count = len(tup.type.elements)
    return [builder.extract_value(tup, i) for i in range(count)]
657
658
def get_item_pointer(context, builder, aryty, ary, inds, wraparound=False,
                     boundscheck=False):
    """
    Compute a pointer to the element of *ary* selected by indices *inds*.

    Set boundscheck=True for any pointer access that should be
    boundschecked; do_boundscheck() will handle enabling or disabling the
    actual boundschecking based on the user config.
    """
    ndim = aryty.ndim
    shapes = unpack_tuple(builder, ary.shape, count=ndim)
    strides = unpack_tuple(builder, ary.strides, count=ndim)
    return get_item_pointer2(context, builder, data=ary.data, shape=shapes,
                             strides=strides, layout=aryty.layout, inds=inds,
                             wraparound=wraparound, boundscheck=boundscheck)
669
670
def do_boundscheck(context, builder, ind, dimlen, axis=None):
    """
    Emit a bounds check of index *ind* against dimension length *dimlen*,
    raising IndexError out of line when it fails.  *axis* (a Python int
    or an LLVM value) is only used for the debug printout.
    """
    def _dbg():
        # Remove this when we figure out how to include this information
        # in the error message.
        if axis is not None:
            if isinstance(axis, int):
                printf(builder, "debug: IndexError: index %d is out of bounds "
                       "for axis {} with size %d\n".format(axis), ind, dimlen)
            else:
                printf(builder, "debug: IndexError: index %d is out of bounds "
                       "for axis %d with size %d\n", ind, axis,
                       dimlen)
        else:
            printf(builder,
                   "debug: IndexError: index %d is out of bounds for size %d\n",
                   ind, dimlen)

    msg = "index is out of bounds"
    # Upper bound: ind >= dimlen
    out_of_bounds_upper = builder.icmp_signed('>=', ind, dimlen)
    with if_unlikely(builder, out_of_bounds_upper):
        if config.FULL_TRACEBACKS:
            _dbg()
        context.call_conv.return_user_exc(builder, IndexError, (msg,))
    # Lower bound: ind < 0
    out_of_bounds_lower = builder.icmp_signed('<', ind, ind.type(0))
    with if_unlikely(builder, out_of_bounds_lower):
        if config.FULL_TRACEBACKS:
            _dbg()
        context.call_conv.return_user_exc(builder, IndexError, (msg,))
699
700
def get_item_pointer2(context, builder, data, shape, strides, layout, inds,
                      wraparound=False, boundscheck=False):
    """
    Compute a pointer into *data* for the multi-dimensional index *inds*,
    given the unpacked *shape* and *strides* and the array *layout*
    ('C', 'F', or anything else for generic strided access).

    Set boundscheck=True for any pointer access that should be
    boundschecked. do_boundscheck() will handle enabling or disabling the
    actual boundschecking based on the user config.
    """
    if wraparound:
        # Wraparound: negative indices select from the end (ind + dimlen).
        indices = []
        for ind, dimlen in zip(inds, shape):
            negative = builder.icmp_signed('<', ind, ind.type(0))
            wrapped = builder.add(dimlen, ind)
            selected = builder.select(negative, wrapped, ind)
            indices.append(selected)
    else:
        indices = inds
    if boundscheck:
        for axis, (ind, dimlen) in enumerate(zip(indices, shape)):
            do_boundscheck(context, builder, ind, dimlen, axis)

    if not indices:
        # Indexing with empty tuple
        return builder.gep(data, [int32_t(0)])
    intp = indices[0].type
    # Indexing code
    if layout in 'CF':
        # Contiguous layouts: compute a flat element index from the shape.
        steps = []
        # Compute steps for each dimension
        if layout == 'C':
            # C contiguous: step of dim i is the product of trailing dims.
            for i in range(len(shape)):
                last = intp(1)
                for j in shape[i + 1:]:
                    last = builder.mul(last, j)
                steps.append(last)
        elif layout == 'F':
            # F contiguous: step of dim i is the product of leading dims.
            for i in range(len(shape)):
                last = intp(1)
                for j in shape[:i]:
                    last = builder.mul(last, j)
                steps.append(last)
        else:
            raise Exception("unreachable")

        # Compute index: sum of (index * step) over all dimensions.
        loc = intp(0)
        for i, s in zip(indices, steps):
            tmp = builder.mul(i, s)
            loc = builder.add(loc, tmp)
        ptr = builder.gep(data, [loc])
        return ptr
    else:
        # Any layout: use byte strides and raw pointer arithmetic.
        dimoffs = [builder.mul(s, i) for s, i in zip(strides, indices)]
        offset = functools.reduce(builder.add, dimoffs)
        return pointer_add(builder, data, offset)
757
758
def _scalar_pred_against_zero(builder, value, fpred, icond):
    """
    Compare *value* against zero of its type: floats use *fpred*,
    integers use an icmp_signed with condition *icond*.
    """
    zero = value.type(0)
    if isinstance(value.type, (ir.FloatType, ir.DoubleType)):
        return fpred(value, zero)
    if isinstance(value.type, ir.IntType):
        return builder.icmp_signed(icond, value, zero)
    raise TypeError("unexpected value type %s" % (value.type,))
768
769
def is_scalar_zero(builder, value):
    """
    Return a predicate representing whether *value* is equal to zero.
    """
    float_eq = functools.partial(builder.fcmp_ordered, '==')
    return _scalar_pred_against_zero(builder, value, float_eq, '==')
776
777
def is_not_scalar_zero(builder, value):
    """
    Return a predicate representing whether a *value* is not equal to zero.
    (not exactly "not is_scalar_zero" because of nans)
    """
    float_ne = functools.partial(builder.fcmp_unordered, '!=')
    return _scalar_pred_against_zero(builder, value, float_ne, '!=')
785
786
def is_scalar_zero_or_nan(builder, value):
    """
    Return a predicate representing whether *value* is equal to either zero
    or NaN.
    """
    float_eq_unordered = functools.partial(builder.fcmp_unordered, '==')
    return _scalar_pred_against_zero(builder, value, float_eq_unordered, '==')
794
795
# Truthiness aliases: "true" is an unordered != 0 compare (NaN is true),
# "false" is an ordered == 0 compare (NaN is not false).
is_true = is_not_scalar_zero
is_false = is_scalar_zero
798
799
def is_scalar_neg(builder, value):
    """
    Is *value* negative?  Assumes *value* is signed.
    """
    float_lt = functools.partial(builder.fcmp_ordered, '<')
    return _scalar_pred_against_zero(builder, value, float_lt, '<')
806
807
def guard_null(context, builder, value, exc_tuple):
    """
    Guard against *value* being null or zero.
    *exc_tuple* should be a (exception type, arguments...) tuple.
    """
    zero_pred = is_scalar_zero(builder, value)
    with builder.if_then(zero_pred, likely=False):
        exc = exc_tuple[0]
        exc_args = exc_tuple[1:] or None
        context.call_conv.return_user_exc(builder, exc, exc_args)
817
818
def guard_memory_error(context, builder, pointer, msg=None):
    """
    Guard against *pointer* being NULL (and raise a MemoryError).
    """
    assert isinstance(pointer.type, ir.PointerType), pointer.type
    # Note: a falsy *msg* (e.g. '') yields no exception arguments,
    # matching the historical behaviour.
    exc_args = (msg,) if msg else ()
    with builder.if_then(is_null(builder, pointer), likely=False):
        context.call_conv.return_user_exc(builder, MemoryError, exc_args)
827
828
@contextmanager
def if_zero(builder, value, likely=False):
    """
    Execute the given block if the scalar value is zero.
    """
    zero_pred = is_scalar_zero(builder, value)
    with builder.if_then(zero_pred, likely=likely):
        yield
836
837
# Guarding against zero is the same operation as guarding against null.
guard_zero = guard_null
839
840
def is_pointer(ltyp):
    """
    Whether the LLVM type *ltyp* is a pointer type.
    """
    # The previous docstring incorrectly said "struct type" and referred
    # to a non-existent *typ* parameter.
    return isinstance(ltyp, ir.PointerType)
846
847
def get_record_member(builder, record, offset, typ):
    """
    Return a pointer, cast to ``typ*``, to the member of *record*
    at index *offset*.
    """
    member_ptr = gep_inbounds(builder, record, 0, offset)
    assert not is_pointer(member_ptr.type.pointee)
    return builder.bitcast(member_ptr, typ.as_pointer())
852
853
def is_neg_int(builder, val):
    """Return an i1 predicate: signed *val* is strictly negative."""
    zero = val.type(0)
    return builder.icmp_signed('<', val, zero)
856
857
def gep_inbounds(builder, ptr, *inds, **kws):
    """
    Same as *gep*, but add the `inbounds` keyword.

    `inbounds` tells LLVM the computed address stays within the
    allocated object, enabling better optimization.
    """
    return gep(builder, ptr, *inds, inbounds=True, **kws)
863
864
def gep(builder, ptr, *inds, **kws):
    """
    Emit a getelementptr instruction for the given pointer and indices.
    The indices can be LLVM values or Python int constants.
    """
    name = kws.pop('name', '')
    inbounds = kws.pop('inbounds', False)
    assert not kws
    # NOTE: llvm only accepts int32 inside structs, not int64
    idx = [int32_t(i) if isinstance(i, utils.INT_TYPES) else i
           for i in inds]
    return builder.gep(ptr, idx, name=name, inbounds=inbounds)
882
883
def pointer_add(builder, ptr, offset, return_type=None):
    """
    Add an integral *offset* to pointer *ptr*, and return a pointer
    of *return_type* (or, if omitted, the same type as *ptr*).

    Note the computation is done in bytes, and ignores the width of
    the pointed item type.
    """
    if isinstance(offset, utils.INT_TYPES):
        offset = intp_t(offset)
    raw_addr = builder.ptrtoint(ptr, intp_t)
    shifted = builder.add(raw_addr, offset)
    return builder.inttoptr(shifted, return_type or ptr.type)
897
898
def memset(builder, ptr, size, value):
    """
    Fill *size* bytes starting from *ptr* with *value*.
    """
    if isinstance(value, int):
        value = int8_t(value)
    intrinsic = builder.module.declare_intrinsic('llvm.memset',
                                                 (voidptr_t, size.type))
    byte_ptr = builder.bitcast(ptr, voidptr_t)
    # Trailing argument: the "isvolatile" flag, always false here.
    builder.call(intrinsic, [byte_ptr, value, size, bool_t(0)])
908
909
def global_constant(builder_or_module, name, value, linkage='internal'):
    """
    Get or create a (LLVM module-)global constant with *name* and *value*.

    *builder_or_module* may be either an IR module or a builder, in which
    case the builder's module is used.
    """
    if isinstance(builder_or_module, ir.Module):
        module = builder_or_module
    else:
        module = builder_or_module.module
    data = module.add_global_variable(value.type, name=name)
    data.linkage = linkage
    data.global_constant = True
    data.initializer = value
    return data
923
924
def divmod_by_constant(builder, val, divisor):
    """
    Compute the (quotient, remainder) of *val* divided by the constant
    positive *divisor*.  The semantics reflects those of Python integer
    floor division, rather than C's / LLVM's signed division and modulo.
    The difference lies with a negative *val*.
    """
    assert divisor > 0
    divisor_const = val.type(divisor)
    one = val.type(1)

    quot_slot = alloca_once(builder, val.type)

    with builder.if_else(is_neg_int(builder, val)) as (if_neg, if_pos):
        with if_pos:
            # Non-negative: plain signed division already floors.
            builder.store(builder.sdiv(val, divisor_const), quot_slot)
        with if_neg:
            # Negative: quot = (val + 1) / divisor - 1, which rounds
            # towards negative infinity (Python floor semantics).
            bumped = builder.add(val, one)
            floored = builder.sdiv(bumped, divisor_const)
            builder.store(builder.sub(floored, one), quot_slot)

    # rem = val - quot * divisor
    # (should be slightly faster than a separate modulo operation)
    quotient = builder.load(quot_slot)
    remainder = builder.sub(val, builder.mul(quotient, divisor_const))
    return quotient, remainder
954
955
def cbranch_or_continue(builder, cond, bbtrue):
    """
    Branch to *bbtrue* if *cond* holds, otherwise fall through.

    Note: a new block is created and the builder is positioned at the
          end of that new block.
    """
    fallthrough = builder.append_basic_block('.continue')
    builder.cbranch(cond, bbtrue, fallthrough)
    builder.position_at_end(fallthrough)
    return fallthrough
967
968
def memcpy(builder, dst, src, count):
    """
    Emit an element-wise copy loop from *src* to *dst*.

    Unlike the C equivalent, each element can be any LLVM type.

    Assumes
    -------
    * dst.type == src.type
    * count is positive
    """
    # Note this does seem to be optimized as a raw memcpy() by LLVM
    # whenever possible...
    assert dst.type == src.type
    with for_range(builder, count, intp=count.type) as loop:
        dst_elem = builder.gep(dst, [loop.index])
        src_elem = builder.gep(src, [loop.index])
        builder.store(builder.load(src_elem), dst_elem)
988
989
def _raw_memcpy(builder, func_name, dst, src, count, itemsize, align):
    # Shared implementation for raw_memcpy() / raw_memmove(): calls the
    # given llvm.memcpy / llvm.memmove intrinsic for count * itemsize bytes.
    # NOTE: *align* is accepted for API compatibility but currently unused.
    size_t = count.type
    if isinstance(itemsize, utils.INT_TYPES):
        itemsize = ir.Constant(size_t, itemsize)

    intrinsic = builder.module.declare_intrinsic(
        func_name, [voidptr_t, voidptr_t, size_t])
    nbytes = builder.mul(count, itemsize)
    # Last argument is the "isvolatile" flag.
    builder.call(intrinsic, [builder.bitcast(dst, voidptr_t),
                             builder.bitcast(src, voidptr_t),
                             nbytes,
                             false_bit])
1002
1003
def raw_memcpy(builder, dst, src, count, itemsize, align=1):
    """
    Emit a raw memcpy() call for `count` items of size `itemsize`
    from `src` to `dst`.
    """
    return _raw_memcpy(builder, 'llvm.memcpy', dst, src, count,
                       itemsize, align)
1010
1011
def raw_memmove(builder, dst, src, count, itemsize, align=1):
    """
    Emit a raw memmove() call for `count` items of size `itemsize`
    from `src` to `dst`.  Unlike raw_memcpy(), the regions may overlap.
    """
    return _raw_memcpy(
        builder, 'llvm.memmove', dst, src, count, itemsize, align)
1019
1020
def muladd_with_overflow(builder, a, b, c):
    """
    Compute (a * b + c) and return a (result, overflow bit) pair.
    The operands must be signed integers.
    """
    mul_pair = builder.smul_with_overflow(a, b)
    product = builder.extract_value(mul_pair, 0)
    mul_ovf = builder.extract_value(mul_pair, 1)
    add_pair = builder.sadd_with_overflow(product, c)
    result = builder.extract_value(add_pair, 0)
    # Overflow if either the multiply or the add overflowed.
    add_ovf = builder.extract_value(add_pair, 1)
    overflow = builder.or_(mul_ovf, add_ovf)
    return result, overflow
1033
1034
def printf(builder, format, *args):
    """
    Calls printf().
    Argument `format` is expected to be a Python string.
    Values to be printed are listed in `args`.

    Note: There is no checking to ensure there is a correct number of
    values in `args` and that their types match the declaration in the
    format string.
    """
    assert isinstance(format, str)
    module = builder.module
    # Embed the format string as a NUL-terminated global byte array.
    char_star = voidptr_t
    fmt_bytes = make_bytearray((format + '\00').encode('ascii'))
    fmt_global = global_constant(module, "printf_format", fmt_bytes)
    printf_ty = ir.FunctionType(int32_t, [char_star], var_arg=True)
    # Declare printf() in the module if not already present.
    try:
        printf_fn = module.get_global('printf')
    except KeyError:
        printf_fn = ir.Function(module, printf_ty, name="printf")
    # Emit the call.
    fmt_ptr = builder.bitcast(fmt_global, char_star)
    return builder.call(printf_fn, [fmt_ptr] + list(args))
1059
1060
def snprintf(builder, buffer, bufsz, format, *args):
    """Calls libc snprintf(buffer, bufsz, format, ...args)
    """
    assert isinstance(format, str)
    module = builder.module
    # Embed the format string as a NUL-terminated global byte array.
    char_star = voidptr_t
    fmt_bytes = make_bytearray((format + '\00').encode('ascii'))
    fmt_global = global_constant(module, "snprintf_format", fmt_bytes)
    snprintf_ty = ir.FunctionType(
        int32_t, [char_star, intp_t, char_star], var_arg=True,
    )
    # On win32 the actual symbol carries a leading underscore.
    symbol = '_snprintf' if config.IS_WIN32 else 'snprintf'
    # Declare snprintf() in the module if not already present.
    try:
        snprintf_fn = module.get_global(symbol)
    except KeyError:
        snprintf_fn = ir.Function(module, snprintf_ty, name=symbol)
    # Emit the call.
    fmt_ptr = builder.bitcast(fmt_global, char_star)
    return builder.call(snprintf_fn, [buffer, bufsz, fmt_ptr] + list(args))
1085
1086
def snprintf_stackbuffer(builder, bufsz, format, *args):
    """Similar to `snprintf()` but the buffer is stack allocated to size
    *bufsz*.

    Returns the buffer pointer as i8*.
    """
    assert isinstance(bufsz, int)
    buf_type = ir.ArrayType(ir.IntType(8), bufsz)
    # Zero-fill so the buffer is NUL-terminated even if nothing is written.
    stack_buf = alloca_once(builder, buf_type, zfill=True)
    buf_ptr = builder.bitcast(stack_buf, voidptr_t)
    snprintf(builder, buf_ptr, intp_t(bufsz), format, *args)
    return buf_ptr
1098
1099
def normalize_ir_text(text):
    """
    Normalize the given string to a latin1-compatible encoding that is
    suitable for use in LLVM IR.
    """
    # Round-tripping utf8 -> latin1 maps every character into the latin1
    # range (a multi-byte utf8 sequence becomes several latin1 chars).
    encoded = text.encode('utf8')
    return encoded.decode('latin1')
1107
1108
def hexdump(builder, ptr, nbytes):
    """Debug print the memory region in *ptr* to *ptr + nbytes*
    as hex.
    """
    bytes_per_line = 16
    nbytes = builder.zext(nbytes, intp_t)
    printf(builder, "hexdump p=%p n=%zu",
           ptr, nbytes)
    i8 = ir.IntType(8)
    ptr = builder.bitcast(ptr, i8.as_pointer())
    # Loop over the bytes, printing each as hex and breaking the line
    # every `bytes_per_line` bytes.
    with for_range(builder, nbytes) as idx:
        column = builder.urem(idx.index, intp_t(bytes_per_line))
        at_line_start = builder.icmp_unsigned("==", column, intp_t(0))
        with builder.if_then(at_line_start):
            printf(builder, "\n")

        byte_ptr = builder.gep(ptr, [idx.index])
        byte_val = builder.load(byte_ptr)
        printf(builder, " %02x", byte_val)
    printf(builder, "\n")
1130
1131
def is_nonelike(ty):
    """Return whether *ty* represents a none-like type (None itself,
    types.NoneType, or types.Omitted)."""
    if ty is None:
        return True
    return isinstance(ty, (types.NoneType, types.Omitted))
1139