1#!/usr/local/bin/python3.8 2# 3# Electrum - lightweight Bitcoin client 4# Copyright (C) 2011 thomasv@gitorious 5# 6# Permission is hereby granted, free of charge, to any person 7# obtaining a copy of this software and associated documentation files 8# (the "Software"), to deal in the Software without restriction, 9# including without limitation the rights to use, copy, modify, merge, 10# publish, distribute, sublicense, and/or sell copies of the Software, 11# and to permit persons to whom the Software is furnished to do so, 12# subject to the following conditions: 13# 14# The above copyright notice and this permission notice shall be 15# included in all copies or substantial portions of the Software. 16# 17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 21# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 22# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 23# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24# SOFTWARE. 
import os
import re
import ssl
import sys
import traceback
import asyncio
import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence, Dict
from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools
import logging
import hashlib
import functools

import aiorpcx
from aiorpcx import TaskGroup
from aiorpcx import RPCSession, Notification, NetAddress, NewlineFramer
from aiorpcx.curio import timeout_after, TaskTimeout
from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
from aiorpcx.rawsocket import RSClient
import certifi

from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
                   is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
                   is_int_or_float, is_non_negative_int_or_float)
from . import util
from . import x509
from . import pem
from . import version
from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE
from . import bitcoin
from . import constants
from .i18n import _
from .logging import Logger
from .transaction import Transaction

if TYPE_CHECKING:
    from .network import Network
    from .simple_config import SimpleConfig


# CA bundle from certifi; passed as `cafile` when building the SSL context
# used to check whether a server certificate is CA-signed.
ca_path = certifi.where()

# Bucket name returned by Interface.bucket_based_on_ipaddress() for .onion hosts.
BUCKET_NAME_OF_ONION_SERVERS = 'onion'

# Default cap for a single incoming message; can be overridden with the
# 'network_max_incoming_msg_size' config key (see NotificationSession.default_framer).
MAX_INCOMING_MSG_SIZE = 1_000_000  # in bytes

# 't' = plaintext TCP, 's' = SSL/TLS
_KNOWN_NETWORK_PROTOCOLS = {'t', 's'}
PREFERRED_NETWORK_PROTOCOL = 's'
assert PREFERRED_NETWORK_PROTOCOL in _KNOWN_NETWORK_PROTOCOLS


class NetworkTimeout:
    """Network timeout presets, in seconds.

    ``Urgent`` subclasses ``Generic`` and overrides every level with a shorter value.
    """
    # seconds
    class Generic:
        NORMAL = 30
        RELAXED = 45
        MOST_RELAXED = 600

    class Urgent(Generic):
        NORMAL = 10
        RELAXED = 20
        MOST_RELAXED = 60


# Validators for data received from servers. They raise RequestCorrupted
# (defined further below in this module) rather than using `assert`, so the
# checks also run under `python -O`.

def assert_non_negative_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative integer."""
    if not is_non_negative_integer(val):
        raise RequestCorrupted(f'{val!r} should be a non-negative integer')


def assert_integer(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an integer."""
    if not is_integer(val):
        raise RequestCorrupted(f'{val!r} should be an integer')


def assert_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is an int or a float."""
    if not is_int_or_float(val):
        raise RequestCorrupted(f'{val!r} should be int or float')


def assert_non_negative_int_or_float(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a non-negative int or float."""
    if not is_non_negative_int_or_float(val):
        raise RequestCorrupted(f'{val!r} should be a non-negative int or float')


def assert_hash256_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hash256 hex string."""
    if not is_hash256_str(val):
        raise RequestCorrupted(f'{val!r} should be a hash256 str')


def assert_hex_str(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a hex string."""
    if not is_hex_str(val):
        raise RequestCorrupted(f'{val!r} should be a hex str')


def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
    """Check that *d* is a dict containing *field_name*; return ``d[field_name]``.

    :raises RequestCorrupted: if *d* is not a dict or lacks the field.
    """
    if not isinstance(d, dict):
        raise RequestCorrupted(f'{d!r} should be a dict')
    if field_name not in d:
        raise RequestCorrupted(f'required field {field_name!r} missing from dict')
    return d[field_name]


def assert_list_or_tuple(val: Any) -> None:
    """Raise RequestCorrupted unless *val* is a list or a tuple."""
    if not isinstance(val, (list, tuple)):
        raise RequestCorrupted(f'{val!r} should be a list or tuple')


