1#!/usr/local/bin/python3.8
2#
3# Electrum - lightweight Bitcoin client
4# Copyright (C) 2011 thomasv@gitorious
5#
6# Permission is hereby granted, free of charge, to any person
7# obtaining a copy of this software and associated documentation files
8# (the "Software"), to deal in the Software without restriction,
9# including without limitation the rights to use, copy, modify, merge,
10# publish, distribute, sublicense, and/or sell copies of the Software,
11# and to permit persons to whom the Software is furnished to do so,
12# subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be
15# included in all copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
21# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
22# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
23# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24# SOFTWARE.
25import os
26import re
27import ssl
28import sys
29import traceback
30import asyncio
31import socket
32from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
33from collections import defaultdict
34from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
35import itertools
36import logging
37import hashlib
38import functools
39
40import aiorpcx
41from aiorpcx import TaskGroup
42from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
43from aiorpcx.curio import timeout_after, TaskTimeout
44from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
45from aiorpcx.rawsocket import RSClient
46import certifi
47
48from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
49                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
50                   is_int_or_float, is_non_negative_int_or_float)
51from . import util
52from . import x509
53from . import pem
54from . import version
55from . import blockchain
56from .blockchain import Blockchain, HEADER_SIZE
57from . import bitcoin
58from . import constants
59from .i18n import _
60from .logging import Logger
61from .transaction import Transaction
62
63if TYPE_CHECKING:
64    from .network import Network
65    from .simple_config import SimpleConfig
66
67
# CA bundle (from certifi) used to verify servers with CA-signed certificates.
ca_path = certifi.where()

# Bucket name used to group onion servers.
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# Default cap on the size of one incoming message, in bytes; can be overridden
# through the 'network_max_incoming_msg_size' config key.
MAX_INCOMING_MSG_SIZE = 10 ** 6

