# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

import logging
import textwrap
from collections import namedtuple
from copy import deepcopy
from enum import IntEnum
from itertools import zip_longest
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from mnemonic import Mnemonic

from . import mapping, messages, protobuf
from .client import TrezorClient
from .exceptions import TrezorFailure
from .log import DUMP_BYTES
from .tools import expect

if TYPE_CHECKING:
    from .transport import Transport
    from .messages import PinMatrixRequestType

    ExpectedMessage = Union[
        protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
    ]

# Number of expected responses shown before and after the offending one
# when a response mismatch is reported.
EXPECTED_RESPONSES_CONTEXT_LINES = 3

LayoutLines = namedtuple("LayoutLines", "lines text")

LOG = logging.getLogger(__name__)


def layout_lines(lines: Sequence[str]) -> LayoutLines:
    """Wrap `lines` in a LayoutLines tuple together with their space-joined text."""
    return LayoutLines(lines, " ".join(lines))


class DebugLink:
    """Sends DebugLink* messages over a dedicated transport."""

    def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
        self.transport = transport
        # When False, `input()` (and everything built on it) is a no-op.
        self.allow_interactions = auto_interact
        self.mapping = mapping.DEFAULT_MAPPING

    def open(self) -> None:
        self.transport.begin_session()

    def close(self) -> None:
        self.transport.end_session()

    def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
        """Encode and write `msg`; unless `nowait` is set, read and decode the reply.

        Returns the decoded reply message, or None when `nowait` is True.
        """
        LOG.debug(
            f"sending message: {msg.__class__.__name__}",
            extra={"protobuf": msg},
        )
        msg_type, msg_bytes = self.mapping.encode(msg)
        LOG.log(
            DUMP_BYTES,
            f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
        )
        self.transport.write(msg_type, msg_bytes)
        if nowait:
            return None

        ret_type, ret_bytes = self.transport.read()
        # Log what was actually received, not the request that was just sent.
        LOG.log(
            DUMP_BYTES,
            f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
        )
        msg = self.mapping.decode(ret_type, ret_bytes)
        LOG.debug(
            f"received message: {msg.__class__.__name__}",
            extra={"protobuf": msg},
        )
        return msg

    def state(self) -> messages.DebugLinkState:
        return self._call(messages.DebugLinkGetState())

    def read_layout(self) -> LayoutLines:
        return layout_lines(self.state().layout_lines)

    def wait_layout(self) -> LayoutLines:
        """Request the state with `wait_layout=True` and return its layout lines.

        Raises TrezorFailure if the device answers with a Failure message.
        """
        obj = self._call(messages.DebugLinkGetState(wait_layout=True))
        if isinstance(obj, messages.Failure):
            raise TrezorFailure(obj)
        return layout_lines(obj.layout_lines)

    def watch_layout(self, watch: bool) -> None:
        """Enable or disable watching layouts.
        If disabled, wait_layout will not work.

        The message is missing on T1. Use `TrezorClientDebugLink.watch_layout` for
        cross-version compatibility.
        """
        self._call(messages.DebugLinkWatchLayout(watch=watch))

    def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
        """Transform correct PIN according to the displayed matrix."""
        if matrix is None:
            matrix = self.state().matrix
            if matrix is None:
                # we are on trezor-core
                return pin

        return "".join([str(matrix.index(p) + 1) for p in pin])

    def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
        state = self.state()
        return (state.recovery_fake_word, state.recovery_word_pos)

    def read_reset_word(self) -> str:
        state = self._call(messages.DebugLinkGetState(wait_word_list=True))
        return state.reset_word

    def read_reset_word_pos(self) -> int:
        state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
        return state.reset_word_pos

    def input(
        self,
        word: Optional[str] = None,
        button: Optional[bool] = None,
        swipe: Optional[messages.DebugSwipeDirection] = None,
        x: Optional[int] = None,
        y: Optional[int] = None,
        wait: Optional[bool] = None,
        hold_ms: Optional[int] = None,
    ) -> Optional[LayoutLines]:
        """Send a single DebugLinkDecision to the device.

        Exactly one of `word`, `button`, `swipe` or `x` (a click, given together
        with `y`) must be provided, otherwise ValueError is raised.

        Does nothing and returns None when interactions are disabled. A reply is
        only awaited when `wait` is truthy; in that case the lines of the reply
        are returned as LayoutLines, otherwise None.
        """
        if not self.allow_interactions:
            return None

        args = sum(a is not None for a in (word, button, swipe, x))
        if args != 1:
            raise ValueError(
                "Invalid input - must use one of word, button, swipe, click"
            )

        decision = messages.DebugLinkDecision(
            yes_no=button, swipe=swipe, input=word, x=x, y=y, wait=wait, hold_ms=hold_ms
        )
        ret = self._call(decision, nowait=not wait)
        if ret is not None:
            return layout_lines(ret.lines)

        return None

    def click(
        self, click: Tuple[int, int], wait: bool = False
    ) -> Optional[LayoutLines]:
        x, y = click
        return self.input(x=x, y=y, wait=wait)

    def press_yes(self) -> None:
        self.input(button=True)

    def press_no(self) -> None:
        self.input(button=False)

    def swipe_up(self, wait: bool = False) -> None:
        self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)

    def swipe_down(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.DOWN)

    def swipe_right(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.RIGHT)

    def swipe_left(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.LEFT)

    def stop(self) -> None:
        self._call(messages.DebugLinkStop(), nowait=True)

    def reseed(self, value: int) -> protobuf.MessageType:
        return self._call(messages.DebugLinkReseedRandom(value=value))

    def start_recording(self, directory: str) -> None:
        self._call(messages.DebugLinkRecordScreen(target_directory=directory))

    def stop_recording(self) -> None:
        self._call(messages.DebugLinkRecordScreen(target_directory=None))

    @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
    def memory_read(self, address: int, length: int) -> protobuf.MessageType:
        return self._call(messages.DebugLinkMemoryRead(address=address, length=length))

    def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
        self._call(
            messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
            nowait=True,
        )

    def flash_erase(self, sector: int) -> None:
        self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)

    @expect(messages.Success)
    def erase_sd_card(self, format: bool = True) -> messages.Success:
        return self._call(messages.DebugLinkEraseSdCard(format=format))


