1#! /usr/bin/python3
2
3from pyln.spec.bolt7 import (channel_announcement, channel_update,
4                             node_announcement)
5from pyln.proto import ShortChannelId, PublicKey
6from typing import Any, Dict, List, Optional, Union, cast
7
8import io
9import struct
10
# These duplicate constants in lightning/common/gossip_store.h; keep them in
# sync with that header.
GOSSIP_STORE_VERSION = 9
# Flag bits carried in the top bits of each record header's 32-bit length
# field (see GossipStoreHeader).
GOSSIP_STORE_LEN_DELETED_BIT = 0x80000000
GOSSIP_STORE_LEN_PUSH_BIT = 0x40000000
# Clears both flag bits. Note that `~` on a Python int yields a negative
# number; ANDing it with a u32 length still keeps exactly the low 30 bits.
GOSSIP_STORE_LEN_MASK = (~(GOSSIP_STORE_LEN_PUSH_BIT
                           | GOSSIP_STORE_LEN_DELETED_BIT))

# These duplicate constants in lightning/gossipd/gossip_store_wiregen.h.
# They are gossip_store record types, matched against the 2-byte type prefix
# of each record in Gossmap.refresh().
WIRE_GOSSIP_STORE_PRIVATE_CHANNEL = 4104
WIRE_GOSSIP_STORE_PRIVATE_UPDATE = 4102
WIRE_GOSSIP_STORE_DELETE_CHAN = 4103
WIRE_GOSSIP_STORE_ENDED = 4105
WIRE_GOSSIP_STORE_CHANNEL_AMOUNT = 4101
24
25
class GossipStoreHeader(object):
    """The 12-byte header that precedes every gossip_store record.

    It consists of three big-endian u32 values: length (with flag bits in its
    top bits), crc and timestamp.

    Attributes:
        deleted: True if GOSSIP_STORE_LEN_DELETED_BIT is set.
        push: True if GOSSIP_STORE_LEN_PUSH_BIT is set.
        length: number of record bytes following the header, flags masked off.
        crc: raw crc field.
        timestamp: raw timestamp field.

    Raises:
        struct.error: if `buf` is not exactly 12 bytes long.
    """
    def __init__(self, buf: bytes):
        length, self.crc, self.timestamp = struct.unpack('>III', buf)
        self.deleted = (length & GOSSIP_STORE_LEN_DELETED_BIT) != 0
        # The push flag shares the length word; decode it too rather than
        # silently discarding it.
        self.push = (length & GOSSIP_STORE_LEN_PUSH_BIT) != 0
        self.length = (length & GOSSIP_STORE_LEN_MASK)
31
32
class GossmapHalfchannel(object):
    """One direction of a GossmapChannel, as described by a channel_update.

    Direction 0 runs from channel.node1 to channel.node2; any other direction
    value runs from channel.node2 to channel.node1. htlc_maximum_msat may be
    None (GossmapChannel passes None when the update lacks that field).
    """
    def __init__(self, channel: 'GossmapChannel', direction: int,
                 timestamp: int, cltv_expiry_delta: int,
                 htlc_minimum_msat: int, htlc_maximum_msat: Optional[int],
                 fee_base_msat: int, fee_proportional_millionths: int):

        self.channel = channel
        self.direction = direction
        # Orient the endpoints according to the direction.
        self.source = channel.node1 if direction == 0 else channel.node2
        self.destination = channel.node2 if direction == 0 else channel.node1

        self.timestamp: int = timestamp
        self.cltv_expiry_delta: int = cltv_expiry_delta
        self.htlc_minimum_msat: int = htlc_minimum_msat
        self.htlc_maximum_msat: Optional[int] = htlc_maximum_msat
        self.fee_base_msat: int = fee_base_msat
        self.fee_proportional_millionths: int = fee_proportional_millionths

    def __repr__(self):
        return "GossmapHalfchannel[{}x{}]".format(str(self.channel.scid),
                                                 self.direction)
