1import time
2from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any
3from decimal import Decimal
4
5import attr
6
7from .json_db import StoredObject
8from .i18n import _
9from .util import age, InvoiceError
10from .lnaddr import lndecode, LnAddr
11from . import constants
12from .bitcoin import COIN, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC
13from .transaction import PartialTxOutput
14
15if TYPE_CHECKING:
16    from .paymentrequest import PaymentRequest
17
# convention: 'invoices' = outgoing, 'request' = incoming

# types of payment requests
# NOTE: the type value is kept in the 'type' field and read back by
# Invoice.from_json to choose the class, so existing numbers must not change.
PR_TYPE_ONCHAIN = 0  # on-chain invoice (OnchainInvoice)
PR_TYPE_LN = 2       # lightning / BOLT-11 invoice (LNInvoice)

# status of payment requests
# (Invoice.get_status_str looks statuses up in pr_tooltips, so each status
# passed there needs an entry in that table)
PR_UNPAID   = 0     # if onchain: invoice amt not reached by txs in mempool+chain. if LN: invoice not paid.
PR_EXPIRED  = 1     # invoice is unpaid and expiry time reached
PR_UNKNOWN  = 2     # e.g. invoice not found
PR_PAID     = 3     # if onchain: paid and mined (1 conf). if LN: invoice is paid.
PR_INFLIGHT = 4     # only for LN. payment attempt in progress
PR_FAILED   = 5     # only for LN. we attempted to pay it, but all attempts failed
PR_ROUTING  = 6     # only for LN. *unused* atm.
PR_UNCONFIRMED = 7  # only onchain. invoice is satisfied but tx is not mined yet.
33
# Display colour per payment-request status.
# NOTE(review): tuples look like RGBA components in [0, 1] for the GUI — confirm.
pr_color = {
    PR_UNPAID:   (.7, .7, .7, 1),
    PR_PAID:     (.2, .9, .2, 1),
    PR_UNKNOWN:  (.7, .7, .7, 1),
    PR_EXPIRED:  (.9, .2, .2, 1),
    PR_INFLIGHT: (.9, .6, .3, 1),
    PR_FAILED:   (.9, .2, .2, 1),
    PR_ROUTING: (.9, .6, .3, 1),
    PR_UNCONFIRMED: (.9, .6, .3, 1),
}

# Translated, human-readable label per payment-request status
# (used directly by Invoice.get_status_str).
pr_tooltips = {
    PR_UNPAID:_('Unpaid'),
    PR_PAID:_('Paid'),
    PR_UNKNOWN:_('Unknown'),
    PR_EXPIRED:_('Expired'),
    PR_INFLIGHT:_('In progress'),
    PR_FAILED:_('Failed'),
    PR_ROUTING: _('Computing route...'),
    PR_UNCONFIRMED: _('Unconfirmed'),
}
55
# Default expiry, in seconds, used when creating a new request.
PR_DEFAULT_EXPIRATION_WHEN_CREATING = 24*60*60  # 1 day
# Selectable expiry durations in seconds, mapped to their translated labels.
# 0 stands for "never expires".
pr_expiration_values = {
    0: _('Never'),
    10*60: _('10 minutes'),
    60*60: _('1 hour'),
    24*60*60: _('1 day'),
    7*24*60*60: _('1 week'),
}
# Import-time sanity check: the default must be one of the selectable values.
assert PR_DEFAULT_EXPIRATION_WHEN_CREATING in pr_expiration_values
65
66
def _decode_outputs(outputs) -> List[PartialTxOutput]:
    """Normalize *outputs* into a list of PartialTxOutput objects.

    Items that already are PartialTxOutput instances are kept as-is; any
    other item is treated as a legacy tuple and unpacked into
    ``PartialTxOutput.from_legacy_tuple``. Used as the attrs converter for
    ``OnchainInvoice.outputs``.
    """
    return [
        item if isinstance(item, PartialTxOutput)
        else PartialTxOutput.from_legacy_tuple(*item)
        for item in outputs
    ]
