#=========================================================================
# test_utility.py
#=========================================================================
# Author : Peitian Pan
# Date   : Jun 5, 2019
"""Provide utility methods for testing."""
import copy
from collections import deque

from hypothesis import strategies as st

from pymtl3.datatypes import Bits1, mk_bits
from pymtl3.dsl import OutPort
from pymtl3.passes.rtlir import RTLIRDataType as rdt
from pymtl3.passes.rtlir import RTLIRType as rt
from pymtl3.stdlib.test_utils import TestVectorSimulator

from ...yosys import YosysTranslationImportPass
from .. import VerilogTranslationImportPass

#=========================================================================
# test utility functions
#=========================================================================

def trim( s ):
  """Normalize source text for whitespace/comment-insensitive comparison.

  All whitespace is removed from every line of `s`; lines that end up
  empty, or that start with `//` after whitespace removal, are dropped.
  The surviving lines are joined with newlines.
  """
  kept = []
  for line in s.split( '\n' ):
    # Compute the whitespace-free form once and reuse it.
    compact = "".join( line.split() )
    if compact and not compact.startswith( '//' ):
      kept.append( compact )
  return "\n".join( kept )

def check_eq( s, t ):
  """Assert that `s` and `t` are equal after normalization by `trim`.

  When both arguments are lists, their elements are compared pairwise.
  Note that `zip` stops at the shorter list, so surplus trailing elements
  of the longer list are not checked.
  """
  if isinstance( s, list ) and isinstance( t, list ):
    for _s, _t in zip( s, t ):
      assert trim(_s) == trim(_t)
  else:
    assert trim(s) == trim(t)

#=========================================================================
# Hypothesis strategies for testing
#=========================================================================

def flatten( _rtype ):
  """Split an RTLIR type into ``(dimension sizes, element type)``.

  For an `rt.Array`, its dimension sizes and sub-type are returned. Any
  other type is returned unchanged together with an empty dimension list.
  """
  if isinstance( _rtype, rt.Array ):
    return _rtype.get_dim_sizes(), _rtype.get_sub_type()
  return [], _rtype

#-------------------------------------------------------------------------
# Generate initialization data for a signal
#-------------------------------------------------------------------------

def VectorInitData( dtype ):
  """Return a zero value whose bit width is that of vector type `dtype`."""
  return mk_bits( dtype.get_length() )( 0 )

def StructInitData( dtype ):
  """Return an instance of struct `dtype` with every field initialized.

  Each field is filled recursively via `DataTypeInitData`.
  """
  data = dtype.get_class()()
  for field_name, field in dtype.get_all_properties().items():
    setattr( data, field_name, DataTypeInitData( field ) )
  return data

def PackedArrayInitData( n_dim, sub_dtype ):
  """Return nested lists shaped by `n_dim` holding init data of `sub_dtype`.

  With an empty `n_dim`, a single element's init data is returned.
  """
  if not n_dim:
    return DataTypeInitData( sub_dtype )
  return [ PackedArrayInitData( n_dim[1:], sub_dtype ) for _ in range( n_dim[0] ) ]

def DataTypeInitData( dtype ):
  """Return zero-valued initialization data for RTLIR data type `dtype`.

  Vectors become a zero `Bits` value, and structs become an instance whose
  fields are recursively initialized. Packed arrays become (nested) lists
  of element data.

  Raises:
    AssertionError: if `dtype` is not a vector, struct, or packed array.
  """
  if isinstance( dtype, rdt.Vector ):
    return VectorInitData( dtype )
  if isinstance( dtype, rdt.Struct ):
    return StructInitData( dtype )
  if isinstance( dtype, rdt.PackedArray ):
    return PackedArrayInitData( dtype.get_dim_sizes(), dtype.get_sub_dtype() )
  # Raise explicitly (not `assert False`) so the check survives `python -O`,
  # and report `dtype` itself -- no `sub_dtype` is bound on this path.
  raise AssertionError( f"unrecognized data type {dtype}!" )

def InPortInitData( id_, port ):
  """Return ``{id_: init data}`` for the data type of input port `port`."""
  return { id_ : DataTypeInitData( port.get_dtype() ) }