54
55
class GossmapNodeId(object):
    """Raw 33-byte node_id (compressed public key) of a Lightning node.

    Equality, ordering and hashing all delegate to the underlying bytes, so
    ids can be used as dict keys and sorted.
    """
    def __init__(self, buf: Union[bytes, str]):
        raw = bytes.fromhex(buf) if isinstance(buf, str) else buf
        # Accept only 33 bytes whose first byte is 0x02 or 0x03.
        if len(raw) != 33 or raw[0] not in (2, 3):
            raise ValueError("{} is not a valid node_id".format(raw.hex()))
        self.nodeid = raw

    def to_pubkey(self) -> PublicKey:
        """Convert this id into a pyln PublicKey."""
        return PublicKey(self.nodeid)

    def __eq__(self, other):
        if isinstance(other, GossmapNodeId):
            return self.nodeid.__eq__(other.nodeid)
        return False

    def __lt__(self, other):
        if isinstance(other, GossmapNodeId):
            # Lexicographic byte comparison gives a total order.
            return self.nodeid.__lt__(other.nodeid)
        raise ValueError(f"Cannot compare GossmapNodeId with {type(other)}")

    def __hash__(self):
        return hash(self.nodeid)

    def __repr__(self):
        return f"GossmapNodeId[{self.nodeid.hex()}]"

    @classmethod
    def from_str(cls, s: str):
        """Parse a 66-character hex node_id, optionally prefixed by '0x'."""
        hexstr = s[2:] if s.startswith('0x') else s
        if len(hexstr) != 66:
            raise ValueError(f"{hexstr} is not a valid hexstring of a node_id")
        return cls(bytes.fromhex(hexstr))
90
91
class GossmapChannel(object):
    """A channel: fields of channel_announcement are in .fields.

    Per-direction channel_update data is kept in two-element lists indexed by
    direction (0 or 1). Each entry stays None until an update for that
    direction has been applied through _update_channel().

    Attributes:
        fields: decoded channel_announcement fields.
        announce_offset: offset recorded for the announcement.
        scid: the channel's short channel id.
        node1, node2: the two endpoint GossmapNodes.
        is_private: whether the caller flagged this channel as private.
        updates_fields: [dir0, dir1] decoded channel_update fields, or None.
        updates_offset: [dir0, dir1] offsets of those updates, or None.
        satoshis: channel capacity; None until the caller sets it.
        half_channels: [dir0, dir1] GossmapHalfchannel, or None.
    """
    def __init__(self,
                 fields: Dict[str, Any],
                 announce_offset: int,
                 scid,
                 node1: 'GossmapNode',
                 node2: 'GossmapNode',
                 is_private: bool):
        self.fields = fields
        self.announce_offset = announce_offset
        self.is_private = is_private
        self.scid = scid
        self.node1 = node1
        self.node2 = node2
        self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None]
        self.updates_offset: List[Optional[int]] = [None, None]
        self.satoshis: Optional[int] = None
        self.half_channels: List[Optional['GossmapHalfchannel']] = [None, None]

    def _update_channel(self,
                        direction: int,
                        fields: Dict[str, Any],
                        off: int):
        """Store a decoded channel_update for `direction` (0 or 1) and
        rebuild that direction's GossmapHalfchannel from it.

        htlc_maximum_msat is optional in `fields`; None is used if absent.
        """
        self.updates_fields[direction] = fields
        self.updates_offset[direction] = off

        half = GossmapHalfchannel(self, direction,
                                  fields['timestamp'],
                                  fields['cltv_expiry_delta'],
                                  fields['htlc_minimum_msat'],
                                  fields.get('htlc_maximum_msat', None),
                                  fields['fee_base_msat'],
                                  fields['fee_proportional_millionths'])
        self.half_channels[direction] = half

    def get_direction(self, direction: int) -> Optional['GossmapHalfchannel']:
        """Return the GossmapHalfchannel for `direction` if a channel_update
        for it is known, else None.

        Raises:
            ValueError: if `direction` is not the integer 0 or 1.
        """
        # Reject non-integers too (e.g. 0.5), which would otherwise pass a
        # range check and then fail as a list index.
        if not isinstance(direction, int) or direction not in (0, 1):
            raise ValueError("direction can only be 0 or 1")
        return self.half_channels[direction]

    def __repr__(self):
        return "GossmapChannel[{}]".format(str(self.scid))
136
137
class GossmapNode(object):
    """A node in the gossip map.

    Attributes:
        node_id: the node's GossmapNodeId.
        announce_fields: decoded node_announcement fields, or None if there
            has been no node_announcement.
        announce_offset: offset of that node_announcement, or None.
        channels: list of the GossmapChannels attached to this node.

    Nodes compare, order and hash by node_id, so they can be stored in sets
    and used as dict keys.
    """
    def __init__(self, node_id: Union['GossmapNodeId', bytes, str]):
        # Raw bytes or a hex string are wrapped into a GossmapNodeId.
        if isinstance(node_id, (bytes, str)):
            node_id = GossmapNodeId(node_id)
        self.announce_fields: Optional[Dict[str, Any]] = None
        self.announce_offset: Optional[int] = None
        self.channels: List[GossmapChannel] = []
        self.node_id = node_id

    def __repr__(self):
        return "GossmapNode[{}]".format(self.node_id.nodeid.hex())

    def __eq__(self, other):
        if not isinstance(other, GossmapNode):
            return False
        return self.node_id.__eq__(other.node_id)

    def __hash__(self):
        # Defining __eq__ alone makes Python set __hash__ to None (nodes
        # would be unhashable); hash consistently with __eq__ instead.
        return hash(self.node_id)

    def __lt__(self, other):
        if not isinstance(other, GossmapNode):
            raise ValueError(f"Cannot compare GossmapNode with {type(other)}")
        return self.node_id.__lt__(other.node_id)
