1#! /usr/bin/python3
2from pyln.proto.message import MessageNamespace, Message
3import pytest
4import io
5
6
def test_fundamental():
    """Round-trip every fundamental field type through text and wire form.

    Each field holds a boundary value (all-ones integers) or an easily
    recognisable byte pattern. A mis-sized or mis-ordered encoding therefore
    shows up in the byte comparison, not just in the text comparison.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test,1',
                 'msgdata,test,test_byte,byte,',
                 'msgdata,test,test_u16,u16,',
                 'msgdata,test,test_u32,u32,',
                 'msgdata,test,test_u64,u64,',
                 'msgdata,test,test_chain_hash,chain_hash,',
                 'msgdata,test,test_channel_id,channel_id,',
                 'msgdata,test,test_sha256,sha256,',
                 'msgdata,test,test_signature,signature,',
                 'msgdata,test,test_point,point,',
                 'msgdata,test,test_short_channel_id,short_channel_id,',
                 ])

    mstr = """test
 test_byte=255
 test_u16=65535
 test_u32=4294967295
 test_u64=18446744073709551615
 test_chain_hash=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
 test_channel_id=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
 test_sha256=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
 test_signature=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f40
 test_point=0201030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021
 test_short_channel_id=1x2x3"""
    m = Message.from_str(ns, mstr)

    # Same (ignoring whitespace differences)
    assert m.to_str().split() == mstr.split()

    # Wire form: the u16 message type, then each field in CSV order.
    expected = (bytes([0, 1])
                + bytes([255])                          # byte
                + bytes([255] * 2)                      # u16
                + bytes([255] * 4)                      # u32
                + bytes([255] * 8)                      # u64
                + bytes(range(1, 33)) * 3               # chain_hash, channel_id, sha256
                + bytes(range(1, 65))                   # signature
                + bytes([2, 1]) + bytes(range(3, 34))   # point
                + bytes([0, 0, 1, 0, 0, 2, 0, 3]))      # short_channel_id 1x2x3
    buf = io.BytesIO()
    m.write(buf)
    assert buf.getvalue() == expected

    # Decoding the wire form must give back the same text form.
    decoded = Message.read(ns, io.BytesIO(expected))
    assert decoded.to_str().split() == mstr.split()
37
38
def test_static_array():
    """Fixed-length arrays (of byte and of short_channel_id) round-trip.

    Each case is checked three ways: text -> text, text -> wire, and
    wire -> text.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_arr,byte,4'])
    ns.load_csv(['msgtype,test2,2',
                 'msgdata,test2,test_arr,short_channel_id,4'])

    byte_case = ("test1 test_arr=00010203",
                 bytes([0, 1] + [0, 1, 2, 3]))
    scid_case = ("test2 test_arr=[0x1x2,4x5x6,7x8x9,10x11x12]",
                 bytes([0, 2]
                       + [0, 0, 0, 0, 0, 1, 0, 2]
                       + [0, 0, 4, 0, 0, 5, 0, 6]
                       + [0, 0, 7, 0, 0, 8, 0, 9]
                       + [0, 0, 10, 0, 0, 11, 0, 12]))

    for text, wire in (byte_case, scid_case):
        # Text -> Message -> text.
        msg = Message.from_str(ns, text)
        assert msg.to_str() == text

        # Message -> wire.
        out = io.BytesIO()
        msg.write(out)
        assert out.getvalue() == wire

        # Wire -> Message -> text.
        assert Message.read(ns, io.BytesIO(wire)).to_str() == text
