#! /usr/bin/python3
"""Tests for pyln.proto.message string/binary encoding and decoding."""
import io

import pytest

from pyln.proto.message import Message, MessageNamespace


def _check_roundtrip(ns, text, binary):
    """Assert that ``text`` and ``binary`` describe the same message in ``ns``.

    Three checks are made:
    - parsing ``text`` and printing it again yields ``text`` unchanged;
    - serializing the parsed message yields exactly ``binary``;
    - reading ``binary`` back and printing it yields ``text``.
    """
    m = Message.from_str(ns, text)
    assert m.to_str() == text
    buf = io.BytesIO()
    m.write(buf)
    # Compare as hex so that a mismatch produces a readable diff.
    assert buf.getvalue().hex() == binary.hex()
    assert Message.read(ns, io.BytesIO(binary)).to_str() == text


def test_fundamental():
    """Every fundamental field type survives a string round trip."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test,1',
                 'msgdata,test,test_byte,byte,',
                 'msgdata,test,test_u16,u16,',
                 'msgdata,test,test_u32,u32,',
                 'msgdata,test,test_u64,u64,',
                 'msgdata,test,test_chain_hash,chain_hash,',
                 'msgdata,test,test_channel_id,channel_id,',
                 'msgdata,test,test_sha256,sha256,',
                 'msgdata,test,test_signature,signature,',
                 'msgdata,test,test_point,point,',
                 'msgdata,test,test_short_channel_id,short_channel_id,',
                 ])

    # Integer fields use their maximum values to exercise the full width.
    mstr = """test
    test_byte=255
    test_u16=65535
    test_u32=4294967295
    test_u64=18446744073709551615
    test_chain_hash=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_channel_id=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_sha256=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20
    test_signature=0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f40
    test_point=0201030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f2021
    test_short_channel_id=1x2x3"""
    m = Message.from_str(ns, mstr)

    # Same (ignoring whitespace differences)
    assert m.to_str().split() == mstr.split()


def test_static_array():
    """Fixed-length arrays of bytes and of short_channel_ids round-trip."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_arr,byte,4'])
    ns.load_csv(['msgtype,test2,2',
                 'msgdata,test2,test_arr,short_channel_id,4'])

    # Each binary begins with the two-byte message type ([0, 1] / [0, 2]).
    # In test2 every short_channel_id occupies 8 bytes, split 3/3/2 as the
    # data shows (e.g. 4x5x6 -> 0,0,4, 0,0,5, 0,6).
    cases = [
        ("test1 test_arr=00010203",
         bytes([0, 1] + [0, 1, 2, 3])),
        ("test2 test_arr=[0x1x2,4x5x6,7x8x9,10x11x12]",
         bytes([0, 2]
               + [0, 0, 0, 0, 0, 1, 0, 2]
               + [0, 0, 4, 0, 0, 5, 0, 6]
               + [0, 0, 7, 0, 0, 8, 0, 9]
               + [0, 0, 10, 0, 0, 11, 0, 12])),
    ]
    for text, binary in cases:
        _check_roundtrip(ns, text, binary)


