1# -*- coding: utf-8 -*-
2#
3# Electrum - lightweight Bitcoin client
4# Copyright (C) 2018 The Electrum developers
5#
6# Permission is hereby granted, free of charge, to any person
7# obtaining a copy of this software and associated documentation files
8# (the "Software"), to deal in the Software without restriction,
9# including without limitation the rights to use, copy, modify, merge,
10# publish, distribute, sublicense, and/or sell copies of the Software,
11# and to permit persons to whom the Software is furnished to do so,
12# subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be
15# included in all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24# SOFTWARE.
25
26import queue
27from collections import defaultdict
28from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
29import time
30from threading import RLock
31import attr
32from math import inf
33
34from .util import bh2u, profiler, with_lock
35from .logging import Logger
36from .lnutil import (NUM_MAX_EDGES_IN_PAYMENT_PATH, ShortChannelID, LnFeatures,
37                     NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE)
38from .channel_db import ChannelDB, Policy, NodeInfo
39
40if TYPE_CHECKING:
41    from .lnchannel import Channel
42
# Penalty charged by the path finder for a channel whose ability to forward
# the payment amount is unknown (see LiquidityHintMgr.penalty):
DEFAULT_PENALTY_BASE_MSAT = 500  # fixed part, in msat
DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH = 100  # proportional part, in millionths of the amount
BLACKLIST_DURATION = 60 * 60  # seconds a channel stays blacklisted (one hour)
HINT_DURATION = 60 * 60  # seconds a liquidity hint stays valid (one hour)
47
48
class NoChannelPolicy(Exception):
    """Raised when no forwarding policy can be found for a channel."""

    def __init__(self, short_channel_id: bytes):
        scid = ShortChannelID.normalize(short_channel_id)
        message = f'cannot find channel policy for short_channel_id: {scid}'
        super().__init__(message)
53
54
class LNPathInconsistent(Exception):
    """Raised when the edges of a payment path do not form a consistent chain."""
56
57
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
    """Return the fee (msat) charged for forwarding `forwarded_amount_msat`.

    The fee is the base fee plus the proportional part. The proportional part
    uses integer floor division by one million.
    """
    proportional_msat = (forwarded_amount_msat * fee_proportional_millionths) // 1_000_000
    return fee_base_msat + proportional_msat
61
62
@attr.s(slots=True)
class PathEdge:
    """A directed hop from start_node to end_node over short_channel_id.

    It carries no fee or cltv information; RouteEdge adds those fields.
    All fields are keyword-only.
    """
    start_node = attr.ib(type=bytes, kw_only=True, repr=lambda node: node.hex())
    end_node = attr.ib(type=bytes, kw_only=True, repr=lambda node: node.hex())
    short_channel_id = attr.ib(type=ShortChannelID, kw_only=True, repr=str)

    @property
    def node_id(self) -> bytes:
        # legacy alias for end_node, kept for compatibility  # TODO rm
        destination = self.end_node
        return destination
73
@attr.s
class RouteEdge(PathEdge):
    """A PathEdge together with the forwarding policy that applies to it.

    The fee and cltv fields come from the channel policy published by
    start_node. node_features describes end_node.
    """
    fee_base_msat = attr.ib(type=int, kw_only=True)
    fee_proportional_millionths = attr.ib(type=int, kw_only=True)
    cltv_expiry_delta = attr.ib(type=int, kw_only=True)
    node_features = attr.ib(type=int, kw_only=True, repr=lambda val: str(int(val)))  # note: for end node!

    def fee_for_edge(self, amount_msat: int) -> int:
        """Return the fee (msat) this edge charges to forward `amount_msat`."""
        return fee_for_edge_msat(forwarded_amount_msat=amount_msat,
                                 fee_base_msat=self.fee_base_msat,
                                 fee_proportional_millionths=self.fee_proportional_millionths)

    @classmethod
    def from_channel_policy(
            cls,
            *,
            channel_policy: 'Policy',
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            node_info: Optional[NodeInfo],  # for end_node
    ) -> 'RouteEdge':
        """Build an edge from `channel_policy`, the policy of `start_node`.

        Instantiates `cls`, so calling this on a subclass returns an instance
        of that subclass. node_features is taken from `node_info` (describing
        end_node). It falls back to 0 when no node info is known.
        """
        assert isinstance(short_channel_id, bytes)
        assert type(start_node) is bytes
        assert type(end_node) is bytes
        return cls(
            start_node=start_node,
            end_node=end_node,
            short_channel_id=ShortChannelID.normalize(short_channel_id),
            fee_base_msat=channel_policy.fee_base_msat,
            fee_proportional_millionths=channel_policy.fee_proportional_millionths,
            cltv_expiry_delta=channel_policy.cltv_expiry_delta,
            node_features=node_info.features if node_info else 0)

    def is_sane_to_use(self, amount_msat: int) -> bool:
        """Return False if this edge looks unreasonable for `amount_msat`.

        This happens when the cltv delta exceeds 14 * 144 blocks (about two
        weeks) or the fee fails is_fee_sane.
        """
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        if self.cltv_expiry_delta > 14 * 144:
            return False
        total_fee = self.fee_for_edge(amount_msat)
        if not is_fee_sane(total_fee, payment_amount_msat=amount_msat):
            return False
        return True

    def has_feature_varonion(self) -> bool:
        """Return True if the end node advertises the optional var_onion feature."""
        features = LnFeatures(self.node_features)
        return features.supports(LnFeatures.VAR_ONION_OPT)

    def is_trampoline(self) -> bool:
        """Plain route edges are never trampoline edges."""
        return False
