1#=========================================================================
2# test_utility.py
3#=========================================================================
4# Author : Peitian Pan
5# Date   : Jun 5, 2019
6"""Provide utility methods for testing."""
7import copy
8from collections import deque
9
10from hypothesis import strategies as st
11
12from pymtl3.datatypes import Bits1, mk_bits
13from pymtl3.dsl import OutPort
14from pymtl3.passes.rtlir import RTLIRDataType as rdt
15from pymtl3.passes.rtlir import RTLIRType as rt
16from pymtl3.stdlib.test_utils import TestVectorSimulator
17
18from ...yosys import YosysTranslationImportPass
19from .. import VerilogTranslationImportPass
20
21#=========================================================================
22# test utility functions
23#=========================================================================
24
def trim( s ):
  """Normalize `s` for whitespace-insensitive comparison.

  Every whitespace character is removed from each line (not only the
  leading/trailing ones). Lines that become empty, and lines whose
  whitespace-free form starts with ``//``, are dropped. The surviving
  lines are joined back together with newlines.
  """
  kept = []
  for line in s.split( '\n' ):
    # Compute the whitespace-free form once and reuse it for both the
    # filter and the output.
    compact = "".join( line.split() )
    if compact and not compact.startswith( '//' ):
      kept.append( compact )
  return "\n".join( kept )
34
def check_eq( s, t ):
  """Assert that `s` and `t` are equal after `trim` normalization.

  When both arguments are lists, they must have the same length and match
  element-wise. Otherwise, both are trimmed and compared directly.
  """
  if isinstance( s, list ) and isinstance( t, list ):
    # `zip` silently stops at the shorter list, so check lengths first;
    # otherwise a missing or extra element would go unnoticed.
    assert len( s ) == len( t ), \
      f"length mismatch: {len(s)} vs {len(t)}"
    for _s, _t in zip( s, t ):
      assert trim(_s) == trim(_t)
  else:
    assert trim(s) == trim(t)
41
42#=========================================================================
43# Hypothesis strategies for testing
44#=========================================================================
45
def flatten( _rtype ):
  """Split an RTLIR type into (dimension sizes, element type).

  For an `rt.Array` this returns its `get_dim_sizes()` result and its
  `get_sub_type()`. For any other type it returns an empty list together
  with the type unchanged.
  """
  if not isinstance( _rtype, rt.Array ):
    return [], _rtype
  dims = _rtype.get_dim_sizes()
  elem = _rtype.get_sub_type()
  return dims, elem
54
55#-------------------------------------------------------------------------
56# Generate initialization data for a signal
57#-------------------------------------------------------------------------
58
def VectorInitData( dtype ):
  """Return a zero-valued Bits object as wide as the vector type `dtype`."""
  # No range computation is needed: the initial value is always 0.
  return mk_bits( dtype.get_length() )( 0 )
64
def StructInitData( dtype ):
  """Instantiate the struct class of `dtype` with every field initialized.

  A fresh object of `dtype.get_class()` is created. Each field listed by
  `dtype.get_all_properties()` is then set to `DataTypeInitData` of that
  field's data type.
  """
  inst = dtype.get_class()()
  fields = dtype.get_all_properties()
  for name in fields:
    setattr( inst, name, DataTypeInitData( fields[ name ] ) )
  return inst
70
def PackedArrayInitData( n_dim, sub_dtype ):
  """Return nested lists of initial values for a packed array.

  `n_dim` holds the remaining dimension sizes, outermost first. When it is
  empty, the initial value of the element type `sub_dtype` is returned.
  """
  if n_dim:
    inner_dims = n_dim[1:]
    return [ PackedArrayInitData( inner_dims, sub_dtype )
             for _ in range( n_dim[0] ) ]
  return DataTypeInitData( sub_dtype )