74
75
# Hack: BOLT-11 is not really clear on what an expiry of 0 means; it is most
# likely interpreted as 0 seconds, i.e. already expired.
# Our higher-level invoice code, however, uses 0 for "never" (see the 0 entry
# in pr_expiration_values). Hence a very high expiration is used for lightning
# instead, and treated like "never" (Invoice.get_status_str shows no expiry
# countdown for it).
LN_EXPIRY_NEVER = 100 * 365 * 24 * 60 * 60  # ~100 years, in seconds
81
@attr.s
class Invoice(StoredObject):
    """Common base class for outgoing invoices (on-chain or lightning).

    Subclasses supply ``message``, ``exp`` and ``time``. ``exp`` is an
    offset that is added to ``time`` to obtain the expiration timestamp.
    ``type`` (PR_TYPE_ONCHAIN / PR_TYPE_LN) tells the two kinds apart.
    """
    type = attr.ib(type=int, kw_only=True)

    message: str
    exp: int
    time: int

    def is_lightning(self):
        """Return True if this invoice is a lightning (PR_TYPE_LN) invoice."""
        lightning = self.type == PR_TYPE_LN
        return lightning

    def get_status_str(self, status):
        """Return the text to display for *status*.

        For an unpaid invoice that has a real deadline, an "Expires <age>"
        text replaces the plain status label. A non-positive expiry, or
        LN_EXPIRY_NEVER, counts as having no deadline.
        """
        label = pr_tooltips[status]
        if status != PR_UNPAID:
            return label
        expiry = self.exp
        if expiry <= 0 or expiry == LN_EXPIRY_NEVER:
            # nothing meaningful to count down to
            return label
        deadline = expiry + self.time
        return _('Expires') + ' ' + age(deadline, include_seconds=True)

    def get_amount_sat(self) -> Union[int, Decimal, str, None]:
        """Return the amount in satoshis (possibly a Decimal), '!' or None.

        Abstract: every subclass must override this.
        """
        raise NotImplementedError()

    @classmethod
    def from_json(cls, x: dict) -> 'Invoice':
        """Rebuild an invoice object from its JSON dict *x*.

        ``x['type'] == PR_TYPE_LN`` selects LNInvoice; any other or missing
        type falls back to OnchainInvoice. The constructors raise if *x*
        contains fields they do not know.
        """
        target_cls = LNInvoice if x.get('type') == PR_TYPE_LN else OnchainInvoice
        return target_cls(**x)
112
113
@attr.s
class OnchainInvoice(Invoice):
    """Outgoing on-chain invoice: pay ``amount_sat`` to ``outputs``."""
    message = attr.ib(type=str, kw_only=True)
    # in satoshis, or the string '!' (validated by _validate_amount below)
    # NOTE(review): '!' presumably means "send max" — confirm with callers.
    amount_sat = attr.ib(kw_only=True)  # type: Union[int, str]  # in satoshis. can be '!'
    # seconds after `time` at which the invoice expires; 0 means "never"
    exp = attr.ib(type=int, kw_only=True, validator=attr.validators.instance_of(int))
    time = attr.ib(type=int, kw_only=True, validator=attr.validators.instance_of(int))
    id = attr.ib(type=str, kw_only=True)
    # legacy tuples are converted to PartialTxOutput on construction
    outputs = attr.ib(kw_only=True, converter=_decode_outputs)  # type: List[PartialTxOutput]
    # hex of the raw BIP70 payment request, when built from one
    bip70 = attr.ib(type=str, kw_only=True)  # type: Optional[str]
    requestor = attr.ib(type=str, kw_only=True)  # type: Optional[str]
    # NOTE(review): presumably the block height recorded at creation — confirm.
    height = attr.ib(type=int, kw_only=True, validator=attr.validators.instance_of(int))

    def get_address(self) -> str:
        """Return the address of the first output, to be displayed in the GUI."""
        return self.outputs[0].address

    def get_amount_sat(self) -> Union[int, str]:
        """Return ``amount_sat`` ('!' passes through); a falsy amount yields 0."""
        return self.amount_sat or 0

    @amount_sat.validator
    def _validate_amount(self, attribute, value):
        """attrs validator for ``amount_sat``.

        Accepts an int in [0, total coin supply in sat] or the exact string
        '!'; raises InvoiceError for anything else.
        """
        if isinstance(value, int):
            if not (0 <= value <= TOTAL_COIN_SUPPLY_LIMIT_IN_BTC * COIN):
                raise InvoiceError(f"amount is out-of-bounds: {value!r} sat")
        elif isinstance(value, str):
            if value != "!":
                raise InvoiceError(f"unexpected amount: {value!r}")
        else:
            raise InvoiceError(f"unexpected amount: {value!r}")

    @classmethod
    def from_bip70_payreq(cls, pr: 'PaymentRequest', height: int) -> 'OnchainInvoice':
        """Build an on-chain invoice from a BIP70 payment request *pr*.

        :param pr: the BIP70 payment request to convert.
        :param height: stored as-is in the ``height`` field.
        :returns: an instance of ``cls`` (OnchainInvoice or a subclass).
        """
        created = pr.get_time()
        expires_at = pr.get_expiration_date()
        # BIP70's `expires` field is optional. When it is absent the
        # expiration date is presumably 0, and subtracting the creation time
        # would store a negative, meaningless expiry. Store 0 ("never")
        # instead, matching the convention used for `exp` elsewhere.
        # NOTE(review): confirm against PaymentRequest.get_expiration_date.
        exp = expires_at - created if expires_at else 0
        return cls(
            type=PR_TYPE_ONCHAIN,
            amount_sat=pr.get_amount(),
            outputs=pr.get_outputs(),
            message=pr.get_memo(),
            id=pr.get_id(),
            time=created,
            exp=exp,
            bip70=pr.raw.hex(),
            requestor=pr.get_requestor(),
            height=height,
        )
