1# -*- coding: utf-8 -*-
2#
3# Electrum - lightweight Bitcoin client
4# Copyright (C) 2018 The Electrum developers
5#
6# Permission is hereby granted, free of charge, to any person
7# obtaining a copy of this software and associated documentation files
8# (the "Software"), to deal in the Software without restriction,
9# including without limitation the rights to use, copy, modify, merge,
10# publish, distribute, sublicense, and/or sell copies of the Software,
11# and to permit persons to whom the Software is furnished to do so,
12# subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be
15# included in all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24# SOFTWARE.
25
26import time
27import random
28import os
29from collections import defaultdict
30from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
31import binascii
32import base64
33import asyncio
34import threading
35from enum import IntEnum
36
37from aiorpcx import NetAddress
38
39from .sql_db import SqlDB, sql
40from . import constants, util
41from .util import bh2u, profiler, get_headers_dir, is_ip_address, json_normalize
42from .logging import Logger
43from .lnutil import (LNPeerAddr, format_short_channel_id, ShortChannelID,
44                     validate_features, IncompatibleOrInsaneFeatures, InvalidGossipMsg)
45from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
46from .lnmsg import decode_msg
47from . import ecc
48from .crypto import sha256d
49
50if TYPE_CHECKING:
51    from .network import Network
52    from .lnchannel import Channel
53    from .lnrouter import RouteEdge
54
55
# Bits of a channel_update's channel_flags field.
FLAG_DIRECTION = 0b01  # which endpoint published the update: 0 -> node1, 1 -> node2
FLAG_DISABLE = 0b10    # channel is disabled in the publishing direction
58
59
class ChannelInfo(NamedTuple):
    """A public channel, as described by a channel_announcement.

    node1_id and node2_id are the two endpoints in ascending byte order.
    The factory methods below leave capacity_sat as None; callers may fill
    it in with _replace().
    """
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]

    @staticmethod
    def from_msg(payload: dict) -> 'ChannelInfo':
        """Build from a decoded channel_announcement payload.

        Feature bits are checked with validate_features (callers in this
        module catch IncompatibleOrInsaneFeatures around this call). The two
        node ids must already be in ascending order.
        """
        validate_features(int.from_bytes(payload['features'], 'big'))
        raw_scid = payload['short_channel_id']
        first, second = payload['node_id_1'], payload['node_id_2']
        # the announcement must list its endpoints in ascending order
        assert sorted([first, second]) == [first, second]
        return ChannelInfo(
            short_channel_id=ShortChannelID.normalize(raw_scid),
            node1_id=first,
            node2_id=second,
            capacity_sat=None,
        )

    @staticmethod
    def from_raw_msg(raw: bytes) -> 'ChannelInfo':
        """Decode a raw channel_announcement and build a ChannelInfo from it."""
        return ChannelInfo.from_msg(decode_msg(raw)[1])

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'ChannelInfo':
        """Build a ChannelInfo (capacity unknown) from the endpoints of a route edge."""
        low, high = sorted((route_edge.start_node, route_edge.end_node))
        return ChannelInfo(
            short_channel_id=route_edge.short_channel_id,
            node1_id=low,
            node2_id=high,
            capacity_sat=None,
        )