class NullDebugLink(DebugLink):
    """DebugLink stand-in that never touches a transport."""

    def __init__(self) -> None:
        # Ignoring type error as self.transport will not be touched while using NullDebugLink
        super().__init__(None)  # type: ignore ["None" cannot be assigned to parameter of type "Transport"]

    def open(self) -> None:
        pass

    def close(self) -> None:
        pass

    def _call(
        self, msg: protobuf.MessageType, nowait: bool = False
    ) -> Optional[messages.DebugLinkState]:
        """Answer state queries with an empty DebugLinkState.

        Fire-and-forget calls (`nowait=True`) are silently dropped; any other
        call that expects a reply raises RuntimeError.
        """
        if not nowait:
            if isinstance(msg, messages.DebugLinkGetState):
                return messages.DebugLinkState()
            else:
                raise RuntimeError("unexpected call to a fake debuglink")

        return None


class DebugUI:
    """UI object that answers device prompts through a DebugLink."""

    # Sentinel stored in `input_flow` once the configured generator is exhausted.
    INPUT_FLOW_DONE = object()

    def __init__(self, debuglink: DebugLink) -> None:
        self.debuglink = debuglink
        self.clear()

    def clear(self) -> None:
        """Forget the configured PIN sequence, passphrase and input flow."""
        self.pins: Optional[Iterator[str]] = None
        self.passphrase = ""
        self.input_flow: Union[
            Generator[None, messages.ButtonRequest, None], object, None
        ] = None

    def button_request(self, br: messages.ButtonRequest) -> None:
        """React to a ButtonRequest.

        Without an input flow: enter the next PIN for PinEntry requests, otherwise
        swipe through `br.pages - 1` pages (when `pages` is set) and press yes.
        With an input flow: forward `br` to the generator. Raises AssertionError
        if the input flow has already finished.
        """
        if self.input_flow is None:
            if br.code == messages.ButtonRequestType.PinEntry:
                self.debuglink.input(self.get_pin())
            else:
                if br.pages is not None:
                    for _ in range(br.pages - 1):
                        self.debuglink.swipe_up(wait=True)
                self.debuglink.press_yes()
        elif self.input_flow is self.INPUT_FLOW_DONE:
            raise AssertionError("input flow ended prematurely")
        else:
            try:
                assert isinstance(self.input_flow, Generator)
                self.input_flow.send(br)
            except StopIteration:
                self.input_flow = self.INPUT_FLOW_DONE

    def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
        """Return the next configured PIN, encoded for the displayed matrix.

        `code` is accepted but not used here.

        Raises RuntimeError if no PIN sequence was configured and AssertionError
        if the sequence is exhausted.
        """
        if self.pins is None:
            raise RuntimeError("PIN requested but no sequence was configured")

        try:
            return self.debuglink.encode_pin(next(self.pins))
        except StopIteration:
            # The StopIteration itself carries no useful information.
            raise AssertionError("PIN sequence ended prematurely") from None

    def get_passphrase(self, available_on_device: bool) -> str:
        return self.passphrase