124
@attr.s
class TrampolineEdge(RouteEdge):
    """A RouteEdge towards a trampoline node.

    It optionally carries the invoice's routing info and feature bits.
    """
    invoice_routing_info = attr.ib(type=bytes, default=None)
    invoice_features = attr.ib(type=int, default=None)
    # redefined from the parent class only so that it gets a default value:
    short_channel_id = attr.ib(default=ShortChannelID(8), repr=str)

    def is_trampoline(self):
        # overrides RouteEdge.is_trampoline, which always answers False
        answer = True
        return answer
134
135
# A path: a sequence of hops, without fee/cltv information.
LNPaymentPath = Sequence[PathEdge]
# A route: a sequence of hops that carry their forwarding policy (fees, cltv delta).
LNPaymentRoute = Sequence[RouteEdge]
138
139
def is_route_sane_to_use(route: LNPaymentRoute, invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Sanity-check a whole route before attempting to pay over it.

    Called when we are paying, so e.g. a lower total cltv is better.
    The route is rejected in any of these cases:
    - it has too many edges;
    - any non-first edge fails RouteEdge.is_sane_to_use;
    - the accumulated cltv is too far into the future;
    - the total fee fails is_fee_sane.
    The first edge starts at the payer, so it adds no fee or cltv here.
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    amount_msat = invoice_amount_msat
    total_cltv = min_final_cltv_expiry
    # walk from the recipient back towards us, skipping index 0
    for edge in route[:0:-1]:
        if not edge.is_sane_to_use(amount_msat):
            return False
        amount_msat += edge.fee_for_edge(amount_msat)
        total_cltv += edge.cltv_expiry_delta
    # TODO revise ad-hoc heuristics
    if total_cltv > NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE:
        return False
    return is_fee_sane(amount_msat - invoice_amount_msat, payment_amount_msat=invoice_amount_msat)
159
160
def is_fee_sane(fee_msat: int, *, payment_amount_msat: int) -> bool:
    """Return True if `fee_msat` is acceptable for a payment of `payment_amount_msat`.

    A fee is acceptable if it is at most 5 sat (5000 msat), or at most 1% of
    the payment amount.
    """
    return fee_msat <= 5_000 or 100 * fee_msat <= payment_amount_msat
169
170
class LiquidityHint:
    """Encodes the amounts that can and cannot be sent over the direction of a
    channel and whether the channel is blacklisted.

    A LiquidityHint is stored as the value of a dict keyed by short channel id
    (see LiquidityHintMgr). The channel's two directions are tracked
    separately as "forward" and "backward".

    The can/cannot-send amounts expire HINT_DURATION seconds after the last
    update; the public getters then return None. The in-flight HTLC counters
    and the blacklist timestamp are not affected by that expiry.
    """
    def __init__(self):
        # use "can_send_forward + can_send_backward < cannot_send_forward + cannot_send_backward" as a sanity check?
        self._can_send_forward = None
        self._cannot_send_forward = None
        self._can_send_backward = None
        self._cannot_send_backward = None
        self.blacklist_timestamp = 0
        self.hint_timestamp = 0
        self._inflight_htlcs_forward = 0
        self._inflight_htlcs_backward = 0

    def is_hint_invalid(self) -> bool:
        """Return True if the liquidity amounts are older than HINT_DURATION seconds."""
        now = int(time.time())
        return now - self.hint_timestamp > HINT_DURATION

    def _forget_expired_liquidity(self) -> None:
        """Drop all stored can/cannot-send amounts if the hint has expired.

        Must run before hint_timestamp is refreshed. Otherwise stale amounts
        would block newer, "less significant" updates in the setters below.
        They would also become visible again once the timestamp is renewed.
        """
        if self.is_hint_invalid():
            self._can_send_forward = None
            self._cannot_send_forward = None
            self._can_send_backward = None
            self._cannot_send_backward = None

    @property
    def can_send_forward(self):
        return None if self.is_hint_invalid() else self._can_send_forward

    @can_send_forward.setter
    def can_send_forward(self, amount):
        # we don't want to record less significant info
        # (sendable amount is lower than known sendable amount):
        if self._can_send_forward is not None and self._can_send_forward > amount:
            return
        self._can_send_forward = amount
        # sanity check: a sendable amount must be lower than a non-sendable one,
        # so drop the now-contradicted cannot_send info
        if self._cannot_send_forward is not None and self._can_send_forward > self._cannot_send_forward:
            self._cannot_send_forward = None

    @property
    def can_send_backward(self):
        return None if self.is_hint_invalid() else self._can_send_backward

    @can_send_backward.setter
    def can_send_backward(self, amount):
        if self._can_send_backward is not None and self._can_send_backward > amount:
            return
        self._can_send_backward = amount
        if self._cannot_send_backward is not None and self._can_send_backward > self._cannot_send_backward:
            self._cannot_send_backward = None

    @property
    def cannot_send_forward(self):
        return None if self.is_hint_invalid() else self._cannot_send_forward

    @cannot_send_forward.setter
    def cannot_send_forward(self, amount):
        # we don't want to record less significant info
        # (not sendable amount is higher than known not sendable amount):
        if self._cannot_send_forward is not None and self._cannot_send_forward < amount:
            return
        self._cannot_send_forward = amount
        if self._can_send_forward is not None and self._can_send_forward > self._cannot_send_forward:
            self._can_send_forward = None
        # if we can't send over the channel, we should be able to send in the
        # reverse direction
        self.can_send_backward = amount

    @property
    def cannot_send_backward(self):
        return None if self.is_hint_invalid() else self._cannot_send_backward

    @cannot_send_backward.setter
    def cannot_send_backward(self, amount):
        if self._cannot_send_backward is not None and self._cannot_send_backward < amount:
            return
        self._cannot_send_backward = amount
        if self._can_send_backward is not None and self._can_send_backward > self._cannot_send_backward:
            self._can_send_backward = None
        self.can_send_forward = amount

    def can_send(self, is_forward_direction: bool):
        """Return the known sendable amount for the direction, or None (unknown/expired)."""
        if is_forward_direction:
            return self.can_send_forward
        else:
            return self.can_send_backward

    def cannot_send(self, is_forward_direction: bool):
        """Return the known non-sendable amount for the direction, or None (unknown/expired)."""
        if is_forward_direction:
            return self.cannot_send_forward
        else:
            return self.cannot_send_backward

    def update_can_send(self, is_forward_direction: bool, amount: int):
        """Record that `amount` could be sent in the given direction and refresh the hint."""
        self._forget_expired_liquidity()
        self.hint_timestamp = int(time.time())
        if is_forward_direction:
            self.can_send_forward = amount
        else:
            self.can_send_backward = amount

    def update_cannot_send(self, is_forward_direction: bool, amount: int):
        """Record that `amount` could not be sent in the given direction and refresh the hint."""
        self._forget_expired_liquidity()
        self.hint_timestamp = int(time.time())
        if is_forward_direction:
            self.cannot_send_forward = amount
        else:
            self.cannot_send_backward = amount

    def num_inflight_htlcs(self, is_forward_direction: bool) -> int:
        """Return the number of HTLCs currently counted as in flight in the direction."""
        if is_forward_direction:
            return self._inflight_htlcs_forward
        else:
            return self._inflight_htlcs_backward

    def add_htlc(self, is_forward_direction: bool):
        """Count one more in-flight HTLC in the given direction."""
        if is_forward_direction:
            self._inflight_htlcs_forward += 1
        else:
            self._inflight_htlcs_backward += 1

    def remove_htlc(self, is_forward_direction: bool):
        """Count one fewer in-flight HTLC in the given direction (never below 0)."""
        if is_forward_direction:
            self._inflight_htlcs_forward = max(0, self._inflight_htlcs_forward - 1)
        else:
            # each direction must decrement its own counter
            self._inflight_htlcs_backward = max(0, self._inflight_htlcs_backward - 1)

    def __repr__(self):
        is_blacklisted = False if not self.blacklist_timestamp else int(time.time()) - self.blacklist_timestamp < BLACKLIST_DURATION
        return f"forward: can send: {self._can_send_forward} msat, cannot send: {self._cannot_send_forward} msat, htlcs: {self._inflight_htlcs_forward}\n" \
               f"backward: can send: {self._can_send_backward} msat, cannot send: {self._cannot_send_backward} msat, htlcs: {self._inflight_htlcs_backward}\n" \
               f"blacklisted: {is_blacklisted}"
301
302
class LiquidityHintMgr:
    """Implements liquidity hints for channels in the graph.

    This class can be used to update liquidity information about channels in the
    graph. It implements a penalty function for edge weighting in the pathfinding
    algorithm that favors channels which can route payments and penalizes
    channels that cannot.

    Hints are stored per short channel id. The direction within a channel is
    derived by comparing node ids: node_from < node_to is "forward".
    """
    # TODO: hints based on node pairs only (shadow channels, non-strict forwarding)?
    def __init__(self):
        self.lock = RLock()
        self._liquidity_hints: Dict[ShortChannelID, LiquidityHint] = {}

    @with_lock
    def get_hint(self, channel_id: ShortChannelID) -> LiquidityHint:
        """Return the hint for `channel_id`, creating and storing an empty one if needed."""
        hint = self._liquidity_hints.get(channel_id)
        if hint is None:
            hint = LiquidityHint()
            self._liquidity_hints[channel_id] = hint
        return hint

    @with_lock
    def update_can_send(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID, amount: int):
        """Record that `amount` msat could be sent node_from -> node_to over channel_id."""
        hint = self.get_hint(channel_id)
        hint.update_can_send(node_from < node_to, amount)

    @with_lock
    def update_cannot_send(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID, amount: int):
        """Record that `amount` msat could not be sent node_from -> node_to over channel_id."""
        hint = self.get_hint(channel_id)
        hint.update_cannot_send(node_from < node_to, amount)

    @with_lock
    def add_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID):
        """Count an in-flight HTLC node_from -> node_to over channel_id."""
        hint = self.get_hint(channel_id)
        hint.add_htlc(node_from < node_to)

    @with_lock
    def remove_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID):
        """Remove an in-flight HTLC node_from -> node_to over channel_id."""
        hint = self.get_hint(channel_id)
        hint.remove_htlc(node_from < node_to)

    def penalty(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID, amount: int) -> float:
        """Gives a penalty when sending `amount` (in millisatoshi) from node_from
        to node_to over channel_id.

        The penalty depends on the can_send and cannot_send values that were
        possibly recorded in previous payment attempts.

        A channel that can send an amount is assigned a penalty of zero, a
        channel that cannot send an amount is assigned an infinite penalty.
        If the sending amount lies between can_send and cannot_send, there's
        uncertainty and we give a default penalty. The default penalty
        serves the function of giving a positive offset (the Dijkstra
        algorithm doesn't work with negative weights), from which we can discount
        from. There is a competition between low-fee channels and channels where
        we know with some certainty that they can support a payment. The penalty
        ultimately boils down to: how much more fees do we want to pay for
        certainty of payment success? This can be tuned via DEFAULT_PENALTY_BASE_MSAT
        and DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH. A base _and_ relative penalty
        was chosen such that the penalty will be able to compete with the regular
        base and relative fees. Each in-flight HTLC in this direction adds
        another default penalty on top.
        """
        # we only evaluate hints here, so use dict get (to not create many hints with self.get_hint)
        hint = self._liquidity_hints.get(channel_id)
        if not hint:
            can_send, cannot_send, num_inflight_htlcs = None, None, 0
        else:
            is_forward = node_from < node_to
            can_send = hint.can_send(is_forward)
            cannot_send = hint.cannot_send(is_forward)
            num_inflight_htlcs = hint.num_inflight_htlcs(is_forward)

        if cannot_send is not None and amount >= cannot_send:
            return inf
        if can_send is not None and amount <= can_send:
            return 0
        success_fee = fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH)
        inflight_htlc_fee = num_inflight_htlcs * success_fee
        return success_fee + inflight_htlc_fee

    @with_lock
    def add_to_blacklist(self, channel_id: ShortChannelID):
        """Blacklist `channel_id` for BLACKLIST_DURATION seconds from now."""
        hint = self.get_hint(channel_id)
        hint.blacklist_timestamp = int(time.time())

    @with_lock
    def get_blacklist(self) -> Set[ShortChannelID]:
        """Return the channel ids blacklisted within the last BLACKLIST_DURATION seconds."""
        now = int(time.time())
        return {
            scid for scid, hint in self._liquidity_hints.items()
            if now - hint.blacklist_timestamp < BLACKLIST_DURATION
        }

    @with_lock
    def clear_blacklist(self):
        """Remove every channel from the blacklist."""
        for hint in self._liquidity_hints.values():
            hint.blacklist_timestamp = 0

    @with_lock
    def reset_liquidity_hints(self):
        """Expire all recorded liquidity amounts (HTLC counts and blacklist are kept)."""
        for hint in self._liquidity_hints.values():
            hint.hint_timestamp = 0

    @with_lock
    def __repr__(self):
        # hold the lock like the other dict-iterating methods: get_hint may
        # insert into the dict concurrently
        entries = "".join(f"{k}: {v}\n" for k, v in self._liquidity_hints.items())
        return "liquidity hints:\n" + entries