# Transport protocols: 't' is plaintext TCP, 's' is TLS.
_KNOWN_NETWORK_PROTOCOLS = {'s', 't'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS, PREFERRED_NETWORK_PROTOCOL
77
78
class NetworkTimeout:
    """Timeout presets for network requests, in seconds.

    Generic is the default tier. Urgent is a tighter tier, e.g. for requests
    made while a lock is held.
    """

    class Generic:
        NORMAL, RELAXED, MOST_RELAXED = 30, 45, 600

    class Urgent(Generic):
        NORMAL, RELAXED, MOST_RELAXED = 10, 20, 60
90
91
def assert_non_negative_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_non_negative_integer."""
    if is_non_negative_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative integer')
95
96
def assert_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_integer."""
    if is_integer(val):
        return
    raise RequestCorrupted(f'{val!r} should be an integer')
100
101
def assert_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_int_or_float."""
    if is_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be int or float')
105
106
def assert_non_negative_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_non_negative_int_or_float."""
    if is_non_negative_int_or_float(val):
        return
    raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
110
111
def assert_hash256_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_hash256_str."""
    if is_hash256_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hash256 str')
115
116
def assert_hex_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* passes is_hex_str."""
    if is_hex_str(val):
        return
    raise RequestCorrupted(f'{val!r} should be a hex str')
120
121
def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Return ``d[field_name]``.

    Raises RequestCorrupted if *d* is not a dict or has no such key.
    """
    if isinstance(d, dict):
        if field_name in d:
            return d[field_name]
        raise RequestCorrupted(f'required field {field_name!r} missing from dict')
    raise RequestCorrupted(f'{d!r} should be a dict')
128
129
def assert_list_or_tuple(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a list or a tuple."""
    if isinstance(val, (list, tuple)):
        return
    raise RequestCorrupted(f'{val!r} should be a list or tuple')
133
134
class NotificationSession(RPCSession):
    """RPC session that routes server notifications to subscriber queues.

    Subscriptions and their latest results are keyed by
    get_hashable_key_for_rpc_call(method, params). A later subscriber to an
    already-cached key is served from the cache instead of the network.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super().__init__(*args, **kwargs)
        # rpc-call key -> list of asyncio.Queue objects subscribed to it
        self.subscriptions = defaultdict(list)
        # rpc-call key -> most recent result received for it
        self.cache = {}
        self.default_timeout = NetworkTimeout.Generic.NORMAL
        # running counter; ids appear only in debug log lines
        self._msg_counter = itertools.count(start=1)
        self.interface = interface
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        # Only notifications for keys we subscribed to are accepted; anything
        # else is logged and the session is closed.
        self.maybe_log(f"--> {request}")
        try:
            if not isinstance(request, Notification):
                raise Exception('unexpected request. not a notification')
            params, result = request.args[:-1], request.args[-1]
            key = self.get_hashable_key_for_rpc_call(request.method, params)
            if key not in self.subscriptions:
                raise Exception('unexpected notification')
            self.cache[key] = result
            for subscriber in self.subscriptions[key]:
                await subscriber.put(request.args)
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        # note: semaphores/timeouts/backpressure etc are handled by aiorpcx;
        # in most cases the timeout arg here should not be set.
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # RPCSession.send_request raises TaskTimeout on timeout. TaskTimeout
            # subclasses CancelledError, which TaskGroups *suppress*, so it is
            # converted to RequestTimedOut below.
            response = await asyncio.wait_for(super().send_request(*args, **kwargs), timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        self.maybe_log(f"--> {response} (id: {msg_id})")
        return response

    def set_default_timeout(self, timeout):
        # Both aiorpcx knobs get the same value.
        self.sent_request_timeout = timeout
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        self.subscriptions[key].append(queue)
        if key not in self.cache:
            self.cache[key] = await self.send_request(method, params)
        await queue.put(params + [self.cache[key]])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for subscribers in self.subscriptions.values():
            if queue in subscribers:
                subscribers.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache"""
        return ''.join((str(method), repr(params)))

    def maybe_log(self, msg: str) -> None:
        # Log only when debugging is enabled on this interface or on the network.
        iface = self.interface
        if iface and (iface.debug or iface.network.debug):
            iface.logger.debug(msg)

    def default_framer(self):
        # overridden so that max_size can be customized
        configured = self.interface.network.config.get('network_max_incoming_msg_size',
                                                       MAX_INCOMING_MSG_SIZE)
        return NewlineFramer(max_size=int(configured))
223
224
class NetworkException(Exception):
    """Base class for the network errors defined in this module."""


class GracefulDisconnect(NetworkException):
    """Raised to close an interface cleanly.

    The disconnect is logged at ``log_level``. That defaults to INFO and can be
    overridden per instance with the ``log_level`` keyword.
    """
    log_level = logging.INFO

    def __init__(self, *args, log_level=None, **kwargs):
        Exception.__init__(self, *args, **kwargs)
        if log_level is None:
            return
        self.log_level = log_level


class RequestTimedOut(GracefulDisconnect):
    """A request to the server did not complete in time."""

    def __str__(self):
        return _("Network request timed out.")


class RequestCorrupted(Exception):
    """Server response failed validation."""


class ErrorParsingSSLCert(Exception):
    """A saved certificate file could not be parsed."""


class ErrorGettingSSLCertFromServer(Exception):
    """The server's certificate could not be obtained."""


class ErrorSSLCertFingerprintMismatch(Exception):
    """Server certificate does not match the configured fingerprint."""


class InvalidOptionCombination(Exception):
    """Mutually incompatible options were given."""


class ConnectError(NetworkException):
    """Opening the connection failed; the cause is chained as __cause__."""
249
250
class _RSClient(RSClient):
    """RSClient that reports connection-level OSErrors as ConnectError."""

    async def create_connection(self):
        try:
            connection = await super().create_connection()
        except OSError as exc:
            # Chaining with "from" stores the OSError as __cause__. Callers look
            # at it, e.g. to tell SSL verification failures apart.
            raise ConnectError(exc) from exc
        return connection
258
259
class ServerAddr:
    """Address of an Electrum server: host, port and transport protocol.

    protocol is one of _KNOWN_NETWORK_PROTOCOLS ('t' plaintext TCP, 's' TLS).
    str(self) is "<net_addr_str>:<protocol>", where net_addr_str is
    str(NetAddress(host, port)).
    """

    def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
        assert isinstance(host, str), repr(host)
        if protocol is None:
            protocol = 's'
        if not host:
            raise ValueError('host must not be empty')
        if host[0] == '[' and host[-1] == ']':  # IPv6
            host = host[1:-1]
        try:
            net_addr = NetAddress(host, port)  # this validates host and port
        except Exception as e:
            raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
        if protocol not in _KNOWN_NETWORK_PROTOCOLS:
            raise ValueError(f"invalid network protocol: {protocol}")
        self.host = str(net_addr.host)  # canonical form (if e.g. IPv6 address)
        self.port = int(net_addr.port)
        self.protocol = protocol
        self._net_addr_str = str(net_addr)

    @classmethod
    def from_str(cls, s: str) -> 'ServerAddr':
        """Parse "host:port:protocol" (all three parts required).

        Raises ValueError if the string is malformed or invalid.
        """
        # host might be IPv6 address, hence do rsplit:
        host, port, protocol = str(s).rsplit(':', 2)
        return ServerAddr(host=host, port=port, protocol=protocol)

    @classmethod
    def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']:
        """Construct ServerAddr from str, guessing missing details.

        Accepts "host:port" or "host:port:protocol". A missing protocol
        defaults to PREFERRED_NETWORK_PROTOCOL. IPv6 hosts may be written in
        brackets, e.g. "[::1]:50002", so the output of to_friendly_name()
        parses back. Returns None for an empty string or when no port is
        present. Raises ValueError for an invalid host, port or protocol.
        Ongoing compatibility not guaranteed.
        """
        if not s:
            return None
        s = str(s)
        bracketed_host = None
        if s.startswith('[') and ']' in s:
            # Bracketed IPv6 literal: split it off first. Otherwise its colons
            # would be taken as port/protocol separators by the rsplit below
            # ("[::1]:50002" would split into ['[:', '1]', '50002']).
            close = s.index(']')
            bracketed_host = s[1:close]
            s = s[close + 1:]
        items = s.rsplit(':', 2)
        if len(items) < 2:
            return None  # although maybe we could guess the port too?
        if bracketed_host is None:
            host = items[0]
        elif items[0]:
            raise ValueError(f"unexpected text after IPv6 address in server string: {items[0]!r}")
        else:
            host = bracketed_host
        port = items[1]
        protocol = items[2] if len(items) >= 3 else PREFERRED_NETWORK_PROTOCOL
        return ServerAddr(host=host, port=port, protocol=protocol)

    def to_friendly_name(self) -> str:
        # note: this method is closely linked to from_str_with_inference
        if self.protocol == 's':  # hide trailing ":s"
            return self.net_addr_str()
        return str(self)

    def __str__(self):
        return '{}:{}'.format(self.net_addr_str(), self.protocol)

    def to_json(self) -> str:
        return str(self)

    def __repr__(self):
        return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'

    def net_addr_str(self) -> str:
        return self._net_addr_str

    def __eq__(self, other):
        if not isinstance(other, ServerAddr):
            return False
        return (self.host == other.host
                and self.port == other.port
                and self.protocol == other.protocol)

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # consistent with __eq__: same fields, same order
        return hash((self.host, self.port, self.protocol))
335
336
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
    """Return the path of the saved-certificate file for *host*.

    The file lives in ``<config.path>/certs/``. An IPv6 literal is stored as
    ``ipv6_<hex of packed address>``, presumably because ':' is not valid in
    filenames on every platform. Every other host is used verbatim.
    """
    try:
        parsed = ip_address(host)
    except ValueError:
        parsed = None
    if isinstance(parsed, IPv6Address):
        filename = f"ipv6_{parsed.packed.hex()}"
    else:
        filename = host
    return os.path.join(config.path, 'certs', filename)
347
348
349class Interface(Logger):
350
351    LOGGING_SHORTCUT = 'i'
352
353    def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
354        self.ready = asyncio.Future()
355        self.got_disconnected = asyncio.Event()
356        self.server = server
357        Logger.__init__(self)
358        assert network.config.path
359        self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
360        self.blockchain = None  # type: Optional[Blockchain]
361        self._requested_chunks = set()  # type: Set[int]
362        self.network = network
363        self.proxy = MySocksProxy.from_proxy_dict(proxy)
364        self.session = None  # type: Optional[NotificationSession]
365        self._ipaddr_bucket = None
366
367        # Latest block header and corresponding height, as claimed by the server.
368        # Note that these values are updated before they are verified.
369        # Especially during initial header sync, verification can take a long time.
370        # Failing verification will get the interface closed.
371        self.tip_header = None
372        self.tip = 0
373
374        self.fee_estimates_eta = {}  # type: Dict[int, int]
375
376        # Dump network messages (only for this interface).  Set at runtime from the console.
377        self.debug = False
378
379        self.taskgroup = SilentTaskGroup()
380
381        async def spawn_task():
382            task = await self.network.taskgroup.spawn(self.run())
383            if sys.version_info >= (3, 8):
384                task.set_name(f"interface::{str(server)}")
385        asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop)
386
387    @property
388    def host(self):
389        return self.server.host
390
391    @property
392    def port(self):
393        return self.server.port
394
395    @property
396    def protocol(self):
397        return self.server.protocol
398
399    def diagnostic_name(self):
400        return self.server.net_addr_str()
401
402    def __str__(self):
403        return f"<Interface {self.diagnostic_name()}>"
404
405    async def is_server_ca_signed(self, ca_ssl_context):
406        """Given a CA enforcing SSL context, returns True if the connection
407        can be established. Returns False if the server has a self-signed
408        certificate but otherwise is okay. Any other failures raise.
409        """
410        try:
411            await self.open_session(ca_ssl_context, exit_early=True)
412        except ConnectError as e:
413            cause = e.__cause__
414            if isinstance(cause, ssl.SSLError) and cause.reason == 'CERTIFICATE_VERIFY_FAILED':
415                # failures due to self-signed certs are normal
416                return False
417            raise
418        return True
419
420    async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context):
421        ca_signed = await self.is_server_ca_signed(ca_ssl_context)
422        if ca_signed:
423            if self._get_expected_fingerprint():
424                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
425            with open(self.cert_path, 'w') as f:
426                # empty file means this is CA signed, not self-signed
427                f.write('')
428        else:
429            await self._save_certificate()
430
431    def _is_saved_ssl_cert_available(self):
432        if not os.path.exists(self.cert_path):
433            return False
434        with open(self.cert_path, 'r') as f:
435            contents = f.read()
436        if contents == '':  # CA signed
437            if self._get_expected_fingerprint():
438                raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers")
439            return True
440        # pinned self-signed cert
441        try:
442            b = pem.dePem(contents, 'CERTIFICATE')
443        except SyntaxError as e:
444            self.logger.info(f"error parsing already saved cert: {e}")
445            raise ErrorParsingSSLCert(e) from e
446        try:
447            x = x509.X509(b)
448        except Exception as e:
449            self.logger.info(f"error parsing already saved cert: {e}")
450            raise ErrorParsingSSLCert(e) from e
451        try:
452            x.check_date()
453        except x509.CertificateError as e:
454            self.logger.info(f"certificate has expired: {e}")
455            os.unlink(self.cert_path)  # delete pinned cert only in this case
456            return False
457        self._verify_certificate_fingerprint(bytearray(b))
458        return True
459
460    async def _get_ssl_context(self):
461        if self.protocol != 's':
462            # using plaintext TCP
463            return None
464
465        # see if we already have cert for this server; or get it for the first time
466        ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path)
467        if not self._is_saved_ssl_cert_available():
468            try:
469                await self._try_saving_ssl_cert_for_first_time(ca_sslc)
470            except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e:
471                raise ErrorGettingSSLCertFromServer(e) from e
472        # now we have a file saved in our certificate store
473        siz = os.stat(self.cert_path).st_size
474        if siz == 0:
475            # CA signed cert
476            sslc = ca_sslc
477        else:
478            # pinned self-signed cert
479            sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
480            sslc.check_hostname = 0
481        return sslc
482
483    def handle_disconnect(func):
484        @functools.wraps(func)
485        async def wrapper_func(self: 'Interface', *args, **kwargs):
486            try:
487                return await func(self, *args, **kwargs)
488            except GracefulDisconnect as e:
489                self.logger.log(e.log_level, f"disconnecting due to {repr(e)}")
490            except aiorpcx.jsonrpc.RPCError as e:
491                self.logger.warning(f"disconnecting due to {repr(e)}")
492                self.logger.debug(f"(disconnect) trace for {repr(e)}", exc_info=True)
493            finally:
494                self.got_disconnected.set()
495                await self.network.connection_down(self)
496                # if was not 'ready' yet, schedule waiting coroutines:
497                self.ready.cancel()
498        return wrapper_func
499
500    @ignore_exceptions  # do not kill network.taskgroup
501    @log_exceptions
502    @handle_disconnect
503    async def run(self):
504        try:
505            ssl_context = await self._get_ssl_context()
506        except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e:
507            self.logger.info(f'disconnecting due to: {repr(e)}')
508            return
509        try:
510            await self.open_session(ssl_context)
511        except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e:
512            # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues)
513            if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError)
514                    and self.is_main_server() and not self.network.auto_connect):
515                self.logger.warning(f'Cannot connect to main server due to SSL error '
516                                    f'(maybe cert changed compared to "{self.cert_path}"). Exc: {repr(e)}')
517            else:
518                self.logger.info(f'disconnecting due to: {repr(e)}')
519            return
520
521    def _mark_ready(self) -> None:
522        if self.ready.cancelled():
523            raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled')
524        if self.ready.done():
525            return
526
527        assert self.tip_header
528        chain = blockchain.check_header(self.tip_header)
529        if not chain:
530            self.blockchain = blockchain.get_best_chain()
531        else:
532            self.blockchain = chain
533        assert self.blockchain is not None
534
535        self.logger.info(f"set blockchain with height {self.blockchain.height()}")
536
537        self.ready.set_result(1)
538
539    async def _save_certificate(self) -> None:
540        if not os.path.exists(self.cert_path):
541            # we may need to retry this a few times, in case the handshake hasn't completed
542            for _ in range(10):
543                dercert = await self._fetch_certificate()
544                if dercert:
545                    self.logger.info("succeeded in getting cert")
546                    self._verify_certificate_fingerprint(dercert)
547                    with open(self.cert_path, 'w') as f:
548                        cert = ssl.DER_cert_to_PEM_cert(dercert)
549                        # workaround android bug
550                        cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
551                        f.write(cert)
552                        # even though close flushes we can't fsync when closed.
553                        # and we must flush before fsyncing, cause flush flushes to OS buffer
554                        # fsync writes to OS buffer to disk
555                        f.flush()
556                        os.fsync(f.fileno())
557                    break
558                await asyncio.sleep(1)
559            else:
560                raise GracefulDisconnect("could not get certificate after 10 tries")
561
562    async def _fetch_certificate(self) -> bytes:
563        sslc = ssl.SSLContext()
564        async with _RSClient(session_factory=RPCSession,
565                             host=self.host, port=self.port,
566                             ssl=sslc, proxy=self.proxy) as session:
567            asyncio_transport = session.transport._asyncio_transport  # type: asyncio.BaseTransport
568            ssl_object = asyncio_transport.get_extra_info("ssl_object")  # type: ssl.SSLObject
569            return ssl_object.getpeercert(binary_form=True)
570
571    def _get_expected_fingerprint(self) -> Optional[str]:
572        if self.is_main_server():
573            return self.network.config.get("serverfingerprint")
574
575    def _verify_certificate_fingerprint(self, certificate):
576        expected_fingerprint = self._get_expected_fingerprint()
577        if not expected_fingerprint:
578            return
579        fingerprint = hashlib.sha256(certificate).hexdigest()
580        fingerprints_match = fingerprint.lower() == expected_fingerprint.lower()
581        if not fingerprints_match:
582            util.trigger_callback('cert_mismatch')
583            raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch')
584        self.logger.info("cert fingerprint verification passed")
585
586    async def get_block_header(self, height, assert_mode):
587        self.logger.info(f'requesting block header {height} in mode {assert_mode}')
588        # use lower timeout as we usually have network.bhi_lock here
589        timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent)
590        res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout)
591        return blockchain.deserialize_header(bytes.fromhex(res), height)
592
593    async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
594        if not is_non_negative_integer(height):
595            raise Exception(f"{repr(height)} is not a block height")
596        index = height // 2016
597        if can_return_early and index in self._requested_chunks:
598            return
599        self.logger.info(f"requesting chunk from height {height}")
600        size = 2016
601        if tip is not None:
602            size = min(size, tip - index * 2016 + 1)
603            size = max(size, 0)
604        try:
605            self._requested_chunks.add(index)
606            res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
607        finally:
608            self._requested_chunks.discard(index)
609        assert_dict_contains_field(res, field_name='count')
610        assert_dict_contains_field(res, field_name='hex')
611        assert_dict_contains_field(res, field_name='max')
612        assert_non_negative_integer(res['count'])
613        assert_non_negative_integer(res['max'])
614        assert_hex_str(res['hex'])
615        if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
616            raise RequestCorrupted('inconsistent chunk hex and count')
617        # we never request more than 2016 headers, but we enforce those fit in a single response
618        if res['max'] < 2016:
619            raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016")
620        if res['count'] != size:
621            raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
622        conn = self.blockchain.connect_chunk(index, res['hex'])
623        if not conn:
624            return conn, 0
625        return conn, res['count']
626
627    def is_main_server(self) -> bool:
628        return (self.network.interface == self or
629                self.network.interface is None and self.network.default_server == self.server)
630
    async def open_session(self, sslc, exit_early=False):
        """Connect to the server and run this interface's tasks on the session.

        Opens an _RSClient connection (TLS if *sslc* is not None), stores the
        session in self.session and sends 'server.version'. If *exit_early* is
        true, returns right after that handshake; is_server_ca_signed uses this
        as a connectivity probe.

        Otherwise, checks the negotiated protocol version and the
        healthy-spread-of-servers policy. It then spawns ping,
        request_fee_estimates, run_fetch_blocks and monitor_connection in
        self.taskgroup, and does not return until that task group exits.

        Raises GracefulDisconnect on a version-request RPC error, a protocol
        mismatch, a bucket-policy rejection, or selected RPC error codes from
        the tasks (logged at WARNING). Other RPC errors are re-raised.
        """
        # iface=self is bound as a default argument so each session gets a reference to this interface
        session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface)
        async with _RSClient(session_factory=session_factory,
                             host=self.host, port=self.port,
                             ssl=sslc, proxy=self.proxy) as session:
            self.session = session  # type: NotificationSession
            self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic))
            try:
                ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION])
            except aiorpcx.jsonrpc.RPCError as e:
                raise GracefulDisconnect(e)  # probably 'unsupported protocol version'
            if exit_early:
                return
            # ver[1] is compared against the protocol version we asked for
            if ver[1] != version.PROTOCOL_VERSION:
                raise GracefulDisconnect(f'server violated protocol-version-negotiation. '
                                         f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}')
            if not self.network.check_interface_against_healthy_spread_of_connected_servers(self):
                raise GracefulDisconnect(f'too many connected servers already '
                                         f'in bucket {self.bucket_based_on_ipaddress()}')
            self.logger.info(f"connection established. version: {ver}")

            try:
                async with self.taskgroup as group:
                    await group.spawn(self.ping)
                    await group.spawn(self.request_fee_estimates)
                    await group.spawn(self.run_fetch_blocks)
                    await group.spawn(self.monitor_connection)
            except aiorpcx.jsonrpc.RPCError as e:
                # these codes become a graceful disconnect logged at WARNING; other RPC errors propagate
                if e.code in (JSONRPC.EXCESSIVE_RESOURCE_USAGE,
                              JSONRPC.SERVER_BUSY,
                              JSONRPC.METHOD_NOT_FOUND):
                    raise GracefulDisconnect(e, log_level=logging.WARNING) from e
                raise