158
@attr.s
class LNInvoice(Invoice):
    """Outgoing lightning invoice, stored as its BOLT-11 string.

    Only ``invoice`` and ``amount_msat`` are attrs fields. ``exp``, ``time``,
    ``message`` and ``rhash`` are derived from the decoded invoice.
    """
    invoice = attr.ib(type=str)
    amount_msat = attr.ib(kw_only=True)  # type: Optional[int]  # needed for zero amt invoices

    # Lazily filled cache of lndecode(self.invoice); see the _lnaddr property.
    # NOTE(review): the cache is not invalidated if `invoice` is reassigned.
    __lnaddr = None  # type: Optional[LnAddr]

    @invoice.validator
    def _validate_invoice_str(self, attribute, value):
        """attrs validator: *value* must be decodable by lndecode.

        Whatever lndecode raises is propagated unchanged.
        """
        lndecode(value)  # this checks the str can be decoded

    @amount_msat.validator
    def _validate_amount(self, attribute, value):
        """attrs validator for ``amount_msat``.

        Accepts None, or an int in [0, total coin supply in msat]; raises
        InvoiceError otherwise.
        """
        if value is None:
            return
        if isinstance(value, int):
            if not (0 <= value <= TOTAL_COIN_SUPPLY_LIMIT_IN_BTC * COIN * 1000):
                raise InvoiceError(f"amount is out-of-bounds: {value!r} msat")
        else:
            raise InvoiceError(f"unexpected amount: {value!r}")

    @property
    def _lnaddr(self) -> LnAddr:
        """The decoded invoice, computed on first access and cached afterwards."""
        if self.__lnaddr is None:
            self.__lnaddr = lndecode(self.invoice)
        return self.__lnaddr

    @property
    def rhash(self) -> str:
        """Payment hash of the invoice, hex-encoded."""
        return self._lnaddr.paymenthash.hex()

    def get_amount_msat(self) -> Optional[int]:
        """Return the amount in msat.

        The amount encoded in the invoice takes precedence. If it is missing
        or zero, fall back to the stored ``amount_msat``, which may be None.
        """
        amount_btc = self._lnaddr.amount
        amount = int(amount_btc * COIN * 1000) if amount_btc else None
        return amount or self.amount_msat

    def get_amount_sat(self) -> Union[Decimal, None]:
        """Return the amount in satoshis as a Decimal (may be fractional), or None."""
        amount_msat = self.get_amount_msat()
        if amount_msat is None:
            return None
        return Decimal(amount_msat) / 1000

    @property
    def exp(self) -> int:
        """Expiry taken from the decoded invoice."""
        return self._lnaddr.get_expiry()

    @property
    def time(self) -> int:
        """Creation timestamp taken from the decoded invoice."""
        return self._lnaddr.date

    @property
    def message(self) -> str:
        """Description taken from the decoded invoice."""
        return self._lnaddr.get_description()

    @classmethod
    def from_bech32(cls, invoice: str) -> 'LNInvoice':
        """Constructs an LNInvoice (or subclass) object from a BOLT-11 string.

        Raises InvoiceError if the string cannot be decoded.
        """
        try:
            lnaddr = lndecode(invoice)
        except Exception as e:
            raise InvoiceError(e) from e
        obj = cls(
            type=PR_TYPE_LN,
            invoice=invoice,
            amount_msat=lnaddr.get_amount_msat(),
        )
        # Reuse the decode done above for the _lnaddr cache, instead of
        # decoding the same string again on first access.
        obj.__lnaddr = lnaddr
        return obj

    def to_debug_json(self) -> Dict[str, Any]:
        """Return to_json() extended with fields decoded from the invoice."""
        d = self.to_json()
        lnaddr = self._lnaddr
        d.update({
            'pubkey': lnaddr.pubkey.serialize().hex(),
            'amount_BTC': str(lnaddr.amount),
            'rhash': lnaddr.paymenthash.hex(),
            'description': lnaddr.get_description(),
            'exp': lnaddr.get_expiry(),
            'time': lnaddr.date,
        })
        return d
241