79
def DataTypeInitData( dtype ):
  """Return the initial (all-zero) value for the RTLIR data type `dtype`.

  Vectors become a zero Bits value. Structs become an instance with every
  field initialized recursively. Packed arrays become nested lists of
  initial element values.

  Raises:
    AssertionError: if `dtype` is not a Vector, Struct or PackedArray.
  """
  if isinstance( dtype, rdt.Vector ):
    return VectorInitData( dtype )
  elif isinstance( dtype, rdt.Struct ):
    return StructInitData( dtype )
  elif isinstance( dtype, rdt.PackedArray ):
    n_dim = dtype.get_dim_sizes()
    sub_dtype = dtype.get_sub_dtype()
    return PackedArrayInitData( n_dim, sub_dtype )
  # Report the offending type itself (no `sub_dtype` exists on this path).
  # Raise explicitly so the check is not stripped under `python -O`.
  raise AssertionError( f"unrecognized data type {dtype}!" )
91
def InPortInitData( id_, port ):
  """Map the signal name `id_` to the initial value of `port`'s data type."""
  init_value = DataTypeInitData( port.get_dtype() )
  return { id_ : init_value }
94
def InterfaceInitData( id_, ifc ):
  """Collect initial values for every input reachable through `ifc`.

  Returns a dict keyed by signal names of the form `id_ + "." + prop_name`
  (arrays add `[i]` suffixes). The rules are:
    - input ports get their initial value;
    - non-input ports (and arrays of them) are skipped;
    - nested interface views are visited recursively.
  """
  init = {}
  for prop_name, prop_rtype in ifc.get_all_properties_packed():
    full_name = id_ + "." + prop_name
    if isinstance( prop_rtype, rt.Array ):
      n_dim = prop_rtype.get_dim_sizes()
      elem_rtype = prop_rtype.get_sub_type()
      # Keep arrays of interfaces and arrays of input ports only.
      keep = not isinstance( elem_rtype, rt.Port ) or \
             elem_rtype.get_direction() == "input"
      if keep:
        init.update( ArrayInitData( full_name, n_dim, elem_rtype ) )
    elif isinstance( prop_rtype, rt.Port ):
      if prop_rtype.get_direction() == "input":
        init.update( InPortInitData( full_name, prop_rtype ) )
    elif isinstance( prop_rtype, rt.InterfaceView ):
      init.update( InterfaceInitData( full_name, prop_rtype ) )
  return init
110
def ArrayInitData( id_, n_dim, subtype ):
  """Collect initial values for every element of an array of ports/ifcs.

  While `n_dim` is non-empty, the function recurses once per index of the
  outermost dimension, appending `[i]` to the name. At the leaf it
  delegates to InPortInitData for a port and to InterfaceInitData
  otherwise. Returns a flat dict from signal name to value.
  """
  if n_dim:
    init = {}
    for idx in range( n_dim[0] ):
      init.update( ArrayInitData( f'{id_}[{idx}]', n_dim[1:], subtype ) )
    return init
  if isinstance( subtype, rt.Port ):
    return InPortInitData( id_, subtype )
  return InterfaceInitData( id_, subtype )
122
123#-------------------------------------------------------------------------
124# Hypothesis input data strategies
125#-------------------------------------------------------------------------
126
@st.composite
def VectorDataStrategy( draw, dtype ):
  """Draw a random Bits value for the vector type `dtype`.

  An integer is drawn from the signed range of the vector's width,
  [-2**(n-1), 2**(n-1)-1], and converted with `mk_bits`.
  """
  width = dtype.get_length()
  half = 2 ** ( width - 1 )
  raw = draw( st.integers( -half, half - 1 ) )
  return mk_bits( width )( raw )
133
@st.composite
def StructDataStrategy( draw, dtype ):
  """Draw a struct instance of `dtype` whose fields are all random."""
  inst = dtype.get_class()()
  fields = dtype.get_all_properties()
  for name in fields:
    setattr( inst, name, draw( DataTypeDataStrategy( fields[ name ] ) ) )
  return inst