664
665    async def monitor_connection(self):
666        while True:
667            await asyncio.sleep(1)
668            if not self.session or self.session.is_closing():
669                raise GracefulDisconnect('session was closed')
670
671    async def ping(self):
672        while True:
673            await asyncio.sleep(300)
674            await self.session.send_request('server.ping')
675
676    async def request_fee_estimates(self):
677        from .simple_config import FEE_ETA_TARGETS
678        while True:
679            async with TaskGroup() as group:
680                fee_tasks = []
681                for i in FEE_ETA_TARGETS:
682                    fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
683            for nblock_target, task in fee_tasks:
684                fee = task.result()
685                if fee < 0: continue
686                assert isinstance(fee, int)
687                self.fee_estimates_eta[nblock_target] = fee
688            self.network.update_fee_estimates()
689            await asyncio.sleep(60)
690
691    async def close(self, *, force_after: int = None):
692        """Closes the connection and waits for it to be closed.
693        We try to flush buffered data to the wire, so this can take some time.
694        """
695        if force_after is None:
696            # We give up after a while and just abort the connection.
697            # Note: specifically if the server is running Fulcrum, waiting seems hopeless,
698            #       the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
699            force_after = 1  # seconds
700        if self.session:
701            await self.session.close(force_after=force_after)
702        # monitor_connection will cancel tasks
703
704    async def run_fetch_blocks(self):
705        header_queue = asyncio.Queue()
706        await self.session.subscribe('blockchain.headers.subscribe', [], header_queue)
707        while True:
708            item = await header_queue.get()
709            raw_header = item[0]
710            height = raw_header['height']
711            header = blockchain.deserialize_header(bfh(raw_header['hex']), height)
712            self.tip_header = header
713            self.tip = height
714            if self.tip < constants.net.max_checkpoint():
715                raise GracefulDisconnect('server tip below max checkpoint')
716            self._mark_ready()
717            await self._process_header_at_tip()
718            # header processing done
719            util.trigger_callback('blockchain_updated')
720            util.trigger_callback('network_updated')
721            await self.network.switch_unwanted_fork_interface()
722            await self.network.switch_lagging_interface()
723
724    async def _process_header_at_tip(self):
725        height, header = self.tip, self.tip_header
726        async with self.network.bhi_lock:
727            if self.blockchain.height() >= height and self.blockchain.check_header(header):
728                # another interface amended the blockchain
729                self.logger.info(f"skipping header {height}")
730                return
731            _, height = await self.step(height, header)
732            # in the simple case, height == self.tip+1
733            if height <= self.tip:
734                await self.sync_until(height)
735
    async def sync_until(self, height, next_height=None):
        """Sync headers starting at `height` until `next_height` has been processed.

        `next_height` defaults to self.tip. While more than 10 headers remain,
        headers are fetched in bulk with request_chunk(). Otherwise, or when a
        chunk above the max checkpoint does not connect, headers are handled
        one at a time with step().

        Returns:
            (last, height): the tag from the last step/chunk ('catchup',
            'no_fork' or 'fork') and the next height to process.

        Raises:
            GracefulDisconnect: if a chunk starting at or below the max
                checkpoint does not connect.
        """
        if next_height is None:
            next_height = self.tip
        last = None
        # `last is None` guarantees at least one iteration
        while last is None or height <= next_height:
            prev_last, prev_height = last, height
            if next_height > height + 10:
                could_connect, num_headers = await self.request_chunk(height, next_height)
                if not could_connect:
                    if height <= constants.net.max_checkpoint():
                        raise GracefulDisconnect('server chain conflicts with checkpoints or genesis')
                    # fall back to header-by-header processing (may search backwards / fork)
                    last, height = await self.step(height)
                    continue
                util.trigger_callback('network_updated')
                # NOTE(review): assumes the chunk starts at the 2016-aligned boundary
                # containing `height` and num_headers counts from there — confirm
                # against request_chunk
                height = (height // 2016 * 2016) + num_headers
                assert height <= next_height+1, (height, self.tip)
                last = 'catchup'
            else:
                last, height = await self.step(height)
            # an iteration that makes no progress would loop forever
            assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until'
        return last, height
757
    async def step(self, height, header=None):
        """Process the header at `height` and return how syncing should continue.

        The header is fetched from the server if not given. Outcomes:
          * it checks against a known chain -> ('catchup', height + 1);
          * it can connect (blockchain.can_connect) -> ('catchup', height + 1);
            when not mocking it is saved to the chain returned by can_connect;
          * otherwise a backward search finds an earlier header that checks or
            connects. If that header connects -> ('catchup', its height + 1).
            If it only checks, the fork point is binary-searched and resolved,
            returning ('no_fork', ...) or ('fork', ...).

        Headers containing a 'mock' key (used by tests) carry their own
        'check' / 'connect' callables instead of consulting `blockchain`.

        Returns:
            (tag, next_height)
        """
        assert 0 <= height <= self.tip, (height, self.tip)
        if header is None:
            header = await self.get_block_header(height, 'catchup')

        chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
        if chain:
            self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
            # note: there is an edge case here that is not handled.
            # we might know the blockhash (enough for check_header) but
            # not have the header itself. e.g. regtest chain with only genesis.
            # this situation resolves itself on the next block
            return 'catchup', height+1

        can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
        if not can_connect:
            self.logger.info(f"can't connect {height}")
            height, header, bad, bad_header = await self._search_headers_backwards(height, header)
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            # the header returned by the backward search must check or connect
            assert chain or can_connect
        if can_connect:
            self.logger.info(f"could connect {height}")
            height += 1
            if isinstance(can_connect, Blockchain):  # not when mocking
                self.blockchain = can_connect
                self.blockchain.save_header(header)
            return 'catchup', height

        # Reached only after the backward search: the header at `height` checks
        # but does not connect, so the fork point lies between `height` and `bad`.
        good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain)
        return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header)
