1import operator
2
3from numba.core import types, typing, cgutils
4
5from numba.core.imputils import (lower_cast, lower_builtin,
6                                 lower_getattr_generic, impl_ret_untracked,
7                                 lower_setattr_generic)
8
9
def always_return_true_impl(context, builder, sig, args):
    """
    Lowering implementation that ignores all of its inputs and produces the
    constant ``cgutils.true_bit``.
    """
    constant_result = cgutils.true_bit
    return constant_result
12
13
def always_return_false_impl(context, builder, sig, args):
    """
    Lowering implementation that ignores all of its inputs and produces the
    constant ``cgutils.false_bit``.
    """
    constant_result = cgutils.false_bit
    return constant_result
16
17
def optional_is_none(context, builder, sig, args):
    """
    Lower an ``is`` comparison between an Optional operand and ``None``.

    The operands may appear in either order. Whichever one is not
    ``types.none`` is treated as the Optional. The result is the logical
    negation of that Optional's ``valid`` flag, so it is true exactly when
    the Optional holds no value.
    """
    left_ty, right_ty = sig.args
    left_val, right_val = args

    # Select the non-None operand; it is the Optional being inspected.
    if left_ty == types.none:
        opt_type, opt_val = right_ty, right_val
    else:
        opt_type, opt_val = left_ty, left_val

    helper = context.make_helper(builder, opt_type, opt_val)
    has_value = cgutils.as_bool_bit(builder, helper.valid)
    result = builder.not_(has_value)
    return impl_ret_untracked(context, builder, sig.return_type, result)
36
37
# ``None is None``: both operands have the None type, so the lowering
# ignores the operands and always yields a constant true bit.
lower_builtin(operator.is_, types.none, types.none)(always_return_true_impl)

# ``Optional is None`` and ``None is Optional``: registered for both operand
# orders; optional_is_none itself works out which operand is the Optional.
lower_builtin(operator.is_, types.Optional, types.none)(optional_is_none)
lower_builtin(operator.is_, types.none, types.Optional)(optional_is_none)
44
45
@lower_getattr_generic(types.Optional)
def optional_getattr(context, builder, typ, value, attr):
    """
    Generic attribute read on an Optional value.

    The Optional is cast to its wrapped type (``typ.type``). The attribute
    lookup is then delegated to the getattr lowering that the context has
    for that wrapped type.
    """
    wrapped_ty = typ.type
    unwrapped = context.cast(builder, value, typ, wrapped_ty)
    getattr_impl = context.get_getattr(wrapped_ty, attr)
    return getattr_impl(context, builder, wrapped_ty, unwrapped, attr)
55
56
@lower_setattr_generic(types.Optional)
def optional_setattr(context, builder, sig, args, attr):
    """
    Generic attribute assignment on an Optional value.

    The target Optional is cast to its wrapped type. The assignment is then
    forwarded to the setattr lowering for that wrapped type, using a
    signature in which the target type is replaced by the wrapped type.
    """
    opt_ty, value_ty = sig.args
    opt_val, new_val = args
    inner_ty = opt_ty.type
    inner_val = context.cast(builder, opt_val, opt_ty, inner_ty)

    # Re-express the call as a setattr on the wrapped type.
    forwarded_sig = typing.signature(sig.return_type, inner_ty, value_ty)
    setattr_impl = context.get_setattr(attr, forwarded_sig)
    return setattr_impl(builder, (inner_val, new_val))
70
71
@lower_cast(types.Optional, types.Optional)
def optional_to_optional(context, builder, fromty, toty, val):
    """
    Cast ``T?`` to ``U?`` (``?`` denotes an Optional type).

    This case is special-cased so that None propagates correctly. Casting
    ``T?`` to ``U?`` always succeeds. If the source holds None, the result
    is None too. Otherwise the wrapped ``T`` value is cast to ``U`` and
    stored in the result.

    This differs from casting ``T?`` to a plain ``U``, which requires the
    source value not to be None.
    """
    src = context.make_helper(builder, fromty, value=val)
    src_has_value = cgutils.as_bool_bit(builder, src.valid)
    # Uninitialized Optional of the target type; each branch below sets
    # both its ``valid`` flag and its ``data`` payload.
    dst = context.make_helper(builder, toty)

    with builder.if_else(src_has_value) as (then_block, else_block):
        with then_block:
            # Source holds a value: convert the payload to the target type.
            dst.valid = cgutils.true_bit
            dst.data = context.cast(builder, src.data, fromty.type, toty.type)

        with else_block:
            # Source is None: mark the result as None and store a null
            # payload of the target's data type.
            dst.valid = cgutils.false_bit
            dst.data = cgutils.get_null_value(dst.data.type)

    return dst._getvalue()
101
102
@lower_cast(types.Any, types.Optional)
def any_to_optional(context, builder, fromty, toty, val):
    """
    Cast a non-Optional value to ``Optional(T)``, where ``T`` is ``toty.type``.

    A source of type ``types.none`` becomes an empty Optional. Any other
    value is first cast to ``T`` and then wrapped as a present value.
    """
    inner_ty = toty.type
    if fromty == types.none:
        return context.make_optional_none(builder, inner_ty)
    converted = context.cast(builder, val, fromty, inner_ty)
    return context.make_optional_value(builder, inner_ty, converted)
110
111
@lower_cast(types.Optional, types.Any)
@lower_cast(types.Optional, types.Boolean)
def optional_to_any(context, builder, fromty, toty, val):
    """
    Cast ``Optional(T)`` to a non-Optional type ``toty``.

    Emits a runtime check on the Optional's ``valid`` flag. If it holds
    None, a ``TypeError`` with the message ``"expected <T>, got None"`` is
    returned through ``context.call_conv.return_user_exc``. Otherwise the
    wrapped ``T`` payload is cast to ``toty``.
    """
    wrapped_ty = fromty.type
    helper = context.make_helper(builder, fromty, value=val)
    has_value = cgutils.as_bool_bit(builder, helper.valid)
    is_none = builder.not_(has_value)
    # The None path is hinted as unlikely (likely=False).
    with builder.if_then(is_none, likely=False):
        err = "expected %s, got None" % (wrapped_ty,)
        context.call_conv.return_user_exc(builder, TypeError, (err,))

    return context.cast(builder, helper.data, wrapped_ty, toty)
122