def InterfaceInitData( id_, ifc ):
  """Return init data for every input signal reachable through `ifc`.

  Keys are hierarchical names of the form ``id_.prop`` (with ``[i]``
  subscripts for array elements). Non-input ports and arrays of non-input
  ports are skipped. Nested interface views are recursed into.
  """
  init = {}
  for prop_name, prop_rtype in ifc.get_all_properties_packed():
    prop_id = id_ + "." + prop_name
    if isinstance( prop_rtype, rt.Array ):
      n_dim = prop_rtype.get_dim_sizes()
      sub_type = prop_rtype.get_sub_type()
      if isinstance(sub_type, rt.Port) and sub_type.get_direction() != "input":
        continue
      init.update( ArrayInitData( prop_id, n_dim, sub_type ) )
    elif isinstance( prop_rtype, rt.Port ):
      if prop_rtype.get_direction() == "input":
        init.update( InPortInitData( prop_id, prop_rtype ) )
    elif isinstance( prop_rtype, rt.InterfaceView ):
      init.update( InterfaceInitData( prop_id, prop_rtype ) )
  return init

def ArrayInitData( id_, n_dim, subtype ):
  """Return init data for each element of an array of ports or interfaces.

  Elements are keyed ``id_[i][j]...``. Once all dimensions are consumed,
  a port element is handled by `InPortInitData`, and anything else by
  `InterfaceInitData`.
  """
  if not n_dim:
    if isinstance( subtype, rt.Port ):
      return InPortInitData( id_, subtype )
    return InterfaceInitData( id_, subtype )
  init = {}
  for i in range( n_dim[0] ):
    init.update( ArrayInitData( f'{id_}[{i}]', n_dim[1:], subtype ) )
  return init

#-------------------------------------------------------------------------
# Hypothesis input data strategies
#-------------------------------------------------------------------------

@st.composite
def VectorDataStrategy( draw, dtype ):
  """Draw a `Bits` value of `dtype`'s width from the signed integer range."""
  nbits = dtype.get_length()
  min_val, max_val = -(2**(nbits-1)), 2**(nbits-1)-1
  value = draw( st.integers( min_val, max_val ) )
  return mk_bits( nbits )( value )

@st.composite
def StructDataStrategy( draw, dtype ):
  """Draw an instance of struct `dtype` with every field drawn recursively."""
  data = dtype.get_class()()
  for field_name, field in dtype.get_all_properties().items():
    setattr( data, field_name, draw( DataTypeDataStrategy( field ) ) )
  return data

@st.composite
def PackedArrayDataStrategy( draw, n_dim, sub_dtype ):
  """Draw nested lists shaped by `n_dim` with random `sub_dtype` elements."""
  if not n_dim:
    return draw( DataTypeDataStrategy( sub_dtype ) )
  return [ draw( PackedArrayDataStrategy( n_dim[1:], sub_dtype ) )
           for _ in range( n_dim[0] ) ]

@st.composite
def DataTypeDataStrategy( draw, dtype ):
  """Draw a random value of RTLIR data type `dtype`.

  Raises:
    AssertionError: if `dtype` is not a vector, struct, or packed array.
  """
  if isinstance( dtype, rdt.Vector ):
    return draw( VectorDataStrategy( dtype ) )
  if isinstance( dtype, rdt.Struct ):
    return draw( StructDataStrategy( dtype ) )
  if isinstance( dtype, rdt.PackedArray ):
    n_dim = dtype.get_dim_sizes()
    sub_dtype = dtype.get_sub_dtype()
    return draw( PackedArrayDataStrategy( n_dim, sub_dtype ) )
  # Raise explicitly (not `assert False`) so the check survives `python -O`,
  # and report `dtype` itself -- no `sub_dtype` is bound on this path.
  raise AssertionError( f"unrecognized data type {dtype}!" )
@st.composite
def InPortDataStrategy( draw, id_, port ):
  """Draw a random value for input port `port`, returned as ``{id_: value}``."""
  return { id_ : draw(DataTypeDataStrategy( port.get_dtype() )) }

@st.composite
def InterfaceDataStrategy( draw, id_, ifc ):
  """Draw random values for every input signal reachable through `ifc`.

  Keys are hierarchical names of the form ``id_.prop`` (with ``[i]``
  subscripts for array elements). Non-input ports and arrays of non-input
  ports are skipped. Nested interface views are recursed into.
  """
  data = {}
  for prop_name, prop_rtype in ifc.get_all_properties_packed():
    prop_id = id_ + "." + prop_name
    if isinstance( prop_rtype, rt.Array ):
      n_dim = prop_rtype.get_dim_sizes()
      sub_type = prop_rtype.get_sub_type()
      if isinstance(sub_type, rt.Port) and sub_type.get_direction() != "input":
        continue
      data.update(draw(ArrayDataStrategy(prop_id, n_dim, sub_type)))
    elif isinstance( prop_rtype, rt.Port ):
      if prop_rtype.get_direction() == "input":
        data.update(draw(InPortDataStrategy(prop_id, prop_rtype)))
    elif isinstance( prop_rtype, rt.InterfaceView ):
      data.update(draw(InterfaceDataStrategy(prop_id, prop_rtype)))
  return data