789
    async def _search_headers_binary(self, height, bad, bad_header, chain):
        """Binary-search the fork point between `height` and `bad`.

        Expected on entry (as arranged by step()): the header at `height`
        checks against `chain`, and `bad_header` (at height `bad`) does not
        check. The interval is halved by fetching the midpoint header each
        time, until good + 1 == bad. self.blockchain is updated to the chain
        each good header checked against.

        Returns:
            (good, bad, bad_header): `bad_header` is the lowest probed header
            that does not check. It is required to connect.

        Raises:
            Exception: if the final `bad_header` cannot connect (neither the
                mock 'connect' nor self.blockchain.can_connect succeeds), or if
                it checks against a chain.
        """
        assert bad == bad_header['block_height']
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
        good = height
        while True:
            # invariant: header at `good` checks, header at `bad` does not
            assert good < bad, (good, bad)
            height = (good + bad) // 2
            self.logger.info(f"binary step. good {good}, bad {bad}, height {height}")
            header = await self.get_block_header(height, 'binary')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            if chain:
                self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain
                good = height
            else:
                bad = height
                bad_header = header
            if good + 1 == bad:
                break

        # NOTE(review): `height` here is the last probed midpoint, which may equal
        # either `good` or `bad`; only the mock 'connect' callable receives it.
        mock = 'mock' in bad_header and bad_header['mock']['connect'](height)
        real = not mock and self.blockchain.can_connect(bad_header, check_height=False)
        if not real and not mock:
            raise Exception('unexpected bad header during binary: {}'.format(bad_header))
        _assert_header_does_not_check_against_any_chain(bad_header)

        self.logger.info(f"binary search exited. good {good}, bad {bad}")
        return good, bad, bad_header
