1"""
2========================================================================
3NamedObject.py
4========================================================================
5This is the very base class of PyMTL objects. It enables all pymtl
6components/signals to share the functionality of recursively collecting
7objects and tagging objects with the full name.
8
We bookkeep the name hierarchy for error messages and other purposes.
10For example, s.x[0][3].y[2].z[1:3] is stored as
11( ["top","x","y","z"], [ [], [0,3], [2], [slice(1,3,0)] ] )
12Note that len(name) == len(idx)-1 only when the variable ends with slice
13
14We keep all metadata in inst._dsl.*. This is to create a namespace
15to centralize all DSL-related metadata. Passes will create other
16namespaces to put their created metadata.
17
18Author : Shunning Jiang, Yanghui Ou
19Date   : Nov 3, 2018
20"""
21import re
22from collections import deque
23
24from .errors import FieldReassignError, NotElaboratedError
25
26
class DSLMetadata:
  """Empty attribute bag; instances are stored as ``inst._dsl`` and hold
  whatever DSL-related metadata gets assigned onto them."""
29
# Special data structure for constructing the parameter tree.
class ParamTreeNode:
  """One node of the parameter tree.

  Attributes (all start as None and are created lazily):
    compiled_re : compiled pattern if this node's name contained '*',
                  otherwise None.
    children    : dict mapping a component name (or pattern string) to the
                  ParamTreeNode below it.
    leaf        : dict mapping a function name to a dict of keyword
                  arguments recorded for that function.
  """

  def __init__( self ):
    self.compiled_re = None
    self.children = None
    self.leaf = None

  def _clone( self ):
    """Return an independent deep copy of this subtree."""
    twin = ParamTreeNode()
    twin.compiled_re = self.compiled_re
    twin.merge( self )
    return twin

  # TODO do we still need to lazily create leaf?
  def merge( self, other ):
    """Fold the contents of `other` into this node, leaving `other` untouched.

    Leaf entries of `other` overwrite same-key entries already stored here
    (dict.update semantics). Children present on both sides are merged
    recursively; children only present in `other` are copied.
    """
    # Merge leaf
    if other.leaf is not None:
      # Lazily create leaf
      if self.leaf is None:
        self.leaf = {}
      for func_name, subdict in other.leaf.items():
        self.leaf.setdefault( func_name, {} ).update( subdict )

    # Merge children
    if other.children is not None:
      # Lazily create children
      if self.children is None:
        self.children = {}
      for comp_name, node in other.children.items():
        if comp_name in self.children:
          self.children[ comp_name ].merge( node )
        else:
          # Copy instead of adopting the node by reference: a later merge
          # into this tree would otherwise mutate `other`'s subtree, which
          # may be merged into several different targets.
          self.children[ comp_name ] = node._clone()

  def add_params( self, strs, func_name, **kwargs ):
    """Record `kwargs` for `func_name` at the node reached by following `strs`.

    strs      : sequence of component names leading from this node to the
                target node; missing nodes are created on the way. A name
                containing '*' is compiled with re.compile and, when it is
                first added, the parameters are also pushed into every
                already-existing non-pattern sibling whose name it matches
                (re.match, i.e. anchored at the start only).
    func_name : key under which the keyword arguments are stored in the
                target node's leaf.
    kwargs    : the parameters; they update any existing entry.
    """
    # Create each container independently. A node can have children but no
    # leaf (an intermediate node), and its children must not be discarded.
    if self.leaf is None:
      self.leaf = {}
    if self.children is None:
      self.children = {}

    # Traverse to the node
    cur_node = self
    for idx, comp_name in enumerate( strs, 1 ):
      # Lazily create children
      if cur_node.children is None:
        cur_node.children = {}
      siblings = cur_node.children

      if comp_name in siblings:
        # Pop and re-insert so the entry moves to the end of the dict's
        # iteration order -- presumably so the most recently touched entry
        # is merged last (and wins) when the tree is walked later.
        nxt = siblings.pop( comp_name )
      else:
        nxt = ParamTreeNode()
        if '*' in comp_name:
          nxt.compiled_re = re.compile( comp_name )
          # Recursively update existing nodes that match the regex
          for name, node in siblings.items():
            if node.compiled_re is None and nxt.compiled_re.match( name ):
              node.add_params( strs[idx:], func_name, **kwargs )

      siblings[ comp_name ] = nxt
      cur_node = nxt

    # Add parameters to leaf
    if cur_node.leaf is None:
      cur_node.leaf = {}
    cur_node.leaf.setdefault( func_name, {} ).update( kwargs )

  def __repr__( self ):
    return f"\nleaf:{self.leaf}\nchildren:{self.children}"
