1# Copyright (C) 2018 The Electrum developers
2# Distributed under the MIT software license, see the accompanying
3# file LICENCE or http://www.opensource.org/licenses/mit-license.php
4
5from enum import IntFlag, IntEnum
6import enum
7import json
8from collections import namedtuple, defaultdict
9from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set, Sequence
10import re
11import time
12import attr
13from aiorpcx import NetAddress
14
15from .util import bfh, bh2u, inv_dict, UserFacingException
16from .util import list_enabled_bits
17from .crypto import sha256
18from .transaction import (Transaction, PartialTransaction, PartialTxInput, TxOutpoint,
19                          PartialTxOutput, opcodes, TxOutput)
20from .ecc import CURVE_ORDER, sig_string_from_der_sig, ECPubkey, string_to_number
21from . import ecc, bitcoin, crypto, transaction
22from .bitcoin import (push_script, redeem_script_to_address, address_to_script,
23                      construct_witness, construct_script)
24from . import segwit_addr
25from .i18n import _
26from .lnaddr import lndecode
27from .bip32 import BIP32Node, BIP32_PRIME
28from .transaction import BCDataStream
29
30if TYPE_CHECKING:
31    from .lnchannel import Channel, AbstractChannel
32    from .lnrouter import LNPaymentRoute
33    from .lnonion import OnionRoutingFailure
34
35
# defined in BOLT-03:
# transaction/output weights used for fee computation
# (HTLC_SUCCESS_WEIGHT / HTLC_TIMEOUT_WEIGHT are used by make_htlc_tx_output
# to compute the fee of second-stage HTLC transactions)
HTLC_TIMEOUT_WEIGHT = 663
HTLC_SUCCESS_WEIGHT = 703
COMMITMENT_TX_WEIGHT = 724
HTLC_OUTPUT_WEIGHT = 172

# largest channel funding amount, 2**24 - 1 sat
# (presumably the BOLT-02 cap for non-"wumbo" channels — confirm)
LN_MAX_FUNDING_SAT = pow(2, 24) - 1
43
def ln_dummy_address():
    """Return a dummy address for fee estimation of the funding tx.

    It is the P2WSH address of the empty witness script.
    """
    empty_script_hex = ''
    return redeem_script_to_address('p2wsh', empty_script_hex)
47
48from .json_db import StoredObject
49
50
def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]:
    """Derive the channel id from the funding outpoint.

    The hex ``funding_txid`` is decoded and byte-reversed, read as a big-endian
    integer, XOR-ed with ``funding_index`` and re-encoded as 32 big-endian bytes.

    Returns ``(channel_id, reversed_txid_bytes)``.
    """
    reversed_txid = bytes.fromhex(funding_txid)[::-1]
    as_number = int.from_bytes(reversed_txid, byteorder='big')
    channel_id = (as_number ^ funding_index).to_bytes(32, byteorder='big')
    return channel_id, reversed_txid
55
def hex_to_bytes(v):
    """attrs converter: return ``bytes`` unchanged, decode a hex string, pass ``None`` through."""
    if isinstance(v, bytes):
        return v
    if v is None:
        return None
    return bytes.fromhex(v)


def json_to_keypair(v):
    """attrs converter for key-pair fields.

    An existing ``OnlyPubkeyKeypair`` (or subclass) instance is returned unchanged.
    A mapping with exactly two entries becomes a ``Keypair``; any other mapping
    becomes an ``OnlyPubkeyKeypair``.
    """
    if isinstance(v, OnlyPubkeyKeypair):
        return v
    if len(v) == 2:
        return Keypair(**v)
    return OnlyPubkeyKeypair(**v)
58
59
@attr.s
class OnlyPubkeyKeypair(StoredObject):
    """Key pair for which only the public key is held.

    ``pubkey`` is stored as bytes; hex strings are converted by ``hex_to_bytes``.
    """
    pubkey = attr.ib(type=bytes, converter=hex_to_bytes)
63
@attr.s
class Keypair(OnlyPubkeyKeypair):
    """Key pair holding both ``pubkey`` (inherited) and ``privkey``, each stored as bytes."""
    privkey = attr.ib(type=bytes, converter=hex_to_bytes)
67
@attr.s
class Config(StoredObject):
    """Channel parameters common to the local and the remote side of a channel.

    Field comments note which commitment transaction (ctx) a value applies to,
    or which direction of HTLCs it limits.
    """
    # shared channel config fields
    payment_basepoint = attr.ib(type=OnlyPubkeyKeypair, converter=json_to_keypair)
    multisig_key = attr.ib(type=OnlyPubkeyKeypair, converter=json_to_keypair)
    htlc_basepoint = attr.ib(type=OnlyPubkeyKeypair, converter=json_to_keypair)
    delayed_basepoint = attr.ib(type=OnlyPubkeyKeypair, converter=json_to_keypair)
    revocation_basepoint = attr.ib(type=OnlyPubkeyKeypair, converter=json_to_keypair)
    to_self_delay = attr.ib(type=int)  # applies to OTHER ctx
    dust_limit_sat = attr.ib(type=int)  # applies to SAME ctx
    max_htlc_value_in_flight_msat = attr.ib(type=int)  # max val of INCOMING htlcs
    max_accepted_htlcs = attr.ib(type=int)  # max num of INCOMING htlcs
    initial_msat = attr.ib(type=int)
    reserve_sat = attr.ib(type=int)  # applies to OTHER ctx
    htlc_minimum_msat = attr.ib(type=int)  # smallest value for INCOMING htlc
    upfront_shutdown_script = attr.ib(type=bytes, converter=hex_to_bytes)

    def validate_params(self, *, funding_sat: int) -> None:
        """Sanity-check this config against the channel's funding amount.

        Raises a plain ``Exception`` describing the first violated constraint.
        """
        label = self.__class__.__name__
        keypairs = (
            self.payment_basepoint,
            self.multisig_key,
            self.htlc_basepoint,
            self.delayed_basepoint,
            self.revocation_basepoint,
        )
        for keypair in keypairs:
            pub = keypair.pubkey
            if len(pub) != 33 or not ecc.ECPubkey.is_pubkey_bytes(pub):
                raise Exception(f"{label}. invalid pubkey in channel config")

        if self.dust_limit_sat > self.reserve_sat:
            raise Exception(f"{label}. MUST set channel_reserve_satoshis greater than or equal to dust_limit_satoshis")
        # technically this could be using the lower DUST_LIMIT_DEFAULT_SAT_SEGWIT
        # but other implementations are checking against this value too; also let's be conservative
        if self.dust_limit_sat < bitcoin.DUST_LIMIT_DEFAULT_SAT_LEGACY:
            raise Exception(f"{label}. dust limit too low: {self.dust_limit_sat} sat")
        max_reserve_sat = funding_sat // 100  # reserve may be at most 1% of the funding amount
        if self.reserve_sat > max_reserve_sat:
            raise Exception(f"{label}. reserve too high: {self.reserve_sat}, funding_sat: {funding_sat}")

        htlc_min = self.htlc_minimum_msat
        if htlc_min > 1_000:
            raise Exception(f"{label}. htlc_minimum_msat too high: {htlc_min} msat")
        lowest_allowed_htlc_min = 0  # should be at least 1 really, but apparently some nodes are sending zero...
        if htlc_min < lowest_allowed_htlc_min:
            raise Exception(f"{label}. htlc_minimum_msat too low: {htlc_min} msat < {lowest_allowed_htlc_min}")

        # 483 is the protocol-wide upper bound on accepted HTLCs
        if self.max_accepted_htlcs < 1:
            raise Exception(f"{label}. max_accepted_htlcs too low: {self.max_accepted_htlcs}")
        if self.max_accepted_htlcs > 483:
            raise Exception(f"{label}. max_accepted_htlcs too high: {self.max_accepted_htlcs}")

        if self.to_self_delay > MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED:
            raise Exception(f"{label}. to_self_delay too high: {self.to_self_delay} > {MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED}")

        min_in_flight_msat = min(1000 * funding_sat, 100_000_000)
        if self.max_htlc_value_in_flight_msat < min_in_flight_msat:
            raise Exception(f"{label}. max_htlc_value_in_flight_msat is too small: {self.max_htlc_value_in_flight_msat}")
117
118
@attr.s
class LocalConfig(Config):
    """Channel config of the local side.

    Besides the shared ``Config`` fields, it tracks the channel seed, signature
    state and the per-commitment secret seed.
    """
    channel_seed = attr.ib(type=bytes, converter=hex_to_bytes)  # type: Optional[bytes]
    funding_locked_received = attr.ib(type=bool)
    was_announced = attr.ib(type=bool)
    current_commitment_signature = attr.ib(type=bytes, converter=hex_to_bytes)
    current_htlc_signatures = attr.ib(type=bytes, converter=hex_to_bytes)
    per_commitment_secret_seed = attr.ib(type=bytes, converter=hex_to_bytes)

    @classmethod
    def from_seed(cls, **kwargs):
        """Construct a config whose keys are derived from ``kwargs['channel_seed']``.

        ``static_remotekey`` is popped from ``kwargs`` (it is required). When it is
        truthy, it is used as-is as the (pubkey-only) payment basepoint. Otherwise
        the payment basepoint is derived from the seed like the other basepoints.
        All remaining kwargs are forwarded to the constructor of ``cls``.
        """
        channel_seed = kwargs['channel_seed']
        static_remotekey = kwargs.pop('static_remotekey')
        node = BIP32Node.from_rootseed(channel_seed, xtype='standard')

        def keypair_generator(family):
            return generate_keypair(node, family)

        kwargs['per_commitment_secret_seed'] = keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey
        kwargs['multisig_key'] = keypair_generator(LnKeyFamily.MULTISIG)
        kwargs['htlc_basepoint'] = keypair_generator(LnKeyFamily.HTLC_BASE)
        kwargs['delayed_basepoint'] = keypair_generator(LnKeyFamily.DELAY_BASE)
        kwargs['revocation_basepoint'] = keypair_generator(LnKeyFamily.REVOCATION_BASE)
        kwargs['payment_basepoint'] = OnlyPubkeyKeypair(static_remotekey) if static_remotekey else keypair_generator(LnKeyFamily.PAYMENT_BASE)
        # use cls (not a hard-coded class name) so subclasses get instances of themselves
        return cls(**kwargs)

    def validate_params(self, *, funding_sat: int) -> None:
        """Run the base ``Config`` checks plus stricter checks for our own config.

        Raises ``Exception`` on the first violated constraint.
        """
        conf_name = type(self).__name__
        # run base checks regardless whether LOCAL/REMOTE config
        super().validate_params(funding_sat=funding_sat)
        # run some stricter checks on LOCAL config (make sure we ourselves do the sane thing,
        # even if we are lenient with REMOTE for compatibility reasons)
        HTLC_MINIMUM_MSAT_MIN = 1
        if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN:
            raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}")