819
820    async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header):
821        assert good + 1 == bad
822        assert bad == bad_header['block_height']
823        _assert_header_does_not_check_against_any_chain(bad_header)
824        # 'good' is the height of a block 'good_header', somewhere in self.blockchain.
825        # bad_header connects to good_header; bad_header itself is NOT in self.blockchain.
826
827        bh = self.blockchain.height()
828        assert bh >= good, (bh, good)
829        if bh == good:
830            height = good + 1
831            self.logger.info(f"catching up from {height}")
832            return 'no_fork', height
833
834        # this is a new fork we don't yet have
835        height = bad + 1
836        self.logger.info(f"new fork at bad height {bad}")
837        forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork']
838        b = forkfun(bad_header)  # type: Blockchain
839        self.blockchain = b
840        assert b.forkpoint == bad
841        return 'fork', height
842
    async def _search_headers_backwards(self, height, header):
        """Probe ever further back from `height` for a header that checks or connects.

        `header` (at `height`) must not check against any chain. Probing starts
        at min(local_max + 1, height - 1), where local_max is the highest height
        of any known chain (infinite when mocking). After that, each failed
        probe doubles the distance from self.tip. Probes at or below the max
        checkpoint are clamped to the checkpoint height.

        Returns:
            (height, header, bad, bad_header): the probe that checks or
            connects, and the most recent probe (or the input) that did neither.

        Raises:
            GracefulDisconnect: if the header at the max checkpoint neither
                checks nor connects.
        """
        async def iterate():
            # Fetch the header at the current probe `height` (clamped to the
            # checkpoint) into the enclosing `height`/`header`. Return True to
            # keep searching, or False once it checks or connects.
            nonlocal height, header
            checkp = False
            if height <= constants.net.max_checkpoint():
                height = constants.net.max_checkpoint()
                checkp = True
            header = await self.get_block_header(height, 'backward')
            chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
            can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height)
            if chain or can_connect:
                return False
            if checkp:
                raise GracefulDisconnect("server chain conflicts with checkpoints")
            return True

        bad, bad_header = height, header
        _assert_header_does_not_check_against_any_chain(bad_header)
        with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values())
        local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf')
        # never probe above what we could possibly connect to
        height = min(local_max + 1, height - 1)
        while await iterate():
            bad, bad_header = height, header
            # double the distance from the tip on each failed probe
            delta = self.tip - height
            height = self.tip - 2 * delta

        _assert_header_does_not_check_against_any_chain(bad_header)
        self.logger.info(f"exiting backward mode at {height}")
        return height, header, bad, bad_header
