# "magictoken" is used to mark the beginning and end of example text.
2
3import unittest
4
5# magictoken.ex_structref_type_definition.begin
6import numpy as np
7
8from numba import njit
9from numba.core import types
10from numba.experimental import structref
11
12from numba.tests.support import skip_unless_scipy
13
14
# Define a StructRef.
# `structref.register` associates the type with the default data model.
# This also installs getters and setters for the fields of
# the StructRef.
@structref.register
class MyStructType(types.StructRef):
    """Numba type backing the ``MyStruct`` StructRef.

    Only ``preprocess_fields`` is customised; everything else is inherited
    from ``types.StructRef``.
    """

    def preprocess_fields(self, fields):
        """Strip literal-ness from every field type.

        Called by the type constructor with the ``(name, type)`` pairs that
        describe the struct. Each type is passed through
        ``types.unliteral`` so the struct never holds a ``Literal`` type.
        """
        unliteral = types.unliteral
        return tuple(
            (field_name, unliteral(field_type))
            for field_name, field_type in fields
        )
26
27
# Define a Python type that can be used as a proxy to the StructRef
# allocated inside Numba. Users can construct the StructRef via
# the constructor for this type in Python code and jit-code.
class MyStruct(structref.StructRefProxy):
    """Python-side proxy for a ``MyStructType`` StructRef.

    Instances can be built from Python code or from jit-code; in both
    cases the underlying StructRef is allocated inside Numba.
    """

    def __new__(cls, name, vector):
        # Overriding __new__ is optional. Doing so gives the constructor
        # named parameters, so Python callers may pass them by keyword (the
        # inherited __new__ only accepts ``*args``). It is also the hook for
        # any other custom construction behaviour.
        # IMPORTANT: users should not override __init__.
        instance = structref.StructRefProxy.__new__(cls, name, vector)
        return instance

    # The proxy does not mirror the struct's attributes or methods on the
    # Python side by itself (this may be automated in the future), so each
    # field is exposed explicitly as a read-only property that delegates to
    # a small jitted getter defined further down in this module.

    @property
    def name(self):
        """The ``name`` field, read in jit-code by ``MyStruct_get_name``."""
        return MyStruct_get_name(self)

    @property
    def vector(self):
        """The ``vector`` field, read in jit-code by ``MyStruct_get_vector``."""
        return MyStruct_get_vector(self)
55
56
@njit
def MyStruct_get_name(self):
    """Return the ``name`` field of a MyStruct instance.

    Plain attribute access works here because ``structref.register``
    exposes the StructRef's attributes to jit-code.
    """
    field_value = self.name
    return field_value
62
63
@njit
def MyStruct_get_vector(self):
    """Return the ``vector`` field of a MyStruct instance (jit-side read)."""
    field_value = self.vector
    return field_value
67
68
# This associates the proxy with MyStructType for the given set of
# fields. Notice how we are not constraining the type of each field.
# Field types remain generic.
structref.define_proxy(
    MyStruct,
    MyStructType,
    ["name", "vector"],
)
73# magictoken.ex_structref_type_definition.end
74
75
@skip_unless_scipy
class TestStructRefUsage(unittest.TestCase):
    """Executable versions of the StructRef usage examples.

    The code between each pair of ``magictoken`` markers is example text;
    the assertions that follow each example check that it really behaves
    the way the example describes.
    """

    def test_type_definition(self):
        """Build MyStruct in Python and in jit-code, then use both."""
        # Seeds NumPy's global generator, which fixes Alice's vector (it is
        # drawn in Python). NOTE(review): Numba documents its jit-side
        # generator as separate from NumPy's, so Bob's vector (drawn inside
        # make_bob) is not fixed by this seed; it is only checked
        # structurally below.
        np.random.seed(0)
        # Redirect print
        buf = []

        def print(*args):
            buf.append(args)

        # magictoken.ex_structref_type_definition_test.begin
        # Let's test our new StructRef.

        # Define one in Python
        alice = MyStruct("Alice", vector=np.random.random(3))

        # Define one in jit-code
        @njit
        def make_bob():
            bob = MyStruct("unnamed", vector=np.zeros(3))
            # Mutate the attributes
            bob.name = "Bob"
            bob.vector = np.random.random(3)
            return bob

        bob = make_bob()

        # Out: Alice: [0.5488135  0.71518937 0.60276338]
        print(f"{alice.name}: {alice.vector}")
        # Out: Bob: [0.88325739 0.73527629 0.87746707]
        print(f"{bob.name}: {bob.vector}")

        # Define a jit function to operate on the structs.
        @njit
        def distance(a, b):
            return np.linalg.norm(a.vector - b.vector)

        # Out: 0.4332647200356598
        print(distance(alice, bob))
        # magictoken.ex_structref_type_definition_test.end

        self.assertEqual(len(buf), 3)
        # Fields set from Python are readable through the proxy properties.
        self.assertEqual(alice.name, "Alice")
        np.testing.assert_allclose(
            alice.vector, [0.5488135, 0.71518937, 0.60276338]
        )
        # Mutations made inside jit-code are visible from Python: the name
        # was replaced and the zero vector was overwritten.
        self.assertEqual(bob.name, "Bob")
        self.assertEqual(bob.vector.shape, (3,))
        self.assertTrue(np.any(bob.vector != 0))
        # The jitted helper agrees with the same computation done in NumPy.
        self.assertAlmostEqual(
            distance(alice, bob),
            np.linalg.norm(alice.vector - bob.vector),
        )

    def test_overload_method(self):
        """Add a ``distance`` method to MyStructType via @overload_method."""
        # magictoken.ex_structref_method.begin
        from numba.core.extending import overload_method
        from numba.core.errors import TypingError

        # Use @overload_method to add a method for
        # MyStructType.distance(other)
        # where *other* is an instance of MyStructType.
        @overload_method(MyStructType, "distance")
        def ol_distance(self, other):
            # Guard that *other* is an instance of MyStructType
            if not isinstance(other, MyStructType):
                raise TypingError(
                    f"*other* must be a {MyStructType}; got {other}"
                )

            def impl(self, other):
                return np.linalg.norm(self.vector - other.vector)

            return impl

        # Test
        @njit
        def test():
            alice = MyStruct("Alice", vector=np.random.random(3))
            bob = MyStruct("Bob", vector=np.random.random(3))
            # Use the method
            return alice.distance(bob)
        # magictoken.ex_structref_method.end

        result = test()
        self.assertIsInstance(result, float)
        # A norm can never be negative.
        self.assertGreaterEqual(result, 0.0)

        # The type guard in ol_distance must reject a non-MyStruct argument,
        # which surfaces as a typing failure when the caller is compiled.
        @njit
        def call_with_non_struct():
            alice = MyStruct("Alice", vector=np.random.random(3))
            return alice.distance(1.0)

        with self.assertRaises(TypingError):
            call_with_non_struct()
150