151
@attr.s
class RemoteConfig(Config):
    """Channel config of the remote side, plus the remote per-commitment points.

    ``current_per_commitment_point`` defaults to ``None``.
    """
    next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes)
    current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes)
156
@attr.s
class FeeUpdate(StoredObject):
    """A feerate update of the channel.

    ``ctn_local`` / ``ctn_remote`` default to ``None``; presumably they record the
    commitment numbers at which the update applies on each side — confirm with callers.
    """
    rate = attr.ib(type=int)  # in sat/kw
    ctn_local = attr.ib(default=None, type=int)
    ctn_remote = attr.ib(default=None, type=int)
162
@attr.s
class ChannelConstraints(StoredObject):
    """Fixed properties of a channel: capacity, who funded it, and the funding depth."""
    capacity = attr.ib(type=int)  # in sat
    is_initiator = attr.ib(type=bool)  # note: sometimes also called "funder"
    # presumably the number of confirmations the funding tx needs — confirm
    funding_txn_minimum_depth = attr.ib(type=int)
168
169
# format version written/checked by ImportedChannelBackupStorage.to_bytes/from_bytes
CHANNEL_BACKUP_VERSION = 0

@attr.s
class ChannelBackupStorage(StoredObject):
    """Fields shared by every kind of channel backup."""
    funding_txid = attr.ib(type=str)
    funding_index = attr.ib(type=int, converter=int)
    funding_address = attr.ib(type=str)
    is_initiator = attr.ib(type=bool)

    def funding_outpoint(self):
        """Return the funding output reference as an ``Outpoint``."""
        return Outpoint(txid=self.funding_txid, output_index=self.funding_index)

    def channel_id(self):
        """Return the channel id derived from the funding txid and output index."""
        return channel_id_from_funding_tx(self.funding_txid, self.funding_index)[0]
185
@attr.s
class OnchainChannelBackupStorage(ChannelBackupStorage):
    """Channel backup that stores only ``node_id_prefix`` for the peer.

    Presumably this is a prefix of the remote node id — confirm with the code that creates it.
    """
    node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes)
189
@attr.s
class ImportedChannelBackupStorage(ChannelBackupStorage):
    """Channel backup that can be exported to / imported from a compact byte string.

    The byte layout is defined by ``to_bytes`` / ``from_bytes`` and versioned
    with ``CHANNEL_BACKUP_VERSION``.
    """
    node_id = attr.ib(type=bytes, converter=hex_to_bytes)
    privkey = attr.ib(type=bytes, converter=hex_to_bytes)
    host = attr.ib(type=str)
    port = attr.ib(type=int, converter=int)
    channel_seed = attr.ib(type=bytes, converter=hex_to_bytes)
    local_delay = attr.ib(type=int, converter=int)
    remote_delay = attr.ib(type=int, converter=int)
    remote_payment_pubkey = attr.ib(type=bytes, converter=hex_to_bytes)
    remote_revocation_pubkey = attr.ib(type=bytes, converter=hex_to_bytes)

    def to_bytes(self) -> bytes:
        """Serialize this backup (inverse of ``from_bytes``).

        16-bit integer fields are written *unsigned*. TCP ports (and output
        indices) can exceed 32767, which does not fit a signed int16 and used to
        make serialization fail. For values below 32768 the unsigned encoding is
        byte-identical to the signed one, so existing backups remain readable.
        """
        vds = BCDataStream()
        vds.write_uint16(CHANNEL_BACKUP_VERSION)
        vds.write_boolean(self.is_initiator)
        vds.write_bytes(self.privkey, 32)
        vds.write_bytes(self.channel_seed, 32)
        vds.write_bytes(self.node_id, 33)
        vds.write_bytes(bfh(self.funding_txid), 32)
        vds.write_uint16(self.funding_index)
        vds.write_string(self.funding_address)
        vds.write_bytes(self.remote_payment_pubkey, 33)
        vds.write_bytes(self.remote_revocation_pubkey, 33)
        vds.write_uint16(self.local_delay)
        vds.write_uint16(self.remote_delay)
        vds.write_string(self.host)
        vds.write_uint16(self.port)
        return bytes(vds.input)

    @staticmethod
    def from_bytes(s):
        """Parse a backup produced by ``to_bytes``.

        Raises ``Exception`` if the encoded version is not ``CHANNEL_BACKUP_VERSION``.
        """
        vds = BCDataStream()
        vds.write(s)
        version = vds.read_uint16()
        if version != CHANNEL_BACKUP_VERSION:
            raise Exception(f"unknown version for channel backup: {version}")
        # keyword arguments are evaluated left to right, so this order must
        # match the write order in to_bytes()
        return ImportedChannelBackupStorage(
            is_initiator=vds.read_boolean(),
            privkey=vds.read_bytes(32).hex(),
            channel_seed=vds.read_bytes(32).hex(),
            node_id=vds.read_bytes(33).hex(),
            funding_txid=vds.read_bytes(32).hex(),
            funding_index=vds.read_uint16(),
            funding_address=vds.read_string(),
            remote_payment_pubkey=vds.read_bytes(33).hex(),
            remote_revocation_pubkey=vds.read_bytes(33).hex(),
            local_delay=vds.read_uint16(),
            remote_delay=vds.read_uint16(),
            host=vds.read_string(),
            port=vds.read_uint16())
241
242
243
class ScriptHtlc(NamedTuple):
    """An HTLC paired with the redeem (witness) script of its output."""
    redeem_script: bytes
    htlc: 'UpdateAddHtlc'
247
248
# FIXME duplicate of TxOutpoint in transaction.py??
@attr.s
class Outpoint(StoredObject):
    """Reference to a transaction output: hex ``txid`` plus ``output_index``."""
    txid = attr.ib(type=str)
    output_index = attr.ib(type=int)

    def to_str(self):
        """Return the ``"<txid>:<output_index>"`` form."""
        return f"{self.txid}:{self.output_index}"
257
258
class HtlcLog(NamedTuple):
    """Record of one payment attempt and its outcome."""
    success: bool
    amount_msat: int  # amount for receiver (e.g. from invoice)
    route: Optional['LNPaymentRoute'] = None
    preimage: Optional[bytes] = None
    error_bytes: Optional[bytes] = None
    failure_msg: Optional['OnionRoutingFailure'] = None
    sender_idx: Optional[int] = None

    def formatted_tuple(self):
        """Return ``(route_str, chan_str, message)`` strings for display.

        * ``route_str``: number of hops in ``route``, or ``'?'`` if no route is known.
        * ``chan_str``: on failure, the short channel id of the hop after the
          error sender (``"Destination node"`` if the sender was the last hop).
          On success, the last hop's short channel id. ``"Unknown"`` if not
          determinable.
        * ``message``: the failure's code name (``"Unknown"`` if there is no
          failure message), or ``"Success"``.

        ``route`` and ``failure_msg`` are optional fields, so both may be ``None``.
        """
        route = self.route
        route_str = str(len(route)) if route is not None else '?'
        short_channel_id = None
        if not self.success:
            sender_idx = self.sender_idx
            failure_msg = self.failure_msg
            if sender_idx is not None and route is not None:
                try:
                    short_channel_id = route[sender_idx + 1].short_channel_id
                except IndexError:
                    # payment destination reported error
                    short_channel_id = _("Destination node")
            message = failure_msg.code_name() if failure_msg is not None else _("Unknown")
        else:
            if route:
                short_channel_id = route[-1].short_channel_id
            message = _('Success')
        chan_str = str(short_channel_id) if short_channel_id else _("Unknown")
        return route_str, chan_str, message
287
288
class LightningError(Exception):
    """Base class of the Lightning-related errors below."""


class LightningPeerConnectionClosed(LightningError):
    """Lightning error: peer connection closed."""


class UnableToDeriveSecret(LightningError):
    """Raised by shachain_derive / RevocationStore.retrieve_secret when a secret cannot be derived."""


class HandshakeFailed(LightningError):
    """Lightning error: handshake failed."""


class ConnStringFormatError(LightningError):
    """Lightning error: malformed connection string."""


class RemoteMisbehaving(LightningError):
    """Lightning error attributed to the remote peer."""


class UpfrontShutdownScriptViolation(RemoteMisbehaving):
    """RemoteMisbehaving: upfront shutdown script violated."""


class NotFoundChanAnnouncementForUpdate(Exception):
    """No channel announcement found for a channel update."""


class InvalidGossipMsg(Exception):
    """e.g. signature check failed"""


class PaymentFailure(UserFacingException):
    """User-facing payment failure."""


class NoPathFound(PaymentFailure):
    """PaymentFailure whose message is the translated 'No path found'."""

    def __str__(self):
        return _('No path found')
305
# TODO make some of these values configurable?
# NOTE(review): the unit of this delay is not visible here (blocks? seconds?) — confirm at use sites.
REDEEM_AFTER_DOUBLE_SPENT_DELAY = 30

# 24*60*60 = 86400; presumably seconds, i.e. one day — confirm at use sites.
CHANNEL_OPENING_TIMEOUT = 24*60*60

# Small capacity channels are problematic for many reasons. As the onchain fees start to become
# significant compared to the capacity, things start to break down, e.g. the counterparty
# force-closing the channel costs much of the funds in the channel.
# Closing a channel uses ~200 vbytes onchain, and feerates could spike to 100 sat/vbyte or even higher;
# that alone is already 20_000 sats. This mining fee is reserved and cannot be used for payments.
# The value below is chosen arbitrarily to be one order of magnitude higher than that.
MIN_FUNDING_SAT = 200_000

##### CLTV-expiry-delta-related values
# see https://github.com/lightningnetwork/lightning-rfc/blob/master/02-peer-protocol.md#cltv_expiry_delta-selection
# (values are block counts; 144 blocks is roughly one day)

# the minimum cltv_expiry accepted for newly received HTLCs
# note: when changing, consider Blockchain.is_tip_stale()
MIN_FINAL_CLTV_EXPIRY_ACCEPTED = 144
# set it a tiny bit higher for invoices as blocks could get mined
# during forward path of payment
MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE = MIN_FINAL_CLTV_EXPIRY_ACCEPTED + 3

# the deadline for offered HTLCs:
# the deadline after which the channel has to be failed and timed out on-chain
NBLOCK_DEADLINE_AFTER_EXPIRY_FOR_OFFERED_HTLCS = 1

# the deadline for received HTLCs this node has fulfilled:
# the deadline after which the channel has to be failed and the HTLC fulfilled on-chain before its cltv_expiry
NBLOCK_DEADLINE_BEFORE_EXPIRY_FOR_RECEIVED_HTLCS = 72

# 28 * 144 blocks, roughly four weeks
NBLOCK_CLTV_EXPIRY_TOO_FAR_INTO_FUTURE = 28 * 144