@st.composite
def ArrayDataStrategy( draw, id_, n_dim, subtype ):
  """Draw random values for each element of an array of ports/interfaces.

  Elements are keyed ``id_[i][j]...``. Once all dimensions are consumed,
  a port element is drawn via `InPortDataStrategy`, and anything else via
  `InterfaceDataStrategy`.
  """
  if not n_dim:
    if isinstance( subtype, rt.Port ):
      return draw(InPortDataStrategy( id_, subtype ))
    return draw(InterfaceDataStrategy( id_, subtype ))
  data = {}
  for i in range( n_dim[0] ):
    data.update(draw(
        ArrayDataStrategy(f'{id_}[{i}]', n_dim[1:], subtype)))
  return data

@st.composite
def DataStrategy( draw, dut, max_cycles = 10 ):
  """Return a strategy that generates input vector for component `dut`.

  The drawn value is a list of dicts. Each dict maps a signal name to the
  value applied in one half clock cycle:

  * two reset entries, with `reset` held at 1 while `clk` goes 0 -> 1 and
    every other input initialized by the ``*InitData`` helpers;
  * `max_cycles` pairs of entries with randomly drawn inputs. The first
    entry of each pair has `clk` at 0; the second is a deep copy with
    `clk` raised to 1. `reset` is 0 in both.

  Only input ports, including those inside interfaces, receive values.
  """
  ret = []
  dut.elaborate()
  rifc = rt.RTLIRGetter(cache=False).get_component_ifc_rtlir( dut )
  ports = rifc.get_ports_packed()
  ifcs = rifc.get_ifc_views_packed()

  # Add reset cycle at the beginning
  reset1, reset2 = {}, {}
  for id_, port in ports:
    if id_ == "clk":
      reset1.update( { id_ : Bits1(0) } )
      reset2.update( { id_ : Bits1(1) } )
    elif id_ == "reset":
      reset1.update( { id_ : Bits1(1) } )
      reset2.update( { id_ : Bits1(1) } )
    else:
      n_dim, port_rtype = flatten( port )
      if port_rtype.get_direction() != "input":
        continue
      if n_dim:
        reset1.update( ArrayInitData( id_, n_dim, port_rtype ) )
        reset2.update( ArrayInitData( id_, n_dim, port_rtype ) )
      else:
        reset1.update( InPortInitData( id_, port_rtype ) )
        reset2.update( InPortInitData( id_, port_rtype ) )
  for id_, ifc in ifcs:
    n_dim, ifc_rtype = flatten( ifc )
    # Reset entries need concrete values, so use the *InitData helpers. A
    # hypothesis strategy object cannot be merged into a dict without being
    # drawn, and InterfaceInitData takes only (id_, ifc).
    if n_dim:
      reset1.update( ArrayInitData( id_, n_dim, ifc_rtype ) )
      reset2.update( ArrayInitData( id_, n_dim, ifc_rtype ) )
    else:
      reset1.update( InterfaceInitData( id_, ifc_rtype ) )
      reset2.update( InterfaceInitData( id_, ifc_rtype ) )

  ret.append( reset1 )
  ret.append( reset2 )

  for _ in range( max_cycles ):
    data = {}
    for id_, port in ports:
      if id_ in [ "clk", "reset" ]:
        data.update( { id_ : Bits1(0) } )
        continue
      n_dim, port_rtype = flatten( port )
      if port_rtype.get_direction() != "input":
        continue
      if n_dim:
        data.update(draw( ArrayDataStrategy( id_, n_dim, port_rtype ) ))
      else:
        data.update(draw( InPortDataStrategy( id_, port_rtype ) ))
    for id_, ifc in ifcs:
      n_dim, ifc_rtype = flatten( ifc )
      if n_dim:
        data.update(draw( ArrayDataStrategy( id_, n_dim, ifc_rtype ) ))
      else:
        data.update(draw( InterfaceDataStrategy( id_, ifc_rtype ) ))

    # Second half of the cycle: identical inputs (deep-copied) with the
    # clock raised.
    toggle_data = { id_ : Bits1(1) if id_ == "clk" else copy.deepcopy( signal )
                    for id_, signal in data.items() }

    ret.append( data )
    ret.append( toggle_data )

  return ret

