1"""
2========================================================================
3Test sinks
4========================================================================
5Test sinks with CL and RTL interfaces.
6
7Author : Yanghui Ou
8  Date : Mar 11, 2019
9"""
10
11from pymtl3 import *
12from pymtl3.stdlib.ifcs import RecvIfcRTL, RecvRTL2SendCL
13
14
class PyMTLTestSinkError( Exception ):
  """Raised by a test sink when the traffic it receives is wrong."""
16
17#-------------------------------------------------------------------------
18# TestSinkCL
19#-------------------------------------------------------------------------
20
class TestSinkCL( Component ):
  """Cycle-level test sink that checks incoming messages against a list.

  Messages arrive through the non-blocking ``recv`` method. Each one is
  compared with ``cmp_fn`` against the next expected entry in ``msgs``.
  If ``arrival_time`` is given, the cycle in which the message arrives is
  checked as well. A failed check is recorded in ``error_msg`` and raised
  as a ``PyMTLTestSinkError`` at the start of the following cycle, so the
  offending cycle's line trace is still printed.
  """

  def construct( s, Type, msgs, initial_delay=0, interval_delay=0,
                 arrival_time=None, cmp_fn=lambda a, b : a == b ):
    """Elaborate the sink.

    Args:
      Type: message type carried by ``recv``.
      msgs: expected messages, in the order they must arrive.
      initial_delay: number of ``up_sink_count`` ticks (reset cycles
        included) before ``recv`` first becomes ready.
      interval_delay: number of cycles ``recv`` stays not ready after each
        accepted message.
      arrival_time: optional sequence with the same length as ``msgs``.
        Entry i is the latest ``cycle_count`` at which message i may
        arrive.
      cmp_fn: ``cmp_fn( received, expected )``, which returns True when
        the two match.
    """

    # Attach the message type to the ``recv`` method interface
    # (presumably consumed by pymtl3 when connecting ports -- confirm).
    s.recv.Type = Type

    # [msgs] and [arrival_time] must have the same length.
    if arrival_time is not None:
      assert len( msgs ) == len( arrival_time )

    s.idx          = 0   # index of the next expected message
    s.cycle_count  = 0   # +1 per non-reset tick, zeroed while reset is high
    s.msgs         = list( msgs )
    s.arrival_time = None if not arrival_time else list( arrival_time )
    s.cmp_fn       = cmp_fn
    s.error_msg    = ''  # non-empty once a check has failed

    s.all_msg_recved = False
    s.done_flag      = False

    s.count = initial_delay  # recv is ready only when this is 0
    s.intv  = interval_delay

    s.recv_called = False    # set by a successful recv, cleared every tick

    @update_once
    def up_sink_count():
      # Raise exception at the start of next cycle so that the errored
      # line trace gets printed out
      if s.error_msg:
        raise PyMTLTestSinkError( s.error_msg )

      # Tick one more cycle after all messages are received. An error
      # caused by an extra message in that cycle is then raised before
      # done() can report True.
      if s.all_msg_recved:
        s.done_flag = True

      if s.idx >= len( s.msgs ):
        s.all_msg_recved = True

      if not s.reset:
        s.cycle_count += 1
      else:
        s.cycle_count = 0

      # If recv was called in the previous cycle, restart the interval
      # delay. Otherwise, count down toward ready (count == 0); once at
      # 0 the count simply stays there.
      if s.recv_called:
        s.count = s.intv
      elif s.count != 0:
        s.count -= 1

      s.recv_called = False

    # Update the counters before recv / recv.rdy are invoked within a
    # cycle, so readiness reflects this cycle's count.
    s.add_constraints(
      U( up_sink_count ) < M( s.recv ),
      U( up_sink_count ) < M( s.recv.rdy )
    )

  @non_blocking( lambda s: s.count==0 )
  def recv( s, msg ):
    """Accept one message and check it against the next expected one.

    The method is ready only when ``count`` is 0. When a check fails,
    ``error_msg`` is set (``up_sink_count`` raises it on the next tick)
    and the expected-message index is not advanced. When the checks pass,
    the index advances and the interval delay is armed.
    """
    assert s.count == 0, "Invalid en/rdy transaction! Sink is stalled (not ready), but receives a message."

    # Sanity check
    if s.idx >= len( s.msgs ):
      s.error_msg = (
        f'Test sink {s} received more msgs than expected!\n'
        f'Received : {msg}'
      )

    # Check correctness first
    elif not s.cmp_fn( msg, s.msgs[ s.idx ] ):
      s.error_msg = (
        f'Test sink {s} received WRONG message!\n'
        f'Expected : { s.msgs[ s.idx ] }\n'
        f'Received : { msg }'
      )

    # Check timing if performance regression checking is turned on
    elif s.arrival_time and s.cycle_count > s.arrival_time[ s.idx ]:
      s.error_msg = (
        f'Test sink {s} received message LATER than expected!\n'
        f'Expected msg : {s.msgs[ s.idx ]}\n'
        f'Expected at  : {s.arrival_time[ s.idx ]}\n'
        f'Received msg : {msg}\n'
        f'Received at  : {s.cycle_count}'
      )

    else:
      s.idx += 1
      s.recv_called = True

  def done( s ):
    """Return True one tick after every expected message has arrived."""
    return s.done_flag

  # Line trace
  def line_trace( s ):
    """Show the state of the ``recv`` interface."""
    return "{}".format( s.recv )
119
120#-------------------------------------------------------------------------
121# TestSinkRTL
122#-------------------------------------------------------------------------
123
class TestSinkRTL( Component ):
  """Test sink with an RTL receive interface.

  Messages arriving on the RTL ``recv`` interface are passed through a
  ``RecvRTL2SendCL`` adapter into a ``TestSinkCL``, which performs all
  of the checking.
  """

  def construct( s, Type, msgs, initial_delay=0, interval_delay=0,
                 arrival_time=None, cmp_fn=lambda a, b : a == b ):
    """Build the RTL sink; every argument is forwarded to ``TestSinkCL``."""

    # RTL-side receive interface exposed by this component.
    s.recv = RecvIfcRTL( Type )

    # Child components: the checking CL sink and the RTL-to-CL adapter.
    s.sink = TestSinkCL(
      Type, msgs,
      initial_delay  = initial_delay,
      interval_delay = interval_delay,
      arrival_time   = arrival_time,
      cmp_fn         = cmp_fn,
    )
    s.adapter = RecvRTL2SendCL( Type )

    # Data path: recv (RTL) -> adapter -> sink.recv (CL)
    connect( s.recv,         s.adapter.recv )
    connect( s.adapter.send, s.sink.recv    )

  def done( s ):
    """Delegate completion to the wrapped CL sink."""
    return s.sink.done()

  def line_trace( s ):
    """Show the state of the RTL receive interface."""
    return f"{s.recv}"
149