class NotificationSession(RPCSession):
    """aiorpcx RPC session that fans out server notifications to subscriber queues.

    Results are cached per hashable (method, params) key. The cache is filled by
    the initial subscription response and refreshed by later notifications, so
    further subscribers to the same key are served without a network request.
    """

    def __init__(self, *args, interface: 'Interface', **kwargs):
        super(NotificationSession, self).__init__(*args, **kwargs)
        # hashable rpc key -> list of asyncio.Queue to push results into
        self.subscriptions = defaultdict(list)
        # hashable rpc key -> latest known result for that key
        self.cache = {}
        # NOTE(review): not read anywhere in this module section; the timeouts
        # handed to aiorpcx are set by set_default_timeout(). Confirm no external
        # reader before removing.
        self.default_timeout = NetworkTimeout.Generic.NORMAL
        # ids used only to correlate request/response lines in debug logs
        self._msg_counter = itertools.count(start=1)
        self.interface = interface
        self.cost_hard_limit = 0  # disable aiorpcx resource limits

    async def handle_request(self, request):
        """Handle a message initiated by the server.

        Only notifications for keys we are subscribed to are accepted. The last
        element of ``request.args`` is the new result and the rest are the
        subscription params. The result is cached and the full args are put on
        every queue subscribed to that key. Anything else (a request, or an
        unknown notification) is logged and the session is closed.
        """
        self.maybe_log(f"--> {request}")
        try:
            if isinstance(request, Notification):
                params, result = request.args[:-1], request.args[-1]
                key = self.get_hashable_key_for_rpc_call(request.method, params)
                if key in self.subscriptions:
                    self.cache[key] = result
                    for queue in self.subscriptions[key]:
                        await queue.put(request.args)
                else:
                    raise Exception('unexpected notification')
            else:
                raise Exception('unexpected request. not a notification')
        except Exception as e:
            self.interface.logger.info(f"error handling request {request}. exc: {repr(e)}")
            await self.close()

    async def send_request(self, *args, timeout=None, **kwargs):
        """Send a request through aiorpcx and return the server's response.

        :param timeout: optional extra timeout in seconds, enforced with
            asyncio.wait_for on top of aiorpcx's own; None means no extra limit.
        :raises RequestTimedOut: if aiorpcx's timeout or *timeout* expires.
        :raises CodeMessageError: re-raised (after debug logging) when the
            server replies with an RPC error.
        """
        # note: semaphores/timeouts/backpressure etc are handled by
        # aiorpcx. the timeout arg here in most cases should not be set
        msg_id = next(self._msg_counter)
        self.maybe_log(f"<-- {args} {kwargs} (id: {msg_id})")
        try:
            # note: RPCSession.send_request raises TaskTimeout in case of a timeout.
            # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups
            response = await asyncio.wait_for(
                super().send_request(*args, **kwargs),
                timeout)
        except (TaskTimeout, asyncio.TimeoutError) as e:
            # convert to a GracefulDisconnect subclass so it is not swallowed
            raise RequestTimedOut(f'request timed out: {args} (id: {msg_id})') from e
        except CodeMessageError as e:
            self.maybe_log(f"--> {repr(e)} (id: {msg_id})")
            raise
        else:
            self.maybe_log(f"--> {response} (id: {msg_id})")
            return response

    def set_default_timeout(self, timeout):
        """Set aiorpcx's ``sent_request_timeout`` and ``max_send_delay`` to *timeout* seconds."""
        self.sent_request_timeout = timeout
        self.max_send_delay = timeout

    async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
        """Subscribe *queue* to notifications for (*method*, *params*).

        The queue first receives ``params + [result]``. The result comes from the
        cache if present, otherwise from a fresh request. After that, the queue
        receives the args of every matching notification.

        If the initial request fails, this call's registration is removed before
        the exception propagates. A failed subscription therefore does not keep
        receiving notifications or keep *queue* referenced.
        """
        # note: until the cache is written for the first time,
        # each 'subscribe' call might make a request on the network.
        key = self.get_hashable_key_for_rpc_call(method, params)
        # Register before requesting: handle_request only delivers to keys
        # present in self.subscriptions, so notifications that arrive while the
        # request is in flight still reach this queue.
        self.subscriptions[key].append(queue)
        if key in self.cache:
            result = self.cache[key]
        else:
            try:
                result = await self.send_request(method, params)
            except BaseException:
                # Roll back this call's registration. unsubscribe() may already
                # have removed it while we were awaiting, so check first.
                # remove() drops one matching entry; if an earlier successful
                # subscribe registered the same queue, one registration remains.
                subscribers = self.subscriptions[key]
                if queue in subscribers:
                    subscribers.remove(queue)
                raise
            self.cache[key] = result
        await queue.put(params + [result])

    def unsubscribe(self, queue):
        """Unsubscribe a callback to free object references to enable GC."""
        # note: we can't unsubscribe from the server, so we keep receiving
        # subsequent notifications
        for v in self.subscriptions.values():
            if queue in v:
                v.remove(queue)

    @classmethod
    def get_hashable_key_for_rpc_call(cls, method, params):
        """Hashable index for subscriptions and cache.

        Built as ``str(method) + repr(params)``, so a list and a tuple with the
        same items give different keys. Notification keys are built from a slice
        of ``request.args`` (presumably a list decoded from JSON; confirm), so
        subscribers should pass *params* as a list.
        """
        return str(method) + repr(params)

    def maybe_log(self, msg: str) -> None:
        """Log *msg* at debug level if debugging is on for this interface or the network."""
        if not self.interface:
            return
        if self.interface.debug or self.interface.network.debug:
            self.interface.logger.debug(msg)

    def default_framer(self):
        """Return a newline framer whose max message size comes from config.

        Overridden so that max_size can be customized via the
        'network_max_incoming_msg_size' config key (default MAX_INCOMING_MSG_SIZE bytes).
        """
        max_size = int(self.interface.network.config.get('network_max_incoming_msg_size',
                                                         MAX_INCOMING_MSG_SIZE))
        return NewlineFramer(max_size=max_size)