def test_subtype():
    """A fixed-length array of a two-field subtype round-trips, and a bare
    message parsed with ``incomplete_ok`` reports its missing fields."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,test_sub,channel_update_timestamps,4',
                 'subtype,channel_update_timestamps',
                 'subtypedata,'
                 + 'channel_update_timestamps,timestamp_node_id_1,u32,',
                 'subtypedata,'
                 + 'channel_update_timestamps,timestamp_node_id_2,u32,'])

    _check_roundtrip(ns,
                     "test1 test_sub=["
                     "{timestamp_node_id_1=1,timestamp_node_id_2=2}"
                     ",{timestamp_node_id_1=3,timestamp_node_id_2=4}"
                     ",{timestamp_node_id_1=5,timestamp_node_id_2=6}"
                     ",{timestamp_node_id_1=7,timestamp_node_id_2=8}]",
                     bytes([0, 1]
                           + [0, 0, 0, 1, 0, 0, 0, 2]
                           + [0, 0, 0, 3, 0, 0, 0, 4]
                           + [0, 0, 0, 5, 0, 0, 0, 6]
                           + [0, 0, 0, 7, 0, 0, 0, 8]))

    # Test missing field logic.
    m = Message.from_str(ns, "test1", incomplete_ok=True)
    assert m.missing_fields()


def test_subtype_array():
    """Nested variable-length subtype arrays (witness stacks) round-trip.

    The length fields (num_witnesses, num_input_witness, len) are not
    written in the string form; they appear only in the binary encoding.
    """
    ns = MessageNamespace()
    ns.load_csv(['msgtype,tx_signatures,1',
                 'msgdata,tx_signatures,num_witnesses,u16,',
                 'msgdata,tx_signatures,witness_stack,witness_stack,num_witnesses',
                 'subtype,witness_stack',
                 'subtypedata,witness_stack,num_input_witness,u16,',
                 'subtypedata,witness_stack,witness_element,witness_element,num_input_witness',
                 'subtype,witness_element',
                 'subtypedata,witness_element,len,u16,',
                 'subtypedata,witness_element,witness,byte,len'])

    text = ("tx_signatures witness_stack="
            "[{witness_element=[{witness=3045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01},{witness=02d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b}]}]")
    # 0001 msgtype, 0001 num_witnesses, 0002 num_input_witness,
    # then each element as a u16 len (0048, 0021) followed by its bytes.
    binary = bytes.fromhex('00010001000200483045022100ac0fdee3e157f50be3214288cb7f11b03ce33e13b39dadccfcdb1a174fd3729a02200b69b286ac9f0fc5c51f9f04ae5a9827ac11d384cc203a0eaddff37e8d15c1ac01002102d6a3c2d0cf7904ab6af54d7c959435a452b24a63194e1c4e7c337d3ebbb3017b')
    _check_roundtrip(ns, text, binary)


def test_tlv():
    """TLV streams round-trip, including unknown types, and are written in
    canonical (ascending type) order regardless of input order."""
    ns = MessageNamespace()
    ns.load_csv(['msgtype,test1,1',
                 'msgdata,test1,tlvs,test_tlvstream,',
                 'tlvtype,test_tlvstream,tlv1,1',
                 'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                 'tlvdata,test_tlvstream,tlv1,field2,u32,',
                 'tlvtype,test_tlvstream,tlv2,255',
                 'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    # tlv1 (type 1) encodes as [1, 8, ...]: type, length 8, then payload.
    # tlv2 (type 255) encodes its type as [253, 0, 255], then length 4.
    # "4=010203" is a type the namespace does not define; it is kept raw.
    cases = [
        ("test1 tlvs={tlv1={field1=01020304,field2=5}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5])),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
               + [253, 0, 255, 4, 1, 2, 3, 4])),
        ("test1 tlvs={tlv1={field1=01020304,field2=5},4=010203,tlv2={field3=01020304}}",
         bytes([0, 1]
               + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
               + [4, 3, 1, 2, 3]
               + [253, 0, 255, 4, 1, 2, 3, 4])),
    ]
    for text, binary in cases:
        _check_roundtrip(ns, text, binary)

    # Ordering test (turns into canonical ordering)
    m = Message.from_str(ns, 'test1 tlvs={tlv1={field1=01020304,field2=5},tlv2={field3=01020304},4=010203}')
    buf = io.BytesIO()
    m.write(buf)
    assert buf.getvalue() == bytes([0, 1]
                                   + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
                                   + [4, 3, 1, 2, 3]
                                   + [253, 0, 255, 4, 1, 2, 3, 4])


def test_tlv_complex():
    """A real reply_channel_range message reads and re-writes byte-for-byte."""
    # A real example from the spec.
    ns = MessageNamespace(["msgtype,reply_channel_range,264,gossip_queries",
                           "msgdata,reply_channel_range,chain_hash,chain_hash,",
                           "msgdata,reply_channel_range,first_blocknum,u32,",
                           "msgdata,reply_channel_range,number_of_blocks,u32,",
                           "msgdata,reply_channel_range,full_information,byte,",
                           "msgdata,reply_channel_range,len,u16,",
                           "msgdata,reply_channel_range,encoded_short_ids,byte,len",
                           "msgdata,reply_channel_range,tlvs,reply_channel_range_tlvs,",
                           "tlvtype,reply_channel_range_tlvs,timestamps_tlv,1",
                           "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoding_type,byte,",
                           "tlvdata,reply_channel_range_tlvs,timestamps_tlv,encoded_timestamps,byte,...",
                           "tlvtype,reply_channel_range_tlvs,checksums_tlv,3",
                           "tlvdata,reply_channel_range_tlvs,checksums_tlv,checksums,channel_update_checksums,...",
                           "subtype,channel_update_timestamps",
                           "subtypedata,channel_update_timestamps,timestamp_node_id_1,u32,",
                           "subtypedata,channel_update_timestamps,timestamp_node_id_2,u32,",
                           "subtype,channel_update_checksums",
                           "subtypedata,channel_update_checksums,checksum_node_id_1,u32,",
                           "subtypedata,channel_update_checksums,checksum_node_id_2,u32,"])

    binmsg = bytes.fromhex('010806226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f000000670000000701001100000067000001000000006d000001000003101112fa300000000022d7a4a79bece840')
    msg = Message.read(ns, io.BytesIO(binmsg))
    buf = io.BytesIO()
    msg.write(buf)
    assert buf.getvalue() == binmsg


def test_message_constructor():
    """Message() accepts field values as strings and writes canonical TLVs."""
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,tlvs,test_tlvstream,',
                           'tlvtype,test_tlvstream,tlv1,1',
                           'tlvdata,test_tlvstream,tlv1,field1,byte,4',
                           'tlvdata,test_tlvstream,tlv1,field2,u32,',
                           'tlvtype,test_tlvstream,tlv2,255',
                           'tlvdata,test_tlvstream,tlv2,field3,byte,...'])

    # The TLV string is deliberately out of order (type 4 last); the output
    # bytes below are in ascending type order.
    m = Message(ns.get_msgtype('test1'),
                tlvs='{tlv1={field1=01020304,field2=5}'
                ',tlv2={field3=01020304},4=010203}')
    buf = io.BytesIO()
    m.write(buf)
    assert buf.getvalue() == bytes([0, 1]
                                   + [1, 8, 1, 2, 3, 4, 0, 0, 0, 5]
                                   + [4, 3, 1, 2, 3]
                                   + [253, 0, 255, 4, 1, 2, 3, 4])


def test_dynamic_array():
    """Arrays sharing a length field must agree on that length.

    The ``count`` field is not supplied; it is derived from the arrays, and
    construction fails with ValueError when the arrays disagree.
    """
    ns = MessageNamespace(['msgtype,test1,1',
                           'msgdata,test1,count,u16,',
                           'msgdata,test1,arr1,byte,count',
                           'msgdata,test1,arr2,u32,count'])

    # This one is fine.
    m = Message(ns.get_msgtype('test1'),
                arr1='01020304', arr2='[1,2,3,4]')
    buf = io.BytesIO()
    m.write(buf)
    assert buf.getvalue() == bytes([0, 1]
                                   + [0, 4]
                                   + [1, 2, 3, 4]
                                   + [0, 0, 0, 1,
                                      0, 0, 0, 2,
                                      0, 0, 0, 3,
                                      0, 0, 0, 4])

    # These ones are not: arr2 is shorter, then longer, than arr1.
    with pytest.raises(ValueError, match='Inconsistent length.*count'):
        Message(ns.get_msgtype('test1'),
                arr1='01020304', arr2='[1,2,3]')

    with pytest.raises(ValueError, match='Inconsistent length.*count'):
        Message(ns.get_msgtype('test1'),
                arr1='01020304', arr2='[1,2,3,4,5]')