"""
========================================================================
NamedObject.py
========================================================================
This is the very base class of PyMTL objects. It enables all pymtl
components/signals to share the functionality of recursively collecting
objects and tagging objects with the full name.

We bookkeep name hierarchy for error message and other purposes.
For example, s.x[0][3].y[2].z[1:3] is stored as
( ["top","x","y","z"], [ [], [0,3], [2], [slice(1,3,0)] ] )
Note that len(name) == len(idx)-1 only when the variable ends with slice

We keep all metadata in inst._dsl.*. This is to create a namespace
to centralize all DSL-related metadata. Passes will create other
namespaces to put their created metadata.

Author : Shunning Jiang, Yanghui Ou
Date   : Nov 3, 2018
"""
import re
from collections import deque

from .errors import FieldReassignError, NotElaboratedError


class DSLMetadata:
  """Plain attribute bag that holds all DSL metadata of an object (inst._dsl)."""
  pass

# Special data structure for constructing the parameter tree.
class ParamTreeNode:
  """One node of the parameter tree built by NamedObject.set_param.

  Attributes:
    compiled_re : compiled regex, or None. Set when this node's path
                  component contains '*'; the node then applies to every
                  child whose name the regex matches (re.match semantics,
                  i.e. anchored at the start of the name only).
    children    : dict mapping a path component (a field name such as
                  "x[0]", or a regex string) to a ParamTreeNode, or None.
    leaf        : dict mapping a function name (e.g. "construct") to the
                  keyword arguments recorded for it, or None.
  """

  def __init__( self ):
    self.compiled_re = None
    self.children    = None
    self.leaf        = None

  # TODO do we still need to lazily create leaf?
  def merge( self, other ):
    """Recursively merge the contents of `other` into this node.

    Keyword arguments from `other.leaf` override those already stored for
    the same function name. `other` is left unmodified, and none of its
    nodes end up shared with this tree.
    """
    # Merge leaf
    if other.leaf is not None:
      # Lazily create leaf
      if self.leaf is None:
        self.leaf = {}
      for func_name, subdict in other.leaf.items():
        self.leaf.setdefault( func_name, {} ).update( subdict )

    # Merge children
    if other.children is not None:
      # Lazily create children
      if self.children is None:
        self.children = {}
      for comp_name, node in other.children.items():
        if comp_name not in self.children:
          # Copy rather than alias `node`: a later merge into this child
          # would otherwise mutate `other`'s tree. That tree may be merged
          # into several objects (e.g. a regex node that matches several
          # siblings), so aliasing leaks parameters between them.
          child = ParamTreeNode()
          child.compiled_re = node.compiled_re
          self.children[ comp_name ] = child
        self.children[ comp_name ].merge( node )

  def add_params( self, strs, func_name, **kwargs ):
    """Record `kwargs` for `func_name` at path `strs` below this node.

    strs      : list of path components (field names or regex strings)
    func_name : name of the function the kwargs are meant for

    When a component containing '*' is added for the first time, the
    parameters are also pushed into every existing non-regex sibling whose
    name the new regex matches.
    """
    # Lazily create leaf/children of this node. Each is initialized only
    # if missing so that existing children are never discarded.
    if self.leaf is None:
      self.leaf = {}
    if self.children is None:
      self.children = {}

    # Traverse to the node
    cur_node = self
    for idx, comp_name in enumerate( strs, start=1 ):
      # Lazily create children
      if cur_node.children is None:
        cur_node.children = {}
      children = cur_node.children

      if comp_name not in children:
        new_node = ParamTreeNode()
        if '*' in comp_name:
          new_node.compiled_re = re.compile( comp_name )
          # Recursively update existing nodes that match the regex
          for name, node in children.items():
            if node.compiled_re is None and new_node.compiled_re.match( name ):
              node.add_params( strs[idx:], func_name, **kwargs )
        children[ comp_name ] = new_node
      else:
        # Re-insert the existing node so it moves to the end of the dict's
        # insertion order (presumably so the most recently configured entry
        # is merged last and wins during elaboration).
        new_node = children.pop( comp_name )
        children[ comp_name ] = new_node

      cur_node = new_node

    # Add parameters to leaf
    if cur_node.leaf is None:
      cur_node.leaf = {}
    cur_node.leaf.setdefault( func_name, {} ).update( kwargs )

  def __repr__( self ):
    return f"\nleaf:{self.leaf}\nchildren:{self.children}"