class MessageFilter:
    """Pattern describing an expected message: its type plus some field values."""

    def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
        self.message_type = message_type
        self.fields: Dict[str, Any] = {}
        self.update_fields(**fields)

    def update_fields(self, **fields: Any) -> "MessageFilter":
        """Add or replace expected field values and return self.

        Values that are messages or message types are converted to nested
        MessageFilters; anything else is stored as-is.
        """
        for name, value in fields.items():
            try:
                self.fields[name] = self.from_message_or_type(value)
            except TypeError:
                self.fields[name] = value

        return self

    @classmethod
    def from_message_or_type(
        cls, message_or_type: "ExpectedMessage"
    ) -> "MessageFilter":
        """Build a filter from a MessageFilter, a message instance or a message type.

        Raises TypeError for any other kind of argument.
        """
        if isinstance(message_or_type, cls):
            return message_or_type
        if isinstance(message_or_type, protobuf.MessageType):
            return cls.from_message(message_or_type)
        if isinstance(message_or_type, type) and issubclass(
            message_or_type, protobuf.MessageType
        ):
            return cls(message_or_type)
        raise TypeError("Invalid kind of expected response")

    @classmethod
    def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
        """Build a filter from the fields of `message` that carry a value.

        Fields that are None, an empty list or the required-field placeholder
        are left out, i.e. they are not checked when matching.
        """
        fields = {}
        for field in message.FIELDS.values():
            value = getattr(message, field.name)
            if value in (None, [], protobuf.REQUIRED_FIELD_PLACEHOLDER):
                continue
            fields[field.name] = value
        return cls(type(message), **fields)

    def match(self, message: protobuf.MessageType) -> bool:
        """Return True if `message` is exactly of the expected type and every
        configured field matches."""
        # Exact type comparison on purpose: subclasses do not match.
        if type(message) is not self.message_type:
            return False

        for field, expected_value in self.fields.items():
            actual_value = getattr(message, field, None)
            if isinstance(expected_value, MessageFilter):
                if actual_value is None or not expected_value.match(actual_value):
                    return False
            elif expected_value != actual_value:
                return False

        return True

    def to_string(self, maxwidth: int = 80) -> str:
        """Render the filter as `TypeName(field=value, ...)`.

        A single line is used when it is shorter than `maxwidth`; otherwise each
        field goes on its own indented line.
        """
        fields: List[Tuple[str, str]] = []
        for field in self.message_type.FIELDS.values():
            if field.name not in self.fields:
                continue
            value = self.fields[field.name]
            if isinstance(value, IntEnum):
                field_str = value.name
            elif isinstance(value, MessageFilter):
                # nested filters lose four columns to the indentation below
                field_str = value.to_string(maxwidth - 4)
            elif isinstance(value, protobuf.MessageType):
                field_str = protobuf.format_message(value)
            else:
                field_str = repr(value)
            # indent continuation lines only; the first line follows "name="
            field_str = textwrap.indent(field_str, "    ").lstrip()
            fields.append((field.name, field_str))

        pairs = [f"{k}={v}" for k, v in fields]
        oneline_str = ", ".join(pairs)
        if len(oneline_str) < maxwidth:
            return f"{self.message_type.__name__}({oneline_str})"
        else:
            item: List[str] = []
            item.append(f"{self.message_type.__name__}(")
            for pair in pairs:
                item.append(f"    {pair}")
            item.append(")")
            return "\n".join(item)