96
97
class Policy(NamedTuple):
    """One direction of a channel's routing policy, as carried by a channel_update.

    `key` is the 8-byte short channel id followed by the node id of the
    endpoint that published the update (see short_channel_id / start_node).
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload: dict) -> 'Policy':
        """Build from a decoded channel_update payload.

        The payload must carry an extra 'start_node' entry (the publishing
        node's id); 'htlc_maximum_msat' is optional and defaults to None.
        """
        return Policy(
            key                         = payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta           = payload['cltv_expiry_delta'],
            htlc_minimum_msat           = payload['htlc_minimum_msat'],
            htlc_maximum_msat           = payload.get('htlc_maximum_msat', None),
            fee_base_msat               = payload['fee_base_msat'],
            fee_proportional_millionths = payload['fee_proportional_millionths'],
            message_flags               = int.from_bytes(payload['message_flags'], "big"),
            channel_flags               = int.from_bytes(payload['channel_flags'], "big"),
            timestamp                   = payload['timestamp'],
        )

    @staticmethod
    def from_raw_msg(key:bytes, raw: bytes) -> 'Policy':
        """Build from a raw channel_update stored under `key` (scid + start node id)."""
        payload = decode_msg(raw)[1]
        payload['start_node'] = key[8:]
        return Policy.from_msg(payload)

    @staticmethod
    def from_route_edge(route_edge: 'RouteEdge') -> 'Policy':
        """Build a permissive policy (no htlc limits, zero flags/timestamp) from a route edge."""
        return Policy(
            key=route_edge.short_channel_id + route_edge.start_node,
            cltv_expiry_delta=route_edge.cltv_expiry_delta,
            htlc_minimum_msat=0,
            htlc_maximum_msat=None,
            fee_base_msat=route_edge.fee_base_msat,
            fee_proportional_millionths=route_edge.fee_proportional_millionths,
            channel_flags=0,
            message_flags=0,
            timestamp=0,
        )

    def is_disabled(self) -> bool:
        """Return True iff the disable bit is set in channel_flags."""
        # bool() so callers get True/False rather than the raw masked value (0 or 2)
        return bool(self.channel_flags & FLAG_DISABLE)

    @property
    def short_channel_id(self) -> ShortChannelID:
        """The short channel id: the first 8 bytes of `key`."""
        return ShortChannelID.normalize(self.key[0:8])

    @property
    def start_node(self) -> bytes:
        """Node id of the publishing endpoint: `key` after the 8-byte scid."""
        return self.key[8:]
153
154
class NodeInfo(NamedTuple):
    """A node as advertised by a node_announcement."""
    node_id: bytes
    features: int
    timestamp: int
    alias: str

    @staticmethod
    def from_msg(payload) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Build a NodeInfo and its advertised peer addresses from a decoded payload.

        Feature bits are checked with validate_features (callers in this
        module catch IncompatibleOrInsaneFeatures around this call).
        Addresses that LNPeerAddr rejects with ValueError are skipped, and an
        alias that is not valid UTF-8 becomes ''.
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        peer_addrs = []
        for host, port in addresses:
            try:
                peer_addrs.append(LNPeerAddr(host=host, port=port, pubkey=node_id))
            except ValueError:
                pass
        # strip trailing NUL bytes before decoding
        alias = payload['alias'].rstrip(b'\x00')
        try:
            alias = alias.decode('utf8')
        except UnicodeDecodeError:
            alias = ''
        timestamp = payload['timestamp']
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        return node_info, peer_addrs

    @staticmethod
    def from_raw_msg(raw: bytes) -> Tuple['NodeInfo', Sequence['LNPeerAddr']]:
        """Decode a raw node_announcement and delegate to from_msg."""
        payload_dict = decode_msg(raw)[1]
        return NodeInfo.from_msg(payload_dict)

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Parse the `addresses` field of a node_announcement into (host, port) tuples.

        Each descriptor is a 1-byte type followed by a type-specific payload.
        Type 0 carries no payload and is skipped. Parsing stops at the first
        unknown type (its length is unknown) or at a truncated descriptor.
        Entries with port 0, and IPv4/IPv6 hosts rejected by is_ip_address,
        are dropped.
        """
        # payload length (address bytes + 2-byte big-endian port) per known type
        payload_len_by_type = {
            1: 4 + 2,   # IPv4
            2: 16 + 2,  # IPv6
            3: 10 + 2,  # onion v2
            4: 35 + 2,  # onion v3
        }
        buf = addresses_field
        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data
        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                continue
            payload_len = payload_len_by_type.get(atype)
            if payload_len is None:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
            if len(buf) < payload_len:
                # truncated descriptor: stop instead of parsing a partial address/port
                break
            if atype == 1:  # IPv4
                host = '.'.join('%d' % x for x in read(4))
                is_valid_host = is_ip_address(host)
            elif atype == 2:  # IPv6
                host = b':'.join(binascii.hexlify(read(2)) for _ in range(8)).decode('ascii')
                is_valid_host = is_ip_address(host)
            else:  # onion v2 (type 3) or onion v3 (type 4)
                host = (base64.b32encode(read(payload_len - 2)) + b'.onion').decode('ascii').lower()
                is_valid_host = True
            port = int.from_bytes(read(2), 'big')
            # port 0 is not connectable; reject it for every address type
            if is_valid_host and port != 0:
                addresses.append((host, port))
        return addresses
226
227
class UpdateStatus(IntEnum):
    """Outcome of ChannelDB.add_channel_update for a single channel_update."""
    ORPHANED = 0    # no known channel for the update's short_channel_id
    EXPIRED = 1     # older than the caller-supplied max_age
    DEPRECATED = 2  # >60s in the future, or not >60s newer than the stored policy
    UNCHANGED = 3   # stored, but policy fields equal the previous policy
    GOOD = 4        # stored, and either new or changed
234
class CategorizedChannelUpdates(NamedTuple):
    """channel_update payloads grouped by the UpdateStatus they produced."""
    orphaned: List    # channel unknown (no channel announcement for it)
    expired: List     # older than the max_age passed by the caller
    deprecated: List  # not newer than the stored policy, or from the future
    unchanged: List   # stored, but policy fields identical to before
    good: List        # stored, and new or changed
241
242
def get_mychannel_info(short_channel_id: ShortChannelID,
                       my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[ChannelInfo]:
    """Return a ChannelInfo for one of our own channels, or None if not ours.

    It is built from our (unsigned) channel_announcement, with capacity_sat
    taken from the channel's constraints.
    """
    chan = my_channels.get(short_channel_id)
    if chan:
        announcement = chan.construct_channel_announcement_without_sigs()
        info = ChannelInfo.from_raw_msg(announcement)
        return info._replace(capacity_sat=chan.constraints.capacity)
    return None
250
def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
                         my_channels: Dict[ShortChannelID, 'Channel']) -> Optional[Policy]:
    """Return the policy of one direction of one of our own channels.

    If node_id is the remote node, the remote's channel_update is used and
    re-stamped with the current time. If node_id is our own key, our
    outgoing gossip channel_update is used. Returns None for unknown
    channels, a missing remote update, or a node_id that is neither endpoint.
    """
    chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
    if not chan:
        return None
    if node_id == chan.node_id:  # incoming direction (to us)
        raw_update = chan.get_remote_update()
        if not raw_update:
            return None
        now = int(time.time())
        fields = decode_msg(raw_update)[1]
        fields['timestamp'] = now
    elif node_id == chan.get_local_pubkey():  # outgoing direction (from us)
        fields = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
    else:
        return None
    fields['start_node'] = node_id
    return Policy.from_msg(fields)
269
270
271create_channel_info = """
272CREATE TABLE IF NOT EXISTS channel_info (
273short_channel_id BLOB(8),
274msg BLOB,
275PRIMARY KEY(short_channel_id)
276)"""
277
278create_policy = """
279CREATE TABLE IF NOT EXISTS policy (
280key BLOB(41),
281msg BLOB,
282PRIMARY KEY(key)
283)"""
284
285create_address = """
286CREATE TABLE IF NOT EXISTS address (
287node_id BLOB(33),
288host STRING(256),
289port INTEGER NOT NULL,
290timestamp INTEGER,
291PRIMARY KEY(node_id, host, port)
292)"""
293
294create_node_info = """
295CREATE TABLE IF NOT EXISTS node_info (
296node_id BLOB(33),
297msg BLOB,
298PRIMARY KEY(node_id)
299)"""
300
301
302class ChannelDB(SqlDB):
303
304    NUM_MAX_RECENT_PEERS = 20
305
306    def __init__(self, network: 'Network'):
307        path = os.path.join(get_headers_dir(network.config), 'gossip_db')
308        super().__init__(network.asyncio_loop, path, commit_interval=100)
309        self.lock = threading.RLock()
310        self.num_nodes = 0
311        self.num_channels = 0
312        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
313        self.ca_verifier = LNChannelVerifier(network, self)
314
315        # initialized in load_data
316        # note: modify/iterate needs self.lock
317        self._channels = {}  # type: Dict[ShortChannelID, ChannelInfo]
318        self._policies = {}  # type: Dict[Tuple[bytes, ShortChannelID], Policy]  # (node_id, scid) -> Policy
319        self._nodes = {}  # type: Dict[bytes, NodeInfo]  # node_id -> NodeInfo
320        # node_id -> NetAddress -> timestamp
321        self._addresses = defaultdict(dict)  # type: Dict[bytes, Dict[NetAddress, int]]
322        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[ShortChannelID]]
323        self._recent_peers = []  # type: List[bytes]  # list of node_ids
324        self._chans_with_0_policies = set()  # type: Set[ShortChannelID]
325        self._chans_with_1_policies = set()  # type: Set[ShortChannelID]
326        self._chans_with_2_policies = set()  # type: Set[ShortChannelID]
327
328        self.data_loaded = asyncio.Event()
329        self.network = network # only for callback
330
331    def update_counts(self):
332        self.num_nodes = len(self._nodes)
333        self.num_channels = len(self._channels)
334        self.num_policies = len(self._policies)
335        util.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
336        util.trigger_callback('ln_gossip_sync_progress')
337
338    def get_channel_ids(self):
339        with self.lock:
340            return set(self._channels.keys())
341
342    def add_recent_peer(self, peer: LNPeerAddr):
343        now = int(time.time())
344        node_id = peer.pubkey
345        with self.lock:
346            self._addresses[node_id][peer.net_addr()] = now
347            # list is ordered
348            if node_id in self._recent_peers:
349                self._recent_peers.remove(node_id)
350            self._recent_peers.insert(0, node_id)
351            self._recent_peers = self._recent_peers[:self.NUM_MAX_RECENT_PEERS]
352        self._db_save_node_address(peer, now)
353
354    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
355        with self.lock:
356            unshuffled = set(self._nodes.keys()) - node_ids
357        return random.sample(unshuffled, min(200, len(unshuffled)))
358
359    def get_last_good_address(self, node_id: bytes) -> Optional[LNPeerAddr]:
360        """Returns latest address we successfully connected to, for given node."""
361        addr_to_ts = self._addresses.get(node_id)
362        if not addr_to_ts:
363            return None
364        addr = sorted(list(addr_to_ts), key=lambda a: addr_to_ts[a], reverse=True)[0]
365        try:
366            return LNPeerAddr(str(addr.host), addr.port, node_id)
367        except ValueError:
368            return None
369
370    def get_recent_peers(self):
371        if not self.data_loaded.is_set():
372            raise Exception("channelDB data not loaded yet!")
373        with self.lock:
374            ret = [self.get_last_good_address(node_id)
375                   for node_id in self._recent_peers]
376            return ret
377
378    # note: currently channel announcements are trusted by default (trusted=True);
379    #       they are not SPV-verified. Verifying them would make the gossip sync
380    #       even slower; especially as servers will start throttling us.
381    #       It would probably put significant strain on servers if all clients
382    #       verified the complete gossip.
383    def add_channel_announcements(self, msg_payloads, *, trusted=True):
384        # note: signatures have already been verified.
385        if type(msg_payloads) is dict:
386            msg_payloads = [msg_payloads]
387        added = 0
388        for msg in msg_payloads:
389            short_channel_id = ShortChannelID(msg['short_channel_id'])
390            if short_channel_id in self._channels:
391                continue
392            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
393                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
394                continue
395            try:
396                channel_info = ChannelInfo.from_msg(msg)
397            except IncompatibleOrInsaneFeatures as e:
398                self.logger.info(f"unknown or insane feature bits: {e!r}")
399                continue
400            if trusted:
401                added += 1
402                self.add_verified_channel_info(msg)
403            else:
404                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)
405
406        self.update_counts()
407        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
408
409    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
410        try:
411            channel_info = ChannelInfo.from_msg(msg)
412        except IncompatibleOrInsaneFeatures:
413            return
414        channel_info = channel_info._replace(capacity_sat=capacity_sat)
415        with self.lock:
416            self._channels[channel_info.short_channel_id] = channel_info
417            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
418            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
419        self._update_num_policies_for_chan(channel_info.short_channel_id)
420        if 'raw' in msg:
421            self._db_save_channel(channel_info.short_channel_id, msg['raw'])
422
423    def policy_changed(self, old_policy: Policy, new_policy: Policy, verbose: bool) -> bool:
424        changed = False
425        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
426            changed |= True
427            if verbose:
428                self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
429        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
430            changed |= True
431            if verbose:
432                self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
433        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
434            changed |= True
435            if verbose:
436                self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
437        if old_policy.fee_base_msat != new_policy.fee_base_msat:
438            changed |= True
439            if verbose:
440                self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
441        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
442            changed |= True
443            if verbose:
444                self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
445        if old_policy.channel_flags != new_policy.channel_flags:
446            changed |= True
447            if verbose:
448                self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
449        if old_policy.message_flags != new_policy.message_flags:
450            changed |= True
451            if verbose:
452                self.logger.info(f'message_flags: {old_policy.message_flags} -> {new_policy.message_flags}')
453        if not changed and verbose:
454            self.logger.info(f'policy unchanged: {old_policy.timestamp} -> {new_policy.timestamp}')
455        return changed
456
457    def add_channel_update(
458            self, payload, *, max_age=None, verify=True, verbose=True) -> UpdateStatus:
459        now = int(time.time())
460        short_channel_id = ShortChannelID(payload['short_channel_id'])
461        timestamp = payload['timestamp']
462        if max_age and now - timestamp > max_age:
463            return UpdateStatus.EXPIRED
464        if timestamp - now > 60:
465            return UpdateStatus.DEPRECATED
466        channel_info = self._channels.get(short_channel_id)
467        if not channel_info:
468            return UpdateStatus.ORPHANED
469        flags = int.from_bytes(payload['channel_flags'], 'big')
470        direction = flags & FLAG_DIRECTION
471        start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
472        payload['start_node'] = start_node
473        # compare updates to existing database entries
474        short_channel_id = ShortChannelID(payload['short_channel_id'])
475        key = (start_node, short_channel_id)
476        old_policy = self._policies.get(key)
477        if old_policy and timestamp <= old_policy.timestamp + 60:
478            return UpdateStatus.DEPRECATED
479        if verify:
480            self.verify_channel_update(payload)
481        policy = Policy.from_msg(payload)
482        with self.lock:
483            self._policies[key] = policy
484        self._update_num_policies_for_chan(short_channel_id)
485        if 'raw' in payload:
486            self._db_save_policy(policy.key, payload['raw'])
487        if old_policy and not self.policy_changed(old_policy, policy, verbose):
488            return UpdateStatus.UNCHANGED
489        else:
490            return UpdateStatus.GOOD
491
492    def add_channel_updates(self, payloads, max_age=None) -> CategorizedChannelUpdates:
493        orphaned = []
494        expired = []
495        deprecated = []
496        unchanged = []
497        good = []
498        for payload in payloads:
499            r = self.add_channel_update(payload, max_age=max_age, verbose=False, verify=True)
500            if r == UpdateStatus.ORPHANED:
501                orphaned.append(payload)
502            elif r == UpdateStatus.EXPIRED:
503                expired.append(payload)
504            elif r == UpdateStatus.DEPRECATED:
505                deprecated.append(payload)
506            elif r == UpdateStatus.UNCHANGED:
507                unchanged.append(payload)
508            elif r == UpdateStatus.GOOD:
509                good.append(payload)
510        self.update_counts()
511        return CategorizedChannelUpdates(
512            orphaned=orphaned,
513            expired=expired,
514            deprecated=deprecated,
515            unchanged=unchanged,
516            good=good)
517
518
519    def create_database(self):
520        c = self.conn.cursor()
521        c.execute(create_node_info)
522        c.execute(create_address)
523        c.execute(create_policy)
524        c.execute(create_channel_info)
525        self.conn.commit()
526
527    @sql
528    def _db_save_policy(self, key: bytes, msg: bytes):
529        # 'msg' is a 'channel_update' message
530        c = self.conn.cursor()
531        c.execute("""REPLACE INTO policy (key, msg) VALUES (?,?)""", [key, msg])
532
533    @sql
534    def _db_delete_policy(self, node_id: bytes, short_channel_id: ShortChannelID):
535        key = short_channel_id + node_id
536        c = self.conn.cursor()
537        c.execute("""DELETE FROM policy WHERE key=?""", (key,))
538
539    @sql
540    def _db_save_channel(self, short_channel_id: ShortChannelID, msg: bytes):
541        # 'msg' is a 'channel_announcement' message
542        c = self.conn.cursor()
543        c.execute("REPLACE INTO channel_info (short_channel_id, msg) VALUES (?,?)", [short_channel_id, msg])
544
545    @sql
546    def _db_delete_channel(self, short_channel_id: ShortChannelID):
547        c = self.conn.cursor()
548        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))
549
550    @sql
551    def _db_save_node_info(self, node_id: bytes, msg: bytes):
552        # 'msg' is a 'node_announcement' message
553        c = self.conn.cursor()
554        c.execute("REPLACE INTO node_info (node_id, msg) VALUES (?,?)", [node_id, msg])
555
556    @sql
557    def _db_save_node_address(self, peer: LNPeerAddr, timestamp: int):
558        c = self.conn.cursor()
559        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)",
560                  (peer.pubkey, peer.host, peer.port, timestamp))
561
562    @sql
563    def _db_save_node_addresses(self, node_addresses: Sequence[LNPeerAddr]):
564        c = self.conn.cursor()
565        for addr in node_addresses:
566            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.pubkey, addr.host, addr.port))
567            r = c.fetchall()
568            if r == []:
569                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.pubkey, addr.host, addr.port, 0))
570
571    @classmethod
572    def verify_channel_update(cls, payload, *, start_node: bytes = None) -> None:
573        short_channel_id = payload['short_channel_id']
574        short_channel_id = ShortChannelID(short_channel_id)
575        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
576            raise InvalidGossipMsg('wrong chain hash')
577        start_node = payload.get('start_node', None) or start_node
578        assert start_node is not None
579        if not verify_sig_for_channel_update(payload, start_node):
580            raise InvalidGossipMsg(f'failed verifying channel update for {short_channel_id}')
581
582    @classmethod
583    def verify_channel_announcement(cls, payload) -> None:
584        h = sha256d(payload['raw'][2+256:])
585        pubkeys = [payload['node_id_1'], payload['node_id_2'], payload['bitcoin_key_1'], payload['bitcoin_key_2']]
586        sigs = [payload['node_signature_1'], payload['node_signature_2'], payload['bitcoin_signature_1'], payload['bitcoin_signature_2']]
587        for pubkey, sig in zip(pubkeys, sigs):
588            if not ecc.verify_signature(pubkey, sig, h):
589                raise InvalidGossipMsg('signature failed')
590
591    @classmethod
592    def verify_node_announcement(cls, payload) -> None:
593        pubkey = payload['node_id']
594        signature = payload['signature']
595        h = sha256d(payload['raw'][66:])
596        if not ecc.verify_signature(pubkey, signature, h):
597            raise InvalidGossipMsg('signature failed')
598
599    def add_node_announcements(self, msg_payloads):
600        # note: signatures have already been verified.
601        if type(msg_payloads) is dict:
602            msg_payloads = [msg_payloads]
603        new_nodes = {}
604        for msg_payload in msg_payloads:
605            try:
606                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
607            except IncompatibleOrInsaneFeatures:
608                continue
609            node_id = node_info.node_id
610            # Ignore node if it has no associated channel (DoS protection)
611            if node_id not in self._channels_for_node:
612                #self.logger.info('ignoring orphan node_announcement')
613                continue
614            node = self._nodes.get(node_id)
615            if node and node.timestamp >= node_info.timestamp:
616                continue
617            node = new_nodes.get(node_id)
618            if node and node.timestamp >= node_info.timestamp:
619                continue
620            # save
621            with self.lock:
622                self._nodes[node_id] = node_info
623            if 'raw' in msg_payload:
624                self._db_save_node_info(node_id, msg_payload['raw'])
625            with self.lock:
626                for addr in node_addresses:
627                    net_addr = NetAddress(addr.host, addr.port)
628                    self._addresses[node_id][net_addr] = self._addresses[node_id].get(net_addr) or 0
629            self._db_save_node_addresses(node_addresses)
630
631        self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
632        self.update_counts()
633
634    def get_old_policies(self, delta) -> Sequence[Tuple[bytes, ShortChannelID]]:
635        with self.lock:
636            _policies = self._policies.copy()
637        now = int(time.time())
638        return list(k for k, v in _policies.items() if v.timestamp <= now - delta)
639
640    def prune_old_policies(self, delta):
641        old_policies = self.get_old_policies(delta)
642        if old_policies:
643            for key in old_policies:
644                node_id, scid = key
645                with self.lock:
646                    self._policies.pop(key)
647                self._db_delete_policy(*key)
648                self._update_num_policies_for_chan(scid)
649            self.update_counts()
650            self.logger.info(f'Deleting {len(old_policies)} old policies')
651
652    def prune_orphaned_channels(self):
653        with self.lock:
654            orphaned_chans = self._chans_with_0_policies.copy()
655        if orphaned_chans:
656            for short_channel_id in orphaned_chans:
657                self.remove_channel(short_channel_id)
658            self.update_counts()
659            self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')
660
661    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool:
662        """Returns True iff the channel update was successfully added and it was different than
663        what we had before (if any).
664        """
665        if not verify_sig_for_channel_update(msg_payload, start_node_id):
666            return False  # ignore
667        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
668        msg_payload['start_node'] = start_node_id
669        key = (start_node_id, short_channel_id)
670        prev_chanupd = self._channel_updates_for_private_channels.get(key)
671        if prev_chanupd == msg_payload:
672            return False
673        self._channel_updates_for_private_channels[key] = msg_payload
674        return True
675
676    def remove_channel(self, short_channel_id: ShortChannelID):
677        # FIXME what about rm-ing policies?
678        with self.lock:
679            channel_info = self._channels.pop(short_channel_id, None)
680            if channel_info:
681                self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
682                self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
683        self._update_num_policies_for_chan(short_channel_id)
684        # delete from database
685        self._db_delete_channel(short_channel_id)
686
687    def get_node_addresses(self, node_id: bytes) -> Sequence[Tuple[str, int, int]]:
688        """Returns list of (host, port, timestamp)."""
689        addr_to_ts = self._addresses.get(node_id)
690        if not addr_to_ts:
691            return []
692        return [(str(net_addr.host), net_addr.port, ts)
693                for net_addr, ts in addr_to_ts.items()]
694
    @sql
    @profiler
    def load_data(self):
        """Populate the in-memory gossip indexes from the sqlite database.

        No-op if data_loaded is already set. The load happens in this order:
        addresses (and the recent-peers list derived from their timestamps),
        channels, nodes, policies. It then rebuilds the per-node channel
        index and the policy-count sets, logs summary counts, sets
        data_loaded on the asyncio loop, and fires 'gossip_db_loaded'.
        NOTE(review): the indexes are mutated here without self.lock;
        presumably safe because nothing reads them before data_loaded is set —
        confirm.
        """
        if self.data_loaded.is_set():
            return
        # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            try:
                net_addr = NetAddress(host, port)
            except Exception:
                # skip rows whose host/port cannot form a NetAddress
                continue
            self._addresses[node_id][net_addr] = int(timestamp or 0)
        def newest_ts_for_node_id(node_id):
            # most recent address timestamp recorded for this node (0 if none)
            newest_ts = 0
            for addr, ts in self._addresses[node_id].items():
                newest_ts = max(newest_ts, ts)
            return newest_ts
        # recent peers = nodes with the newest address timestamps, newest first
        sorted_node_ids = sorted(self._addresses.keys(), key=newest_ts_for_node_id, reverse=True)
        self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
        c.execute("""SELECT * FROM channel_info""")
        for short_channel_id, msg in c:
            try:
                ci = ChannelInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                # stored announcement has feature bits we now reject
                continue
            self._channels[ShortChannelID.normalize(short_channel_id)] = ci
        c.execute("""SELECT * FROM node_info""")
        for node_id, msg in c:
            try:
                node_info, node_addresses = NodeInfo.from_raw_msg(msg)
            except IncompatibleOrInsaneFeatures:
                continue
            # don't load node_addresses because they dont have timestamps
            self._nodes[node_id] = node_info
        c.execute("""SELECT * FROM policy""")
        for key, msg in c:
            # the db key is scid + start node; the in-memory key is (start node, scid)
            p = Policy.from_raw_msg(key, msg)
            self._policies[(p.start_node, p.short_channel_id)] = p
        # rebuild derived indexes; this needs both channels and policies loaded
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self._update_num_policies_for_chan(channel_info.short_channel_id)
        # note: 'nodes' here counts nodes that have at least one channel entry
        self.logger.info(f'data loaded. {len(self._channels)} chans. {len(self._policies)} policies. '
                         f'{len(self._channels_for_node)} nodes.')
        self.update_counts()
        (nchans_with_0p, nchans_with_1p, nchans_with_2p) = self.get_num_channels_partitioned_by_policy_count()
        self.logger.info(f'num_channels_partitioned_by_policy_count. '
                         f'0p: {nchans_with_0p}, 1p: {nchans_with_1p}, 2p: {nchans_with_2p}')
        # asyncio.Event is not thread-safe: schedule set() on the event loop
        self.asyncio_loop.call_soon_threadsafe(self.data_loaded.set)
        util.trigger_callback('gossip_db_loaded')