409
410
class LNPathFinder(Logger):
    """Path finding over the Lightning channel graph.

    Runs a Dijkstra search over self.channel_db, together with our own channels
    and any private route edges passed in. Each edge is weighted with
    _edge_cost, which includes liquidity penalties learned from earlier
    payment attempts (self.liquidity_hints).
    """

    def __init__(self, channel_db: ChannelDB):
        Logger.__init__(self)
        self.channel_db = channel_db
        self.liquidity_hints = LiquidityHintMgr()

    def update_liquidity_hints(
            self,
            route: LNPaymentRoute,
            amount_msat: int,
            failing_channel: Optional[ShortChannelID] = None,
    ):
        """Record the outcome of a payment attempt over `route`.

        Edges are walked in route order:
        - every edge before `failing_channel` is recorded as able to forward
          `amount_msat`;
        - the failing edge is recorded as unable to;
        - edges after it are left untouched.
        If `failing_channel` is None (or not part of the route), every edge is
        recorded as able to forward.
        """
        # note: actual routable amounts are slightly different than reported here
        # as fees would need to be added
        for r in route:
            if r.short_channel_id != failing_channel:
                self.logger.info(f"report {r.short_channel_id} to be able to forward {amount_msat} msat")
                self.liquidity_hints.update_can_send(r.start_node, r.end_node, r.short_channel_id, amount_msat)
            else:
                self.logger.info(f"report {r.short_channel_id} to be unable to forward {amount_msat} msat")
                self.liquidity_hints.update_cannot_send(r.start_node, r.end_node, r.short_channel_id, amount_msat)
                break

    def update_inflight_htlcs(self, route: LNPaymentRoute, add_htlcs: bool):
        """Add (add_htlcs=True) or remove one in-flight HTLC on every edge of `route`."""
        self.logger.info(f"{'Adding' if add_htlcs else 'Removing'} inflight htlcs to graph (liquidity hints).")
        for r in route:
            if add_htlcs:
                self.liquidity_hints.add_htlc(r.start_node, r.end_node, r.short_channel_id)
            else:
                self.liquidity_hints.remove_htlc(r.start_node, r.end_node, r.short_channel_id)

    def _edge_cost(
            self,
            *,
            short_channel_id: bytes,
            start_node: bytes,
            end_node: bytes,
            payment_amt_msat: int,
            ignore_costs=False,
            is_mine=False,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Tuple[float, int]:
        """Heuristic cost (distance metric) of going through a channel.

        Returns (heuristic_cost, fee_for_edge_msat). The cost is infinite (and
        the fee 0) when the edge cannot be used for this amount:
        - the channel or start_node's policy is unknown;
        - the backward policy is missing on a channel that is neither ours
          nor a private route edge;
        - the policy is disabled;
        - the amount is below htlc_minimum_msat, above htlc_maximum_msat, or
          above the channel capacity;
        - RouteEdge.is_sane_to_use rejects the edge.
        With ignore_costs=True, a usable edge costs a flat
        DEFAULT_PENALTY_BASE_MSAT and 0 fee.
        """
        if private_route_edges is None:
            private_route_edges = {}
        channel_info = self.channel_db.get_channel_info(
            short_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_info is None:
            return float('inf'), 0
        channel_policy = self.channel_db.get_policy_for_node(
            short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if channel_policy is None:
            return float('inf'), 0
        # channels that did not publish both policies often return temporary channel failure
        channel_policy_backwards = self.channel_db.get_policy_for_node(
            short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges)
        if (channel_policy_backwards is None
                and not is_mine
                and short_channel_id not in private_route_edges):
            return float('inf'), 0
        if channel_policy.is_disabled():
            return float('inf'), 0
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return float('inf'), 0  # payment amount too little
        if channel_info.capacity_sat is not None and \
                payment_amt_msat // 1000 > channel_info.capacity_sat:
            return float('inf'), 0  # payment amount too large
        if channel_policy.htlc_maximum_msat is not None and \
                payment_amt_msat > channel_policy.htlc_maximum_msat:
            return float('inf'), 0  # payment amount too large
        route_edge = private_route_edges.get(short_channel_id, None)
        if route_edge is None:
            node_info = self.channel_db.get_node_info_for_node_id(node_id=end_node)
            route_edge = RouteEdge.from_channel_policy(
                channel_policy=channel_policy,
                short_channel_id=short_channel_id,
                start_node=start_node,
                end_node=end_node,
                node_info=node_info)
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return float('inf'), 0  # thanks but no thanks
        # Distance metric notes:  # TODO constants are ad-hoc
        # ( somewhat based on https://github.com/lightningnetwork/lnd/pull/1358 )
        # - Edges have a base cost. (more edges -> less likely none will fail)
        # - The larger the payment amount, and the longer the CLTV,
        #   the more irritating it is if the HTLC gets stuck.
        # - Paying lower fees is better. :)
        if ignore_costs:
            return DEFAULT_PENALTY_BASE_MSAT, 0
        fee_msat = route_edge.fee_for_edge(payment_amt_msat)
        cltv_cost = route_edge.cltv_expiry_delta * payment_amt_msat * 15 / 1_000_000_000
        # the liquidity penalty makes sure we favor edges that should be able to forward
        # the payment and penalize edges that cannot
        liquidity_penalty = self.liquidity_hints.penalty(start_node, end_node, short_channel_id, payment_amt_msat)
        overall_cost = fee_msat + cltv_cost + liquidity_penalty
        return overall_cost, fee_msat

    def get_distances(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Dict[bytes, PathEdge]:
        """Run Dijkstra from nodeB back towards nodeA.

        Returns a dict that maps each node the search improved to the PathEdge
        leaving that node on its cheapest known path towards nodeB. nodeA is
        present only if a path from nodeA to nodeB was found. Blacklisted
        channels are skipped.
        """
        # note: we don't lock self.channel_db, so while the path finding runs,
        #       the underlying graph could potentially change... (not good but maybe ~OK?)
        if my_channels is None:
            my_channels = {}  # `in my_channels` below must not fail on None

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        blacklist = self.liquidity_hints.get_blacklist()
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}  # type: Dict[bytes, PathEdge]
        nodes_to_explore = queue.PriorityQueue()
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(
                    edge_endnode, my_channels=my_channels, private_route_edges=private_route_edges):
                assert isinstance(edge_channel_id, bytes)
                if blacklist and edge_channel_id in blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(
                    edge_channel_id, my_channels=my_channels, private_route_edges=private_route_edges)
                if channel_info is None:
                    continue
                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                is_mine = edge_channel_id in my_channels
                if is_mine:
                    if edge_startnode == nodeA:  # payment outgoing, on our channel
                        if not my_channels[edge_channel_id].can_pay(amount_msat, check_frozen=True):
                            continue
                    else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                        assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                        if not my_channels[edge_channel_id].can_receive(amount_msat, check_frozen=True):
                            continue
                edge_cost, fee_for_edge_msat = self._edge_cost(
                    short_channel_id=edge_channel_id,
                    start_node=edge_startnode,
                    end_node=edge_endnode,
                    payment_amt_msat=amount_msat,
                    ignore_costs=(edge_startnode == nodeA),
                    is_mine=is_mine,
                    my_channels=my_channels,
                    private_route_edges=private_route_edges)
                alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
                if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                    distance_from_start[edge_startnode] = alt_dist_to_neighbour
                    prev_node[edge_startnode] = PathEdge(
                        start_node=edge_startnode,
                        end_node=edge_endnode,
                        short_channel_id=ShortChannelID(edge_channel_id))
                    amount_to_forward_msat = amount_msat + fee_for_edge_msat
                    nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        return prev_node

    @profiler
    def find_path_for_payment(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentPath]:
        """Return a path from nodeA to nodeB, or None if no path was found."""
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = {}

        prev_node = self.get_distances(
            nodeA=nodeA,
            nodeB=nodeB,
            invoice_amount_msat=invoice_amount_msat,
            my_channels=my_channels,
            private_route_edges=private_route_edges)

        if nodeA not in prev_node:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        # FIXME paths cannot be longer than 20 edges (onion packet)...
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge = prev_node[edge_startnode]
            path.append(edge)
            edge_startnode = edge.end_node
        return path

    def create_route_from_path(
            self,
            path: Optional[LNPaymentPath],
            *,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> LNPaymentRoute:
        """Turn a path into a route by attaching each edge's forwarding policy.

        Edges listed in `private_route_edges` are used as-is. Every other edge
        gets the policy of its start node from the channel db.

        Raises:
            Exception: if `path` is None.
            LNPathInconsistent: if an edge's endpoints do not match its channel,
                or consecutive edges do not chain together.
            NoChannelPolicy: if no policy is known for an edge.
        """
        if path is None:
            raise Exception('cannot create route from None path')
        if private_route_edges is None:
            private_route_edges = {}
        route = []
        prev_end_node = path[0].start_node
        for path_edge in path:
            short_channel_id = path_edge.short_channel_id
            _endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
            if _endnodes and sorted(_endnodes) != sorted([path_edge.start_node, path_edge.end_node]):
                raise LNPathInconsistent("endpoints of edge inconsistent with short_channel_id")
            if path_edge.start_node != prev_end_node:
                raise LNPathInconsistent("edges do not chain together")
            route_edge = private_route_edges.get(short_channel_id, None)
            if route_edge is None:
                channel_policy = self.channel_db.get_policy_for_node(
                    short_channel_id=short_channel_id,
                    node_id=path_edge.start_node,
                    my_channels=my_channels)
                if channel_policy is None:
                    raise NoChannelPolicy(short_channel_id)
                node_info = self.channel_db.get_node_info_for_node_id(node_id=path_edge.end_node)
                route_edge = RouteEdge.from_channel_policy(
                    channel_policy=channel_policy,
                    short_channel_id=short_channel_id,
                    start_node=path_edge.start_node,
                    end_node=path_edge.end_node,
                    node_info=node_info)
            route.append(route_edge)
            prev_end_node = path_edge.end_node
        return route

    def find_route(
            self,
            *,
            nodeA: bytes,
            nodeB: bytes,
            invoice_amount_msat: int,
            path: Optional[LNPaymentPath] = None,
            my_channels: Dict[ShortChannelID, 'Channel'] = None,
            private_route_edges: Dict[ShortChannelID, RouteEdge] = None,
    ) -> Optional[LNPaymentRoute]:
        """Return a route from nodeA to nodeB, or None if no path was found.

        If `path` is empty or None, a path is searched first with
        find_path_for_payment.
        """
        route = None
        if not path:
            path = self.find_path_for_payment(
                nodeA=nodeA,
                nodeB=nodeB,
                invoice_amount_msat=invoice_amount_msat,
                my_channels=my_channels,
                private_route_edges=private_route_edges)
        if path:
            route = self.create_route_from_path(
                path, my_channels=my_channels, private_route_edges=private_route_edges)
        return route
684