140
@st.composite
def PackedArrayDataStrategy( draw, n_dim, sub_dtype ):
  """Draw nested lists of random element values for a packed array.

  `n_dim` holds the remaining dimension sizes, outermost first. When it is
  empty, a single random value of `sub_dtype` is drawn.
  """
  if n_dim:
    inner_dims = n_dim[1:]
    return [ draw( PackedArrayDataStrategy( inner_dims, sub_dtype ) )
             for _ in range( n_dim[0] ) ]
  return draw( DataTypeDataStrategy( sub_dtype ) )
150
@st.composite
def DataTypeDataStrategy( draw, dtype ):
  """Draw a random value of the RTLIR data type `dtype`.

  Dispatches to the Vector, Struct or PackedArray strategy.

  Raises:
    AssertionError: if `dtype` is not a Vector, Struct or PackedArray.
  """
  if isinstance( dtype, rdt.Vector ):
    return draw( VectorDataStrategy( dtype ) )
  elif isinstance( dtype, rdt.Struct ):
    return draw( StructDataStrategy( dtype ) )
  elif isinstance( dtype, rdt.PackedArray ):
    n_dim = dtype.get_dim_sizes()
    sub_dtype = dtype.get_sub_dtype()
    return draw( PackedArrayDataStrategy( n_dim, sub_dtype ) )
  # Report the offending type itself (no `sub_dtype` exists on this path).
  # Raise explicitly so the check is not stripped under `python -O`.
  raise AssertionError( f"unrecognized data type {dtype}!" )
163
@st.composite
def InPortDataStrategy( draw, id_, port ):
  """Draw a random value for `port` and key it by the signal name `id_`."""
  value = draw( DataTypeDataStrategy( port.get_dtype() ) )
  return { id_ : value }
167
@st.composite
def InterfaceDataStrategy( draw, id_, ifc ):
  """Draw random values for every input reachable through `ifc`.

  Returns a dict keyed by signal names of the form `id_ + "." + prop_name`
  (arrays add `[i]` suffixes). The rules are:
    - input ports get a random value;
    - non-input ports (and arrays of them) are skipped;
    - nested interface views are visited recursively.
  """
  data = {}
  for prop_name, prop_rtype in ifc.get_all_properties_packed():
    full_name = id_ + "." + prop_name
    if isinstance( prop_rtype, rt.Array ):
      n_dim = prop_rtype.get_dim_sizes()
      elem_rtype = prop_rtype.get_sub_type()
      # Keep arrays of interfaces and arrays of input ports only.
      keep = not isinstance( elem_rtype, rt.Port ) or \
             elem_rtype.get_direction() == "input"
      if keep:
        data.update( draw( ArrayDataStrategy( full_name, n_dim, elem_rtype ) ) )
    elif isinstance( prop_rtype, rt.Port ):
      if prop_rtype.get_direction() == "input":
        data.update( draw( InPortDataStrategy( full_name, prop_rtype ) ) )
    elif isinstance( prop_rtype, rt.InterfaceView ):
      data.update( draw( InterfaceDataStrategy( full_name, prop_rtype ) ) )
  return data
184
@st.composite
def ArrayDataStrategy( draw, id_, n_dim, subtype ):
  """Draw random values for every element of an array of ports/ifcs.

  While `n_dim` is non-empty, the function recurses once per index of the
  outermost dimension, appending `[i]` to the name. At the leaf it draws
  from InPortDataStrategy for a port and from InterfaceDataStrategy
  otherwise. Returns a flat dict from signal name to value.
  """
  if n_dim:
    data = {}
    for idx in range( n_dim[0] ):
      data.update( draw( ArrayDataStrategy( f'{id_}[{idx}]', n_dim[1:], subtype ) ) )
    return data
  if isinstance( subtype, rt.Port ):
    return draw( InPortDataStrategy( id_, subtype ) )
  return draw( InterfaceDataStrategy( id_, subtype ) )