872
873    @classmethod
874    def client_name(cls) -> str:
875        return f'electrum/{version.ELECTRUM_VERSION}'
876
877    def is_tor(self):
878        return self.host.endswith('.onion')
879
880    def ip_addr(self) -> Optional[str]:
881        session = self.session
882        if not session: return None
883        peer_addr = session.remote_address()
884        if not peer_addr: return None
885        return str(peer_addr.host)
886
887    def bucket_based_on_ipaddress(self) -> str:
888        def do_bucket():
889            if self.is_tor():
890                return BUCKET_NAME_OF_ONION_SERVERS
891            try:
892                ip_addr = ip_address(self.ip_addr())  # type: Union[IPv4Address, IPv6Address]
893            except ValueError:
894                return ''
895            if not ip_addr:
896                return ''
897            if ip_addr.is_loopback:  # localhost is exempt
898                return ''
899            if ip_addr.version == 4:
900                slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16)
901                return str(slash16)
902            elif ip_addr.version == 6:
903                slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48)
904                return str(slash48)
905            return ''
906
907        if not self._ipaddr_bucket:
908            self._ipaddr_bucket = do_bucket()
909        return self._ipaddr_bucket
910
911    async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
912        if not is_hash256_str(tx_hash):
913            raise Exception(f"{repr(tx_hash)} is not a txid")
914        if not is_non_negative_integer(tx_height):
915            raise Exception(f"{repr(tx_height)} is not a block height")
916        # do request
917        res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
918        # check response
919        block_height = assert_dict_contains_field(res, field_name='block_height')
920        merkle = assert_dict_contains_field(res, field_name='merkle')
921        pos = assert_dict_contains_field(res, field_name='pos')
922        # note: tx_height was just a hint to the server, don't enforce the response to match it
923        assert_non_negative_integer(block_height)
924        assert_non_negative_integer(pos)
925        assert_list_or_tuple(merkle)
926        for item in merkle:
927            assert_hash256_str(item)
928        return res
929
930    async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
931        if not is_hash256_str(tx_hash):
932            raise Exception(f"{repr(tx_hash)} is not a txid")
933        raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
934        # validate response
935        if not is_hex_str(raw):
936            raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw!r}")
937        tx = Transaction(raw)
938        try:
939            tx.deserialize()  # see if raises
940        except Exception as e:
941            raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
942        if tx.txid() != tx_hash:
943            raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
944        return raw
945
946    async def get_history_for_scripthash(self, sh: str) -> List[dict]:
947        if not is_hash256_str(sh):
948            raise Exception(f"{repr(sh)} is not a scripthash")
949        # do request
950        res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
951        # check response
952        assert_list_or_tuple(res)
953        prev_height = 1
954        for tx_item in res:
955            height = assert_dict_contains_field(tx_item, field_name='height')
956            assert_dict_contains_field(tx_item, field_name='tx_hash')
957            assert_integer(height)
958            assert_hash256_str(tx_item['tx_hash'])
959            if height in (-1, 0):
960                assert_dict_contains_field(tx_item, field_name='fee')
961                assert_non_negative_integer(tx_item['fee'])
962                prev_height = - float("inf")  # this ensures confirmed txs can't follow mempool txs
963            else:
964                # check monotonicity of heights
965                if height < prev_height:
966                    raise RequestCorrupted(f'heights of confirmed txs must be in increasing order')
967                prev_height = height
968        hashes = set(map(lambda item: item['tx_hash'], res))
969        if len(hashes) != len(res):
970            # Either server is sending garbage... or maybe if server is race-prone
971            # a recently mined tx could be included in both last block and mempool?
972            # Still, it's simplest to just disregard the response.
973            raise RequestCorrupted(f"server history has non-unique txids for sh={sh}")
974        return res
975
976    async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
977        if not is_hash256_str(sh):
978            raise Exception(f"{repr(sh)} is not a scripthash")
979        # do request
980        res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
981        # check response
982        assert_list_or_tuple(res)
983        for utxo_item in res:
984            assert_dict_contains_field(utxo_item, field_name='tx_pos')
985            assert_dict_contains_field(utxo_item, field_name='value')
986            assert_dict_contains_field(utxo_item, field_name='tx_hash')
987            assert_dict_contains_field(utxo_item, field_name='height')
988            assert_non_negative_integer(utxo_item['tx_pos'])
989            assert_non_negative_integer(utxo_item['value'])
990            assert_non_negative_integer(utxo_item['height'])
991            assert_hash256_str(utxo_item['tx_hash'])
992        return res
993
994    async def get_balance_for_scripthash(self, sh: str) -> dict:
995        if not is_hash256_str(sh):
996            raise Exception(f"{repr(sh)} is not a scripthash")
997        # do request
998        res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
999        # check response
1000        assert_dict_contains_field(res, field_name='confirmed')
1001        assert_dict_contains_field(res, field_name='unconfirmed')
1002        assert_non_negative_integer(res['confirmed'])
1003        assert_non_negative_integer(res['unconfirmed'])
1004        return res
1005
1006    async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
1007        if not is_non_negative_integer(tx_height):
1008            raise Exception(f"{repr(tx_height)} is not a block height")
1009        if not is_non_negative_integer(tx_pos):
1010            raise Exception(f"{repr(tx_pos)} should be non-negative integer")
1011        # do request
1012        res = await self.session.send_request(
1013            'blockchain.transaction.id_from_pos',
1014            [tx_height, tx_pos, merkle],
1015        )
1016        # check response
1017        if merkle:
1018            assert_dict_contains_field(res, field_name='tx_hash')
1019            assert_dict_contains_field(res, field_name='merkle')
1020            assert_hash256_str(res['tx_hash'])
1021            assert_list_or_tuple(res['merkle'])
1022            for node_hash in res['merkle']:
1023                assert_hash256_str(node_hash)
1024        else:
1025            assert_hash256_str(res)
1026        return res
1027
1028    async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
1029        # do request
1030        res = await self.session.send_request('mempool.get_fee_histogram')
1031        # check response
1032        assert_list_or_tuple(res)
1033        prev_fee = float('inf')
1034        for fee, s in res:
1035            assert_non_negative_int_or_float(fee)
1036            assert_non_negative_integer(s)
1037            if fee >= prev_fee:  # check monotonicity
1038                raise RequestCorrupted(f'fees must be in decreasing order')
1039            prev_fee = fee
1040        return res
1041
1042    async def get_server_banner(self) -> str:
1043        # do request
1044        res = await self.session.send_request('server.banner')
1045        # check response
1046        if not isinstance(res, str):
1047            raise RequestCorrupted(f'{res!r} should be a str')
1048        return res
1049
1050    async def get_donation_address(self) -> str:
1051        # do request
1052        res = await self.session.send_request('server.donation_address')
1053        # check response
1054        if not res:  # ignore empty string
1055            return ''
1056        if not bitcoin.is_address(res):
1057            # note: do not hard-fail -- allow server to use future-type
1058            #       bitcoin address we do not recognize
1059            self.logger.info(f"invalid donation address from server: {repr(res)}")
1060            res = ''
1061        return res
1062
1063    async def get_relay_fee(self) -> int:
1064        """Returns the min relay feerate in sat/kbyte."""
1065        # do request
1066        res = await self.session.send_request('blockchain.relayfee')
1067        # check response
1068        assert_non_negative_int_or_float(res)
1069        relayfee = int(res * bitcoin.COIN)
1070        relayfee = max(0, relayfee)
1071        return relayfee
1072
1073    async def get_estimatefee(self, num_blocks: int) -> int:
1074        """Returns a feerate estimate for getting confirmed within
1075        num_blocks blocks, in sat/kbyte.
1076        """
1077        if not is_non_negative_integer(num_blocks):
1078            raise Exception(f"{repr(num_blocks)} is not a num_blocks")
1079        # do request
1080        res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
1081        # check response
1082        if res != -1:
1083            assert_non_negative_int_or_float(res)
1084            res = int(res * bitcoin.COIN)
1085        return res
1086
1087
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
    """Raise if `header` checks against a known chain.

    Headers carrying a 'mock' key supply their own 'check' callable.
    Otherwise blockchain.check_header() is consulted.

    Raises:
        Exception: if the check result is truthy.
    """
    if 'mock' in header:
        checks = header['mock']['check'](header)
    else:
        checks = blockchain.check_header(header)
    if checks:
        raise Exception('bad_header must not check!')