#-------------------------------------------------------------------------
# Private helpers
#-------------------------------------------------------------------------

def _collect_named_objects( root, filters ):
  """Walk the object graph from `root` and bucket NamedObjects by filter.

  The walk follows:
    - every public (non-underscore) string-keyed entry of each
      NamedObject's __dict__;
    - every tuple-keyed entry, which represents a slice;
    - every element of (possibly nested) lists.
  Other container types are not entered.

  Each NamedObject/list is expanded at most once, so shared references are
  not re-walked and reference cycles terminate.

  filters : sequence of predicates taking a NamedObject

  Returns a list with one set per filter. Each set holds the reachable
  NamedObjects (root included) for which that filter returned True.
  """
  ret      = [ set() for _ in filters ]
  expanded = set() # ids of NamedObjects/lists already expanded
  stack    = [ root ]

  while stack:
    u = stack.pop()

    if isinstance( u, NamedObject ):
      if id( u ) in expanded:
        continue
      expanded.add( id( u ) )

      for bucket, filt in zip( ret, filters ):
        if filt( u ): # Check if u satisfies the filter
          bucket.add( u )

      for name, obj in u.__dict__.items():
        # If the key is a string, it is a normal child field. Otherwise it
        # should be a tuple that represents a slice.
        if isinstance( name, str ):
          if name[0] != '_': # filter private variables
            stack.append( obj )
        elif isinstance( name, tuple ): # name = [1:3]
          stack.append( obj )

    # ONLY LIST IS SUPPORTED
    elif isinstance( u, list ):
      if id( u ) in expanded:
        continue
      expanded.add( id( u ) )
      stack.extend( u )

  return ret

def _register_field( s, name, obj ):
  """Record `name` as a hardware-construct field of NamedObject `s`.

  Returns True if `name` was newly registered. Returns False if `name` is
  already registered and getattr(s, name) is obj, in which case
  re-assigning the same object is a no-op.

  Raises FieldReassignError if `name` already holds a different object.
  """
  fields = s._dsl.NamedObject_fields
  if name in fields:
    if getattr( s, name ) is obj:
      return False
    raise FieldReassignError(f"The attempt to assign hardware construct to field {name} is illegal:\n"
                             f" - top{repr(s)[1:]} already has field {name} with type {type(getattr( s, name ))}.")
  fields.add( name )
  return True

def _attach_child( s, u, field_name, u_name, indices ):
  """Wire NamedObject `u` into the hierarchy under `s`, then construct it.

  field_name : attribute name on `s` (stored as u._dsl._my_name)
  u_name     : field_name plus any list indices, e.g. "x[0][3]"
  indices    : tuple of list indices, or None for a plain attribute
  """
  sd = s._dsl
  ud = u._dsl

  ud.parent_obj = s
  ud.level      = sd.level + 1

  ud._my_name   = field_name
  ud.my_name    = u_name
  ud.full_name  = f"{sd.full_name}.{u_name}"

  ud._my_indices = indices

  # Merge into u's param tree every child entry of s's param tree that
  # applies to u: an exact name match, or a regex node matching u_name.
  # Entries are merged in insertion order, and merge() overrides existing
  # keys, so later entries take precedence.
  parent_tree = sd.param_tree
  if parent_tree is not None and parent_tree.children is not None:
    for comp_name, node in parent_tree.children.items():
      if comp_name == u_name or \
         ( node.compiled_re is not None and node.compiled_re.match( u_name ) ):
        # Lazily create the param tree
        if ud.param_tree is None:
          ud.param_tree = ParamTreeNode()
        ud.param_tree.merge( node )

  ud.NamedObject_fields = set()

  # Point u's top to my top
  ud.elaborate_top = sd.elaborate_top

  NamedObject._elaborate_stack.append( u )
  u._construct()
  NamedObject._elaborate_stack.pop()