198
def _reset_cycle_data( ports, ifcs, clk, reset ):
  """Return one reset-phase input vector (a dict from signal name to value).

  `ports` and `ifcs` are the packed (name, rtype) pairs of the component.
  The port named "clk" gets `clk` and the port named "reset" gets `reset`.
  Every other input port, including inputs nested in interfaces and
  arrays, gets its deterministic initial value from the *InitData helpers.
  """
  vec = {}
  for id_, port in ports:
    if id_ == "clk":
      vec[ id_ ] = clk
    elif id_ == "reset":
      vec[ id_ ] = reset
    else:
      n_dim, port_rtype = flatten( port )
      if port_rtype.get_direction() == "input":
        if n_dim:
          vec.update( ArrayInitData( id_, n_dim, port_rtype ) )
        else:
          vec.update( InPortInitData( id_, port_rtype ) )
  for id_, ifc in ifcs:
    n_dim, ifc_rtype = flatten( ifc )
    # Reset cycles use the (non-random) *InitData helpers. They return
    # plain dicts, unlike the *DataStrategy helpers, which return strategies.
    if n_dim:
      vec.update( ArrayInitData( id_, n_dim, ifc_rtype ) )
    else:
      vec.update( InterfaceInitData( id_, ifc_rtype ) )
  return vec

def _random_cycle_data( draw, ports, ifcs ):
  """Return one clk-low, reset-deasserted input vector with random inputs.

  `draw` is the draw function of the enclosing composite strategy. Every
  input port other than clk/reset, and every interface (or array of
  interfaces), is filled from the matching *DataStrategy.
  """
  vec = {}
  for id_, port in ports:
    if id_ in [ "clk", "reset" ]:
      vec[ id_ ] = Bits1(0)
      continue
    n_dim, port_rtype = flatten( port )
    if port_rtype.get_direction() != "input":
      continue
    if n_dim:
      vec.update( draw( ArrayDataStrategy( id_, n_dim, port_rtype ) ) )
    else:
      vec.update( draw( InPortDataStrategy( id_, port_rtype ) ) )
  for id_, ifc in ifcs:
    n_dim, ifc_rtype = flatten( ifc )
    if n_dim:
      vec.update( draw( ArrayDataStrategy( id_, n_dim, ifc_rtype ) ) )
    else:
      vec.update( draw( InterfaceDataStrategy( id_, ifc_rtype ) ) )
  return vec

@st.composite
def DataStrategy( draw, dut, max_cycles = 10 ):
  """Return a strategy that generates input vectors for component `dut`.

  The drawn value is a list of dicts, each mapping a signal name to the
  value to drive in that step:
    - two reset vectors with reset=1, the first with clk=0 and the second
      with clk=1 (all other inputs at their initial values);
    - then `max_cycles` pairs of vectors with reset=0 and random inputs.
      The first vector of each pair has clk=0. The second has clk=1 and
      deep copies of the same input values.
  """
  dut.elaborate()
  rifc = rt.RTLIRGetter(cache=False).get_component_ifc_rtlir( dut )
  ports = rifc.get_ports_packed()
  ifcs = rifc.get_ifc_views_packed()

  # Each call builds fresh value objects, so the two reset vectors do not
  # share any mutable values.
  ret = [
    _reset_cycle_data( ports, ifcs, Bits1(0), Bits1(1) ),
    _reset_cycle_data( ports, ifcs, Bits1(1), Bits1(1) ),
  ]

  for _ in range( max_cycles ):
    data = _random_cycle_data( draw, ports, ifcs )
    # Toggle the clock high while holding the inputs. Values are deep-copied
    # so the two vectors hold distinct objects.
    toggle_data = {
      id_ : Bits1(1) if id_ == "clk" else copy.deepcopy( signal )
      for id_, signal in data.items()
    }
    ret.append( data )
    ret.append( toggle_data )

  return ret