223 224 225class NetworkException(Exception): pass 226 227 228class GracefulDisconnect(NetworkException): 229 log_level = logging.INFO 230 231 def __init__(self, *args, log_level=None, **kwargs): 232 Exception.__init__(self, *args, **kwargs) 233 if log_level is not None: 234 self.log_level = log_level 235 236 237class RequestTimedOut(GracefulDisconnect): 238 def __str__(self): 239 return _("Network request timed out.") 240 241 242class RequestCorrupted(Exception): pass 243 244class ErrorParsingSSLCert(Exception): pass 245class ErrorGettingSSLCertFromServer(Exception): pass 246class ErrorSSLCertFingerprintMismatch(Exception): pass 247class InvalidOptionCombination(Exception): pass 248class ConnectError(NetworkException): pass 249 250 251class _RSClient(RSClient): 252 async def create_connection(self): 253 try: 254 return await super().create_connection() 255 except OSError as e: 256 # note: using "from e" here will set __cause__ of ConnectError 257 raise ConnectError(e) from e 258 259 260class ServerAddr: 261 262 def __init__(self, host: str, port: Union[int, str], *, protocol: str = None): 263 assert isinstance(host, str), repr(host) 264 if protocol is None: 265 protocol = 's' 266 if not host: 267 raise ValueError('host must not be empty') 268 if host[0] == '[' and host[-1] == ']': # IPv6 269 host = host[1:-1] 270 try: 271 net_addr = NetAddress(host, port) # this validates host and port 272 except Exception as e: 273 raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e 274 if protocol not in _KNOWN_NETWORK_PROTOCOLS: 275 raise ValueError(f"invalid network protocol: {protocol}") 276 self.host = str(net_addr.host) # canonical form (if e.g. 
IPv6 address) 277 self.port = int(net_addr.port) 278 self.protocol = protocol 279 self._net_addr_str = str(net_addr) 280 281 @classmethod 282 def from_str(cls, s: str) -> 'ServerAddr': 283 # host might be IPv6 address, hence do rsplit: 284 host, port, protocol = str(s).rsplit(':', 2) 285 return ServerAddr(host=host, port=port, protocol=protocol) 286 287 @classmethod 288 def from_str_with_inference(cls, s: str) -> Optional['ServerAddr']: 289 """Construct ServerAddr from str, guessing missing details. 290 Ongoing compatibility not guaranteed. 291 """ 292 if not s: 293 return None 294 items = str(s).rsplit(':', 2) 295 if len(items) < 2: 296 return None # although maybe we could guess the port too? 297 host = items[0] 298 port = items[1] 299 if len(items) >= 3: 300 protocol = items[2] 301 else: 302 protocol = PREFERRED_NETWORK_PROTOCOL 303 return ServerAddr(host=host, port=port, protocol=protocol) 304 305 def to_friendly_name(self) -> str: 306 # note: this method is closely linked to from_str_with_inference 307 if self.protocol == 's': # hide trailing ":s" 308 return self.net_addr_str() 309 return str(self) 310 311 def __str__(self): 312 return '{}:{}'.format(self.net_addr_str(), self.protocol) 313 314 def to_json(self) -> str: 315 return str(self) 316 317 def __repr__(self): 318 return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>' 319 320 def net_addr_str(self) -> str: 321 return self._net_addr_str 322 323 def __eq__(self, other): 324 if not isinstance(other, ServerAddr): 325 return False 326 return (self.host == other.host 327 and self.port == other.port 328 and self.protocol == other.protocol) 329 330 def __ne__(self, other): 331 return not (self == other) 332 333 def __hash__(self): 334 return hash((self.host, self.port, self.protocol)) 335 336 337def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str: 338 filename = host 339 try: 340 ip = ip_address(host) 341 except ValueError: 342 pass 343 else: 344 if isinstance(ip, 
IPv6Address): 345 filename = f"ipv6_{ip.packed.hex()}" 346 return os.path.join(config.path, 'certs', filename) 347 348 349class Interface(Logger): 350 351 LOGGING_SHORTCUT = 'i' 352 353 def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]): 354 self.ready = asyncio.Future() 355 self.got_disconnected = asyncio.Event() 356 self.server = server 357 Logger.__init__(self) 358 assert network.config.path 359 self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host) 360 self.blockchain = None # type: Optional[Blockchain] 361 self._requested_chunks = set() # type: Set[int] 362 self.network = network 363 self.proxy = MySocksProxy.from_proxy_dict(proxy) 364 self.session = None # type: Optional[NotificationSession] 365 self._ipaddr_bucket = None 366 367 # Latest block header and corresponding height, as claimed by the server. 368 # Note that these values are updated before they are verified. 369 # Especially during initial header sync, verification can take a long time. 370 # Failing verification will get the interface closed. 371 self.tip_header = None 372 self.tip = 0 373 374 self.fee_estimates_eta = {} # type: Dict[int, int] 375 376 # Dump network messages (only for this interface). Set at runtime from the console. 
377 self.debug = False 378 379 self.taskgroup = SilentTaskGroup() 380 381 async def spawn_task(): 382 task = await self.network.taskgroup.spawn(self.run()) 383 if sys.version_info >= (3, 8): 384 task.set_name(f"interface::{str(server)}") 385 asyncio.run_coroutine_threadsafe(spawn_task(), self.network.asyncio_loop) 386 387 @property 388 def host(self): 389 return self.server.host 390 391 @property 392 def port(self): 393 return self.server.port 394 395 @property 396 def protocol(self): 397 return self.server.protocol 398 399 def diagnostic_name(self): 400 return self.server.net_addr_str() 401 402 def __str__(self): 403 return f"<Interface {self.diagnostic_name()}>" 404 405 async def is_server_ca_signed(self, ca_ssl_context): 406 """Given a CA enforcing SSL context, returns True if the connection 407 can be established. Returns False if the server has a self-signed 408 certificate but otherwise is okay. Any other failures raise. 409 """ 410 try: 411 await self.open_session(ca_ssl_context, exit_early=True) 412 except ConnectError as e: 413 cause = e.__cause__ 414 if isinstance(cause, ssl.SSLError) and cause.reason == 'CERTIFICATE_VERIFY_FAILED': 415 # failures due to self-signed certs are normal 416 return False 417 raise 418 return True 419 420 async def _try_saving_ssl_cert_for_first_time(self, ca_ssl_context): 421 ca_signed = await self.is_server_ca_signed(ca_ssl_context) 422 if ca_signed: 423 if self._get_expected_fingerprint(): 424 raise InvalidOptionCombination("cannot use --serverfingerprint with CA signed servers") 425 with open(self.cert_path, 'w') as f: 426 # empty file means this is CA signed, not self-signed 427 f.write('') 428 else: 429 await self._save_certificate() 430 431 def _is_saved_ssl_cert_available(self): 432 if not os.path.exists(self.cert_path): 433 return False 434 with open(self.cert_path, 'r') as f: 435 contents = f.read() 436 if contents == '': # CA signed 437 if self._get_expected_fingerprint(): 438 raise InvalidOptionCombination("cannot 
use --serverfingerprint with CA signed servers") 439 return True 440 # pinned self-signed cert 441 try: 442 b = pem.dePem(contents, 'CERTIFICATE') 443 except SyntaxError as e: 444 self.logger.info(f"error parsing already saved cert: {e}") 445 raise ErrorParsingSSLCert(e) from e 446 try: 447 x = x509.X509(b) 448 except Exception as e: 449 self.logger.info(f"error parsing already saved cert: {e}") 450 raise ErrorParsingSSLCert(e) from e 451 try: 452 x.check_date() 453 except x509.CertificateError as e: 454 self.logger.info(f"certificate has expired: {e}") 455 os.unlink(self.cert_path) # delete pinned cert only in this case 456 return False 457 self._verify_certificate_fingerprint(bytearray(b)) 458 return True 459 460 async def _get_ssl_context(self): 461 if self.protocol != 's': 462 # using plaintext TCP 463 return None 464 465 # see if we already have cert for this server; or get it for the first time 466 ca_sslc = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_path) 467 if not self._is_saved_ssl_cert_available(): 468 try: 469 await self._try_saving_ssl_cert_for_first_time(ca_sslc) 470 except (OSError, ConnectError, aiorpcx.socks.SOCKSError) as e: 471 raise ErrorGettingSSLCertFromServer(e) from e 472 # now we have a file saved in our certificate store 473 siz = os.stat(self.cert_path).st_size 474 if siz == 0: 475 # CA signed cert 476 sslc = ca_sslc 477 else: 478 # pinned self-signed cert 479 sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path) 480 sslc.check_hostname = 0 481 return sslc 482 483 def handle_disconnect(func): 484 @functools.wraps(func) 485 async def wrapper_func(self: 'Interface', *args, **kwargs): 486 try: 487 return await func(self, *args, **kwargs) 488 except GracefulDisconnect as e: 489 self.logger.log(e.log_level, f"disconnecting due to {repr(e)}") 490 except aiorpcx.jsonrpc.RPCError as e: 491 self.logger.warning(f"disconnecting due to {repr(e)}") 492 self.logger.debug(f"(disconnect) trace for 
{repr(e)}", exc_info=True) 493 finally: 494 self.got_disconnected.set() 495 await self.network.connection_down(self) 496 # if was not 'ready' yet, schedule waiting coroutines: 497 self.ready.cancel() 498 return wrapper_func 499 500 @ignore_exceptions # do not kill network.taskgroup 501 @log_exceptions 502 @handle_disconnect 503 async def run(self): 504 try: 505 ssl_context = await self._get_ssl_context() 506 except (ErrorParsingSSLCert, ErrorGettingSSLCertFromServer) as e: 507 self.logger.info(f'disconnecting due to: {repr(e)}') 508 return 509 try: 510 await self.open_session(ssl_context) 511 except (asyncio.CancelledError, ConnectError, aiorpcx.socks.SOCKSError) as e: 512 # make SSL errors for main interface more visible (to help servers ops debug cert pinning issues) 513 if (isinstance(e, ConnectError) and isinstance(e.__cause__, ssl.SSLError) 514 and self.is_main_server() and not self.network.auto_connect): 515 self.logger.warning(f'Cannot connect to main server due to SSL error ' 516 f'(maybe cert changed compared to "{self.cert_path}"). 
Exc: {repr(e)}') 517 else: 518 self.logger.info(f'disconnecting due to: {repr(e)}') 519 return 520 521 def _mark_ready(self) -> None: 522 if self.ready.cancelled(): 523 raise GracefulDisconnect('conn establishment was too slow; *ready* future was cancelled') 524 if self.ready.done(): 525 return 526 527 assert self.tip_header 528 chain = blockchain.check_header(self.tip_header) 529 if not chain: 530 self.blockchain = blockchain.get_best_chain() 531 else: 532 self.blockchain = chain 533 assert self.blockchain is not None 534 535 self.logger.info(f"set blockchain with height {self.blockchain.height()}") 536 537 self.ready.set_result(1) 538 539 async def _save_certificate(self) -> None: 540 if not os.path.exists(self.cert_path): 541 # we may need to retry this a few times, in case the handshake hasn't completed 542 for _ in range(10): 543 dercert = await self._fetch_certificate() 544 if dercert: 545 self.logger.info("succeeded in getting cert") 546 self._verify_certificate_fingerprint(dercert) 547 with open(self.cert_path, 'w') as f: 548 cert = ssl.DER_cert_to_PEM_cert(dercert) 549 # workaround android bug 550 cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) 551 f.write(cert) 552 # even though close flushes we can't fsync when closed. 
553 # and we must flush before fsyncing, cause flush flushes to OS buffer 554 # fsync writes to OS buffer to disk 555 f.flush() 556 os.fsync(f.fileno()) 557 break 558 await asyncio.sleep(1) 559 else: 560 raise GracefulDisconnect("could not get certificate after 10 tries") 561 562 async def _fetch_certificate(self) -> bytes: 563 sslc = ssl.SSLContext() 564 async with _RSClient(session_factory=RPCSession, 565 host=self.host, port=self.port, 566 ssl=sslc, proxy=self.proxy) as session: 567 asyncio_transport = session.transport._asyncio_transport # type: asyncio.BaseTransport 568 ssl_object = asyncio_transport.get_extra_info("ssl_object") # type: ssl.SSLObject 569 return ssl_object.getpeercert(binary_form=True) 570 571 def _get_expected_fingerprint(self) -> Optional[str]: 572 if self.is_main_server(): 573 return self.network.config.get("serverfingerprint") 574 575 def _verify_certificate_fingerprint(self, certificate): 576 expected_fingerprint = self._get_expected_fingerprint() 577 if not expected_fingerprint: 578 return 579 fingerprint = hashlib.sha256(certificate).hexdigest() 580 fingerprints_match = fingerprint.lower() == expected_fingerprint.lower() 581 if not fingerprints_match: 582 util.trigger_callback('cert_mismatch') 583 raise ErrorSSLCertFingerprintMismatch('Refusing to connect to server due to cert fingerprint mismatch') 584 self.logger.info("cert fingerprint verification passed") 585 586 async def get_block_header(self, height, assert_mode): 587 self.logger.info(f'requesting block header {height} in mode {assert_mode}') 588 # use lower timeout as we usually have network.bhi_lock here 589 timeout = self.network.get_network_timeout_seconds(NetworkTimeout.Urgent) 590 res = await self.session.send_request('blockchain.block.header', [height], timeout=timeout) 591 return blockchain.deserialize_header(bytes.fromhex(res), height) 592 593 async def request_chunk(self, height: int, tip=None, *, can_return_early=False): 594 if not is_non_negative_integer(height): 595 
raise Exception(f"{repr(height)} is not a block height") 596 index = height // 2016 597 if can_return_early and index in self._requested_chunks: 598 return 599 self.logger.info(f"requesting chunk from height {height}") 600 size = 2016 601 if tip is not None: 602 size = min(size, tip - index * 2016 + 1) 603 size = max(size, 0) 604 try: 605 self._requested_chunks.add(index) 606 res = await self.session.send_request('blockchain.block.headers', [index * 2016, size]) 607 finally: 608 self._requested_chunks.discard(index) 609 assert_dict_contains_field(res, field_name='count') 610 assert_dict_contains_field(res, field_name='hex') 611 assert_dict_contains_field(res, field_name='max') 612 assert_non_negative_integer(res['count']) 613 assert_non_negative_integer(res['max']) 614 assert_hex_str(res['hex']) 615 if len(res['hex']) != HEADER_SIZE * 2 * res['count']: 616 raise RequestCorrupted('inconsistent chunk hex and count') 617 # we never request more than 2016 headers, but we enforce those fit in a single response 618 if res['max'] < 2016: 619 raise RequestCorrupted(f"server uses too low 'max' count for block.headers: {res['max']} < 2016") 620 if res['count'] != size: 621 raise RequestCorrupted(f"expected {size} headers but only got {res['count']}") 622 conn = self.blockchain.connect_chunk(index, res['hex']) 623 if not conn: 624 return conn, 0 625 return conn, res['count'] 626 627 def is_main_server(self) -> bool: 628 return (self.network.interface == self or 629 self.network.interface is None and self.network.default_server == self.server) 630 631 async def open_session(self, sslc, exit_early=False): 632 session_factory = lambda *args, iface=self, **kwargs: NotificationSession(*args, **kwargs, interface=iface) 633 async with _RSClient(session_factory=session_factory, 634 host=self.host, port=self.port, 635 ssl=sslc, proxy=self.proxy) as session: 636 self.session = session # type: NotificationSession 637 
self.session.set_default_timeout(self.network.get_network_timeout_seconds(NetworkTimeout.Generic)) 638 try: 639 ver = await session.send_request('server.version', [self.client_name(), version.PROTOCOL_VERSION]) 640 except aiorpcx.jsonrpc.RPCError as e: 641 raise GracefulDisconnect(e) # probably 'unsupported protocol version' 642 if exit_early: 643 return 644 if ver[1] != version.PROTOCOL_VERSION: 645 raise GracefulDisconnect(f'server violated protocol-version-negotiation. ' 646 f'we asked for {version.PROTOCOL_VERSION!r}, they sent {ver[1]!r}') 647 if not self.network.check_interface_against_healthy_spread_of_connected_servers(self): 648 raise GracefulDisconnect(f'too many connected servers already ' 649 f'in bucket {self.bucket_based_on_ipaddress()}') 650 self.logger.info(f"connection established. version: {ver}") 651 652 try: 653 async with self.taskgroup as group: 654 await group.spawn(self.ping) 655 await group.spawn(self.request_fee_estimates) 656 await group.spawn(self.run_fetch_blocks) 657 await group.spawn(self.monitor_connection) 658 except aiorpcx.jsonrpc.RPCError as e: 659 if e.code in (JSONRPC.EXCESSIVE_RESOURCE_USAGE, 660 JSONRPC.SERVER_BUSY, 661 JSONRPC.METHOD_NOT_FOUND): 662 raise GracefulDisconnect(e, log_level=logging.WARNING) from e 663 raise 664 665 async def monitor_connection(self): 666 while True: 667 await asyncio.sleep(1) 668 if not self.session or self.session.is_closing(): 669 raise GracefulDisconnect('session was closed') 670 671 async def ping(self): 672 while True: 673 await asyncio.sleep(300) 674 await self.session.send_request('server.ping') 675 676 async def request_fee_estimates(self): 677 from .simple_config import FEE_ETA_TARGETS 678 while True: 679 async with TaskGroup() as group: 680 fee_tasks = [] 681 for i in FEE_ETA_TARGETS: 682 fee_tasks.append((i, await group.spawn(self.get_estimatefee(i)))) 683 for nblock_target, task in fee_tasks: 684 fee = task.result() 685 if fee < 0: continue 686 assert isinstance(fee, int) 687 
self.fee_estimates_eta[nblock_target] = fee 688 self.network.update_fee_estimates() 689 await asyncio.sleep(60) 690 691 async def close(self, *, force_after: int = None): 692 """Closes the connection and waits for it to be closed. 693 We try to flush buffered data to the wire, so this can take some time. 694 """ 695 if force_after is None: 696 # We give up after a while and just abort the connection. 697 # Note: specifically if the server is running Fulcrum, waiting seems hopeless, 698 # the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76) 699 force_after = 1 # seconds 700 if self.session: 701 await self.session.close(force_after=force_after) 702 # monitor_connection will cancel tasks 703 704 async def run_fetch_blocks(self): 705 header_queue = asyncio.Queue() 706 await self.session.subscribe('blockchain.headers.subscribe', [], header_queue) 707 while True: 708 item = await header_queue.get() 709 raw_header = item[0] 710 height = raw_header['height'] 711 header = blockchain.deserialize_header(bfh(raw_header['hex']), height) 712 self.tip_header = header 713 self.tip = height 714 if self.tip < constants.net.max_checkpoint(): 715 raise GracefulDisconnect('server tip below max checkpoint') 716 self._mark_ready() 717 await self._process_header_at_tip() 718 # header processing done 719 util.trigger_callback('blockchain_updated') 720 util.trigger_callback('network_updated') 721 await self.network.switch_unwanted_fork_interface() 722 await self.network.switch_lagging_interface() 723 724 async def _process_header_at_tip(self): 725 height, header = self.tip, self.tip_header 726 async with self.network.bhi_lock: 727 if self.blockchain.height() >= height and self.blockchain.check_header(header): 728 # another interface amended the blockchain 729 self.logger.info(f"skipping header {height}") 730 return 731 _, height = await self.step(height, header) 732 # in the simple case, height == self.tip+1 733 if height <= self.tip: 734 await 
self.sync_until(height) 735 736 async def sync_until(self, height, next_height=None): 737 if next_height is None: 738 next_height = self.tip 739 last = None 740 while last is None or height <= next_height: 741 prev_last, prev_height = last, height 742 if next_height > height + 10: 743 could_connect, num_headers = await self.request_chunk(height, next_height) 744 if not could_connect: 745 if height <= constants.net.max_checkpoint(): 746 raise GracefulDisconnect('server chain conflicts with checkpoints or genesis') 747 last, height = await self.step(height) 748 continue 749 util.trigger_callback('network_updated') 750 height = (height // 2016 * 2016) + num_headers 751 assert height <= next_height+1, (height, self.tip) 752 last = 'catchup' 753 else: 754 last, height = await self.step(height) 755 assert (prev_last, prev_height) != (last, height), 'had to prevent infinite loop in interface.sync_until' 756 return last, height 757 758 async def step(self, height, header=None): 759 assert 0 <= height <= self.tip, (height, self.tip) 760 if header is None: 761 header = await self.get_block_header(height, 'catchup') 762 763 chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) 764 if chain: 765 self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain 766 # note: there is an edge case here that is not handled. 767 # we might know the blockhash (enough for check_header) but 768 # not have the header itself. e.g. regtest chain with only genesis. 
769 # this situation resolves itself on the next block 770 return 'catchup', height+1 771 772 can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) 773 if not can_connect: 774 self.logger.info(f"can't connect {height}") 775 height, header, bad, bad_header = await self._search_headers_backwards(height, header) 776 chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) 777 can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) 778 assert chain or can_connect 779 if can_connect: 780 self.logger.info(f"could connect {height}") 781 height += 1 782 if isinstance(can_connect, Blockchain): # not when mocking 783 self.blockchain = can_connect 784 self.blockchain.save_header(header) 785 return 'catchup', height 786 787 good, bad, bad_header = await self._search_headers_binary(height, bad, bad_header, chain) 788 return await self._resolve_potential_chain_fork_given_forkpoint(good, bad, bad_header) 789 790 async def _search_headers_binary(self, height, bad, bad_header, chain): 791 assert bad == bad_header['block_height'] 792 _assert_header_does_not_check_against_any_chain(bad_header) 793 794 self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain 795 good = height 796 while True: 797 assert good < bad, (good, bad) 798 height = (good + bad) // 2 799 self.logger.info(f"binary step. 
good {good}, bad {bad}, height {height}") 800 header = await self.get_block_header(height, 'binary') 801 chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) 802 if chain: 803 self.blockchain = chain if isinstance(chain, Blockchain) else self.blockchain 804 good = height 805 else: 806 bad = height 807 bad_header = header 808 if good + 1 == bad: 809 break 810 811 mock = 'mock' in bad_header and bad_header['mock']['connect'](height) 812 real = not mock and self.blockchain.can_connect(bad_header, check_height=False) 813 if not real and not mock: 814 raise Exception('unexpected bad header during binary: {}'.format(bad_header)) 815 _assert_header_does_not_check_against_any_chain(bad_header) 816 817 self.logger.info(f"binary search exited. good {good}, bad {bad}") 818 return good, bad, bad_header 819 820 async def _resolve_potential_chain_fork_given_forkpoint(self, good, bad, bad_header): 821 assert good + 1 == bad 822 assert bad == bad_header['block_height'] 823 _assert_header_does_not_check_against_any_chain(bad_header) 824 # 'good' is the height of a block 'good_header', somewhere in self.blockchain. 825 # bad_header connects to good_header; bad_header itself is NOT in self.blockchain. 
826 827 bh = self.blockchain.height() 828 assert bh >= good, (bh, good) 829 if bh == good: 830 height = good + 1 831 self.logger.info(f"catching up from {height}") 832 return 'no_fork', height 833 834 # this is a new fork we don't yet have 835 height = bad + 1 836 self.logger.info(f"new fork at bad height {bad}") 837 forkfun = self.blockchain.fork if 'mock' not in bad_header else bad_header['mock']['fork'] 838 b = forkfun(bad_header) # type: Blockchain 839 self.blockchain = b 840 assert b.forkpoint == bad 841 return 'fork', height 842 843 async def _search_headers_backwards(self, height, header): 844 async def iterate(): 845 nonlocal height, header 846 checkp = False 847 if height <= constants.net.max_checkpoint(): 848 height = constants.net.max_checkpoint() 849 checkp = True 850 header = await self.get_block_header(height, 'backward') 851 chain = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) 852 can_connect = blockchain.can_connect(header) if 'mock' not in header else header['mock']['connect'](height) 853 if chain or can_connect: 854 return False 855 if checkp: 856 raise GracefulDisconnect("server chain conflicts with checkpoints") 857 return True 858 859 bad, bad_header = height, header 860 _assert_header_does_not_check_against_any_chain(bad_header) 861 with blockchain.blockchains_lock: chains = list(blockchain.blockchains.values()) 862 local_max = max([0] + [x.height() for x in chains]) if 'mock' not in header else float('inf') 863 height = min(local_max + 1, height - 1) 864 while await iterate(): 865 bad, bad_header = height, header 866 delta = self.tip - height 867 height = self.tip - 2 * delta 868 869 _assert_header_does_not_check_against_any_chain(bad_header) 870 self.logger.info(f"exiting backward mode at {height}") 871 return height, header, bad, bad_header 872 873 @classmethod 874 def client_name(cls) -> str: 875 return f'electrum/{version.ELECTRUM_VERSION}' 876 877 def is_tor(self): 878 return 
self.host.endswith('.onion') 879 880 def ip_addr(self) -> Optional[str]: 881 session = self.session 882 if not session: return None 883 peer_addr = session.remote_address() 884 if not peer_addr: return None 885 return str(peer_addr.host) 886 887 def bucket_based_on_ipaddress(self) -> str: 888 def do_bucket(): 889 if self.is_tor(): 890 return BUCKET_NAME_OF_ONION_SERVERS 891 try: 892 ip_addr = ip_address(self.ip_addr()) # type: Union[IPv4Address, IPv6Address] 893 except ValueError: 894 return '' 895 if not ip_addr: 896 return '' 897 if ip_addr.is_loopback: # localhost is exempt 898 return '' 899 if ip_addr.version == 4: 900 slash16 = IPv4Network(ip_addr).supernet(prefixlen_diff=32-16) 901 return str(slash16) 902 elif ip_addr.version == 6: 903 slash48 = IPv6Network(ip_addr).supernet(prefixlen_diff=128-48) 904 return str(slash48) 905 return '' 906 907 if not self._ipaddr_bucket: 908 self._ipaddr_bucket = do_bucket() 909 return self._ipaddr_bucket 910 911 async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict: 912 if not is_hash256_str(tx_hash): 913 raise Exception(f"{repr(tx_hash)} is not a txid") 914 if not is_non_negative_integer(tx_height): 915 raise Exception(f"{repr(tx_height)} is not a block height") 916 # do request 917 res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height]) 918 # check response 919 block_height = assert_dict_contains_field(res, field_name='block_height') 920 merkle = assert_dict_contains_field(res, field_name='merkle') 921 pos = assert_dict_contains_field(res, field_name='pos') 922 # note: tx_height was just a hint to the server, don't enforce the response to match it 923 assert_non_negative_integer(block_height) 924 assert_non_negative_integer(pos) 925 assert_list_or_tuple(merkle) 926 for item in merkle: 927 assert_hash256_str(item) 928 return res 929 930 async def get_transaction(self, tx_hash: str, *, timeout=None) -> str: 931 if not is_hash256_str(tx_hash): 932 raise 
Exception(f"{repr(tx_hash)} is not a txid") 933 raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout) 934 # validate response 935 if not is_hex_str(raw): 936 raise RequestCorrupted(f"received garbage (non-hex) as tx data (txid {tx_hash}): {raw!r}") 937 tx = Transaction(raw) 938 try: 939 tx.deserialize() # see if raises 940 except Exception as e: 941 raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e 942 if tx.txid() != tx_hash: 943 raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})") 944 return raw 945 946 async def get_history_for_scripthash(self, sh: str) -> List[dict]: 947 if not is_hash256_str(sh): 948 raise Exception(f"{repr(sh)} is not a scripthash") 949 # do request 950 res = await self.session.send_request('blockchain.scripthash.get_history', [sh]) 951 # check response 952 assert_list_or_tuple(res) 953 prev_height = 1 954 for tx_item in res: 955 height = assert_dict_contains_field(tx_item, field_name='height') 956 assert_dict_contains_field(tx_item, field_name='tx_hash') 957 assert_integer(height) 958 assert_hash256_str(tx_item['tx_hash']) 959 if height in (-1, 0): 960 assert_dict_contains_field(tx_item, field_name='fee') 961 assert_non_negative_integer(tx_item['fee']) 962 prev_height = - float("inf") # this ensures confirmed txs can't follow mempool txs 963 else: 964 # check monotonicity of heights 965 if height < prev_height: 966 raise RequestCorrupted(f'heights of confirmed txs must be in increasing order') 967 prev_height = height 968 hashes = set(map(lambda item: item['tx_hash'], res)) 969 if len(hashes) != len(res): 970 # Either server is sending garbage... or maybe if server is race-prone 971 # a recently mined tx could be included in both last block and mempool? 972 # Still, it's simplest to just disregard the response. 
973 raise RequestCorrupted(f"server history has non-unique txids for sh={sh}") 974 return res 975 976 async def listunspent_for_scripthash(self, sh: str) -> List[dict]: 977 if not is_hash256_str(sh): 978 raise Exception(f"{repr(sh)} is not a scripthash") 979 # do request 980 res = await self.session.send_request('blockchain.scripthash.listunspent', [sh]) 981 # check response 982 assert_list_or_tuple(res) 983 for utxo_item in res: 984 assert_dict_contains_field(utxo_item, field_name='tx_pos') 985 assert_dict_contains_field(utxo_item, field_name='value') 986 assert_dict_contains_field(utxo_item, field_name='tx_hash') 987 assert_dict_contains_field(utxo_item, field_name='height') 988 assert_non_negative_integer(utxo_item['tx_pos']) 989 assert_non_negative_integer(utxo_item['value']) 990 assert_non_negative_integer(utxo_item['height']) 991 assert_hash256_str(utxo_item['tx_hash']) 992 return res 993 994 async def get_balance_for_scripthash(self, sh: str) -> dict: 995 if not is_hash256_str(sh): 996 raise Exception(f"{repr(sh)} is not a scripthash") 997 # do request 998 res = await self.session.send_request('blockchain.scripthash.get_balance', [sh]) 999 # check response 1000 assert_dict_contains_field(res, field_name='confirmed') 1001 assert_dict_contains_field(res, field_name='unconfirmed') 1002 assert_non_negative_integer(res['confirmed']) 1003 assert_non_negative_integer(res['unconfirmed']) 1004 return res 1005 1006 async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool): 1007 if not is_non_negative_integer(tx_height): 1008 raise Exception(f"{repr(tx_height)} is not a block height") 1009 if not is_non_negative_integer(tx_pos): 1010 raise Exception(f"{repr(tx_pos)} should be non-negative integer") 1011 # do request 1012 res = await self.session.send_request( 1013 'blockchain.transaction.id_from_pos', 1014 [tx_height, tx_pos, merkle], 1015 ) 1016 # check response 1017 if merkle: 1018 assert_dict_contains_field(res, field_name='tx_hash') 1019 
assert_dict_contains_field(res, field_name='merkle') 1020 assert_hash256_str(res['tx_hash']) 1021 assert_list_or_tuple(res['merkle']) 1022 for node_hash in res['merkle']: 1023 assert_hash256_str(node_hash) 1024 else: 1025 assert_hash256_str(res) 1026 return res 1027 1028 async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]: 1029 # do request 1030 res = await self.session.send_request('mempool.get_fee_histogram') 1031 # check response 1032 assert_list_or_tuple(res) 1033 prev_fee = float('inf') 1034 for fee, s in res: 1035 assert_non_negative_int_or_float(fee) 1036 assert_non_negative_integer(s) 1037 if fee >= prev_fee: # check monotonicity 1038 raise RequestCorrupted(f'fees must be in decreasing order') 1039 prev_fee = fee 1040 return res 1041 1042 async def get_server_banner(self) -> str: 1043 # do request 1044 res = await self.session.send_request('server.banner') 1045 # check response 1046 if not isinstance(res, str): 1047 raise RequestCorrupted(f'{res!r} should be a str') 1048 return res 1049 1050 async def get_donation_address(self) -> str: 1051 # do request 1052 res = await self.session.send_request('server.donation_address') 1053 # check response 1054 if not res: # ignore empty string 1055 return '' 1056 if not bitcoin.is_address(res): 1057 # note: do not hard-fail -- allow server to use future-type 1058 # bitcoin address we do not recognize 1059 self.logger.info(f"invalid donation address from server: {repr(res)}") 1060 res = '' 1061 return res 1062 1063 async def get_relay_fee(self) -> int: 1064 """Returns the min relay feerate in sat/kbyte.""" 1065 # do request 1066 res = await self.session.send_request('blockchain.relayfee') 1067 # check response 1068 assert_non_negative_int_or_float(res) 1069 relayfee = int(res * bitcoin.COIN) 1070 relayfee = max(0, relayfee) 1071 return relayfee 1072 1073 async def get_estimatefee(self, num_blocks: int) -> int: 1074 """Returns a feerate estimate for getting confirmed within 1075 num_blocks 
blocks, in sat/kbyte. 1076 """ 1077 if not is_non_negative_integer(num_blocks): 1078 raise Exception(f"{repr(num_blocks)} is not a num_blocks") 1079 # do request 1080 res = await self.session.send_request('blockchain.estimatefee', [num_blocks]) 1081 # check response 1082 if res != -1: 1083 assert_non_negative_int_or_float(res) 1084 res = int(res * bitcoin.COIN) 1085 return res 1086 1087 1088def _assert_header_does_not_check_against_any_chain(header: dict) -> None: 1089 chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) 1090 if chain_bad: 1091 raise Exception('bad_header must not check!') 1092 1093 1094def check_cert(host, cert): 1095 try: 1096 b = pem.dePem(cert, 'CERTIFICATE') 1097 x = x509.X509(b) 1098 except: 1099 traceback.print_exc(file=sys.stdout) 1100 return 1101 1102 try: 1103 x.check_date() 1104 expired = False 1105 except: 1106 expired = True 1107 1108 m = "host: %s\n"%host 1109 m += "has_expired: %s\n"% expired 1110 util.print_msg(m) 1111 1112 1113# Used by tests 1114def _match_hostname(name, val): 1115 if val == name: 1116 return True 1117 1118 return val.startswith('*.') and name.endswith(val[1:]) 1119 1120 1121def test_certificates(): 1122 from .simple_config import SimpleConfig 1123 config = SimpleConfig() 1124 mydir = os.path.join(config.path, "certs") 1125 certs = os.listdir(mydir) 1126 for c in certs: 1127 p = os.path.join(mydir,c) 1128 with open(p, encoding='utf-8') as f: 1129 cert = f.read() 1130 check_cert(c, cert) 1131 1132if __name__ == "__main__": 1133 test_certificates() 1134