#-------------------------------------------------------------------------
# NamedObject
#-------------------------------------------------------------------------

class NamedObject:

  def __new__( cls, *args, **kwargs ):
    """Create the instance and stash its constructor arguments in inst._dsl.

    The arguments are not used here. construct() is invoked with them later,
    during elaboration (see _construct).
    """
    inst = super().__new__( cls )
    inst._dsl = DSLMetadata() # TODO an actual object?

    # Save parameters for elaborate
    inst._dsl.args        = args
    inst._dsl.kwargs      = kwargs
    inst._dsl.constructed = False

    # A tree of parameters.
    inst._dsl.param_tree = None

    return inst

  #-----------------------------------------------------------------------
  # Private methods
  #-----------------------------------------------------------------------

  def _construct( s ):
    """Call s.construct once, with the saved args/kwargs.

    Any keyword arguments recorded for "construct" in s's param tree
    (via set_param) are merged in, overriding same-named kwargs.
    """
    if s._dsl.constructed:
      return

    # Merge the actual keyword args and those args set by set_param.
    # NOTE: this updates s._dsl.kwargs in place, so the saved kwargs also
    # reflect the set_param overrides afterwards.
    kwargs     = s._dsl.kwargs
    param_tree = s._dsl.param_tree
    if param_tree is not None and param_tree.leaf is not None:
      if "construct" in param_tree.leaf:
        kwargs.update( param_tree.leaf[ "construct" ] )

    s.construct( *s._dsl.args, **kwargs )

    s._dsl.constructed = True

  def __setattr_for_elaborate__( s, name, obj ):
    """Replacement __setattr__ installed on NamedObject during elaboration.

    For a public attribute holding a NamedObject, or a non-empty list whose
    first element is a NamedObject or list, the attribute works as follows:
      - `name` is registered as a field of s;
      - every contained NamedObject gets its name/parent/level metadata;
      - every contained NamedObject inherits matching param-tree entries;
      - every contained NamedObject is constructed immediately.
    The attribute is then set normally.

    Re-assigning the same object to an existing field is a no-op. Assigning
    a different object raises FieldReassignError.
    """
    # I use non-recursive BFS to reduce error message depth

    if name[0] != '_': # filter private variables

      if isinstance( obj, NamedObject ):
        if not _register_field( s, name, obj ):
          return
        _attach_child( s, obj, name, name, None )

      # ONLY LIST IS SUPPORTED, SORRY.
      # I don't want to support any iterable object because later "Wire"
      # can be infinitely iterated and cause infinite loop. Special
      # casing Wire will be a mess around everywhere.
      elif isinstance( obj, list ) and obj and isinstance( obj[0], (NamedObject, list) ):
        if not _register_field( s, name, obj ):
          return

        Q = deque( (u, (i,)) for i, u in enumerate( obj ) )
        while Q:
          u, indices = Q.popleft()

          if isinstance( u, NamedObject ):
            u_name = name + "".join( f"[{x}]" for x in indices )
            _attach_child( s, u, name, u_name, indices )

          elif isinstance( u, list ):
            Q.extend( (v, indices + (i,)) for i, v in enumerate( u ) )

    super().__setattr__( name, obj )

  def _collect_all_single( s, filt=lambda x: isinstance( x, NamedObject ) ):
    """Return the set of NamedObjects reachable from s (s included) that
    satisfy `filt`."""
    return _collect_named_objects( s, (filt,) )[0]

  # It is possible to take multiple filters
  def _collect_all( s, filt=( lambda x: isinstance( x, NamedObject ), ) ):
    """Return one set per filter in `filt`. Each set holds the NamedObjects
    reachable from s (s included) that satisfy that filter."""
    return _collect_named_objects( s, filt )

  # Developers should use repr(x) everywhere to get the name

  def __repr__( s ):
    try:
      return s._dsl.full_name
    except AttributeError:
      return super().__repr__()

  #-----------------------------------------------------------------------
  # Construction time APIs
  #-----------------------------------------------------------------------

  def construct( s, *args, **kwargs ):
    pass

  def set_param( s, string, **kwargs ):
    """Record keyword arguments for a function of s or one of its children.

    `string` has the form "top.<path>.<func_name>", e.g.
    "top.dut.construct". "top" refers to s itself. Path components are
    field names (e.g. "x[0]"). Components containing '*' are treated as
    regular expressions. The recorded kwargs are applied when the matching
    object is elaborated. This must be called before s is constructed.
    """
    assert not s._dsl.constructed

    strs = string.split( "." )

    assert strs[0] == "top", "The component should start at top"
    assert '*' not in strs[-1], "We don't support * with function name!"
    assert len( strs ) >= 2

    func_name = strs[-1]
    path      = strs[1:-1]
    if s._dsl.param_tree is None:
      s._dsl.param_tree = ParamTreeNode()
    s._dsl.param_tree.add_params( path, func_name, **kwargs )

  # There are two reasons I refactored this function into two separate
  # functions. First of all in later levels of components, named objects
  # can be spawned after the previous monolithic elaborate and hence this
  # collect part won't capture them. Second, later levels can override
  # this function and simply call construct at the beginning and call
  # collect at the middle/end.

  #-----------------------------------------------------------------------
  # elaborate
  #-----------------------------------------------------------------------

  def _elaborate_construct( s ):
    """Initialize s as the elaboration top and construct the hierarchy.

    Does nothing if s has already been constructed.
    """
    if s._dsl.constructed:
      # Yanghui : a warning ("Don't elaborate the same model twice. Use APIs
      # to mutate the model.") used to be issued here; it was muted for the
      # isca tutorial.
      return

    # Initialize the top level

    s._dsl.parent_obj         = None
    s._dsl.level              = 0
    s._dsl.my_name            = "s"
    s._dsl.full_name          = "s"
    s._dsl.elaborate_top      = s
    s._dsl.NamedObject_fields = set()

    # Secret sauce for letting the child know the field name of itself
    # -- override setattr for elaboration, and remove it afterwards
    # -- and the global elaborate to enable free function as decorator

    NamedObject.__setattr__     = NamedObject.__setattr_for_elaborate__
    NamedObject._elaborate_stack = [ s ]

    try:
      s._construct()
    finally:
      # Restore the default __setattr__ whether or not construction raised,
      # so the rest of execution is not affected.
      del NamedObject.__setattr__
      del NamedObject._elaborate_stack

  def _elaborate_collect_all_named_objects( s ):
    s._dsl.all_named_objects = s._collect_all_single()

  def elaborate( s ):
    s._elaborate_construct()
    s._elaborate_collect_all_named_objects()

  #-----------------------------------------------------------------------
  # Post-elaborate public APIs (can only be called after elaboration)
  #-----------------------------------------------------------------------

  def is_component( s ):
    raise NotImplementedError

  def is_signal( s ):
    raise NotImplementedError

  def is_interface( s ):
    raise NotImplementedError

  # These two APIs are reused across Connectable and Component

  def get_field_name( s ):
    try:
      return s._dsl.my_name
    except AttributeError:
      raise NotElaboratedError()

  def get_parent_object( s ):
    try:
      return s._dsl.parent_obj
    except AttributeError:
      raise NotElaboratedError()