class MessageFilterGenerator:
    """Attribute access yields a MessageFilter factory for the message of that name."""

    def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
        message_type = getattr(messages, key)
        return MessageFilter(message_type).update_fields


message_filters = MessageFilterGenerator()


class TrezorClientDebugLink(TrezorClient):
    # This class implements automatic responses
    # and other functionality for unit tests
    # for various callbacks, created in order
    # to automatically pass unit tests.
    #
    # This mixin should be used only for purposes
    # of unit testing, because it will fail to work
    # without special DebugLink interface provided
    # by the device.

    def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
        try:
            debug_transport = transport.find_debug()
            self.debug = DebugLink(debug_transport, auto_interact)
            # try to open debuglink, see if it works
            self.debug.open()
            self.debug.close()
        except Exception:
            # Without auto-interaction a missing debuglink is tolerated.
            if not auto_interact:
                self.debug = NullDebugLink()
            else:
                raise

        self.reset_debug_features()

        super().__init__(transport, ui=self.ui)

    def reset_debug_features(self) -> None:
        """Prepare the debugging client for a new testcase.

        Clears all debugging state that might have been modified by a testcase.
        """
        self.ui: DebugUI = DebugUI(self.debug)
        self.in_with_statement = False
        self.expected_responses: Optional[List[MessageFilter]] = None
        self.actual_responses: Optional[List[protobuf.MessageType]] = None
        self.filters: Dict[
            Type[protobuf.MessageType],
            Callable[[protobuf.MessageType], protobuf.MessageType],
        ] = {}

    def open(self) -> None:
        super().open()
        if self.session_counter == 1:
            self.debug.open()

    def close(self) -> None:
        if self.session_counter == 1:
            self.debug.close()
        super().close()

    def set_filter(
        self,
        message_type: Type[protobuf.MessageType],
        callback: Callable[[protobuf.MessageType], protobuf.MessageType],
    ) -> None:
        """Configure a filter function for a specified message type.

        The `callback` must be a function that accepts a protobuf message, and returns
        a (possibly modified) protobuf message of the same type. Whenever a message
        is sent or received that matches `message_type`, `callback` is invoked on the
        message and its result is substituted for the original.

        Useful for test scenarios with an active malicious actor on the wire.
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        self.filters[message_type] = callback

    def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
        """Pass a deep copy of `msg` through the filter registered for its class.

        Returns `msg` unchanged when no filter is registered.
        """
        message_type = msg.__class__
        callback = self.filters.get(message_type)
        if callable(callback):
            return callback(deepcopy(msg))
        else:
            return msg

    def set_input_flow(
        self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
    ) -> None:
        """Configure a sequence of input events for the current with-block.

        The `input_flow` must be a generator function (or an already created
        generator). A `yield` statement in the input flow function waits for a
        ButtonRequest from the device, and returns that ButtonRequest message.

        Example usage:

        >>> def input_flow():
        >>>     # wait for first button prompt
        >>>     br = yield
        >>>     assert br.code == ButtonRequestType.Other
        >>>     # press No
        >>>     client.debug.press_no()
        >>>
        >>>     # wait for second button prompt
        >>>     yield
        >>>     # press Yes
        >>>     client.debug.press_yes()
        >>>
        >>> with client:
        >>>     client.set_input_flow(input_flow)
        >>>     some_call(client)
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        if callable(input_flow):
            input_flow = input_flow()
        if not hasattr(input_flow, "send"):
            raise RuntimeError("input_flow should be a generator function")
        self.ui.input_flow = input_flow
        input_flow.send(None)  # start the generator

    def watch_layout(self, watch: bool = True) -> None:
        """Enable or disable watching layout changes.

        Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
        using `debug.wait_layout()`, otherwise layout changes are not reported.
        """
        if self.version >= (2, 3, 2):
            # version check is necessary because otherwise we cannot reliably detect
            # whether and where to wait for reply:
            # - T1 reports unknown debuglink messages on the wirelink
            # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
            self.debug.watch_layout(watch)

    def __enter__(self) -> "TrezorClientDebugLink":
        # For usage in with/expected_responses
        if self.in_with_statement:
            raise RuntimeError("Do not nest!")
        self.in_with_statement = True
        return self

    def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        # copy expected/actual responses before clearing them
        expected_responses = self.expected_responses
        actual_responses = self.actual_responses
        try:
            self.watch_layout(False)
        finally:
            # Always leave the with-state, even if talking to the device failed;
            # otherwise the client would refuse every later `with` as nested.
            self.reset_debug_features()

        if exc_type is None:
            # If no other exception was raised, evaluate missed responses
            # (raises AssertionError on mismatch)
            self._verify_responses(expected_responses, actual_responses)

    def set_expected_responses(
        self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
    ) -> None:
        """Set a sequence of expected responses to client calls.

        Within a given with-block, the list of received responses from device must
        match the list of expected responses, otherwise an AssertionError is raised.

        If an expected response is given a field value other than None, that field value
        must exactly match the received field value. If a given field is None
        (or unspecified) in the expected response, the received field value is not
        checked.

        Each expected response can also be a tuple (bool, message). In that case, the
        expected response is only evaluated if the first field is True.
        This is useful for differentiating sequences between Trezor models:

        >>> trezor_one = client.features.model == "1"
        >>> client.set_expected_responses([
        >>>     messages.ButtonRequest(code=ConfirmOutput),
        >>>     (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
        >>>     messages.Success(),
        >>> ])
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        # make sure all items are (bool, message) tuples
        expected_with_validity = (
            e if isinstance(e, tuple) else (True, e) for e in expected
        )

        # only apply those items that are (True, message)
        self.expected_responses = [
            MessageFilter.from_message_or_type(expected)
            for valid, expected in expected_with_validity
            if valid
        ]
        self.actual_responses = []

    def use_pin_sequence(self, pins: Iterable[str]) -> None:
        """Respond to PIN prompts from device with the provided PINs.
        The sequence must be at least as long as the expected number of PIN prompts.
        """
        self.ui.pins = iter(pins)

    def use_passphrase(self, passphrase: str) -> None:
        """Respond to passphrase prompts from device with the provided passphrase."""
        self.ui.passphrase = Mnemonic.normalize_string(passphrase)

    def use_mnemonic(self, mnemonic: str) -> None:
        """Use the provided mnemonic to respond to device.
        Only applies to T1, where device prompts the host for mnemonic words."""
        self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")

    def _raw_read(self) -> protobuf.MessageType:
        """Read a message, apply filters and record it when responses are tracked."""
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        resp = super()._raw_read()
        resp = self._filter_message(resp)
        if self.actual_responses is not None:
            self.actual_responses.append(resp)
        return resp

    def _raw_write(self, msg: protobuf.MessageType) -> None:
        return super()._raw_write(self._filter_message(msg))

    @staticmethod
    def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
        """Format the expected responses around index `current` for an error report.

        Shows up to EXPECTED_RESPONSES_CONTEXT_LINES entries on either side and
        marks the entry at `current` with ">>> ".
        """
        start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
        stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
        output: List[str] = []
        output.append("Expected responses:")
        if start_at > 0:
            output.append(f"    (...{start_at} previous responses omitted)")
        for i in range(start_at, stop_at):
            exp = expected[i]
            prefix = "    " if i != current else ">>> "
            output.append(textwrap.indent(exp.to_string(), prefix))
        if stop_at < len(expected):
            omitted = len(expected) - stop_at
            output.append(f"    (...{omitted} following responses omitted)")

        output.append("")
        return output

    @classmethod
    def _verify_responses(
        cls,
        expected: Optional[List[MessageFilter]],
        actual: Optional[List[protobuf.MessageType]],
    ) -> None:
        """Compare recorded responses with the expected ones.

        Does nothing when no expectations were set. Raises AssertionError with a
        formatted report on the first surplus, missing or mismatching response.
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        if expected is None and actual is None:
            return

        assert expected is not None
        assert actual is not None

        for i, (exp, act) in enumerate(zip_longest(expected, actual)):
            if exp is None:
                output = cls._expectation_lines(expected, i)
                output.append("No more messages were expected, but we got:")
                for resp in actual[i:]:
                    output.append(
                        textwrap.indent(protobuf.format_message(resp), "    ")
                    )
                raise AssertionError("\n".join(output))

            if act is None:
                output = cls._expectation_lines(expected, i)
                output.append("This and the following message was not received.")
                raise AssertionError("\n".join(output))

            if not exp.match(act):
                output = cls._expectation_lines(expected, i)
                output.append("Actually received:")
                output.append(textwrap.indent(protobuf.format_message(act), "    "))
                raise AssertionError("\n".join(output))

    def mnemonic_callback(self, _: Any) -> str:
        """Return the word the device is asking for during recovery.

        Uses the fake word reported by the debuglink if there is one, otherwise
        the word at the reported (1-based) position of the configured mnemonic.
        Raises RuntimeError when the debuglink reports neither.
        """
        word, pos = self.debug.read_recovery_word()
        if word:
            return word
        if pos:
            return self.mnemonic[pos - 1]

        raise RuntimeError("Unexpected call")