748
749    def _update_num_policies_for_chan(self, short_channel_id: ShortChannelID) -> None:
750        channel_info = self.get_channel_info(short_channel_id)
751        if channel_info is None:
752            with self.lock:
753                self._chans_with_0_policies.discard(short_channel_id)
754                self._chans_with_1_policies.discard(short_channel_id)
755                self._chans_with_2_policies.discard(short_channel_id)
756            return
757        p1 = self.get_policy_for_node(short_channel_id, channel_info.node1_id)
758        p2 = self.get_policy_for_node(short_channel_id, channel_info.node2_id)
759        with self.lock:
760            self._chans_with_0_policies.discard(short_channel_id)
761            self._chans_with_1_policies.discard(short_channel_id)
762            self._chans_with_2_policies.discard(short_channel_id)
763            if p1 is not None and p2 is not None:
764                self._chans_with_2_policies.add(short_channel_id)
765            elif p1 is None and p2 is None:
766                self._chans_with_0_policies.add(short_channel_id)
767            else:
768                self._chans_with_1_policies.add(short_channel_id)
769
770    def get_num_channels_partitioned_by_policy_count(self) -> Tuple[int, int, int]:
771        nchans_with_0p = len(self._chans_with_0_policies)
772        nchans_with_1p = len(self._chans_with_1_policies)
773        nchans_with_2p = len(self._chans_with_2_policies)
774        return nchans_with_0p, nchans_with_1p, nchans_with_2p
775
776    def get_policy_for_node(
777            self,
778            short_channel_id: bytes,
779            node_id: bytes,
780            *,
781            my_channels: Dict[ShortChannelID, 'Channel'] = None,
782            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
783    ) -> Optional['Policy']:
784        channel_info = self.get_channel_info(short_channel_id)
785        if channel_info is not None:  # publicly announced channel
786            policy = self._policies.get((node_id, short_channel_id))
787            if policy:
788                return policy
789        else:  # private channel
790            chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
791            if chan_upd_dict:
792                return Policy.from_msg(chan_upd_dict)
793        # check if it's one of our own channels
794        if my_channels:
795            policy = get_mychannel_policy(short_channel_id, node_id, my_channels)
796            if policy:
797                return policy
798        if private_route_edges:
799            route_edge = private_route_edges.get(short_channel_id, None)
800            if route_edge:
801                return Policy.from_route_edge(route_edge)
802
803    def get_channel_info(
804            self,
805            short_channel_id: ShortChannelID,
806            *,
807            my_channels: Dict[ShortChannelID, 'Channel'] = None,
808            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
809    ) -> Optional[ChannelInfo]:
810        ret = self._channels.get(short_channel_id)
811        if ret:
812            return ret
813        # check if it's one of our own channels
814        if my_channels:
815            channel_info = get_mychannel_info(short_channel_id, my_channels)
816            if channel_info:
817                return channel_info
818        if private_route_edges:
819            route_edge = private_route_edges.get(short_channel_id)
820            if route_edge:
821                return ChannelInfo.from_route_edge(route_edge)
822
823    def get_channels_for_node(
824            self,
825            node_id: bytes,
826            *,
827            my_channels: Dict[ShortChannelID, 'Channel'] = None,
828            private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
829    ) -> Set[bytes]:
830        """Returns the set of short channel IDs where node_id is one of the channel participants."""
831        if not self.data_loaded.is_set():
832            raise Exception("channelDB data not loaded yet!")
833        relevant_channels = self._channels_for_node.get(node_id) or set()
834        relevant_channels = set(relevant_channels)  # copy
835        # add our own channels  # TODO maybe slow?
836        if my_channels:
837            for chan in my_channels.values():
838                if node_id in (chan.node_id, chan.get_local_pubkey()):
839                    relevant_channels.add(chan.short_channel_id)
840        # add private channels  # TODO maybe slow?
841        if private_route_edges:
842            for route_edge in private_route_edges.values():
843                if node_id in (route_edge.start_node, route_edge.end_node):
844                    relevant_channels.add(route_edge.short_channel_id)
845        return relevant_channels
846
847    def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
848                              my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
849        channel_info = self.get_channel_info(short_channel_id)
850        if channel_info is not None:  # publicly announced channel
851            return channel_info.node1_id, channel_info.node2_id
852        # check if it's one of our own channels
853        if not my_channels:
854            return
855        chan = my_channels.get(short_channel_id)  # type: Optional[Channel]
856        if not chan:
857            return
858        return chan.get_local_pubkey(), chan.node_id
859
860    def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
861        return self._nodes.get(node_id)
862
863    def get_node_infos(self) -> Dict[bytes, NodeInfo]:
864        with self.lock:
865            return self._nodes.copy()
866
867    def get_node_policies(self) -> Dict[Tuple[bytes, ShortChannelID], Policy]:
868        with self.lock:
869            return self._policies.copy()
870
871    def get_node_by_prefix(self, prefix):
872        with self.lock:
873            for k in self._addresses.keys():
874                if k.startswith(prefix):
875                    return k
876        raise Exception('node not found')
877
878    def to_dict(self) -> dict:
879        """ Generates a graph representation in terms of a dictionary.
880
881        The dictionary contains only native python types and can be encoded
882        to json.
883        """
884        with self.lock:
885            graph = {'nodes': [], 'channels': []}
886
887            # gather nodes
888            for pk, nodeinfo in self._nodes.items():
889                # use _asdict() to convert NamedTuples to json encodable dicts
890                graph['nodes'].append(
891                    nodeinfo._asdict(),
892                )
893                graph['nodes'][-1]['addresses'] = [
894                    {'host': str(addr.host), 'port': addr.port, 'timestamp': ts}
895                    for addr, ts in self._addresses[pk].items()
896                ]
897
898            # gather channels
899            for cid, channelinfo in self._channels.items():
900                graph['channels'].append(
901                    channelinfo._asdict(),
902                )
903                policy1 = self._policies.get(
904                    (channelinfo.node1_id, channelinfo.short_channel_id))
905                policy2 = self._policies.get(
906                    (channelinfo.node2_id, channelinfo.short_channel_id))
907                graph['channels'][-1]['policy1'] = policy1._asdict() if policy1 else None
908                graph['channels'][-1]['policy2'] = policy2._asdict() if policy2 else None
909
910        # need to use json_normalize otherwise json encoding in rpc server fails
911        graph = json_normalize(graph)
912        return graph
913