1092
1093
def check_cert(host, cert):
    """Print whether the PEM certificate `cert` for `host` has expired.

    `cert` is decoded as a PEM 'CERTIFICATE' and parsed with x509.X509. If
    that fails, the traceback is printed to stdout and None is returned.
    Otherwise the host and a 'has_expired' flag are printed via
    util.print_msg. The flag is True when x.check_date() raises.
    """
    try:
        b = pem.dePem(cert, 'CERTIFICATE')
        x = x509.X509(b)
    except Exception:
        # Best-effort debug helper: report the unparsable cert and skip it.
        # `except Exception` (not a bare except) lets KeyboardInterrupt /
        # SystemExit propagate.
        traceback.print_exc(file=sys.stdout)
        return

    try:
        x.check_date()
        expired = False
    except Exception:
        # presumably check_date() raises for an out-of-range validity period — confirm
        expired = True

    m = "host: %s\n" % host
    m += "has_expired: %s\n" % expired
    util.print_msg(m)
1111
1112
1113# Used by tests
def _match_hostname(name, val):
    """Return True if hostname `name` matches pattern `val`.

    `val` is either an exact hostname or a '*.'-prefixed wildcard. The
    wildcard matches any `name` ending in the part after '*'. Note this
    includes names with several extra labels, e.g. 'a.b.x.com' for '*.x.com'.
    """
    if val == name:
        return True
    if not val.startswith('*.'):
        return False
    return name.endswith(val[1:])
1119
1120
def test_certificates():
    """Run check_cert() on every file in the 'certs' directory of the config path.

    Each file is read as UTF-8 text. Its filename is passed to check_cert()
    as the host.
    """
    from .simple_config import SimpleConfig
    certs_dir = os.path.join(SimpleConfig().path, "certs")
    for filename in os.listdir(certs_dir):
        with open(os.path.join(certs_dir, filename), encoding='utf-8') as cert_file:
            pem_text = cert_file.read()
        check_cert(filename, pem_text)


if __name__ == "__main__":
    # running this module directly checks all stored certificates
    test_certificates()
1134