59
60
def test_subtype():
    """A fixed-length array of a two-field subtype round-trips.

    An incomplete message must report missing fields.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_sub,channel_update_timestamps,4',
                 'subtype,channel_update_timestamps',
                 'subtypedata,'
                 + 'channel_update_timestamps,timestamp_node_id_1,u32,',
                 'subtypedata,'
                 + 'channel_update_timestamps,timestamp_node_id_2,u32,'])

    text = ("test1 test_sub=["
            "{timestamp_node_id_1=1,timestamp_node_id_2=2}"
            ",{timestamp_node_id_1=3,timestamp_node_id_2=4}"
            ",{timestamp_node_id_1=5,timestamp_node_id_2=6}"
            ",{timestamp_node_id_1=7,timestamp_node_id_2=8}]")
    # Message type, then four subtype entries of two u32 values each.
    wire = bytes([0, 1]
                 + [0, 0, 0, 1, 0, 0, 0, 2]
                 + [0, 0, 0, 3, 0, 0, 0, 4]
                 + [0, 0, 0, 5, 0, 0, 0, 6]
                 + [0, 0, 0, 7, 0, 0, 0, 8])

    msg = Message.from_str(ns, text)
    assert msg.to_str() == text

    out = io.BytesIO()
    msg.write(out)
    assert out.getvalue() == wire

    assert Message.read(ns, io.BytesIO(wire)).to_str() == text

    # Test missing field logic: only the message name is given, which
    # incomplete_ok=True permits; missing_fields() must then be non-empty.
    partial = Message.from_str(ns, "test1", incomplete_ok=True)
    assert partial.missing_fields()
91
92
def test_subtype_array():
    """A length-prefixed array of subtypes that contain arrays round-trips.

    The length fields (num_witnesses, num_input_witness, len) do not appear
    in the text form. The expected wire bytes show them being filled in
    from the array sizes.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,tx_signatures,1',
                 'msgdata,tx_signatures,num_witnesses,u16,',
                 'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
                 'subtype,witness_stack',
                 'subtypedata,witness_stack,num_input_witness,u16,',
                 'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
                 'subtype,witness_element',
                 'subtypedata,witness_element,len,u16,',
                 'subtypedata,witness_element,witness,byte,len'])

    text = ("tx_signatures witness_stack="
            "[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]")
    wire = bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')

    msg = Message.from_str(ns, text)
    assert msg.to_str() == text

    out = io.BytesIO()
    msg.write(out)
    # Compared as hex, presumably so a failure shows a readable diff.
    assert out.getvalue().hex() == wire.hex()

    assert Message.read(ns, io.BytesIO(wire)).to_str() == text