#-------------------------------------------------------------------------
# closed_loop_component_input_test
#-------------------------------------------------------------------------

def closed_loop_component_input_test( dut, test_vector, tv_in, backend = "verilog" ):
  """Check that the translated-and-imported `dut` matches its Python model.

  `dut` is first simulated as a pure-Python component with `test_vector`
  and `tv_in`. Every time the simulator calls the output callback, the
  value of every output port is recorded. The component is then
  translated and imported back through `backend` ("verilog" or "yosys").
  It is simulated again with the same vectors, and each output is
  asserted to equal the corresponding recorded reference value.

  Raises:
    AssertionError: on an invalid `backend`, if the imported model runs
      more cycles than the reference, or on any output value mismatch.
  """

  # Filter to collect all output ports of a component
  def outport_filter( obj ):
    return isinstance( obj, OutPort )

  assert backend in [ "verilog", "yosys" ], f"invalid backend {backend}!"

  dut.elaborate()
  reference_output = deque()
  all_output_ports = dut.get_local_object_filter( outport_filter )

  # Method to record reference outputs of the pure python component
  def ref_tv_out( model, test_vector ):
    dct = {}
    for out_port in all_output_ports:
      # `my_name` may contain a subscript (e.g. an array element), which is
      # presumably why `eval` is used instead of `getattr`. Store a clone
      # rather than the live signal object -- cloning is required here so
      # the recorded value is a snapshot.
      dct[ out_port ] = eval( "model." + out_port._dsl.my_name ).clone()
    reference_output.append( dct )

  # Method to compare the outputs of the imported model and the pure python one
  def tv_out( model, test_vector ):
    assert len(reference_output) > 0, \
      "Reference runs for fewer cycles than the imported model!"
    for out_port in all_output_ports:
      ref = reference_output[0][out_port]
      imp = eval( "model." + out_port._dsl.my_name )
      assert ref == imp, f"Value mismatch: reference: {ref}, imported: {imp}"
    reference_output.popleft()

  # First simulate the pure python component to see if it has sane behavior
  reference_sim = TestVectorSimulator( dut, test_vector, tv_in, ref_tv_out )
  reference_sim.run_test()
  dut.unlock_simulation()

  # If it simulates correctly, translate it and import it back
  dut.elaborate()
  if backend == "verilog":
    dut.set_metadata( VerilogTranslationImportPass.enable, True )
    imported_obj = VerilogTranslationImportPass()( dut )
  else:  # backend == "yosys", guaranteed by the assertion above
    dut.set_metadata( YosysTranslationImportPass.enable, True )
    imported_obj = YosysTranslationImportPass()( dut )

  # Run another vector simulator spin
  imported_sim = TestVectorSimulator( imported_obj, test_vector, tv_in, tv_out )
  imported_sim.run_test()

#-------------------------------------------------------------------------
# closed_loop_component_test
#-------------------------------------------------------------------------
def closed_loop_component_test( dut, data, backend = "verilog" ):
  """Run a closed-loop translation test on `dut` with hypothesis-drawn inputs.

  `data` is the hypothesis data object to draw from (anything with a
  `draw` method, such as the object hypothesis supplies via `st.data()`).
  Input vectors are drawn from `DataStrategy( dut )`, which reflects on
  the ports and interfaces of `dut`. The comparison itself is delegated
  to `closed_loop_component_input_test` using `backend`.
  """
  def tv_in( model, test_vector ):
    # Apply every (signal name, value) pair of one vector to `model`.
    # `setattr` cannot address subscripted names such as "in_[0]", so the
    # assignment is executed as a statement built from the name. The
    # statement text refers to `model` and `data`, which are supplied
    # explicitly through the namespace mapping.
    for signal_name, signal_value in test_vector.items():
      exec( "model." + signal_name + " @= data", globals(),
            { "model" : model, "data" : signal_value } )

  drawn_vectors = data.draw( DataStrategy( dut ) )
  closed_loop_component_input_test( dut, drawn_vectors, tv_in, backend )

#-------------------------------------------------------------------------
# closed_loop_test
#-------------------------------------------------------------------------

# TODO: A hypothesis test that works on generated test component AND
# generated input data.
def closed_loop_test():
  """Placeholder for a test over generated components and generated inputs."""