163
164
class Gossmap(object):
    """Class to represent the gossip map of the network.

    Builds nodes and channels from a gossip_store file. The constructor checks
    the version byte and loads every complete record; call refresh() to pick
    up records appended to the file afterwards. The file is kept open for
    that purpose: call close() (or use the object as a context manager) to
    release it.
    """
    def __init__(self, store_filename: str = "gossip_store"):
        self.store_filename = store_filename
        self.store_file = open(store_filename, "rb")
        # Bytes of a not-yet-complete record, kept across refresh() calls.
        self.store_buf = bytes()
        self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
        self.channels: Dict[ShortChannelId, GossmapChannel] = {}
        # A CHANNEL_AMOUNT record carries no scid: it applies to the channel
        # added most recently, which is tracked here.
        self._last_scid: Optional[ShortChannelId] = None
        try:
            version = self.store_file.read(1)
            if not version:
                raise ValueError("Empty gossip store {}".format(store_filename))
            if version[0] != GOSSIP_STORE_VERSION:
                # Report the byte value: int(version) on the bytes object
                # would itself raise an unrelated ValueError.
                raise ValueError("Invalid gossip store version {}"
                                 .format(version[0]))
            self.bytes_read = 1
            self.refresh()
        except BaseException:
            # Don't leak the file handle when the store is rejected.
            self.store_file.close()
            raise

    def close(self) -> None:
        """Close the underlying gossip_store file."""
        self.store_file.close()

    def __enter__(self) -> 'Gossmap':
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def _new_channel(self,
                     fields: Dict[str, Any],
                     announce_offset: int,
                     scid: ShortChannelId,
                     node1: GossmapNode,
                     node2: GossmapNode,
                     is_private: bool):
        """Create a GossmapChannel, register it and attach it to both nodes."""
        c = GossmapChannel(fields, announce_offset,
                           scid, node1, node2,
                           is_private)
        self._last_scid = scid
        self.channels[scid] = c
        node1.channels.append(c)
        node2.channels.append(c)

    def _del_channel(self, scid: ShortChannelId):
        """Forget channel `scid`, and any endpoint left with no channels."""
        c = self.channels[scid]
        del self.channels[scid]
        # _new_channel appended c to both endpoints (twice to the same list
        # for a self-channel), so remove it from both.
        c.node1.channels.remove(c)
        c.node2.channels.remove(c)
        # Beware self-channels n1-n1: don't delete the same node twice!
        if len(c.node1.channels) == 0 and c.node1 != c.node2:
            del self.nodes[c.node1.node_id]
        if len(c.node2.channels) == 0:
            del self.nodes[c.node2.node_id]

    def _add_channel(self, rec: bytes, off: int, is_private: bool):
        """Add a channel from `rec`, a channel_announcement message including
        its 2-byte type prefix. Unknown endpoint nodes are created on the fly.
        """
        fields = channel_announcement.read(io.BytesIO(rec[2:]), {})
        node1_id = GossmapNodeId(fields['node_id_1'])
        node2_id = GossmapNodeId(fields['node_id_2'])
        if node1_id not in self.nodes:
            self.nodes[node1_id] = GossmapNode(node1_id)
        if node2_id not in self.nodes:
            self.nodes[node2_id] = GossmapNode(node2_id)
        self._new_channel(fields, off,
                          ShortChannelId.from_int(fields['short_channel_id']),
                          self.get_node(node1_id), self.get_node(node2_id),
                          is_private)

    def _add_private_channel(self, rec: bytes, off: int):
        """Add a channel from a WIRE_GOSSIP_STORE_PRIVATE_CHANNEL record.

        After the 2-byte type comes amount_sat (u64), a u16 length, then the
        embedded channel_announcement with its own type prefix (layout per
        gossipd's gossip_store_private_channel wire definition).
        """
        sats, = struct.unpack(">Q", rec[2:2 + 8])
        self._add_channel(rec[2 + 8 + 2:], off + 2 + 8 + 2, True)
        # The capacity is carried inline here, so record it directly.
        self.channels[self._last_scid].satoshis = sats

    def _set_channel_amount(self, rec: bytes):
        """ Sets channel capacity of last added channel """
        sats, = struct.unpack(">Q", rec[2:])
        self.channels[self._last_scid].satoshis = sats

    def get_channel(self, short_channel_id: ShortChannelId):
        """ Resolves a channel by its short channel id (object or string);
        returns None if unknown. """
        if isinstance(short_channel_id, str):
            short_channel_id = ShortChannelId.from_str(short_channel_id)
        return self.channels.get(short_channel_id)

    def get_node(self, node_id: Union[GossmapNodeId, str]):
        """ Resolves a node by its public key node_id (object or hex string);
        returns None if unknown. """
        if isinstance(node_id, str):
            node_id = GossmapNodeId.from_str(node_id)
        return self.nodes.get(cast(GossmapNodeId, node_id))

    def _update_channel(self, rec: bytes, off: int):
        """Apply a channel_update (with its 2-byte type prefix) to the known
        channel it names; bit 0 of channel_flags selects the direction."""
        fields = channel_update.read(io.BytesIO(rec[2:]), {})
        direction = fields['channel_flags'] & 1
        c = self.channels[ShortChannelId.from_int(fields['short_channel_id'])]
        c._update_channel(direction, fields, off)

    def _add_node_announcement(self, rec: bytes, off: int):
        """Attach a node_announcement (with its 2-byte type prefix) to the
        already-known node it describes."""
        fields = node_announcement.read(io.BytesIO(rec[2:]), {})
        node_id = GossmapNodeId(fields['node_id'])
        self.nodes[node_id].announce_fields = fields
        self.nodes[node_id].announce_offset = off

    def reopen_store(self):
        """Handle a WIRE_GOSSIP_STORE_ENDED record.

        FIXME: Implement! Raises instead of `assert False`, which would be
        stripped under `python -O` and silently continue.
        """
        raise NotImplementedError("Gossmap.reopen_store() is not implemented")

    def _remove_channel_by_deletemsg(self, rec: bytes):
        """Handle a WIRE_GOSSIP_STORE_DELETE_CHAN record (u64 scid)."""
        scidint, = struct.unpack(">Q", rec[2:])
        scid = ShortChannelId.from_int(scidint)
        # It might have already been deleted when we skipped it.
        if scid in self.channels:
            self._del_channel(scid)

    def _pull_bytes(self, length: int) -> bool:
        """Pull bytes from file into our internal buffer until it holds
        `length` bytes; returns False if the file has fewer available."""
        if len(self.store_buf) < length:
            self.store_buf += self.store_file.read(length
                                                   - len(self.store_buf))
        return len(self.store_buf) >= length

    def _read_record(self) -> Optional[bytes]:
        """If a whole record is not in the file, returns None.
        If deleted, returns empty.

        A partial record stays in store_buf so a later call can finish it;
        bytes_read only advances once a whole record has been consumed.
        """
        if not self._pull_bytes(12):
            return None
        hdr = GossipStoreHeader(self.store_buf[:12])
        if not self._pull_bytes(12 + hdr.length):
            return None
        self.bytes_read += len(self.store_buf)
        ret = self.store_buf[12:]
        self.store_buf = bytes()
        if hdr.deleted:
            ret = bytes()
        return ret

    def refresh(self):
        """Catch up with any changes to the gossip store.

        Processes every complete record after the last one read. Deleted
        records and unknown record types are skipped.
        """
        while True:
            off = self.bytes_read
            rec = self._read_record()
            # EOF (or only a partial record so far)?
            if rec is None:
                break
            # Deleted?
            if len(rec) == 0:
                continue

            rectype, = struct.unpack(">H", rec[:2])
            if rectype == channel_announcement.number:
                self._add_channel(rec, off, False)
            elif rectype == WIRE_GOSSIP_STORE_PRIVATE_CHANNEL:
                self._add_private_channel(rec, off)
            elif rectype == WIRE_GOSSIP_STORE_CHANNEL_AMOUNT:
                self._set_channel_amount(rec)
            elif rectype == channel_update.number:
                self._update_channel(rec, off)
            elif rectype == WIRE_GOSSIP_STORE_PRIVATE_UPDATE:
                # type(2) | len(2) | embedded channel_update
                self._update_channel(rec[2 + 2:], off + 2 + 2)
            elif rectype == WIRE_GOSSIP_STORE_DELETE_CHAN:
                self._remove_channel_by_deletemsg(rec)
            elif rectype == node_announcement.number:
                self._add_node_announcement(rec, off)
            elif rectype == WIRE_GOSSIP_STORE_ENDED:
                self.reopen_store()
            # Any other record type is ignored.
313