# upper bound enforced on to_self_delay by Config.validate_params (2016 blocks, roughly two weeks)
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED = 2016
340
class RevocationStore:
    """Compact storage of the remote per-commitment secrets (BOLT-03 shachain).

    Secrets are added in decreasing index order, starting at START_INDEX. Only
    up to 49 "bucket" elements are kept; every earlier secret can be re-derived
    from them. State lives in the caller-supplied ``storage`` mapping
    (keys ``'index'`` and ``'buckets'``).
    """
    # closely based on code in lightningnetwork/lnd

    START_INDEX = 2 ** 48 - 1

    def __init__(self, storage):
        if len(storage) == 0:
            storage['index'] = self.START_INDEX
            storage['buckets'] = {}
        self.storage = storage
        self.buckets = storage['buckets']

    def add_next_entry(self, hsh):
        """Store the secret for the current index, then move to the next (lower) index.

        Raises ``Exception`` if the secrets in the lower buckets cannot be
        derived from ``hsh`` (i.e. the new secret is inconsistent with them).
        """
        index = self.storage['index']
        new_element = ShachainElement(index=index, secret=hsh)
        bucket = count_trailing_zeros(index)
        for i in range(0, bucket):
            this_bucket = self.buckets[i]
            e = shachain_derive(new_element, this_bucket.index)
            if e != this_bucket:
                raise Exception("hash is not derivable: {} {} {}".format(bh2u(e.secret), bh2u(this_bucket.secret), this_bucket.index))
        self.buckets[bucket] = new_element
        self.storage['index'] = index - 1

    def retrieve_secret(self, index: int) -> bytes:
        """Return the secret for ``index``; raises ``UnableToDeriveSecret`` if it is not derivable."""
        assert index <= self.START_INDEX, index
        for i in range(0, 49):
            bucket = self.buckets.get(i)
            if bucket is None:
                raise UnableToDeriveSecret()
            try:
                element = shachain_derive(bucket, index)
            except UnableToDeriveSecret:
                continue
            return element.secret
        raise UnableToDeriveSecret()

    # Equality and hashing are defined over the stored state (next index + buckets).
    # The previous implementation called a non-existent `serialize()` method and
    # raised AttributeError. Note that the hash changes whenever an entry is added.
    def __eq__(self, o):
        return (type(o) is RevocationStore
                and self.storage['index'] == o.storage['index']
                and self.buckets == o.buckets)

    def __hash__(self):
        return hash((self.storage['index'], frozenset(self.buckets.items())))
383
384
def count_trailing_zeros(index):
    """ BOLT-03 (where_to_put_secret)

    Return the number of trailing zero bits of ``index``, or 48 when ``index``
    has no set bit (i.e. it is 0).
    """
    digits = bin(index)[2:]
    if '1' not in digits:
        return 48
    return len(digits) - len(digits.rstrip('0'))
391
def shachain_derive(element, to_index):
    """Derive the shachain element for ``to_index`` from ``element``.

    This is only possible when ``to_index`` shares ``element.index``'s prefix,
    i.e. all bits above ``element.index``'s trailing zeros (within 64 bits).
    Raises ``UnableToDeriveSecret`` otherwise.
    """
    from_index = element.index
    zeros = count_trailing_zeros(from_index)
    # 64-bit mask with the low `zeros` bits cleared
    prefix_mask = (1 << 64) - (1 << zeros)
    if (to_index & prefix_mask) != from_index:
        raise UnableToDeriveSecret("prefixes are different; index not derivable")
    derived_secret = get_per_commitment_secret_from_seed(element.secret, to_index, zeros)
    return ShachainElement(derived_secret, to_index)
