# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import time
import random
import os
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
import threading
from enum import IntEnum

from aiorpcx import NetAddress

from .sql_db import SqlDB, sql
from . import constants, util
from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
from .logging import Logger
from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
                     validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg)
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg
from . import ecc
from .crypto import sha256d

if TYPE_CHECKING:
    from .network import Network
    from .lnchannel import Channel
    from .lnrouter import RouteEdge


# bits of the 'channel_flags' field of a channel_update message
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0


class ChannelInfo(NamedTuple):
    """Static info about a channel, as learned from a channel_announcement."""
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    # capacity is not gossiped; it is only known for e.g. our own channels
    capacity_sat: Optional[int]

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        """Build a ChannelInfo from a decoded channel_announcement payload.

        Raises IncompatibleOrInsaneFeatures (via validate_features) for
        unknown/insane feature bits.
        """
        features = int.from_bytes(payload['features'], 'big')
        validate_features(features)
        channel_id = payload['short_channel_id']
        node_id_1 = payload['node_id_1']
        node_id_2 = payload['node_id_2']
        # per BOLT-07, the node ids in the announcement are in ascending order
        assert sorted([node_id_1, node_id_2]) == [node_id_1, node_id_2]
        capacity_sat = None
        return ChannelInfo(
            short_channel_id = ShortChannelID.normalize(channel_id),
            node1_id = node_id_1,
            node2_id = node_id_2,
            capacity_sat = capacity_sat
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        """Build a ChannelInfo from a raw (undecoded) channel_announcement."""
        payload_dict = decode_msg(raw)[1]
        return ChannelInfo.from_msg(payload_dict)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        """Build a synthetic ChannelInfo for a (private) route edge."""
        node1_id, node2_id = sorted([route_edge.start_node, route_edge.end_node])
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=node1_id,
            node2_id=node2_id,
            capacity_sat=None,
        )


class Policy(NamedTuple):
    """One direction's routing policy for a channel (from a channel_update)."""
    # key = short_channel_id (8 bytes) || start_node pubkey (33 bytes)
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        """Build a Policy from a decoded channel_update payload.

        The payload must have been augmented with a 'start_node' entry.
        """
        return Policy(
            key = payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta = payload['cltv_expiry_delta'],
            htlc_minimum_msat = payload['htlc_minimum_msat'],
            htlc_maximum_msat = payload.get('htlc_maximum_msat', None),
            fee_base_msat = payload['fee_base_msat'],
            fee_proportional_millionths = payload['fee_proportional_millionths'],
            message_flags = int.from_bytes(payload['message_flags'], "big"),
            channel_flags = int.from_bytes(payload['channel_flags'], "big"),
            timestamp = payload['timestamp'],
        )

    @staticmethod
    def from_raw_msg(key: bytes, raw: bytes) -> 'Policy':
        """Build a Policy from its db key and a raw channel_update message."""
        payload = decode_msg(raw)[1]
        payload['start_node'] = key[8:]  # key = scid || start_node
        return Policy.from_msg(payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        """Build a synthetic Policy for a (private) route edge."""
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_expiry_delta=route_edge.cltv_expiry_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self):
        """Whether this direction of the channel is flagged as disabled."""
        return self.channel_flags & FLAG_DISABLE

    @property
    def short_channel_id(self) -> ShortChannelID:
        return ShortChannelID.normalize(self.key[0:8])

    @property
    def start_node(self) -> bytes:
        return self.key[8:]