271
272#-------------------------------------------------------------------------
273# closed_loop_component_input_test
274#-------------------------------------------------------------------------
275
def closed_loop_component_input_test( dut, test_vector, tv_in, backend = "verilog" ):
  """Check that the translated-and-imported `dut` matches its Python model.

  Steps:
    1. Run the pure-Python component through a TestVectorSimulator using
       `test_vector`/`tv_in`, and record the value of every OutPort on
       each `tv_out` callback.
    2. Translate and import the component with the pass selected by
       `backend` ("verilog" or "yosys").
    3. Simulate the imported model with the same vectors and assert that
       every output equals the recorded reference value for that step.
  """
  assert backend in ( "verilog", "yosys" ), f"invalid backend {backend}!"
  import_passes = {
    "verilog" : VerilogTranslationImportPass,
    "yosys"   : YosysTranslationImportPass,
  }

  dut.elaborate()
  expected = deque()
  out_ports = dut.get_local_object_filter( lambda obj: isinstance( obj, OutPort ) )

  # tv_out hook for the reference run: snapshot every output port.
  def record_outputs( model, test_vector ):
    # `eval` resolves `my_name` relative to the local `model`. An explicit
    # loop (not a comprehension) keeps `model` visible to eval on Pythons
    # where comprehensions have their own scope. Values are cloned so the
    # snapshot does not alias the live port values.
    snapshot = {}
    for port in out_ports:
      snapshot[ port ] = eval( "model." + port._dsl.my_name ).clone()
    expected.append( snapshot )

  # tv_out hook for the imported run: compare against the next snapshot.
  def check_outputs( model, test_vector ):
    assert len(expected) > 0, \
      "Reference runs for fewer cycles than the imported model!"
    reference = expected[0]
    for port in out_ports:
      ref = reference[ port ]
      imp = eval( "model." + port._dsl.my_name )
      assert ref == imp, f"Value mismatch: reference: {ref}, imported: {imp}"
    expected.popleft()

  # Reference run. The simulation is unlocked afterwards, presumably so the
  # same component can be re-elaborated for translation.
  TestVectorSimulator( dut, test_vector, tv_in, record_outputs ).run_test()
  dut.unlock_simulation()

  dut.elaborate()
  import_pass = import_passes[ backend ]
  dut.set_metadata( import_pass.enable, True )
  imported_obj = import_pass()( dut )

  TestVectorSimulator( imported_obj, test_vector, tv_in, check_outputs ).run_test()
322
323#-------------------------------------------------------------------------
324# closed_loop_component_test
325#-------------------------------------------------------------------------
326
def closed_loop_component_test( dut, data, backend = "verilog" ):
  """Run a closed-loop translation test on `dut` with random input vectors.

  Users who wish to use this method should pass in the hypothesis data
  object (the one whose `draw` method is called here) as `data`. Input
  vectors are generated by reflecting on the ports and interfaces of
  `dut` via `DataStrategy`. They are then checked with
  `closed_loop_component_input_test` using the given `backend`.
  """
  def tv_in( model, test_vector ):
    for port_name, port_value in test_vector.items():
      # `setattr` fails to set the correct value of an array element when
      # the name contains a subscript, so the assignment is executed as
      # code. The explicit namespace binds `data` to the value being fed.
      exec( "model." + port_name + " @= data", globals(),
            { "model" : model, "data" : port_value } )

  vectors = data.draw( DataStrategy( dut ) )
  closed_loop_component_input_test( dut, vectors, tv_in, backend )
343
344#-------------------------------------------------------------------------
345# closed_loop_test
346#-------------------------------------------------------------------------
347
# TODO: A hypothesis test that works on generated test components AND
# generated input data.
def closed_loop_test():
  """Placeholder for a closed-loop test over generated components and data.

  Not implemented yet: calling it does nothing and returns None.
  """
352