99
class NamedObject:
  """Base class that names objects hierarchically while they are constructed.

  Instantiating a NamedObject only records the constructor arguments in
  ``inst._dsl``; the user-defined ``construct`` runs later, driven by
  ``elaborate``. While elaboration is in progress ``NamedObject.__setattr__``
  is temporarily replaced so that assigning a NamedObject (or a possibly
  nested list of them) to a public attribute tags the child with its
  parent, level, name and full name and then constructs it.
  """

  def __new__( cls, *args, **kwargs ):
    """Allocate the instance and save the constructor arguments in ``_dsl``.

    ``construct`` is not called here; ``_construct`` replays the saved
    arguments later.
    """
    inst = super().__new__( cls )
    inst._dsl = DSLMetadata() # TODO an actual object?

    # Save parameters for elaborate
    inst._dsl.args        = args
    inst._dsl.kwargs      = kwargs
    inst._dsl.constructed = False

    # A tree of parameters.
    inst._dsl.param_tree = None

    return inst

  #-----------------------------------------------------------------------
  # Private methods
  #-----------------------------------------------------------------------

  def _construct( s ):
    """Call ``s.construct`` once with the saved args, then mark s constructed.

    Keyword arguments recorded under "construct" in the root leaf of the
    parameter tree are merged into (and override) the saved keyword
    arguments. Note that the saved ``_dsl.kwargs`` dict itself is updated.
    """
    if s._dsl.constructed:
      return

    # Merge the actual keyword args and those args set by set_param
    kwargs = s._dsl.kwargs
    tree   = s._dsl.param_tree
    if tree is not None and tree.leaf is not None and "construct" in tree.leaf:
      kwargs.update( tree.leaf[ "construct" ] )

    s.construct( *s._dsl.args, **kwargs )

    s._dsl.constructed = True

  def __setattr_for_elaborate__( s, name, obj ):
    """Replacement for ``__setattr__`` that is installed during elaboration.

    For a public attribute (name not starting with '_') whose value is a
    NamedObject, or a non-empty list whose first element is a NamedObject
    or a list, every NamedObject reached is tagged with naming metadata,
    receives the matching parts of s's parameter tree, and is constructed.
    Finally the attribute is stored as usual.

    Raises FieldReassignError if the field already holds a different
    hardware construct; re-assigning the very same object is a no-op.

    The scalar and the list case are deliberately written out inline
    rather than shared through a helper: the original author optimized
    the common scalar case and wanted to keep error-message depth low.
    """
    # I use non-recursive BFS to reduce error message depth

    if name[0] != '_': # filter private variables
      sd = s._dsl

      # Common case: the object itself is a NamedObject.
      if isinstance( obj, NamedObject ):
        fields = sd.NamedObject_fields
        if name in fields:
          if getattr( s, name ) is obj:
            return
          raise FieldReassignError(f"The attempt to assign hardware construct to field {name} is illegal:\n"
                                   f" - top{repr(s)[1:]} already has field {name} with type {type(getattr( s, name ))}.")
        fields.add( name )

        ud = obj._dsl

        ud.parent_obj = s
        ud.level      = sd.level + 1

        ud._my_name  = ud.my_name = name
        ud.full_name = f"{sd.full_name}.{name}"

        ud._my_indices = None

        # Hand down every entry of my param tree that names this child,
        # either literally or through a compiled pattern.
        tree = sd.param_tree
        if tree is not None and tree.children is not None:
          for comp_name, node in tree.children.items():
            if comp_name == name or \
               ( node.compiled_re is not None and node.compiled_re.match( name ) ):
              # Lazily create the param tree
              if ud.param_tree is None:
                ud.param_tree = ParamTreeNode()
              ud.param_tree.merge( node )

        ud.NamedObject_fields = set()

        # Point the child's top to my top
        ud.elaborate_top = sd.elaborate_top

        NamedObject._elaborate_stack.append( obj )
        obj._construct()
        NamedObject._elaborate_stack.pop()

      # ONLY LIST IS SUPPORTED, SORRY.
      # I don't want to support any iterable object because later "Wire"
      # can be infinitely iterated and cause infinite loop. Special
      # casing Wire will be a mess around everywhere.

      elif isinstance( obj, list ) and obj and isinstance( obj[0], (NamedObject, list) ):
        fields = sd.NamedObject_fields
        if name in fields:
          if getattr( s, name ) is obj:
            return
          raise FieldReassignError(f"The attempt to assign hardware construct to field {name} is illegal:\n"
                                   f" - top{repr(s)[1:]} already has field {name} with type {type(getattr( s, name ))}.")
        fields.add( name )

        # Breadth-first walk over the nested list; each queue entry carries
        # the tuple of indices that leads to the element.
        Q = deque( (u, (i,)) for i, u in enumerate(obj) )

        while Q:
          u, indices = Q.popleft()

          if isinstance( u, NamedObject ):
            ud = u._dsl

            ud.parent_obj = s
            ud.level      = sd.level + 1

            ud._my_name  = name
            ud.my_name   = u_name = name + "".join( [ f"[{x}]" for x in indices ] )
            ud.full_name = f"{sd.full_name}.{u_name}"

            ud._my_indices = indices

            # Same hand-down as in the scalar case, keyed by the indexed name.
            tree = sd.param_tree
            if tree is not None and tree.children is not None:
              for comp_name, node in tree.children.items():
                if comp_name == u_name or \
                   ( node.compiled_re is not None and node.compiled_re.match( u_name ) ):
                  # Lazily create the param tree
                  if ud.param_tree is None:
                    ud.param_tree = ParamTreeNode()
                  ud.param_tree.merge( node )

            ud.NamedObject_fields = set()

            # Point u's top to my top
            ud.elaborate_top = sd.elaborate_top

            NamedObject._elaborate_stack.append( u )
            u._construct()
            NamedObject._elaborate_stack.pop()

          elif isinstance( u, list ):
            Q.extend( (v, indices+(i,)) for i, v in enumerate(u) )

    super().__setattr__( name, obj )

  def _collect_all_single( s, filt=lambda x: isinstance( x, NamedObject ) ):
    """Return the set of NamedObjects reachable from s that satisfy `filt`.

    The walk follows public string-keyed attributes and tuple-keyed entries
    of each NamedObject's ``__dict__`` and descends into lists. s itself is
    included when it passes the filter.
    """
    ret = set()
    stack = [s]
    while stack:
      u = stack.pop()

      if isinstance( u, NamedObject ):
        if filt( u ): # Check if u satisfies the filter
          ret.add( u )

        for name, obj in u.__dict__.items():

          # If the id is string, it is a normal children field. Otherwise it
          # should be a tuple that represents a slice

          if isinstance( name, str ):
            if name[0] != '_': # filter private variables
              stack.append( obj )

          elif isinstance( name, tuple ): # name = [1:3]
            stack.append( obj )

      # ONLY LIST IS SUPPORTED
      elif isinstance( u, list ):
        stack.extend( u )
    return ret

  # It is possible to take multiple filters
  def _collect_all( s, filt=( lambda x: isinstance( x, NamedObject ), ) ):
    """Like ``_collect_all_single`` but for a sequence of filters.

    Returns a list of sets, one per filter in `filt`, in the same order.
    """
    ret = [ set() for _ in filt ]
    stack = [s]
    while stack:
      u = stack.pop()
      if isinstance( u, NamedObject ):

        for i, accept in enumerate( filt ):
          if accept( u ): # Check if u satisfies the filter
            ret[i].add( u )

        for name, obj in u.__dict__.items():

          # If the id is string, it is a normal children field. Otherwise it
          # should be a tuple that represents a slice

          if isinstance( name, str ):
            if name[0] != '_': # filter private variables
              stack.append( obj )

          elif isinstance( name, tuple ): # name = [1:3]
            stack.append( obj )

      # ONLY LIST IS SUPPORTED
      elif isinstance( u, list ):
        stack.extend( u )
    return ret

  # Developers should use repr(x) everywhere to get the name

  def __repr__( s ):
    """Return the full hierarchical name, or the default repr if unset."""
    try:
      return s._dsl.full_name
    except AttributeError:
      return super().__repr__()

  #-----------------------------------------------------------------------
  # Construction time APIs
  #-----------------------------------------------------------------------

  def construct( s, *args, **kwargs ):
    """Construction hook invoked by ``_construct``; does nothing here."""
    pass

  def set_param( s, string, **kwargs ):
    """Record keyword arguments for a function of a (sub)object.

    `string` is a dot-separated path that must start with "top" and end
    with a function name (no '*' allowed in the function name), e.g.
    "top.x.construct". Components in between may contain '*', in which
    case they are treated as regular-expression patterns. Must be called
    before s has been constructed.
    """
    # Assert no positional arguments
    # assert len( s._dsl.args ) == 0, \
    #   "Cannot use set_param because {} has positional arguments!".format(s._dsl.my_name)
    assert not s._dsl.constructed

    strs = string.split( "." )

    assert strs[0] == "top", "The component should start at top"
    assert '*' not in strs[-1], "We don't support * with function name!"

    assert len( strs ) >= 2
    func_name = strs[-1]
    strs = strs[1:-1]
    if s._dsl.param_tree is None:
      s._dsl.param_tree = ParamTreeNode()
    s._dsl.param_tree.add_params( strs, func_name, **kwargs )

  # There are two reasons I refactored this function into two separate
  # functions. First of all in later levels of components, named objects
  # can be spawned after the previous monolithic elaborate and hence this
  # collect part won't capture them. Second, later levels can override
  # this function and simply call construct at the beginning and call
  # collect at the middle/end.

  #-----------------------------------------------------------------------
  # elaborate
  #-----------------------------------------------------------------------

  def _elaborate_construct( s ):
    """Make s the top of the hierarchy and construct everything below it.

    Does nothing if s has already been constructed. Any exception raised
    during construction propagates to the caller after the temporary
    class-level hooks have been removed.
    """
    if s._dsl.constructed:
      # Yanghui : Mute the warning for the isca tutorial.
      # warnings.warn( "Don't elaborate the same model twice. "
      #                "Use APIs to mutate the model." )
      return

    # Initialize the top level

    s._dsl.parent_obj    = None
    s._dsl.level         = 0
    s._dsl.my_name       = "s"
    s._dsl.full_name     = "s"
    s._dsl.elaborate_top = s
    s._dsl.NamedObject_fields = set()

    # Secret sauce for letting the child know the field name of itself
    # -- override setattr for elaboration, and remove it afterwards
    # -- and the global elaborate to enable free function as decorator

    NamedObject.__setattr__ = NamedObject.__setattr_for_elaborate__
    NamedObject._elaborate_stack = [ s ]

    try:
      s._construct()
    finally:
      # Always undo the class-level patching -- also for BaseException
      # (e.g. KeyboardInterrupt) -- so a failed elaboration cannot leave
      # every NamedObject with the elaboration-time __setattr__.
      del NamedObject.__setattr__
      del NamedObject._elaborate_stack

  def _elaborate_collect_all_named_objects( s ):
    """Store the set of all reachable NamedObjects in ``_dsl.all_named_objects``."""
    s._dsl.all_named_objects = s._collect_all_single()

  def elaborate( s ):
    """Construct the whole hierarchy, then collect all named objects."""
    s._elaborate_construct()
    s._elaborate_collect_all_named_objects()

  #-----------------------------------------------------------------------
  # Post-elaborate public APIs (can only be called after elaboration)
  #-----------------------------------------------------------------------

  # The three predicates below are not implemented at this level;
  # presumably subclasses override them.

  def is_component( s ):
    raise NotImplementedError

  def is_signal( s ):
    raise NotImplementedError

  def is_interface( s ):
    raise NotImplementedError

  # These two APIs are reused across Connectable and Component

  def get_field_name( s ):
    """Return the field name; raise NotElaboratedError if it is not set yet."""
    try:
      return s._dsl.my_name
    except AttributeError:
      raise NotElaboratedError()

  def get_parent_object( s ):
    """Return the parent object; raise NotElaboratedError if it is not set yet."""
    try:
      return s._dsl.parent_obj
    except AttributeError:
      raise NotElaboratedError()
427