@expect(messages.Success, field="message", ret_type=str)
def load_device(
    client: "TrezorClient",
    mnemonic: Union[str, Iterable[str]],
    pin: Optional[str],
    passphrase_protection: bool,
    label: Optional[str],
    language: str = "en-US",
    skip_checksum: bool = False,
    needs_backup: bool = False,
    no_backup: bool = False,
) -> protobuf.MessageType:
    """Send a LoadDevice message with the given mnemonic(s) and settings.

    A single mnemonic string is treated as a one-element list. Every mnemonic is
    normalized before sending. The device is re-initialized after the call.

    Raises RuntimeError if the device reports that it is already initialized.
    """
    if isinstance(mnemonic, str):
        mnemonic = [mnemonic]

    mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]

    if client.features.initialized:
        raise RuntimeError(
            "Device is initialized already. Call device.wipe() and try again."
        )

    resp = client.call(
        messages.LoadDevice(
            mnemonics=mnemonics,
            pin=pin,
            passphrase_protection=passphrase_protection,
            language=language,
            label=label,
            skip_checksum=skip_checksum,
            needs_backup=needs_backup,
            no_backup=no_backup,
        )
    )
    client.init_device()
    return resp


# keep the old name for compatibility
load_device_by_mnemonic = load_device


@expect(messages.Success, field="message", ret_type=str)
def self_test(client: "TrezorClient") -> protobuf.MessageType:
    """Send a SelfTest message with a fixed payload.

    Raises RuntimeError unless the device reports bootloader mode.
    """
    if client.features.bootloader_mode is not True:
        raise RuntimeError("Device must be in bootloader mode")

    return client.call(
        messages.SelfTest(
            payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
        )
    )