403
# a shachain entry: the 32-byte secret together with its index
ShachainElement = namedtuple("ShachainElement", ("secret", "index"))
ShachainElement.__str__ = lambda self: f"ShachainElement({bh2u(self.secret)},{self.index!s})"
406
def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes:
    """Generate per commitment secret.

    Walks bit positions ``bits-1`` down to 0. For every bit set in ``i``, that
    bit of the running value is flipped and the value is replaced by its SHA256.
    With no set bits in range, the seed is returned unchanged (as ``bytes``).
    """
    value = bytearray(seed)
    for pos in reversed(range(bits)):
        if not (i >> pos) & 1:
            continue
        value[pos // 8] ^= 1 << (pos % 8)
        value = bytearray(sha256(value))
    return bytes(value)
417
def secret_to_pubkey(secret: int) -> bytes:
    """Return the compressed public key for the integer private scalar ``secret``."""
    assert type(secret) is int
    privkey = ecc.ECPrivkey.from_secret_scalar(secret)
    return privkey.get_public_key_bytes(compressed=True)
421
def privkey_to_pubkey(priv: bytes) -> bytes:
    """Return the public key of the private key in the first 32 bytes of ``priv``."""
    key_bytes = priv[:32]  # any trailing bytes are ignored
    return ecc.ECPrivkey(key_bytes).get_public_key_bytes()
424
def derive_pubkey(basepoint: bytes, per_commitment_point: bytes) -> bytes:
    """Tweak ``basepoint`` with the per-commitment point.

    Computes ``basepoint + SHA256(per_commitment_point || basepoint) * G``
    (the BOLT-03 per-commitment pubkey derivation).
    """
    base = ecc.ECPubkey(basepoint)
    tweak = ecc.string_to_number(sha256(per_commitment_point + basepoint))
    return (base + ecc.GENERATOR * tweak).get_public_key_bytes()
428
def derive_privkey(secret: int, per_commitment_point: bytes) -> int:
    """Private-key counterpart of ``derive_pubkey``.

    Returns ``(secret + SHA256(per_commitment_point || basepoint)) mod n``, where
    ``basepoint`` is the compressed pubkey of ``secret``.
    """
    assert type(secret) is int
    base_pubkey = secret_to_pubkey(secret)
    tweak = ecc.string_to_number(sha256(per_commitment_point + base_pubkey))
    return (secret + tweak) % CURVE_ORDER
435
def derive_blinded_pubkey(basepoint: bytes, per_commitment_point: bytes) -> bytes:
    """Derive the BOLT-03 revocation pubkey.

    ``basepoint * SHA256(basepoint || pcp) + pcp * SHA256(pcp || basepoint)``
    """
    base_term = ecc.ECPubkey(basepoint) * ecc.string_to_number(sha256(basepoint + per_commitment_point))
    point_term = ecc.ECPubkey(per_commitment_point) * ecc.string_to_number(sha256(per_commitment_point + basepoint))
    blinded = base_term + point_term
    return blinded.get_public_key_bytes()
440
def derive_blinded_privkey(basepoint_secret: bytes, per_commitment_secret: bytes) -> bytes:
    """Derive the BOLT-03 revocation private key (counterpart of ``derive_blinded_pubkey``).

    ``basepoint_secret * SHA256(basepoint || pcp) + per_commitment_secret * SHA256(pcp || basepoint)  (mod n)``

    Returns the key as 32 big-endian bytes.
    """
    basepoint = ecc.ECPrivkey(basepoint_secret).get_public_key_bytes(compressed=True)
    per_commitment_point = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
    k1 = ecc.string_to_number(basepoint_secret) * ecc.string_to_number(sha256(basepoint + per_commitment_point))
    k2 = ecc.string_to_number(per_commitment_secret) * ecc.string_to_number(sha256(per_commitment_point + basepoint))
    # named privkey_int (not `sum`) so the builtin is not shadowed
    privkey_int = (k1 + k2) % ecc.CURVE_ORDER
    return privkey_int.to_bytes(32, byteorder='big', signed=False)
448
449
def make_htlc_tx_output(amount_msat, local_feerate, revocationpubkey, local_delayedpubkey, success, to_self_delay):
    """Build the single output of a second-stage HTLC transaction.

    The output pays to P2WSH of a script spendable either by a signature for
    ``revocationpubkey``, or after ``to_self_delay`` (OP_CHECKSEQUENCEVERIFY)
    by a signature for ``local_delayedpubkey``.

    The fee ``local_feerate * weight`` (weight = HTLC_SUCCESS_WEIGHT if
    ``success`` else HTLC_TIMEOUT_WEIGHT) is rounded down to a multiple of 1000
    and subtracted from ``amount_msat``. With ``local_feerate`` in sat/kw
    (cf. FeeUpdate.rate) that product is in msat.

    Returns ``(witness_script, output)``.
    """
    assert type(amount_msat) is int
    assert type(local_feerate) is int
    for pubkey in (revocationpubkey, local_delayedpubkey):
        assert type(pubkey) is bytes
    witness_script = bfh(construct_script([
        opcodes.OP_IF,
        revocationpubkey,
        opcodes.OP_ELSE,
        to_self_delay,
        opcodes.OP_CHECKSEQUENCEVERIFY,
        opcodes.OP_DROP,
        local_delayedpubkey,
        opcodes.OP_ENDIF,
        opcodes.OP_CHECKSIG,
    ]))
    address = bitcoin.redeem_script_to_address('p2wsh', bh2u(witness_script))
    tx_weight = HTLC_SUCCESS_WEIGHT if success else HTLC_TIMEOUT_WEIGHT
    fee_sat = (local_feerate * tx_weight) // 1000
    final_amount_sat = (amount_msat - fee_sat * 1000) // 1000
    assert final_amount_sat > 0, final_amount_sat
    output = PartialTxOutput.from_address_and_value(address, final_amount_sat)
    return witness_script, output
475
def make_htlc_tx_witness(remotehtlcsig: bytes, localhtlcsig: bytes,
                         payment_preimage: bytes, witness_script: bytes) -> bytes:
    """Serialize the witness that spends an HTLC output via an HTLC tx.

    Stack items, in order: 0 (the dummy element consumed by OP_CHECKMULTISIG),
    the remote and the local HTLC signature, the payment preimage, and the
    witness script. Every argument must be ``bytes``.
    """
    for item in (remotehtlcsig, localhtlcsig, payment_preimage, witness_script):
        assert type(item) is bytes
    stack = [0, remotehtlcsig, localhtlcsig, payment_preimage, witness_script]
    return bfh(construct_witness(stack))
483
def make_htlc_tx_inputs(htlc_output_txid: str, htlc_output_index: int,
                        amount_msat: int, witness_script: str) -> List[PartialTxInput]:
    """Return a one-element list with the input spending the given HTLC output.

    ``witness_script`` is hex; the input has nSequence 0 and an empty scriptSig.
    Its trusted value is ``amount_msat`` rounded down to whole satoshis.
    """
    for value, expected_type in ((htlc_output_txid, str), (htlc_output_index, int),
                                 (amount_msat, int), (witness_script, str)):
        assert type(value) is expected_type
    outpoint = TxOutpoint(txid=bfh(htlc_output_txid), out_idx=htlc_output_index)
    txin = PartialTxInput(prevout=outpoint, nsequence=0)
    txin.witness_script = bfh(witness_script)
    txin.script_sig = b''
    txin._trusted_value_sats = amount_msat // 1000
    return [txin]
497
def make_htlc_tx(*, cltv_expiry: int, inputs: List[PartialTxInput], output: PartialTxOutput) -> PartialTransaction:
    """Assemble a version-2 HTLC tx with the single ``output`` and nLockTime ``cltv_expiry``."""
    assert type(cltv_expiry) is int
    return PartialTransaction.from_io(inputs, [output], locktime=cltv_expiry, version=2)
503
def make_offered_htlc(revocation_pubkey: bytes, remote_htlcpubkey: bytes,
                      local_htlcpubkey: bytes, payment_hash: bytes) -> bytes:
    """Build the witness script of an offered-HTLC output (BOLT-03).

    Spending paths:
      * revocation: a key whose HASH160 matches ``revocation_pubkey``'s, plus its signature;
      * no 32-byte preimage on the stack: 2-of-2 signatures for
        ``remote_htlcpubkey`` and ``local_htlcpubkey`` (the HTLC-timeout path);
      * 32-byte preimage hashing to ``payment_hash``, plus a signature for ``remote_htlcpubkey``.

    All arguments must be ``bytes``. Returns the serialized script.
    """
    assert type(revocation_pubkey) is bytes
    assert type(remote_htlcpubkey) is bytes
    assert type(local_htlcpubkey) is bytes
    assert type(payment_hash) is bytes
    script = bfh(construct_script([
        # revocation check: does the supplied key hash to the revocation pubkey's hash?
        opcodes.OP_DUP,
        opcodes.OP_HASH160,
        bitcoin.hash_160(revocation_pubkey),
        opcodes.OP_EQUAL,
        opcodes.OP_IF,
        opcodes.OP_CHECKSIG,
        opcodes.OP_ELSE,
        remote_htlcpubkey,
        opcodes.OP_SWAP,
        # is the next stack item 32 bytes long (a preimage)?
        opcodes.OP_SIZE,
        32,
        opcodes.OP_EQUAL,
        opcodes.OP_NOTIF,
        # not a preimage: require 2-of-2 signatures (remote, local)
        opcodes.OP_DROP,
        2,
        opcodes.OP_SWAP,
        local_htlcpubkey,
        2,
        opcodes.OP_CHECKMULTISIG,
        opcodes.OP_ELSE,
        # preimage: HASH160(preimage) == RIPEMD160(payment_hash), then signature for remote_htlcpubkey
        opcodes.OP_HASH160,
        crypto.ripemd(payment_hash),
        opcodes.OP_EQUALVERIFY,
        opcodes.OP_CHECKSIG,
        opcodes.OP_ENDIF,
        opcodes.OP_ENDIF,
    ]))
    return script
539
def make_received_htlc(revocation_pubkey: bytes, remote_htlcpubkey: bytes,
                       local_htlcpubkey: bytes, payment_hash: bytes, cltv_expiry: int) -> bytes:
    """Build the witness script of a received-HTLC output (BOLT-03).

    Spending paths:
      * revocation: a key whose HASH160 matches ``revocation_pubkey``'s, plus its signature;
      * 32-byte preimage hashing to ``payment_hash`` plus 2-of-2 signatures for
        ``remote_htlcpubkey`` and ``local_htlcpubkey`` (the HTLC-success path);
      * otherwise: after ``cltv_expiry`` (OP_CHECKLOCKTIMEVERIFY), a signature for ``remote_htlcpubkey``.

    Pubkeys and ``payment_hash`` must be ``bytes``; ``cltv_expiry`` must be an int.
    Returns the serialized script.
    """
    for i in [revocation_pubkey, remote_htlcpubkey, local_htlcpubkey, payment_hash]:
        assert type(i) is bytes
    assert type(cltv_expiry) is int

    script = bfh(construct_script([
        # revocation check: does the supplied key hash to the revocation pubkey's hash?
        opcodes.OP_DUP,
        opcodes.OP_HASH160,
        bitcoin.hash_160(revocation_pubkey),
        opcodes.OP_EQUAL,
        opcodes.OP_IF,
        opcodes.OP_CHECKSIG,
        opcodes.OP_ELSE,
        remote_htlcpubkey,
        opcodes.OP_SWAP,
        # is the next stack item 32 bytes long (a preimage)?
        opcodes.OP_SIZE,
        32,
        opcodes.OP_EQUAL,
        opcodes.OP_IF,
        # preimage path: check the hash, then require 2-of-2 signatures (remote, local)
        opcodes.OP_HASH160,
        crypto.ripemd(payment_hash),
        opcodes.OP_EQUALVERIFY,
        2,
        opcodes.OP_SWAP,
        local_htlcpubkey,
        2,
        opcodes.OP_CHECKMULTISIG,
        opcodes.OP_ELSE,
        # timeout path: only after cltv_expiry, signature for remote_htlcpubkey
        opcodes.OP_DROP,
        cltv_expiry,
        opcodes.OP_CHECKLOCKTIMEVERIFY,
        opcodes.OP_DROP,
        opcodes.OP_CHECKSIG,
        opcodes.OP_ENDIF,
        opcodes.OP_ENDIF,
    ]))
    return script
578
def make_htlc_output_witness_script(is_received_htlc: bool, remote_revocation_pubkey: bytes, remote_htlc_pubkey: bytes,
                                    local_htlc_pubkey: bytes, payment_hash: bytes, cltv_expiry: Optional[int]) -> bytes:
    """Return the witness script of an HTLC output.

    Dispatches to ``make_received_htlc`` (which uses ``cltv_expiry``) or to
    ``make_offered_htlc`` (which ignores it).
    """
    common = dict(revocation_pubkey=remote_revocation_pubkey,
                  remote_htlcpubkey=remote_htlc_pubkey,
                  local_htlcpubkey=local_htlc_pubkey,
                  payment_hash=payment_hash)
    if not is_received_htlc:
        return make_offered_htlc(**common)
    return make_received_htlc(cltv_expiry=cltv_expiry, **common)
592
593
def get_ordered_channel_configs(chan: 'AbstractChannel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig],
                                                                                Union[LocalConfig, RemoteConfig]]:
    """Return ``(conf, other_conf)``: (LOCAL, REMOTE) configs if ``for_us``, else (REMOTE, LOCAL)."""
    local_conf, remote_conf = chan.config[LOCAL], chan.config[REMOTE]
    if for_us:
        return local_conf, remote_conf
    return remote_conf, local_conf
599
600
def possible_output_idxs_of_htlc_in_ctx(*, chan: 'Channel', pcp: bytes, subject: 'HTLCOwner',
                                        htlc_direction: 'Direction', ctx: Transaction,
                                        htlc: 'UpdateAddHtlc') -> Set[int]:
    """Return the indices of the ``ctx`` outputs that could belong to ``htlc``.

    A candidate pays to the P2WSH address of the HTLC's witness script (keys
    derived with ``pcp`` from the configs ordered for ``subject``), and its value
    equals ``htlc.amount_msat // 1000``. Several indices can match when HTLCs
    with identical script and amount exist.
    """
    cltv_expiry, payment_hash = htlc.cltv_expiry, htlc.payment_hash
    expected_value_sat = htlc.amount_msat // 1000
    for_us = subject == LOCAL
    conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=for_us)

    other_revocation_pubkey = derive_blinded_pubkey(other_conf.revocation_basepoint.pubkey, pcp)
    other_htlc_pubkey = derive_pubkey(other_conf.htlc_basepoint.pubkey, pcp)
    htlc_pubkey = derive_pubkey(conf.htlc_basepoint.pubkey, pcp)
    preimage_script = make_htlc_output_witness_script(is_received_htlc=htlc_direction == RECEIVED,
                                                      remote_revocation_pubkey=other_revocation_pubkey,
                                                      remote_htlc_pubkey=other_htlc_pubkey,
                                                      local_htlc_pubkey=htlc_pubkey,
                                                      payment_hash=payment_hash,
                                                      cltv_expiry=cltv_expiry)
    htlc_address = redeem_script_to_address('p2wsh', bh2u(preimage_script))
    candidates = ctx.get_output_idxs_from_address(htlc_address)
    outputs = ctx.outputs()  # hoisted: fetched once rather than once per candidate
    return {output_idx for output_idx in candidates
            if outputs[output_idx].value == expected_value_sat}
621
622
def map_htlcs_to_ctx_output_idxs(*, chan: 'Channel', ctx: Transaction, pcp: bytes,
                                 subject: 'HTLCOwner', ctn: int) -> Dict[Tuple['Direction', 'UpdateAddHtlc'], Tuple[int, int]]:
    """Returns a dict from (htlc_dir, htlc) to (ctx_output_idx, htlc_relative_idx)

    Offered (SENT) HTLCs are matched first, then RECEIVED ones; within each
    direction, HTLCs are processed in order of increasing cltv_expiry. Each HTLC
    claims the lowest-indexed candidate output not already claimed by another
    HTLC. HTLCs for which no unclaimed candidate exists are absent from the result.
    ``htlc_relative_idx`` is the rank of the HTLC's output among all matched
    HTLC outputs, ordered by ctx output index.
    """
    htlc_to_ctx_output_idx_map = {}  # type: Dict[Tuple[Direction, UpdateAddHtlc], int]
    # each ctx output may be assigned to at most one HTLC
    unclaimed_ctx_output_idxs = set(range(len(ctx.outputs())))
    # note: .sort() reorders the lists returned by chan.included_htlcs() in place
    offered_htlcs = chan.included_htlcs(subject, SENT, ctn=ctn)
    offered_htlcs.sort(key=lambda htlc: htlc.cltv_expiry)
    received_htlcs = chan.included_htlcs(subject, RECEIVED, ctn=ctn)
    received_htlcs.sort(key=lambda htlc: htlc.cltv_expiry)
    for direction, htlcs in zip([SENT, RECEIVED], [offered_htlcs, received_htlcs]):
        for htlc in htlcs:
            # HTLCs with identical script and amount share candidate outputs,
            # hence the "unclaimed" bookkeeping below
            cands = sorted(possible_output_idxs_of_htlc_in_ctx(chan=chan,
                                                               pcp=pcp,
                                                               subject=subject,
                                                               htlc_direction=direction,
                                                               ctx=ctx,
                                                               htlc=htlc))
            for ctx_output_idx in cands:
                if ctx_output_idx in unclaimed_ctx_output_idxs:
                    unclaimed_ctx_output_idxs.discard(ctx_output_idx)
                    htlc_to_ctx_output_idx_map[(direction, htlc)] = ctx_output_idx
                    break
    # calc htlc_relative_idx
    inverse_map = {ctx_output_idx: (direction, htlc)
                   for ((direction, htlc), ctx_output_idx) in htlc_to_ctx_output_idx_map.items()}

    # enumerate matched output indices in ascending order to get each HTLC's relative index
    return {inverse_map[ctx_output_idx]: (ctx_output_idx, htlc_relative_idx)
            for htlc_relative_idx, ctx_output_idx in enumerate(sorted(inverse_map))}