114
115
def test_tlv():
    """TLV streams round-trip, including an unknown record type.

    Records are always written in ascending type order.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,tlvs,test_tlvstream,',
                 'tlvtype,test_tlvstream,tlv1,1',
                 'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                 'tlvdata,test_tlvstream,tlv1,field2,u32,',
                 'tlvtype,test_tlvstream,tlv2,255',
                 'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    header = [0, 1]
    # type 1, length 8, field1 (4 bytes), field2 (u32 5)
    tlv1_rec = [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
    # unknown type 4, length 3, raw value 010203
    unknown4_rec = [4, 3, 1, 2, 3]
    # type 255 is encoded as 0xfd 0x00 0xff, then length 4 and field3
    tlv2_rec = [253, 0, 255, 4, 1, 2, 3, 4]

    cases = [
        ("test1 tlvs={tlv1={field1=01020304,field2=5}}",
         bytes(header + tlv1_rec)),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304}}",
         bytes(header + tlv1_rec + tlv2_rec)),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},4=010203,tlv2={field3=01020304}}",
         bytes(header + tlv1_rec + unknown4_rec + tlv2_rec)),
    ]
    for text, wire in cases:
        msg = Message.from_str(ns, text)
        assert msg.to_str() == text

        out = io.BytesIO()
        msg.write(out)
        assert out.getvalue() == wire

        assert Message.read(ns, io.BytesIO(wire)).to_str() == text

    # Ordering test: type 4 is given after type 255 in the text, but it is
    # written between types 1 and 255 (canonical ordering).
    shuffled = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
    out = io.BytesIO()
    shuffled.write(out)
    assert out.getvalue() == bytes(header + tlv1_rec + unknown4_rec + tlv2_rec)
153
154
def test_tlv_complex():
    """Binary round-trip of a real reply_channel_range example from the spec."""
    spec_csv = ["msgtype,reply_channel_range,264,gossip_queries",
                "msgdata,reply_channel_range,chain_hash,chain_hash,",
                "msgdata,reply_channel_range,first_blocknum,u32,",
                "msgdata,reply_channel_range,number_of_blocks,u32,",
                "msgdata,reply_channel_range,full_information,byte,",
                "msgdata,reply_channel_range,len,u16,",
                "msgdata,reply_channel_range,encoded_short_ids,byte,len",
                "msgdata,reply_channel_range,tlvs,reply_channel_range_tlvs,",
                "tlvtype,reply_channel_range_tlvs,timestamps_tlv,1",
                "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,byte,",
                "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoded_timestamps,byte,...",
                "tlvtype,reply_channel_range_tlvs,checksums_tlv,3",
                "tlvdata,reply_channel_range_tlvs,checksums_tlv,checksums,channel_update_checksums,...",
                "subtype,channel_update_timestamps",
                "subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,",
                "subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,",
                "subtype,channel_update_checksums",
                "subtypedata,channel_update_checksums,checksum_node_id_1,u32,",
                "subtypedata,channel_update_checksums,checksum_node_id_2,u32,"]
    ns = MessageNamespace(spec_csv)

    # Layout, following the CSV field order:
    #   0108              message type 264
    #   06226e...88910f   chain_hash (32 bytes)
    #   00000067          first_blocknum
    #   00000007          number_of_blocks
    #   01                full_information
    #   0011              len = 17
    #   00000067...0000   encoded_short_ids (17 bytes)
    #   03 10 1112...e840 checksums_tlv (type 3, length 16): two checksum pairs
    wire = bytes.fromhex('010806226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f000000670000000701001100000067000001000000006d000001000003101112fa300000000022d7a4a79bece840')
    parsed = Message.read(ns, io.BytesIO(wire))

    out = io.BytesIO()
    parsed.write(out)
    assert out.getvalue() == wire
182
183
def test_message_constructor():
    """Message() accepts a field given as text and writes TLVs in type order."""
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,tlvs,test_tlvstream,',
                           'tlvtype,test_tlvstream,tlv1,1',
                           'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                           'tlvdata,test_tlvstream,tlv1,field2,u32,',
                           'tlvtype,test_tlvstream,tlv2,255',
                           'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    msgtype = ns.get_msgtype('test1')
    # Type 4 (unknown) is supplied after type 255 on purpose.
    tlv_text = ('{tlv1={field1=01020304,field2=5}'
                ',tlv2={field3=01020304},4=010203}')
    msg = Message(msgtype, tlvs=tlv_text)

    out = io.BytesIO()
    msg.write(out)
    # On the wire, type 4 comes before type 255.
    expected = bytes([0, 1]
                     + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
                     + [4, 3, 1, 2, 3]
                     + [253, 0, 255, 4, 1, 2, 3, 4])
    assert out.getvalue() == expected
202
203
def test_dynamic_array():
    """Test that dynamic array types enforce matching lengths.

    arr1 and arr2 share the single length field ``count``. Constructing a
    message whose arrays differ in size must raise ValueError, whichever
    array is the odd one out.
    """
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,count,u16,',
                           'msgdata,test1,arr1,byte,count',
                           'msgdata,test1,arr2,u32,count'])
    msgtype = ns.get_msgtype('test1')

    # This one is fine: both arrays have 4 entries, so count is written as 4.
    m = Message(msgtype, arr1='01020304', arr2='[1,2,3,4]')
    buf = io.BytesIO()
    m.write(buf)
    assert buf.getvalue() == bytes([0, 1]
                                   + [0, 4]
                                   + [1, 2, 3, 4]
                                   + [0, 0, 0, 1,
                                      0, 0, 0, 2,
                                      0, 0, 0, 3,
                                      0, 0, 0, 4])

    # These ones are not. The constructor itself must raise, so the result
    # is never bound to a name.
    mismatched = [('01020304', '[1,2,3]'),      # arr2 too short
                  ('01020304', '[1,2,3,4,5]'),  # arr2 too long
                  ('010203', '[1,2,3,4]')]      # arr1 too short
    for arr1, arr2 in mismatched:
        with pytest.raises(ValueError, match='Inconsistent length.*count'):
            Message(msgtype, arr1=arr1, arr2=arr2)
232