# "magictoken" comments mark the beginning and end of example text: the code
# between a matching `.begin` / `.end` pair is the example itself.

import unittest

# magictoken.ex_structref_type_definition.begin
import numpy as np

from numba import njit
from numba.core import types
from numba.experimental import structref

from numba.tests.support import skip_unless_scipy


# Define a StructRef.
# `structref.register` associates the type with the default data model.
# This will also install getters and setters for the fields of
# the StructRef.
@structref.register
class MyStructType(types.StructRef):
    def preprocess_fields(self, fields):
        # This method is called by the type constructor for additional
        # preprocessing on the fields.
        # Here, we don't want the struct to take Literal types.
        return tuple((name, types.unliteral(typ)) for name, typ in fields)


# Define a Python type that can be used as a proxy to the StructRef
# allocated inside Numba. Users can construct the StructRef via
# the constructor for this type in Python code and jit-code.
class MyStruct(structref.StructRefProxy):
    def __new__(cls, name, vector):
        # Overriding the __new__ method is optional; doing so
        # allows Python code to use keyword arguments,
        # or add other customized behavior.
        # The default __new__ takes `*args`.
        # IMPORTANT: Users should not override __init__.
        return structref.StructRefProxy.__new__(cls, name, vector)

    # By default, the proxy type does not reflect the attributes or
    # methods to the Python side. It is up to users to define
    # these. (This may be automated in the future.)

    @property
    def name(self):
        # To access a field, we can define a function that simply
        # returns the field in jit-code.
        # The definition of MyStruct_get_name is shown later.
        return MyStruct_get_name(self)

    @property
    def vector(self):
        # The definition of MyStruct_get_vector is shown later.
        return MyStruct_get_vector(self)


@njit
def MyStruct_get_name(self):
    # In jit-code, the StructRef's attribute is exposed via
    # structref.register
    return self.name


@njit
def MyStruct_get_vector(self):
    return self.vector


# This associates the proxy with MyStructType for the given set of
# fields. Notice how we are not constraining the type of each field.
# Field types remain generic.
structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])
# magictoken.ex_structref_type_definition.end


@skip_unless_scipy
class TestStructRefUsage(unittest.TestCase):
    """Run the StructRef documentation examples defined in this module."""

    def test_type_definition(self):
        """Build MyStruct instances in Python and jit-code, then use them."""
        # NOTE(review): this seeds NumPy's Python-side generator, which is
        # what `alice`'s vector below is drawn from. `make_bob` draws its
        # vector inside jit code; per the Numba docs that generator is
        # separate and is not seeded from Python, so the "Out:" values
        # shown for Bob and for the distance may not be reproducible.
        np.random.seed(0)
        # Redirect print: the example below calls `print`, so shadow the
        # builtin locally and collect each call's arguments in `buf`
        # instead of writing to stdout.
        buf = []

        def print(*args):
            buf.append(args)

        # magictoken.ex_structref_type_definition_test.begin
        # Let's test our new StructRef.

        # Define one in Python
        alice = MyStruct("Alice", vector=np.random.random(3))

        # Define one in jit-code
        @njit
        def make_bob():
            bob = MyStruct("unnamed", vector=np.zeros(3))
            # Mutate the attributes
            bob.name = "Bob"
            bob.vector = np.random.random(3)
            return bob

        bob = make_bob()

        # Out: Alice: [0.5488135  0.71518937 0.60276338]
        print(f"{alice.name}: {alice.vector}")
        # Out: Bob: [0.88325739 0.73527629 0.87746707]
        print(f"{bob.name}: {bob.vector}")

        # Define a jit function to operate on the structs.
        @njit
        def distance(a, b):
            return np.linalg.norm(a.vector - b.vector)

        # Out: 0.4332647200356598
        print(distance(alice, bob))
        # magictoken.ex_structref_type_definition_test.end

        # Only the number of captured print calls is checked, not the values.
        self.assertEqual(len(buf), 3)

    def test_overload_method(self):
        """Add a `distance` method to MyStructType and call it in jit-code."""
        # magictoken.ex_structref_method.begin
        from numba.core.extending import overload_method
        from numba.core.errors import TypingError

        # Use @overload_method to add a method for
        # MyStructType.distance(other)
        # where *other* is an instance of MyStructType.
        @overload_method(MyStructType, "distance")
        def ol_distance(self, other):
            # Guard that *other* is an instance of MyStructType
            if not isinstance(other, MyStructType):
                raise TypingError(
                    f"*other* must be a {MyStructType}; got {other}"
                )

            # The returned function implements `self.distance(other)`.
            def impl(self, other):
                return np.linalg.norm(self.vector - other.vector)

            return impl

        # Test
        @njit
        def test():
            alice = MyStruct("Alice", vector=np.random.random(3))
            bob = MyStruct("Bob", vector=np.random.random(3))
            # Use the method
            return alice.distance(bob)
        # magictoken.ex_structref_method.end

        self.assertIsInstance(test(), float)