651
652
def make_htlc_tx_with_open_channel(*, chan: 'Channel', pcp: bytes, subject: 'HTLCOwner', ctn: int,
                                   htlc_direction: 'Direction', commit: Transaction, ctx_output_idx: int,
                                   htlc: 'UpdateAddHtlc', name: str = None) -> Tuple[bytes, PartialTransaction]:
    """Build the second-stage HTLC tx spending output ``ctx_output_idx`` of ``commit``.

    Keys are derived with ``pcp`` from the channel configs ordered for
    ``subject``. The fee uses the channel feerate of ``subject`` at ``ctn``.
    A RECEIVED ``htlc_direction`` yields an HTLC-success tx (locktime 0);
    otherwise the result is an HTLC-timeout tx with locktime ``htlc.cltv_expiry``.

    Returns ``(witness_script_of_htlc_tx_output, htlc_tx)``.
    """
    # NOTE(review): `name` is not used in this body.
    amount_msat, cltv_expiry, payment_hash = htlc.amount_msat, htlc.cltv_expiry, htlc.payment_hash
    for_us = subject == LOCAL
    conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=for_us)

    delayedpubkey = derive_pubkey(conf.delayed_basepoint.pubkey, pcp)
    other_revocation_pubkey = derive_blinded_pubkey(other_conf.revocation_basepoint.pubkey, pcp)
    other_htlc_pubkey = derive_pubkey(other_conf.htlc_basepoint.pubkey, pcp)
    htlc_pubkey = derive_pubkey(conf.htlc_basepoint.pubkey, pcp)
    # The HTLC tx is an HTLC-success tx when it spends an output that was
    # RECEIVED from the perspective of `subject` (whose ctx this is);
    # otherwise it is an HTLC-timeout tx.
    is_htlc_success = htlc_direction == RECEIVED
    witness_script_of_htlc_tx_output, htlc_tx_output = make_htlc_tx_output(
        amount_msat = amount_msat,
        local_feerate = chan.get_feerate(subject, ctn=ctn),
        revocationpubkey=other_revocation_pubkey,
        local_delayedpubkey=delayedpubkey,
        success = is_htlc_success,
        to_self_delay = other_conf.to_self_delay)
    # witness script of the ctx HTLC output being spent
    preimage_script = make_htlc_output_witness_script(is_received_htlc=is_htlc_success,
                                                      remote_revocation_pubkey=other_revocation_pubkey,
                                                      remote_htlc_pubkey=other_htlc_pubkey,
                                                      local_htlc_pubkey=htlc_pubkey,
                                                      payment_hash=payment_hash,
                                                      cltv_expiry=cltv_expiry)
    htlc_tx_inputs = make_htlc_tx_inputs(
        commit.txid(), ctx_output_idx,
        amount_msat=amount_msat,
        witness_script=bh2u(preimage_script))
    # HTLC-success txs use locktime 0; HTLC-timeout txs are locked until cltv_expiry
    if is_htlc_success:
        cltv_expiry = 0
    htlc_tx = make_htlc_tx(cltv_expiry=cltv_expiry, inputs=htlc_tx_inputs, output=htlc_tx_output)
    return witness_script_of_htlc_tx_output, htlc_tx
688
def make_funding_input(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
        funding_pos: int, funding_txid: str, funding_sat: int) -> PartialTxInput:
    """Return the 2-of-2 p2wsh input spending funding output `funding_txid:funding_pos`.

    The two funding pubkeys are stored ordered by their hex encoding, and the input
    is marked as worth `funding_sat` (via `_trusted_value_sats`).
    """
    ordered_hex_keys = sorted(bh2u(key) for key in (local_funding_pubkey, remote_funding_pubkey))
    funding_outpoint = TxOutpoint(txid=bfh(funding_txid), out_idx=funding_pos)
    txin = PartialTxInput(prevout=funding_outpoint)
    txin.script_type = 'p2wsh'
    txin.pubkeys = [bfh(hex_key) for hex_key in ordered_hex_keys]
    txin.num_sig = 2
    txin._trusted_value_sats = funding_sat
    return txin
700
class HTLCOwner(IntEnum):
    """Channel party: LOCAL (1) or REMOTE (-1).

    The two values are each other's negation, so unary minus and inverted()
    map one party to the other. These are signed values, not combinable bit
    flags, so this is an IntEnum. With IntFlag, Python >= 3.11 treats the
    negative member as a non-canonical flag, which, e.g., drops it from iteration.
    """
    LOCAL = 1
    REMOTE = -LOCAL

    def inverted(self) -> 'HTLCOwner':
        """Return the other party."""
        return -self

    def __neg__(self) -> 'HTLCOwner':
        # int negation, mapped back to the enum member
        return HTLCOwner(super().__neg__())


class Direction(IntEnum):
    """Direction of an HTLC; signed like HTLCOwner (not a bit flag)."""
    SENT = -1     # in the context of HTLCs: "offered" HTLCs
    RECEIVED = 1  # in the context of HTLCs: "received" HTLCs

SENT = Direction.SENT
RECEIVED = Direction.RECEIVED