class NodeInfo(NamedTuple):
    """Info about a node, as learned from a node_announcement."""
    node_id: bytes
    features: int
    timestamp: int
    alias: str

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Parse a decoded node_announcement payload.

        Returns the NodeInfo and the list of network addresses that could be
        parsed into valid LNPeerAddrs.
        Raises IncompatibleOrInsaneFeatures for unknown/insane feature bits.
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                pass  # skip malformed addresses; keep the rest
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except UnicodeDecodeError:
            # alias is free-form remote data; fall back to empty on bad utf8
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Parse a raw (undecoded) node_announcement message."""
        payload_dict = decode_msg(raw)[1]
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Parse the BOLT-07 'addresses' byte blob into (host, port) pairs."""
        buf = addresses_field
        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data
        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass  # padding
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                host = base64.b32encode(read(10)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses


class UpdateStatus(IntEnum):
    """Classification of an incoming channel_update (see add_channel_update)."""
    ORPHANED = 0
    EXPIRED = 1
    DEPRECATED = 2
    UNCHANGED = 3
    GOOD = 4

class CategorizedChannelUpdates(NamedTuple):
    orphaned: List    # no channel announcement for channel update
    expired: List     # update older than two weeks
    deprecated: List  # update older than database entry
    unchanged: List   # unchanged policies
    good: List        # good updates


def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    """Return a ChannelInfo for one of our own channels, or None."""
    chan = my_channels.get(short_channel_id)
    if not chan:
        return
    ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
    # unlike gossiped channels, we know our own channel's capacity
    return ci._replace(capacity_sat=chan.constraints.capacity)

def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    """Return the Policy of one of our own channels in the given direction.

    node_id selects the direction: the remote node's id means "towards us",
    our own pubkey means "from us". Returns None if unknown.
    """
    chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not chan:
        return
    if node_id == chan.node_id:  # incoming direction (to us)
        remote_update_raw = chan.get_remote_update()
        if not remote_update_raw:
            return
        now = int(time.time())
        remote_update_decoded = decode_msg(remote_update_raw)[1]
        # refresh the timestamp so the policy is not considered stale
        remote_update_decoded['timestamp'] = now
        remote_update_decoded['start_node'] = node_id
        return Policy.from_msg(remote_update_decoded)
    elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
        local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
        local_update_decoded['start_node'] = node_id
        return Policy.from_msg(local_update_decoded)


create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
msg BLOB,
PRIMARY KEY(short_channel_id)
)"""

create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key BLOB(41),
msg BLOB,
PRIMARY KEY(key)
)"""
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id BLOB(33),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id BLOB(33),
msg BLOB,
PRIMARY KEY(node_id)
)"""


class ChannelDB(SqlDB):
    """In-memory LN gossip database (channels, policies, nodes, addresses),
    persisted to sqlite via SqlDB. In-memory state is authoritative; sqlite
    writes happen on the database thread via the @sql-decorated helpers.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'gossip_db')
        super().__init__(network.asyncio_loop, path, commit_interval=100)
        self.lock = threading.RLock()
        self.num_nodes = 0
        self.num_channels = 0
        self.num_policies = 0  # kept up-to-date by update_counts()
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)

        # initialized in load_data
        # note: modify/iterate needs self.lock
        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
        # node_id -> NetAddress -> timestamp
        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
        self._recent_peers = []  # type: List[bytes]  # list of node_ids
        # partition of known channels by how many of their two policies we have
        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]

        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    def update_counts(self):
        """Refresh the public counters and notify GUI/callback listeners."""
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
        util.trigger_callback('ln_gossip_sync_progress')

    def get_channel_ids(self):
        """Return a copy of the set of known short channel ids."""
        with self.lock:
            return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        """Record that we successfully connected to 'peer' just now."""
        now = int(time.time())
        node_id = peer.pubkey
        with self.lock:
            self._addresses[node_id][peer.net_addr()] = now
            # list is ordered (most recent first)
            if node_id in self._recent_peers:
                self._recent_peers.remove(node_id)
            self._recent_peers.insert(0, node_id)
            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
        self._db_save_node_address(peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        with self.lock:
            unshuffled = set(self._nodes.keys()) - node_ids
        # note: random.sample requires a sequence (set support was removed
        # in Python 3.11); sort for a deterministic population order.
        return random.sample(sorted(unshuffled), min(200, len(unshuffled)))

    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
        """Returns latest address we successfully connected to, for given node."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return None
        addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
        try:
            return LNPeerAddr(str(addr.host), addr.port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        """Return LNPeerAddrs of recently connected peers, most recent first."""
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        with self.lock:
            ret = [self.get_last_good_address(node_id)
                   for node_id in self._recent_peers]
        return ret

    # note: currently channel announcements are trusted by default (trusted=True);
    #       they are not SPV-verified. Verifying them would make the gossip sync
    #       even slower; especially as servers will start throttling us.
    #       It would probably put significant strain on servers if all clients
    #       verified the complete gossip.
    def add_channel_announcements(self, msg_payloads, *, trusted=True):
        """Add one or more decoded channel_announcement payloads to the db.

        If trusted is False, announcements are handed to the SPV verifier
        instead of being added directly.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except IncompatibleOrInsaneFeatures as e:
                self.logger.info(f"unknown or insane feature bits: {e!r}")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        """Insert a (verified) channel_announcement into the in-memory maps
        and persist it (if the raw message is available).
        """
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except IncompatibleOrInsaneFeatures:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        with self.lock:
            self._channels[channel_info.short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self._update_num_policies_for_chan(channel_info.short_channel_id)
        if 'raw' in msg:
            self._db_save_channel(channel_info.short_channel_id, msg['raw'])

    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
        """Return whether any routing-relevant field differs between the two
        policies; optionally log each difference.
        """
        changed = False
        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
            changed |= True
            if verbose:
                self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
            changed |= True
            if verbose:
                self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
        if old_policy.fee_base_msat != new_policy.fee_base_msat:
            changed |= True
            if verbose:
                self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
            changed |= True
            if verbose:
                self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
        if old_policy.channel_flags != new_policy.channel_flags:
            changed |= True
            if verbose:
                self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
        if old_policy.message_flags != new_policy.message_flags:
            changed |= True
            if verbose:
                self.logger.info(f'message_flags: {old_policy.message_flags} -> {new_policy.message_flags}')
        if not changed and verbose:
            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
        return changed

    def add_channel_update(
            self, payload, *, max_age=None, verify=True, verbose=True) -> UpdateStatus:
        """Process a decoded channel_update payload; return its classification.

        max_age (seconds) rejects stale updates as EXPIRED; updates for
        unknown channels are ORPHANED; older-than-known (or far-future)
        updates are DEPRECATED.
        """
        now = int(time.time())
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        timestamp = payload['timestamp']
        if max_age and now - timestamp > max_age:
            return UpdateStatus.EXPIRED
        if timestamp - now > 60:
            # timestamp is in the future (beyond clock-skew tolerance)
            return UpdateStatus.DEPRECATED
        channel_info = self._channels.get(short_channel_id)
        if not channel_info:
            return UpdateStatus.ORPHANED
        flags = int.from_bytes(payload['channel_flags'], 'big')
        direction = flags & FLAG_DIRECTION
        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
        payload['start_node'] = start_node
        # compare updates to existing database entries
        key = (start_node, short_channel_id)
        old_policy = self._policies.get(key)
        if old_policy and timestamp <= old_policy.timestamp + 60:
            return UpdateStatus.DEPRECATED
        if verify:
            self.verify_channel_update(payload)
        policy = Policy.from_msg(payload)
        with self.lock:
            self._policies[key] = policy
        self._update_num_policies_for_chan(short_channel_id)
        if 'raw' in payload:
            self._db_save_policy(policy.key, payload['raw'])
        if old_policy and not self.policy_changed(old_policy, policy, verbose):
            return UpdateStatus.UNCHANGED
        else:
            return UpdateStatus.GOOD

    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
        """Process a batch of channel_updates and categorize the payloads."""
        orphaned = []
        expired = []
        deprecated = []
        unchanged = []
        good = []
        for payload in payloads:
            r = self.add_channel_update(payload, max_age=max_age, verbose=False, verify=True)
            if r == UpdateStatus.ORPHANED:
                orphaned.append(payload)
            elif r == UpdateStatus.EXPIRED:
                expired.append(payload)
            elif r == UpdateStatus.DEPRECATED:
                deprecated.append(payload)
            elif r == UpdateStatus.UNCHANGED:
                unchanged.append(payload)
            elif r == UpdateStatus.GOOD:
                good.append(payload)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=orphaned,
            expired=expired,
            deprecated=deprecated,
            unchanged=unchanged,
            good=good)


    def create_database(self):
        """Create the sqlite tables (idempotent)."""
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def _db_save_policy(self, key: bytes, msg: bytes):
        # 'msg' is a 'channel_update' message
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])

    @sql
    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
        # 'msg' is a 'channel_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])

    @sql
    def _db_delete_channel(self, short_channel_id: ShortChannelID):
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def _db_save_node_info(self, node_id: bytes, msg: bytes):
        # 'msg' is a 'node_announcement' message
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])

    @sql
    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
                  (peer.pubkey, peer.host, peer.port, timestamp))

    @sql
    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
        # only inserts addresses we did not know about yet (timestamp 0)
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port))
            r = c.fetchall()
            if r == []:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0))

    @classmethod
    def verify_channel_update(cls, payload, *, start_node: bytes = None) -> None:
        """Verify chain hash and signature of a channel_update payload.

        Raises InvalidGossipMsg on failure.
        """
        short_channel_id = payload['short_channel_id']
        short_channel_id = ShortChannelID(short_channel_id)
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise InvalidGossipMsg('wrong chain hash')
        start_node = payload.get('start_node', None) or start_node
        assert start_node is not None
        if not verify_sig_for_channel_update(payload, start_node):
            raise InvalidGossipMsg(f'failed verifying channel update for {short_channel_id}')

    @classmethod
    def verify_channel_announcement(cls, payload) -> None:
        """Verify all four signatures of a channel_announcement payload.

        Raises InvalidGossipMsg on failure.
        """
        # skip the 2-byte message type and the four 64-byte signatures
        h = sha256d(payload['raw'][2+256:])
        pubkeys = [payload['node_id_1'], payload['node_id_2'], payload['bitcoin_key_1'], payload['bitcoin_key_2']]
        sigs = [payload['node_signature_1'], payload['node_signature_2'], payload['bitcoin_signature_1'], payload['bitcoin_signature_2']]
        for pubkey, sig in zip(pubkeys, sigs):
            if not ecc.verify_signature(pubkey, sig, h):
                raise InvalidGossipMsg('signature failed')

    @classmethod
    def verify_node_announcement(cls, payload) -> None:
        """Verify the signature of a node_announcement payload.

        Raises InvalidGossipMsg on failure.
        """
        pubkey = payload['node_id']
        signature = payload['signature']
        # skip the 2-byte message type and the 64-byte signature
        h = sha256d(payload['raw'][66:])
        if not ecc.verify_signature(pubkey, signature, h):
            raise InvalidGossipMsg('signature failed')

    def add_node_announcements(self, msg_payloads):
        """Add one or more decoded node_announcement payloads to the db.

        Announcements for nodes without any known channel are ignored
        (DoS protection), as are announcements older than what we have.
        """
        # note: signatures have already been verified.
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_nodes = {}
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except IncompatibleOrInsaneFeatures:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                #self.logger.info('ignoring orphan node_announcement')
                continue
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            node = new_nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            with self.lock:
                self._nodes[node_id] = node_info
            if 'raw' in msg_payload:
                self._db_save_node_info(node_id, msg_payload['raw'])
            with self.lock:
                for addr in node_addresses:
                    net_addr = NetAddress(addr.host, addr.port)
                    # keep an existing (non-zero) timestamp if we have one
                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
            self._db_save_node_addresses(node_addresses)

        self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
        self.update_counts()

    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
        """Return the keys of all policies older than 'delta' seconds."""
        with self.lock:
            _policies = self._policies.copy()
        now = int(time.time())
        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)

    def prune_old_policies(self, delta):
        """Remove (from memory and db) all policies older than 'delta' seconds."""
        old_policies = self.get_old_policies(delta)
        if old_policies:
            for key in old_policies:
                node_id, scid = key
                with self.lock:
                    self._policies.pop(key)
                self._db_delete_policy(*key)
                self._update_num_policies_for_chan(scid)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_policies)} old policies')

    def prune_orphaned_channels(self):
        """Remove channels for which we have no policy in either direction."""
        with self.lock:
            orphaned_chans = self._chans_with_0_policies.copy()
        if orphaned_chans:
            for short_channel_id in orphaned_chans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
        """Returns True iff the channel update was successfully added and it was different than
        what we had before (if any).
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return False  # ignore
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        key = (start_node_id, short_channel_id)
        prev_chanupd = self._channel_updates_for_private_channels.get(key)
        if prev_chanupd == msg_payload:
            return False
        self._channel_updates_for_private_channels[key] = msg_payload
        return True

    def remove_channel(self, short_channel_id: ShortChannelID):
        """Remove a channel from the in-memory maps and the database."""
        # FIXME what about rm-ing policies?
        with self.lock:
            channel_info = self._channels.pop(short_channel_id, None)
            if channel_info:
                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
        self._update_num_policies_for_chan(short_channel_id)
        # delete from database
        self._db_delete_channel(short_channel_id)

    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
        """Returns list of (host, port, timestamp)."""
        addr_to_ts = self._addresses.get(node_id)
        if not addr_to_ts:
            return []
        return [(str(net_addr.host), net_addr.port, ts)
                for net_addr, ts in addr_to_ts.items()]

    @sql
    @profiler
    def load_data(self):
        """Load the whole gossip db from sqlite into the in-memory maps."""
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                continue  # skip addresses that fail to parse
            self._addresses[node_id][net_addr] = int(timestamp or 0)
        def newest_ts_for_node_id(node_id):
            newest_ts = 0
            for addr, ts in self._addresses[node_id].items():
                newest_ts = max(newest_ts, ts)
            return newest_ts
        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they dont have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        self.logger.info(f'data loaded. {len(self._channels)} chans. {len(self._policies)} policies. '
                         f'{len(self._channels_for_node)} nodes.')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        self.asyncio_loop.call_soon_threadsafe(self.data_loaded.set)
        util.trigger_callback('gossip_db_loaded')

    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
        """Re-bucket a channel into the 0/1/2-known-policies partitions."""
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is None:
            with self.lock:
                self._chans_with_0_policies.discard(short_channel_id)
                self._chans_with_1_policies.discard(short_channel_id)
                self._chans_with_2_policies.discard(short_channel_id)
            return
        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
        with self.lock:
            self._chans_with_0_policies.discard(short_channel_id)
            self._chans_with_1_policies.discard(short_channel_id)
            self._chans_with_2_policies.discard(short_channel_id)
            if p1 is not None and p2 is not None:
                self._chans_with_2_policies.add(short_channel_id)
            elif p1 is None and p2 is None:
                self._chans_with_0_policies.add(short_channel_id)
            else:
                self._chans_with_1_policies.add(short_channel_id)

    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
        """Return (#chans with 0 policies, #with 1 policy, #with 2 policies)."""
        nchans_with_0p = len(self._chans_with_0_policies)
        nchans_with_1p = len(self._chans_with_1_policies)
        nchans_with_2p = len(self._chans_with_2_policies)
        return nchans_with_0p, nchans_with_1p, nchans_with_2p

    def get_policy_for_node(
            self,
            short_channel_id: bytes,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional['Policy']:
        """Return the policy for the given channel direction, looking at
        public gossip, private channel updates, our own channels and
        private route edges, in that order. None if unknown.
        """
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            policy = self._policies.get((node_id, short_channel_id))
            if policy:
                return policy
        else:  # private channel
            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
            if chan_upd_dict:
                return Policy.from_msg(chan_upd_dict)
        # check if it's one of our own channels
        if my_channels:
            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
            if policy:
                return policy
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge:
                return Policy.from_route_edge(route_edge)

    def get_channel_info(
            self,
            short_channel_id: ShortChannelID,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Optional[ChannelInfo]:
        """Return the ChannelInfo for a channel, looking at public gossip,
        our own channels and private route edges, in that order.
        """
        ret = self._channels.get(short_channel_id)
        if ret:
            return ret
        # check if it's one of our own channels
        if my_channels:
            channel_info = get_mychannel_info(short_channel_id, my_channels)
            if channel_info:
                return channel_info
        if private_route_edges:
            route_edge = private_route_edges.get(short_channel_id)
            if route_edge:
                return ChannelInfo.from_route_edge(route_edge)

    def get_channels_for_node(
            self,
            node_id: bytes,
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
    ) -> Set[bytes]:
        """Returns the set of short channel IDs where node_id is one of the channel participants."""
        if not self.data_loaded.is_set():
            raise Exception("channelDB data not loaded yet!")
        relevant_channels = self._channels_for_node.get(node_id) or set()
        relevant_channels = set(relevant_channels)  # copy
        # add our own channels  # TODO maybe slow?
        if my_channels:
            for chan in my_channels.values():
                if node_id in (chan.node_id, chan.get_local_pubkey()):
                    relevant_channels.add(chan.short_channel_id)
        # add private channels  # TODO maybe slow?
        if private_route_edges:
            for route_edge in private_route_edges.values():
                if node_id in (route_edge.start_node, route_edge.end_node):
                    relevant_channels.add(route_edge.short_channel_id)
        return relevant_channels

    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
        """Return the two endpoint node ids of a channel, or None if unknown."""
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:  # publicly announced channel
            return channel_info.node1_id, channel_info.node2_id
        # check if it's one of our own channels
        if not my_channels:
            return
        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
        if not chan:
            return
        return chan.get_local_pubkey(), chan.node_id

    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
        return self._nodes.get(node_id)

    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
        with self.lock:
            return self._nodes.copy()

    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
        with self.lock:
            return self._policies.copy()

    def get_node_by_prefix(self, prefix):
        """Return the first known node_id starting with 'prefix'; raises if none."""
        with self.lock:
            for k in self._addresses.keys():
                if k.startswith(prefix):
                    return k
        raise Exception('node not found')

    def to_dict(self) -> dict:
        """ Generates a graph representation in terms of a dictionary.

        The dictionary contains only native python types and can be encoded
        to json.
        """
        with self.lock:
            graph = {'nodes': [], 'channels': []}

            # gather nodes
            for pk, nodeinfo in self._nodes.items():
                # use _asdict() to convert NamedTuples to json encodable dicts
                graph['nodes'].append(
                    nodeinfo._asdict(),
                )
                graph['nodes'][-1]['addresses'] = [
                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
                    for addr, ts in self._addresses[pk].items()
                ]

            # gather channels
            for cid, channelinfo in self._channels.items():
                graph['channels'].append(
                    channelinfo._asdict(),
                )
                policy1 = self._policies.get(
                    (channelinfo.node1_id, channelinfo.short_channel_id))
                policy2 = self._policies.get(
                    (channelinfo.node2_id, channelinfo.short_channel_id))
                graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None
                graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None

        # need to use json_normalize otherwise json encoding in rpc server fails
        graph = json_normalize(graph)
        return graph