LOCAL = HTLCOwner.LOCAL
REMOTE = HTLCOwner.REMOTE
721
def make_commitment_outputs(*, fees_per_participant: Mapping[HTLCOwner, int], local_amount_msat: int, remote_amount_msat: int,
        local_script: str, remote_script: str, htlcs: List[ScriptHtlc], dust_limit_sat: int) -> Tuple[List[PartialTxOutput], List[PartialTxOutput]]:
    """Build the outputs of a commitment transaction.

    Each party's balance output carries its msat amount minus its fee share,
    floored at 0 and truncated to whole sat. Each HTLC becomes a p2wsh output
    of its witness script.

    Returns (all HTLC outputs, every output with value >= `dust_limit_sat`),
    the latter containing to_local, to_remote, then the HTLC outputs.
    """
    def balance_output(script_hex: str, amount_msat: int, owner: HTLCOwner) -> PartialTxOutput:
        # BOLT-03: "Base commitment transaction fees are extracted from the funder's amount;
        #           if that amount is insufficient, the entire amount of the funder's output is used."
        #   -> if funder cannot afford feerate, their output might go negative, so clamp at 0:
        remaining_msat = max(0, amount_msat - fees_per_participant[owner])
        return PartialTxOutput(scriptpubkey=bfh(script_hex), value=remaining_msat // 1000)

    to_local = balance_output(local_script, local_amount_msat, LOCAL)
    to_remote = balance_output(remote_script, remote_amount_msat, REMOTE)

    htlc_outputs = [
        PartialTxOutput(
            scriptpubkey=bfh(address_to_script(bitcoin.redeem_script_to_address('p2wsh', bh2u(script)))),
            value=htlc.amount_msat // 1000)
        for script, htlc in htlcs
    ]

    # trim outputs below the dust limit
    candidates = [to_local, to_remote] + htlc_outputs
    c_outputs_filtered = [out for out in candidates if out.value >= dust_limit_sat]
    return htlc_outputs, c_outputs_filtered
742
743
def offered_htlc_trim_threshold_sat(*, dust_limit_sat: int, feerate: int) -> int:
    """Return the trim threshold, in sat, for offered HTLCs.

    Offered HTLCs strictly below this amount are trimmed from the commitment tx.
    The threshold is the dust limit plus the HTLC-timeout tx fee at `feerate` (sat/kw).
    """
    htlc_timeout_fee_sat = HTLC_TIMEOUT_WEIGHT * feerate // 1000
    return dust_limit_sat + htlc_timeout_fee_sat
750
751
def received_htlc_trim_threshold_sat(*, dust_limit_sat: int, feerate: int) -> int:
    """Return the trim threshold, in sat, for received HTLCs.

    Received HTLCs strictly below this amount are trimmed from the commitment tx.
    The threshold is the dust limit plus the HTLC-success tx fee at `feerate` (sat/kw).
    """
    htlc_success_fee_sat = HTLC_SUCCESS_WEIGHT * feerate // 1000
    return dust_limit_sat + htlc_success_fee_sat
758
759
def fee_for_htlc_output(*, feerate: int) -> int:
    """Return the fee, in msat, that one HTLC output adds to a commitment tx.

    `feerate` is in sat/kw, so weight * feerate yields msat.
    """
    return HTLC_OUTPUT_WEIGHT * feerate
764
765
def calc_fees_for_commitment_tx(*, num_htlcs: int, feerate: int,
                                is_local_initiator: bool, round_to_sat: bool = True) -> Dict['HTLCOwner', int]:
    """Return each party's share, in msat, of the commitment-tx fee.

    The channel initiator pays the whole fee; the other party pays 0.
    `feerate` is in sat/kw. The fee covers the base commitment weight plus one
    HTLC output weight per HTLC.

    BOLT-02 specifies that msat fees are rounded down to sat. That rounding must
    happen on the *total* fee, so pass round_to_sat=False when the result feeds
    into further fee arithmetic, and round afterwards.
    """
    total_weight = COMMITMENT_TX_WEIGHT + HTLC_OUTPUT_WEIGHT * num_htlcs
    fee_msat = total_weight * feerate
    if round_to_sat:
        fee_msat = (fee_msat // 1000) * 1000
    local_fee_msat = fee_msat if is_local_initiator else 0
    remote_fee_msat = 0 if is_local_initiator else fee_msat
    return {LOCAL: local_fee_msat, REMOTE: remote_fee_msat}
781
782
def make_commitment(
        *,
        ctn: int,
        local_funding_pubkey: bytes,
        remote_funding_pubkey: bytes,
        remote_payment_pubkey: bytes,
        funder_payment_basepoint: bytes,
        fundee_payment_basepoint: bytes,
        revocation_pubkey: bytes,
        delayed_pubkey: bytes,
        to_self_delay: int,
        funding_txid: str,
        funding_pos: int,
        funding_sat: int,
        local_amount: int,
        remote_amount: int,
        dust_limit_sat: int,
        fees_per_participant: Mapping[HTLCOwner, int],
        htlcs: List[ScriptHtlc]
) -> PartialTransaction:
    """Build a commitment transaction spending the funding output.

    The commitment number `ctn` is obscured with get_obscured_ctn(). The lower
    24 bits go into the locktime (under upper byte 0x20); the upper 24 bits go
    into the funding input's nSequence (under upper byte 0x80).

    `local_amount` / `remote_amount` are in msat. Outputs below
    `dust_limit_sat` are dropped. `htlcs` are assumed to be already
    dust-trimmed.
    """
    funding_input = make_funding_input(local_funding_pubkey, remote_funding_pubkey,
                                       funding_pos, funding_txid, funding_sat)
    obscured_ctn = get_obscured_ctn(ctn, funder_payment_basepoint, fundee_payment_basepoint)
    locktime = (0x20 << 24) + (obscured_ctn & 0xffffff)
    funding_input.nsequence = (0x80 << 24) + (obscured_ctn >> 24)

    to_local_address = make_commitment_output_to_local_address(revocation_pubkey, to_self_delay, delayed_pubkey)
    to_remote_address = make_commitment_output_to_remote_address(remote_payment_pubkey)

    # BOLT-03: "Transaction Input and Output Ordering
    #           Lexicographic ordering: see BIP69. In the case of identical HTLC outputs,
    #           the outputs are ordered in increasing cltv_expiry order."
    # sorted() is stable, and the later BIP69 sort is assumed to be stable as well.
    htlcs_by_expiry = sorted(htlcs, key=lambda script_htlc: script_htlc.htlc.cltv_expiry)

    _, outputs = make_commitment_outputs(
        fees_per_participant=fees_per_participant,
        local_amount_msat=local_amount,
        remote_amount_msat=remote_amount,
        local_script=address_to_script(to_local_address),
        remote_script=address_to_script(to_remote_address),
        htlcs=htlcs_by_expiry,
        dust_limit_sat=dust_limit_sat)

    assert sum(x.value for x in outputs) <= funding_sat, (outputs, funding_sat)

    return PartialTransaction.from_io([funding_input], outputs, locktime=locktime, version=2)
838
def make_commitment_output_to_local_witness_script(
        revocation_pubkey: bytes, to_self_delay: int, delayed_pubkey: bytes) -> bytes:
    """Return the to_local witness script as bytes.

    The output is spendable either immediately with the revocation key, or by
    the delayed key once `to_self_delay` has passed (OP_CHECKSEQUENCEVERIFY).
    """
    revocation_branch = [opcodes.OP_IF, revocation_pubkey]
    delayed_branch = [
        opcodes.OP_ELSE,
        to_self_delay, opcodes.OP_CHECKSEQUENCEVERIFY, opcodes.OP_DROP,
        delayed_pubkey,
    ]
    common_tail = [opcodes.OP_ENDIF, opcodes.OP_CHECKSIG]
    return bfh(construct_script(revocation_branch + delayed_branch + common_tail))
853
def make_commitment_output_to_local_address(
        revocation_pubkey: bytes, to_self_delay: int, delayed_pubkey: bytes) -> str:
    """Return the p2wsh address of the to_local witness script."""
    witness_script = make_commitment_output_to_local_witness_script(
        revocation_pubkey, to_self_delay, delayed_pubkey)
    witness_script_hex = bh2u(witness_script)
    return bitcoin.redeem_script_to_address('p2wsh', witness_script_hex)
858
def make_commitment_output_to_remote_address(remote_payment_pubkey: bytes) -> str:
    """Return the p2wpkh address paying to `remote_payment_pubkey`."""
    pubkey_hex = bh2u(remote_payment_pubkey)
    return bitcoin.pubkey_to_address('p2wpkh', pubkey_hex)
861
def sign_and_get_sig_string(tx: PartialTransaction, local_config, remote_config):
    """Sign `tx` with our multisig key and return input 0's signature as converted
    by sig_string_from_der_sig().

    The last byte of the stored signature is stripped before conversion
    (presumably the sighash flag appended to the DER signature).
    `remote_config` is not used.
    """
    multisig_key = local_config.multisig_key
    tx.sign({bh2u(multisig_key.pubkey): (multisig_key.privkey, True)})
    der_sig_with_flag = tx.inputs()[0].part_sigs[multisig_key.pubkey]
    return sig_string_from_der_sig(der_sig_with_flag[:-1])
867
def funding_output_script(local_config, remote_config) -> str:
    """Return the 2-of-2 funding script built from both configs' multisig pubkeys."""
    local_pubkey = local_config.multisig_key.pubkey
    remote_pubkey = remote_config.multisig_key.pubkey
    return funding_output_script_from_keys(local_pubkey, remote_pubkey)
870
def funding_output_script_from_keys(pubkey1: bytes, pubkey2: bytes) -> str:
    """Return the 2-of-2 multisig script over two pubkeys, ordered by hex encoding
    (so the argument order does not matter)."""
    ordered_hex = sorted(map(bh2u, (pubkey1, pubkey2)))
    return transaction.multisig_script(ordered_hex, 2)
874
875
def get_obscured_ctn(ctn: int, funder: bytes, fundee: bytes) -> int:
    """XOR `ctn` with the lower 48 bits of sha256(funder || fundee).

    Being a plain XOR, applying it to an obscured value recovers the original.
    """
    digest = sha256(funder + fundee)
    obscuring_factor = int.from_bytes(digest[-6:], 'big')
    return obscuring_factor ^ ctn
879
def extract_ctn_from_tx(tx: Transaction, txin_index: int, funder_payment_basepoint: bytes,
                        fundee_payment_basepoint: bytes) -> int:
    """Recover the commitment number from a commitment transaction.

    Reassembles the 48-bit obscured value from the lower 24 bits of input
    `txin_index`'s nSequence (upper half) and of the locktime (lower half), then
    un-obscures it with get_obscured_ctn(). Deserializes `tx` first.
    """
    tx.deserialize()
    lower_half = tx.locktime & 0xffffff
    upper_half = tx.inputs()[txin_index].nsequence & 0xffffff
    obscured = (upper_half << 24) | lower_half
    return get_obscured_ctn(obscured, funder_payment_basepoint, fundee_payment_basepoint)
887
def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'AbstractChannel') -> int:
    """Recover the commitment number of `tx`, a commitment tx of `chan` (funding input at index 0).

    The funder is whichever side initiated the channel.
    """
    if chan.is_initiator():
        funder_conf, fundee_conf = chan.config[LOCAL], chan.config[REMOTE]
    else:
        funder_conf, fundee_conf = chan.config[REMOTE], chan.config[LOCAL]
    return extract_ctn_from_tx(
        tx,
        txin_index=0,
        funder_payment_basepoint=funder_conf.payment_basepoint.pubkey,
        fundee_payment_basepoint=fundee_conf.payment_basepoint.pubkey,
    )
894
def get_ecdh(priv: bytes, pub: bytes) -> bytes:
    """Return sha256 of the serialized point `pub` * `priv` (ECDH shared secret)."""
    their_point = ECPubkey(pub)
    shared_point = their_point * string_to_number(priv)
    return sha256(shared_point.get_public_key_bytes())
898
899
class LnFeatureContexts(enum.Flag):
    """Places in which a Lightning feature bit may be advertised.

    The CHAN_ANN_* members steer LnFeatures.for_channel_announcement():
    AS_IS copies the bit, ALWAYS_EVEN keeps only even bits, and ALWAYS_ODD
    maps the bit to the odd bit of its pair.
    """
    INIT = enum.auto()
    NODE_ANN = enum.auto()
    CHAN_ANN_AS_IS = enum.auto()
    CHAN_ANN_ALWAYS_ODD = enum.auto()
    CHAN_ANN_ALWAYS_EVEN = enum.auto()
    INVOICE = enum.auto()

# short alias used while defining LnFeatures below (deleted again via `del LNFC`)
LNFC = LnFeatureContexts

# Both tables are filled in from the LnFeatures class body, where feature names are
# still plain ints, so keys are single-bit ints such as 1 << 9.
# feature bit -> feature bits it directly depends on
_ln_feature_direct_dependencies = defaultdict(set)  # type: Dict[int, Set[int]]
# feature bit -> contexts in which it may appear
_ln_feature_contexts = {}  # type: Dict[int, LnFeatureContexts]
912
class LnFeatures(IntFlag):
    """Lightning feature bits (see BOLT-09).

    Bits are assigned in pairs: the even bit (``*_REQ``) marks a feature as
    compulsory, the odd bit (``*_OPT``) as optional. While this class body
    executes, each member name is still a plain int. Those ints key the
    module-level ``_ln_feature_contexts`` (where a bit may appear) and
    ``_ln_feature_direct_dependencies`` (what a bit requires) tables.
    """
    OPTION_DATA_LOSS_PROTECT_REQ = 1 << 0
    OPTION_DATA_LOSS_PROTECT_OPT = 1 << 1
    _ln_feature_contexts[OPTION_DATA_LOSS_PROTECT_OPT] = (LNFC.INIT | LnFeatureContexts.NODE_ANN)
    _ln_feature_contexts[OPTION_DATA_LOSS_PROTECT_REQ] = (LNFC.INIT | LnFeatureContexts.NODE_ANN)

    INITIAL_ROUTING_SYNC = 1 << 3
    _ln_feature_contexts[INITIAL_ROUTING_SYNC] = LNFC.INIT

    OPTION_UPFRONT_SHUTDOWN_SCRIPT_REQ = 1 << 4
    OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT = 1 << 5
    _ln_feature_contexts[OPTION_UPFRONT_SHUTDOWN_SCRIPT_OPT] = (LNFC.INIT | LNFC.NODE_ANN)
    _ln_feature_contexts[OPTION_UPFRONT_SHUTDOWN_SCRIPT_REQ] = (LNFC.INIT | LNFC.NODE_ANN)

    GOSSIP_QUERIES_REQ = 1 << 6
    GOSSIP_QUERIES_OPT = 1 << 7
    _ln_feature_contexts[GOSSIP_QUERIES_OPT] = (LNFC.INIT | LNFC.NODE_ANN)
    _ln_feature_contexts[GOSSIP_QUERIES_REQ] = (LNFC.INIT | LNFC.NODE_ANN)

    VAR_ONION_REQ = 1 << 8
    VAR_ONION_OPT = 1 << 9
    _ln_feature_contexts[VAR_ONION_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)
    _ln_feature_contexts[VAR_ONION_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)

    GOSSIP_QUERIES_EX_REQ = 1 << 10
    GOSSIP_QUERIES_EX_OPT = 1 << 11
    _ln_feature_direct_dependencies[GOSSIP_QUERIES_EX_OPT] = {GOSSIP_QUERIES_OPT}
    _ln_feature_contexts[GOSSIP_QUERIES_EX_OPT] = (LNFC.INIT | LNFC.NODE_ANN)
    _ln_feature_contexts[GOSSIP_QUERIES_EX_REQ] = (LNFC.INIT | LNFC.NODE_ANN)

    OPTION_STATIC_REMOTEKEY_REQ = 1 << 12
    OPTION_STATIC_REMOTEKEY_OPT = 1 << 13
    _ln_feature_contexts[OPTION_STATIC_REMOTEKEY_OPT] = (LNFC.INIT | LNFC.NODE_ANN)
    _ln_feature_contexts[OPTION_STATIC_REMOTEKEY_REQ] = (LNFC.INIT | LNFC.NODE_ANN)

    PAYMENT_SECRET_REQ = 1 << 14
    PAYMENT_SECRET_OPT = 1 << 15
    _ln_feature_direct_dependencies[PAYMENT_SECRET_OPT] = {VAR_ONION_OPT}
    _ln_feature_contexts[PAYMENT_SECRET_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)
    _ln_feature_contexts[PAYMENT_SECRET_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)

    BASIC_MPP_REQ = 1 << 16
    BASIC_MPP_OPT = 1 << 17
    _ln_feature_direct_dependencies[BASIC_MPP_OPT] = {PAYMENT_SECRET_OPT}
    _ln_feature_contexts[BASIC_MPP_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)
    _ln_feature_contexts[BASIC_MPP_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)

    OPTION_SUPPORT_LARGE_CHANNEL_REQ = 1 << 18
    OPTION_SUPPORT_LARGE_CHANNEL_OPT = 1 << 19
    _ln_feature_contexts[OPTION_SUPPORT_LARGE_CHANNEL_OPT] = (LNFC.INIT | LNFC.NODE_ANN)
    _ln_feature_contexts[OPTION_SUPPORT_LARGE_CHANNEL_REQ] = (LNFC.INIT | LNFC.NODE_ANN)

    OPTION_TRAMPOLINE_ROUTING_REQ = 1 << 24
    OPTION_TRAMPOLINE_ROUTING_OPT = 1 << 25

    _ln_feature_contexts[OPTION_TRAMPOLINE_ROUTING_REQ] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)
    _ln_feature_contexts[OPTION_TRAMPOLINE_ROUTING_OPT] = (LNFC.INIT | LNFC.NODE_ANN | LNFC.INVOICE)

    # temporary
    # note: no contexts are registered for these bits, so the for_*() filters drop them
    OPTION_TRAMPOLINE_ROUTING_REQ_ECLAIR = 1 << 50
    OPTION_TRAMPOLINE_ROUTING_OPT_ECLAIR = 1 << 51

    def validate_transitive_dependencies(self) -> bool:
        """Return whether every enabled feature has its dependencies enabled.

        A compulsory (even) bit implies its optional (odd) pair, so those odd bits
        are filled in first. Only direct dependencies are checked: if every enabled
        flag has its direct dependencies enabled, transitive ones follow.
        """
        features = self  # IntFlag is immutable: '|=' below rebinds, self is unchanged
        for flag in list_enabled_bits(features):
            if flag % 2 == 0:
                features |= 1 << get_ln_flag_pair_of_bit(flag)
        # Walk the (small) dependency table instead of indexing the defaultdict with
        # '1 << flag' for every enabled bit. These features can come from a peer, and
        # each lookup of an unknown bit would permanently insert a (potentially huge)
        # int key into the module-level table.
        for flag_mask, dependencies in list(_ln_feature_direct_dependencies.items()):
            if not (flag_mask & features):
                continue
            for dependency in dependencies:
                if not (dependency & features):
                    return False
        return True

    def _filter_by_context(self, context: 'LnFeatureContexts') -> 'LnFeatures':
        """Return the enabled bits that are allowed in `context`.

        Bits without any registered context are dropped rather than raising KeyError.
        """
        features = LnFeatures(0)
        for flag in list_enabled_bits(self):
            if context & _ln_feature_contexts.get(1 << flag, LnFeatureContexts(0)):
                features |= (1 << flag)
        return features

    def for_init_message(self) -> 'LnFeatures':
        """Return the subset of these features that may appear in an init message."""
        return self._filter_by_context(LnFeatureContexts.INIT)

    def for_node_announcement(self) -> 'LnFeatures':
        """Return the subset of these features that may appear in a node announcement."""
        return self._filter_by_context(LnFeatureContexts.NODE_ANN)

    def for_invoice(self) -> 'LnFeatures':
        """Return the subset of these features that may appear in an invoice."""
        return self._filter_by_context(LnFeatureContexts.INVOICE)

    def for_channel_announcement(self) -> 'LnFeatures':
        """Return these features as they should appear in a channel announcement.

        CHAN_ANN_AS_IS bits are copied; CHAN_ANN_ALWAYS_EVEN bits are kept only if
        even; CHAN_ANN_ALWAYS_ODD bits are mapped to the odd bit of their pair.
        Bits without any registered context are dropped.
        """
        features = LnFeatures(0)
        for flag in list_enabled_bits(self):
            ctxs = _ln_feature_contexts.get(1 << flag, LnFeatureContexts(0))
            if LnFeatureContexts.CHAN_ANN_AS_IS & ctxs:
                features |= (1 << flag)
            elif LnFeatureContexts.CHAN_ANN_ALWAYS_EVEN & ctxs:
                if flag % 2 == 0:
                    features |= (1 << flag)
            elif LnFeatureContexts.CHAN_ANN_ALWAYS_ODD & ctxs:
                if flag % 2 == 0:
                    flag = get_ln_flag_pair_of_bit(flag)
                features |= (1 << flag)
        return features

    def supports(self, feature: 'LnFeatures') -> bool:
        """Returns whether given feature is enabled.

        Helper function that tries to hide the complexity of even/odd bits.
        For example, instead of:
          bool(myfeatures & LnFeatures.VAR_ONION_OPT or myfeatures & LnFeatures.VAR_ONION_REQ)
        you can do:
          myfeatures.supports(LnFeatures.VAR_ONION_OPT)

        Raises ValueError if `feature` has more than one bit set (or none).
        """
        enabled_bits = list_enabled_bits(feature)
        if len(enabled_bits) != 1:
            raise ValueError(f"'feature' cannot be a combination of features: {feature}")
        flag = enabled_bits[0]
        our_flags = set(list_enabled_bits(self))
        return (flag in our_flags
                or get_ln_flag_pair_of_bit(flag) in our_flags)
1043
1044
del LNFC  # name is ambiguous without context

# features that are actually implemented and understood in our codebase:
# (note: this is not what we send in e.g. init!)
# (note: specify both OPT and REQ here)
# validate_features() rejects any even (compulsory) bit that is not in this set.
LN_FEATURES_IMPLEMENTED = (
        LnFeatures(0)
        | LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT | LnFeatures.OPTION_DATA_LOSS_PROTECT_REQ
        | LnFeatures.GOSSIP_QUERIES_OPT | LnFeatures.GOSSIP_QUERIES_REQ
        | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ
        | LnFeatures.VAR_ONION_OPT | LnFeatures.VAR_ONION_REQ
        | LnFeatures.PAYMENT_SECRET_OPT | LnFeatures.PAYMENT_SECRET_REQ
        | LnFeatures.BASIC_MPP_OPT | LnFeatures.BASIC_MPP_REQ
        | LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT | LnFeatures.OPTION_TRAMPOLINE_ROUTING_REQ
)
1060
1061
def get_ln_flag_pair_of_bit(flag_bit: int) -> int:
    """Return the other bit index of a feature-bit pair (BOLT-09 assigns
    flags in even/odd pairs).

    Flipping the lowest bit maps even -> even + 1 and odd -> odd - 1:
    e.g. 6 -> 7
    e.g. 7 -> 6
    """
    return flag_bit ^ 1
1072
1073
1074
class IncompatibleOrInsaneFeatures(Exception):
    """Base class: feature bits are incompatible with ours or internally inconsistent."""


class UnknownEvenFeatureBits(IncompatibleOrInsaneFeatures):
    """A compulsory (even) feature bit is set that we do not implement."""


class IncompatibleLightningFeatures(IncompatibleOrInsaneFeatures):
    """Feature negotiation failed: one side lacks a feature the other requires."""
1078
1079
def ln_compare_features(our_features: 'LnFeatures', their_features: int) -> 'LnFeatures':
    """Returns negotiated features.

    - One of our bits the peer lacks entirely (neither it nor its pair set on their
      side) is dropped if optional (odd). If compulsory (even),
      IncompatibleLightningFeatures is raised.
    - For a compulsory bit both sides support, the matching optional (odd) bit
      is also set, to make feature-bit testing easier.
    - A compulsory bit of theirs that we lack entirely raises
      IncompatibleLightningFeatures.
    """
    our_flags = set(list_enabled_bits(our_features))
    their_flags = set(list_enabled_bits(their_features))

    # our features vs. theirs
    for flag in our_flags:
        pair = get_ln_flag_pair_of_bit(flag)
        is_compulsory = flag % 2 == 0
        if flag in their_flags or pair in their_flags:
            # They too have this flag.
            if is_compulsory and our_features & (1 << flag):
                our_features |= 1 << pair
            continue
        # they don't have this feature we wanted :(
        if is_compulsory:
            raise IncompatibleLightningFeatures(f"remote does not support {LnFeatures(1 << flag)!r}")
        our_features ^= 1 << flag  # disable the optional flag

    # their compulsory features vs. ours
    for flag in their_flags:
        if flag % 2 != 0:
            continue
        if flag not in our_flags and get_ln_flag_pair_of_bit(flag) not in our_flags:
            raise IncompatibleLightningFeatures(f"remote wanted feature we don't have: {LnFeatures(1 << flag)!r}")
    return our_features
1106
1107
def validate_features(features: int) -> None:
    """Raises IncompatibleOrInsaneFeatures if
    - a mandatory feature is listed that we don't recognize
      (as its subclass UnknownEvenFeatureBits, carrying the bit index), or
    - the features are inconsistent (a dependency is missing)
    """
    features = LnFeatures(features)
    for fbit in list_enabled_bits(features):
        if fbit % 2 != 0:
            continue  # unknown optional (odd) bits are fine
        if not (LN_FEATURES_IMPLEMENTED & (1 << fbit)):
            raise UnknownEvenFeatureBits(fbit)
    if not features.validate_transitive_dependencies():
        raise IncompatibleOrInsaneFeatures(f"not all transitive dependencies are set. "
                                           f"features={features}")
1121
1122
def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> bytes:
    """Returns secret to be put into invoice.

    The secret is sha256 of the preimage with the lowest bit of its first byte
    flipped, so it is deterministic and need not be stored.
    Crucially the payment_hash must be derived in an independent way from this.
    """
    # Note that this could be random data too, but then we would need to store it.
    # We derive it identically to clightning, so that we cannot be distinguished:
    # https://github.com/ElementsProject/lightning/blob/faac4b28adee5221e83787d64cd5d30b16b62097/lightningd/invoice.c#L115
    preimage_bytes = bytes(payment_preimage)
    tweaked = bytes([preimage_bytes[0] ^ 1]) + preimage_bytes[1:]
    return sha256(tweaked)
1134
1135
class LNPeerAddr:
    """A Lightning peer: host, port and node pubkey.

    Meant to be treated as immutable, although this is not enforced.
    Equality and hashing use (host, port, pubkey).
    """

    def __init__(self, host: str, port: int, pubkey: bytes):
        assert isinstance(host, str), repr(host)
        assert isinstance(port, int), repr(port)
        assert isinstance(pubkey, bytes), repr(pubkey)
        try:
            validated = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct LNPeerAddr: invalid host or port (host={host}, port={port})") from e
        # The pubkey is deliberately not validated (too expensive), i.e. no
        # `if not ECPubkey.is_pubkey_bytes(pubkey): raise ValueError()`.
        self.host = host
        self.port = port
        self.pubkey = pubkey
        self._net_addr = validated

    def _identity(self):
        # the fields equality and hashing are based on
        return (self.host, self.port, self.pubkey)

    def __str__(self):
        return f'{self.pubkey.hex()}@{self.net_addr_str()}'

    def __repr__(self):
        return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'

    def net_addr(self) -> NetAddress:
        return self._net_addr

    def net_addr_str(self) -> str:
        return str(self._net_addr)

    def __eq__(self, other):
        return isinstance(other, LNPeerAddr) and self._identity() == other._identity()

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash(self._identity())
1178
1179
def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes:
    """Decode a vanilla-BECH32 string with hrp 'ln' into a 33-byte compressed pubkey.

    A payload shorter than 33 bytes is right-padded with zero bytes.

    Raises:
        ValueError: bad checksum, non-BECH32 encoding, unexpected hrp, a payload
            that cannot be regrouped into 8-bit bytes, or a payload longer
            than 33 bytes.
    """
    decoded_bech32 = segwit_addr.bech32_decode(bech32_pubkey)
    hrp = decoded_bech32.hrp
    data_5bits = decoded_bech32.data
    if decoded_bech32.encoding is None:
        raise ValueError("Bad bech32 checksum")
    if decoded_bech32.encoding != segwit_addr.Encoding.BECH32:
        raise ValueError("Bad bech32 encoding: must be using vanilla BECH32")
    if hrp != 'ln':
        # ValueError subclasses Exception, so existing `except Exception` callers still work
        raise ValueError('unexpected hrp: {}'.format(hrp))
    data_8bits = segwit_addr.convertbits(data_5bits, 5, 8, False)
    # convertbits is assumed to follow the BIP-173 reference implementation and return
    # None on invalid input (e.g. non-zero padding bits); previously this crashed with
    # a TypeError on len(None) below.
    if data_8bits is None:
        raise ValueError("Bad bech32 payload: cannot convert to 8-bit bytes")
    COMPRESSED_PUBKEY_LENGTH = 33
    if len(data_8bits) > COMPRESSED_PUBKEY_LENGTH:
        raise ValueError(f"Bad bech32 payload: longer than {COMPRESSED_PUBKEY_LENGTH} bytes")
    # pad with zeroes
    data_8bits = data_8bits + ((COMPRESSED_PUBKEY_LENGTH - len(data_8bits)) * [0])
    return bytes(data_8bits)
1195
1196
def make_closing_tx(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
                    funding_txid: str, funding_pos: int, funding_sat: int,
                    outputs: List[PartialTxOutput]) -> PartialTransaction:
    """Build a closing transaction that spends the funding output to `outputs`.

    The funding input uses nSequence 0xFFFFFFFF; the tx has version 2 and locktime 0.
    """
    funding_input = make_funding_input(local_funding_pubkey, remote_funding_pubkey,
                                       funding_pos, funding_txid, funding_sat)
    funding_input.nsequence = 0xffffffff
    return PartialTransaction.from_io([funding_input], outputs, locktime=0, version=2)
1205
1206
def split_host_port(host_port: str) -> Tuple[str, str]: # port returned as string
    """Split '<host>[:<port>]' or '[<ipv6>][:<port>]' into (host, port).

    The port is returned as a string and defaults to '9735' when absent.
    Bracketed IPv6 literals may use either letter case and may end in an
    embedded dotted-quad IPv4 address (e.g. '[::ffff:1.2.3.4]').

    Raises:
        ConnStringFormatError: if the string matches neither form, or the port
            is not decimal.
    """
    # IPv6 hex digits are case-insensitive, and IPv4-mapped forms contain dots;
    # the previous pattern accepted neither.
    ipv6 = re.compile(r'\[(?P<host>[:.0-9a-f]+)\](?P<port>:\d+)?$', re.IGNORECASE)
    other = re.compile(r'(?P<host>[^:]+)(?P<port>:\d+)?$')
    m = ipv6.match(host_port)
    if not m:
        m = other.match(host_port)
    if not m:
        raise ConnStringFormatError(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
    host = m.group('host')
    if m.group('port'):
        port = m.group('port')[1:]  # strip the leading ':'
    else:
        port = '9735'
    try:
        int(port)
    except ValueError:
        raise ConnStringFormatError(_('Port number must be decimal'))
    return host, port
1225
def extract_nodeid(connect_contents: str) -> Tuple[bytes, Optional[str]]:
    """Parse a connect string and return (node_id, rest).

    Accepted forms:
      - '<node_pubkey_hex>@<rest>': rest is everything after the first '@';
      - a lightning invoice: node_id is the invoice's pubkey, rest is None;
      - a bare hex node id: rest is None.

    Raises:
        ConnStringFormatError: if nothing follows the '@', or if the node id is
            not 33 bytes of hex.
    """
    rest = None
    try:
        # connection string?
        nodeid_hex, rest = connect_contents.split("@", 1)
    except ValueError:
        try:
            # invoice?
            invoice = lndecode(connect_contents)
            nodeid_bytes = invoice.pubkey.serialize()
            nodeid_hex = bh2u(nodeid_bytes)
        except Exception:  # not bare except: must not swallow KeyboardInterrupt/SystemExit
            # node id as hex?
            nodeid_hex = connect_contents
    if rest == '':
        raise ConnStringFormatError(_('At least a hostname must be supplied after the at symbol.'))
    try:
        node_id = bfh(nodeid_hex)
    except Exception as e:
        raise ConnStringFormatError(_('Invalid node ID, must be 33 bytes and hexadecimal')) from e
    if len(node_id) != 33:
        raise ConnStringFormatError(_('Invalid node ID, must be 33 bytes and hexadecimal'))
    return node_id, rest
1249
1250
# key derivation
# see lnd/keychain/derivation.go
class LnKeyFamily(IntEnum):
    # First path element used by generate_keypair() (path: key_family/0/0).
    # Members OR-ed with BIP32_PRIME presumably use a hardened index; NODE_KEY does not.
    # The values define derivation paths, so they must never change.
    MULTISIG = 0 | BIP32_PRIME
    REVOCATION_BASE = 1 | BIP32_PRIME
    HTLC_BASE = 2 | BIP32_PRIME
    PAYMENT_BASE = 3 | BIP32_PRIME
    DELAY_BASE = 4 | BIP32_PRIME
    REVOCATION_ROOT = 5 | BIP32_PRIME
    NODE_KEY = 6
    BACKUP_CIPHER = 7 | BIP32_PRIME
1262
1263
def generate_keypair(node: BIP32Node, key_family: LnKeyFamily) -> Keypair:
    """Derive the key at path key_family/0/0 below `node` (via
    subkey_at_private_derivation) and return Keypair(pubkey bytes, secret bytes)."""
    child = node.subkey_at_private_derivation([key_family, 0, 0])
    secret = child.eckey.get_secret_bytes()
    pubkey = ecc.ECPrivkey(secret).get_public_key_bytes()
    return Keypair(pubkey, secret)
1269
1270
1271
# Presumably the maximum number of hops in a payment path (usage not visible here).
NUM_MAX_HOPS_IN_PAYMENT_PATH = 20
# Edges are bounded by the same number as hops.
NUM_MAX_EDGES_IN_PAYMENT_PATH = NUM_MAX_HOPS_IN_PAYMENT_PATH
1274
1275
class ShortChannelID(bytes):
    """8-byte short channel id.

    Layout, all big-endian: 3-byte block height, 3-byte tx position in the
    block, 2-byte output index.
    """

    def __repr__(self):
        return f"<ShortChannelID: {format_short_channel_id(self)}>"

    def __str__(self):
        return format_short_channel_id(self)

    @classmethod
    def from_components(cls, block_height: int, tx_pos_in_block: int, output_index: int) -> 'ShortChannelID':
        """Build an scid from its three components.

        Raises OverflowError if a component is negative or does not fit its
        field (3, 3 and 2 bytes respectively).
        """
        bh = block_height.to_bytes(3, byteorder='big')
        tpos = tx_pos_in_block.to_bytes(3, byteorder='big')
        oi = output_index.to_bytes(2, byteorder='big')
        return cls(bh + tpos + oi)

    @classmethod
    def from_str(cls, scid: str) -> 'ShortChannelID':
        """Parses a formatted scid str, e.g. '643920x356x0'.

        Raises ValueError on any malformed input, including components that are
        negative or too large for their field.
        """
        components = scid.split("x")
        if len(components) != 3:
            raise ValueError(f"failed to parse ShortChannelID: {scid!r}")
        try:
            components = [int(x) for x in components]
        except ValueError:
            raise ValueError(f"failed to parse ShortChannelID: {scid!r}") from None
        try:
            return cls.from_components(*components)
        except OverflowError:
            # int.to_bytes rejects negative / oversized components with OverflowError,
            # which is not a ValueError; keep this method's documented contract.
            raise ValueError(f"failed to parse ShortChannelID: {scid!r}") from None

    @classmethod
    def normalize(cls, data: Union[None, str, bytes, 'ShortChannelID']) -> Optional['ShortChannelID']:
        """Coerce `data` to a ShortChannelID.

        None and ShortChannelID instances are returned unchanged; a 16-char hex
        str or 8-byte bytes/bytearray is converted (lengths checked via assert).
        Any other type yields None.
        """
        if isinstance(data, ShortChannelID) or data is None:
            return data
        if isinstance(data, str):
            assert len(data) == 16
            return cls.fromhex(data)
        if isinstance(data, (bytes, bytearray)):
            assert len(data) == 8
            return cls(data)

    @property
    def block_height(self) -> int:
        return int.from_bytes(self[:3], byteorder='big')

    @property
    def txpos(self) -> int:
        return int.from_bytes(self[3:6], byteorder='big')

    @property
    def output_index(self) -> int:
        return int.from_bytes(self[6:8], byteorder='big')
1325
1326
def format_short_channel_id(short_channel_id: Optional[bytes]):
    """Format an scid as '<block_height>x<tx_pos>x<output_index>'.

    Returns a translated 'Not yet available' placeholder when the id is empty or None.
    """
    if not short_channel_id:
        return _('Not yet available')
    fields = (short_channel_id[:3], short_channel_id[3:6], short_channel_id[6:])
    return 'x'.join(str(int.from_bytes(field, 'big')) for field in fields)
1333
1334
@attr.s(frozen=True)
class UpdateAddHtlc:
    """Immutable record of an added HTLC.

    `payment_hash` goes through the hex_to_bytes converter and is shown as hex
    in repr. `htlc_id` defaults to None.
    """
    amount_msat = attr.ib(type=int, kw_only=True)
    payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes, repr=lambda val: val.hex())
    cltv_expiry = attr.ib(type=int, kw_only=True)
    timestamp = attr.ib(type=int, kw_only=True)
    htlc_id = attr.ib(type=int, kw_only=True, default=None)

    @classmethod
    def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
        """Inverse of to_tuple()."""
        fields = dict(amount_msat=amount_msat,
                      payment_hash=payment_hash,
                      cltv_expiry=cltv_expiry,
                      htlc_id=htlc_id,
                      timestamp=timestamp)
        return cls(**fields)

    def to_tuple(self):
        """Return (amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp)."""
        field_order = ('amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id', 'timestamp')
        return tuple(getattr(self, field_name) for field_name in field_order)
1353
1354
class OnionFailureCodeMetaFlag(IntFlag):
    # High flag bits OR-ed into onion failure codes to classify the failure
    # (see BOLT-04); as an IntFlag, several may be combined.
    BADONION = 0x8000
    PERM     = 0x4000
    NODE     = 0x2000
    UPDATE   = 0x1000
1360
1361
1362