1# This file is part of the Trezor project.
2#
3# Copyright (C) 2012-2022 SatoshiLabs and contributors
4#
5# This library is free software: you can redistribute it and/or modify
6# it under the terms of the GNU Lesser General Public License version 3
7# as published by the Free Software Foundation.
8#
9# This library is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12# GNU Lesser General Public License for more details.
13#
14# You should have received a copy of the License along with this library.
15# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
16
17import logging
18import textwrap
19from collections import namedtuple
20from copy import deepcopy
21from enum import IntEnum
22from itertools import zip_longest
23from typing import (
24    TYPE_CHECKING,
25    Any,
26    Callable,
27    Dict,
28    Generator,
29    Iterable,
30    Iterator,
31    List,
32    Optional,
33    Sequence,
34    Tuple,
35    Type,
36    Union,
37)
38
39from mnemonic import Mnemonic
40
41from . import mapping, messages, protobuf
42from .client import TrezorClient
43from .exceptions import TrezorFailure
44from .log import DUMP_BYTES
45from .tools import expect
46
47if TYPE_CHECKING:
48    from .transport import Transport
49    from .messages import PinMatrixRequestType
50
51    ExpectedMessage = Union[
52        protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
53    ]
54
# How many expected responses to show before and after the offending one
# when a response-sequence mismatch is reported.
EXPECTED_RESPONSES_CONTEXT_LINES = 3

# A device layout: the individual text lines plus all of them joined by spaces.
LayoutLines = namedtuple("LayoutLines", ["lines", "text"])

LOG = logging.getLogger(__name__)


def layout_lines(lines: Sequence[str]) -> LayoutLines:
    """Wrap layout text lines in a LayoutLines tuple.

    The ``text`` member is all lines joined with single spaces, which is
    convenient for substring checks in tests.
    """
    joined = " ".join(lines)
    return LayoutLines(lines=lines, text=joined)
64
65
class DebugLink:
    """Low-level client for the device's DebugLink interface.

    Messages are encoded with ``mapping.DEFAULT_MAPPING`` and exchanged over
    ``transport``. When ``auto_interact`` is False, :meth:`input` (and every
    helper built on it) returns without sending anything.
    """

    def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
        self.transport = transport
        # When False, input() is a no-op that returns None.
        self.allow_interactions = auto_interact
        self.mapping = mapping.DEFAULT_MAPPING

    def open(self) -> None:
        """Begin a session on the debug transport."""
        self.transport.begin_session()

    def close(self) -> None:
        """End the session on the debug transport."""
        self.transport.end_session()

    def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
        """Encode and send ``msg``; unless ``nowait`` is set, read and decode the reply.

        Returns the decoded reply message, or None when ``nowait`` is True.
        """
        LOG.debug(
            f"sending message: {msg.__class__.__name__}",
            extra={"protobuf": msg},
        )
        msg_type, msg_bytes = self.mapping.encode(msg)
        LOG.log(
            DUMP_BYTES,
            f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
        )
        self.transport.write(msg_type, msg_bytes)
        if nowait:
            return None

        ret_type, ret_bytes = self.transport.read()
        # Dump the reply's type and payload (not the request's).
        LOG.log(
            DUMP_BYTES,
            f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
        )
        response = self.mapping.decode(ret_type, ret_bytes)
        LOG.debug(
            f"received message: {response.__class__.__name__}",
            extra={"protobuf": response},
        )
        return response

    def state(self) -> messages.DebugLinkState:
        """Fetch the current DebugLinkState from the device."""
        return self._call(messages.DebugLinkGetState())

    def read_layout(self) -> LayoutLines:
        """Return the layout lines from the current device state."""
        return layout_lines(self.state().layout_lines)

    def wait_layout(self) -> LayoutLines:
        """Request state with ``wait_layout=True`` and return its layout lines.

        Raises:
            TrezorFailure: if the device answers with a Failure message.
        """
        obj = self._call(messages.DebugLinkGetState(wait_layout=True))
        if isinstance(obj, messages.Failure):
            raise TrezorFailure(obj)
        return layout_lines(obj.layout_lines)

    def watch_layout(self, watch: bool) -> None:
        """Enable or disable watching layouts.
        If disabled, wait_layout will not work.

        The message is missing on T1. Use `TrezorClientDebugLink.watch_layout` for
        cross-version compatibility.
        """
        self._call(messages.DebugLinkWatchLayout(watch=watch))

    def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
        """Transform correct PIN according to the displayed matrix.

        Each PIN digit is replaced by its 1-based position in ``matrix``. If no
        matrix is given, it is read from the device state; if the state has no
        matrix either, the PIN is returned unchanged.
        """
        if matrix is None:
            matrix = self.state().matrix
            if matrix is None:
                # we are on trezor-core
                return pin

        return "".join([str(matrix.index(p) + 1) for p in pin])

    def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
        """Return ``(recovery_fake_word, recovery_word_pos)`` from the device state."""
        state = self.state()
        return (state.recovery_fake_word, state.recovery_word_pos)

    def read_reset_word(self) -> str:
        """Request state with ``wait_word_list=True`` and return its ``reset_word``."""
        state = self._call(messages.DebugLinkGetState(wait_word_list=True))
        return state.reset_word

    def read_reset_word_pos(self) -> int:
        """Request state with ``wait_word_pos=True`` and return its ``reset_word_pos``."""
        state = self._call(messages.DebugLinkGetState(wait_word_pos=True))
        return state.reset_word_pos

    def input(
        self,
        word: Optional[str] = None,
        button: Optional[bool] = None,
        swipe: Optional[messages.DebugSwipeDirection] = None,
        x: Optional[int] = None,
        y: Optional[int] = None,
        wait: Optional[bool] = None,
        hold_ms: Optional[int] = None,
    ) -> Optional[LayoutLines]:
        """Send a single DebugLinkDecision to the device.

        Exactly one of ``word``, ``button``, ``swipe`` or ``x`` (a click, paired
        with ``y``) must be given. When ``wait`` is truthy the reply is awaited
        and its lines are returned as LayoutLines; otherwise None is returned.
        Returns None without sending anything if interactions are disabled.

        Raises:
            ValueError: if not exactly one kind of input is specified.
        """
        if not self.allow_interactions:
            return None

        args = sum(a is not None for a in (word, button, swipe, x))
        if args != 1:
            raise ValueError(
                "Invalid input - must use one of word, button, swipe, click(x,y)"
            )

        decision = messages.DebugLinkDecision(
            yes_no=button, swipe=swipe, input=word, x=x, y=y, wait=wait, hold_ms=hold_ms
        )
        ret = self._call(decision, nowait=not wait)
        if ret is not None:
            return layout_lines(ret.lines)

        return None

    def click(
        self, click: Tuple[int, int], wait: bool = False
    ) -> Optional[LayoutLines]:
        """Click at the ``(x, y)`` coordinates given in ``click``."""
        x, y = click
        return self.input(x=x, y=y, wait=wait)

    def press_yes(self) -> None:
        self.input(button=True)

    def press_no(self) -> None:
        self.input(button=False)

    def swipe_up(self, wait: bool = False) -> None:
        self.input(swipe=messages.DebugSwipeDirection.UP, wait=wait)

    def swipe_down(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.DOWN)

    def swipe_right(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.RIGHT)

    def swipe_left(self) -> None:
        self.input(swipe=messages.DebugSwipeDirection.LEFT)

    def stop(self) -> None:
        """Send DebugLinkStop without waiting for a reply."""
        self._call(messages.DebugLinkStop(), nowait=True)

    def reseed(self, value: int) -> protobuf.MessageType:
        """Send DebugLinkReseedRandom with ``value`` and return the reply."""
        return self._call(messages.DebugLinkReseedRandom(value=value))

    def start_recording(self, directory: str) -> None:
        """Ask the device to record screens into ``directory``."""
        self._call(messages.DebugLinkRecordScreen(target_directory=directory))

    def stop_recording(self) -> None:
        """Send DebugLinkRecordScreen with no target directory."""
        self._call(messages.DebugLinkRecordScreen(target_directory=None))

    @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
    def memory_read(self, address: int, length: int) -> protobuf.MessageType:
        """Read ``length`` bytes at ``address``.

        The ``expect`` decorator presumably checks for a DebugLinkMemory reply and
        unwraps its ``memory`` field — confirm against tools.expect.
        """
        return self._call(messages.DebugLinkMemoryRead(address=address, length=length))

    def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
        """Send DebugLinkMemoryWrite without waiting for a reply."""
        self._call(
            messages.DebugLinkMemoryWrite(address=address, memory=memory, flash=flash),
            nowait=True,
        )

    def flash_erase(self, sector: int) -> None:
        """Send DebugLinkFlashErase for ``sector`` without waiting for a reply."""
        self._call(messages.DebugLinkFlashErase(sector=sector), nowait=True)

    @expect(messages.Success)
    def erase_sd_card(self, format: bool = True) -> messages.Success:
        """Send DebugLinkEraseSdCard with the given ``format`` flag."""
        return self._call(messages.DebugLinkEraseSdCard(format=format))
225
226
class NullDebugLink(DebugLink):
    """Inert DebugLink used when no real debug transport is available.

    It has no transport. open(), close() and _call() are overridden, so the
    transport is never touched. Waiting DebugLinkGetState requests get an empty
    DebugLinkState; any other waiting request raises RuntimeError.
    """

    def __init__(self) -> None:
        # Ignoring type error as self.transport will not be touched while using NullDebugLink
        super().__init__(None)  # type: ignore ["None" cannot be assigned to parameter of type "Transport"]

    def open(self) -> None:
        """Do nothing; there is no session to begin."""

    def close(self) -> None:
        """Do nothing; there is no session to end."""

    def _call(
        self, msg: protobuf.MessageType, nowait: bool = False
    ) -> Optional[messages.DebugLinkState]:
        # Fire-and-forget messages are silently dropped.
        if nowait:
            return None
        if isinstance(msg, messages.DebugLinkGetState):
            return messages.DebugLinkState()
        raise RuntimeError("unexpected call to a fake debuglink")
248
249
class DebugUI:
    """UI callbacks that answer device prompts through a DebugLink.

    Without an input flow, a PinEntry button request is answered with the next
    PIN from the configured sequence. Any other button request is confirmed
    after swiping up through its extra pages. With an input flow, every
    ButtonRequest is sent into the flow generator instead.
    """

    # Sentinel stored in ``input_flow`` once the flow generator has finished.
    INPUT_FLOW_DONE = object()

    def __init__(self, debuglink: DebugLink) -> None:
        self.debuglink = debuglink
        self.clear()

    def clear(self) -> None:
        """Forget any configured input flow, passphrase and PIN sequence."""
        self.input_flow: Union[
            Generator[None, messages.ButtonRequest, None], object, None
        ] = None
        self.passphrase = ""
        self.pins: Optional[Iterator[str]] = None

    def button_request(self, br: messages.ButtonRequest) -> None:
        """React to a ButtonRequest from the device.

        Raises:
            AssertionError: if an input flow was set but has already finished.
        """
        if self.input_flow is self.INPUT_FLOW_DONE:
            raise AssertionError("input flow ended prematurely")

        if self.input_flow is not None:
            # Delegate to the user-provided flow; remember when it is exhausted.
            try:
                assert isinstance(self.input_flow, Generator)
                self.input_flow.send(br)
            except StopIteration:
                self.input_flow = self.INPUT_FLOW_DONE
            return

        if br.code == messages.ButtonRequestType.PinEntry:
            self.debuglink.input(self.get_pin())
            return

        extra_pages = 0 if br.pages is None else br.pages - 1
        for _ in range(extra_pages):
            self.debuglink.swipe_up(wait=True)
        self.debuglink.press_yes()

    def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
        """Return the next configured PIN, encoded for the displayed matrix.

        Raises:
            RuntimeError: if no PIN sequence was configured.
            AssertionError: if the PIN sequence is exhausted.
        """
        pin_source = self.pins
        if pin_source is None:
            raise RuntimeError("PIN requested but no sequence was configured")

        try:
            next_pin = next(pin_source)
            return self.debuglink.encode_pin(next_pin)
        except StopIteration:
            raise AssertionError("PIN sequence ended prematurely")

    def get_passphrase(self, available_on_device: bool) -> str:
        """Return the configured passphrase (empty string by default)."""
        return self.passphrase
293
294
class MessageFilter:
    """Pattern matching protobuf messages of exactly ``message_type``.

    Only the fields given to the filter are compared. Values that are messages,
    message classes or MessageFilters are turned into nested MessageFilters.
    Any other value must compare equal to the message's field.
    """

    def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
        self.message_type = message_type
        self.fields: Dict[str, Any] = {}
        self.update_fields(**fields)

    def update_fields(self, **fields: Any) -> "MessageFilter":
        """Add or replace expected field values; returns ``self`` for chaining."""
        for name, value in fields.items():
            try:
                stored = self.from_message_or_type(value)
            except TypeError:
                # Plain values (ints, strings, enums, ...) are compared directly.
                stored = value
            self.fields[name] = stored

        return self

    @classmethod
    def from_message_or_type(
        cls, message_or_type: "ExpectedMessage"
    ) -> "MessageFilter":
        """Coerce a filter, message instance or message class into a filter.

        Raises:
            TypeError: for any other kind of value.
        """
        if isinstance(message_or_type, cls):
            return message_or_type
        if isinstance(message_or_type, protobuf.MessageType):
            return cls.from_message(message_or_type)
        is_message_class = isinstance(message_or_type, type) and issubclass(
            message_or_type, protobuf.MessageType
        )
        if is_message_class:
            return cls(message_or_type)
        raise TypeError("Invalid kind of expected response")

    @classmethod
    def from_message(cls, message: protobuf.MessageType) -> "MessageFilter":
        """Build a filter from the fields that are actually set on ``message``."""
        unset_markers = (None, [], protobuf.REQUIRED_FIELD_PLACEHOLDER)
        named_values = (
            (spec.name, getattr(message, spec.name)) for spec in message.FIELDS.values()
        )
        present = {
            name: value
            for name, value in named_values
            if value not in unset_markers
        }
        return cls(type(message), **present)

    def match(self, message: protobuf.MessageType) -> bool:
        """Return True if ``message`` has the exact type and all filtered fields match."""
        # Exact type comparison on purpose: subclasses do not match.
        if type(message) != self.message_type:
            return False

        return all(
            self._field_matches(expected, getattr(message, name, None))
            for name, expected in self.fields.items()
        )

    @staticmethod
    def _field_matches(expected: Any, actual: Any) -> bool:
        """Compare a single expected field value against the received one."""
        if isinstance(expected, MessageFilter):
            return actual is not None and expected.match(actual)
        return not (expected != actual)

    def to_string(self, maxwidth: int = 80) -> str:
        """Render the filter, on one line if short enough, else one field per line."""
        rendered_pairs: List[str] = []
        for spec in self.message_type.FIELDS.values():
            if spec.name not in self.fields:
                continue
            value = self.fields[spec.name]
            if isinstance(value, IntEnum):
                text = value.name
            elif isinstance(value, MessageFilter):
                text = value.to_string(maxwidth - 4)
            elif isinstance(value, protobuf.MessageType):
                text = protobuf.format_message(value)
            else:
                text = repr(value)
            # Indent continuation lines of multi-line values.
            text = textwrap.indent(text, "    ").lstrip()
            rendered_pairs.append(f"{spec.name}={text}")

        type_name = self.message_type.__name__
        oneline_str = ", ".join(rendered_pairs)
        if len(oneline_str) < maxwidth:
            return f"{type_name}({oneline_str})"

        lines = [f"{type_name}("]
        lines.extend(f"    {pair}" for pair in rendered_pairs)
        lines.append(")")
        return "\n".join(lines)
376
377
class MessageFilterGenerator:
    """Attribute-style factory for MessageFilter objects.

    ``message_filters.Foo(field=value)`` builds a MessageFilter for
    ``messages.Foo`` with the given expected fields.
    """

    def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
        return MessageFilter(getattr(messages, key)).update_fields


# Shared instance used as ``message_filters.<MessageName>(...)``.
message_filters = MessageFilterGenerator()
385
386
class TrezorClientDebugLink(TrezorClient):
    """TrezorClient that answers device prompts automatically, for unit tests.

    UI callbacks (button confirmations, PINs, passphrase, mnemonic words) are
    handled through a DebugLink connection. Inside a ``with client:`` block the
    client can also record and verify the sequence of received responses, run
    scripted input flows and filter/modify messages on the wire.

    This should only be used for unit testing: it will not work without the
    special DebugLink interface provided by the device.
    """

    def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
        """Set up the debuglink belonging to ``transport`` and the parent client.

        The debug transport is obtained via ``transport.find_debug()`` and probed
        by opening and closing it. If that fails and ``auto_interact`` is False, a
        NullDebugLink is used instead. If ``auto_interact`` is True, the error is
        re-raised.
        """
        try:
            debug_transport = transport.find_debug()
            self.debug = DebugLink(debug_transport, auto_interact)
            # try to open debuglink, see if it works
            self.debug.open()
            self.debug.close()
        except Exception:
            # Without auto-interaction the debuglink is optional, so fall back
            # to a stub instead of failing.
            if not auto_interact:
                self.debug = NullDebugLink()
            else:
                raise

        # Creates self.ui, which must exist before it is handed to the parent.
        self.reset_debug_features()

        super().__init__(transport, ui=self.ui)

    def reset_debug_features(self) -> None:
        """Prepare the debugging client for a new testcase.

        Clears all debugging state that might have been modified by a testcase.
        """
        self.ui: DebugUI = DebugUI(self.debug)
        self.in_with_statement = False
        # None means "not tracking responses"; set by set_expected_responses().
        self.expected_responses: Optional[List[MessageFilter]] = None
        self.actual_responses: Optional[List[protobuf.MessageType]] = None
        # Maps an exact message class to a callback applied in _filter_message().
        self.filters: Dict[
            Type[protobuf.MessageType],
            Callable[[protobuf.MessageType], protobuf.MessageType],
        ] = {}

    def open(self) -> None:
        """Open the client session; also opens the debuglink when the session
        counter is 1 (presumably the outermost session — see TrezorClient)."""
        super().open()
        if self.session_counter == 1:
            self.debug.open()

    def close(self) -> None:
        """Close the debuglink when the session counter is 1, then close the
        client session."""
        if self.session_counter == 1:
            self.debug.close()
        super().close()

    def set_filter(
        self,
        message_type: Type[protobuf.MessageType],
        callback: Callable[[protobuf.MessageType], protobuf.MessageType],
    ) -> None:
        """Configure a filter function for a specified message type.

        The `callback` must be a function that accepts a protobuf message, and returns
        a (possibly modified) protobuf message of the same type. Whenever a message
        is sent or received that matches `message_type`, `callback` is invoked on the
        message and its result is substituted for the original.

        Useful for test scenarios with an active malicious actor on the wire.
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        self.filters[message_type] = callback

    def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
        """Apply the filter registered for ``msg``'s exact class, if any.

        The callback receives a deep copy, so it cannot mutate the original.
        """
        message_type = msg.__class__
        callback = self.filters.get(message_type)
        if callable(callback):
            return callback(deepcopy(msg))
        else:
            return msg

    def set_input_flow(
        self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
    ) -> None:
        """Configure a sequence of input events for the current with-block.

        The `input_flow` must be a generator, or a generator function (which is
        called without arguments to create the generator). The generator is started
        immediately. Each `yield` statement in the input flow waits for a
        ButtonRequest from the device and evaluates to that ButtonRequest message
        (not just its code).

        Example usage:

        >>> def input_flow():
        >>>     # wait for first button prompt
        >>>     br = yield
        >>>     assert br.code == ButtonRequestType.Other
        >>>     # press No
        >>>     client.debug.press_no()
        >>>
        >>>     # wait for second button prompt
        >>>     yield
        >>>     # press Yes
        >>>     client.debug.press_yes()
        >>>
        >>> with client:
        >>>     client.set_input_flow(input_flow)
        >>>     some_call(client)

        Raises RuntimeError when called outside a with-block or when
        `input_flow` does not produce a generator.
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        if callable(input_flow):
            input_flow = input_flow()
        if not hasattr(input_flow, "send"):
            raise RuntimeError("input_flow should be a generator function")
        self.ui.input_flow = input_flow
        input_flow.send(None)  # start the generator

    def watch_layout(self, watch: bool = True) -> None:
        """Enable or disable watching layout changes.

        Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
        using `debug.wait_layout()`, otherwise layout changes are not reported.
        """
        if self.version >= (2, 3, 2):
            # version check is necessary because otherwise we cannot reliably detect
            # whether and where to wait for reply:
            # - T1 reports unknown debuglink messages on the wirelink
            # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
            self.debug.watch_layout(watch)

    def __enter__(self) -> "TrezorClientDebugLink":
        """Start a test block; nested with-blocks raise RuntimeError."""
        # For usage in with/expected_responses
        if self.in_with_statement:
            raise RuntimeError("Do not nest!")
        self.in_with_statement = True
        return self

    def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
        """Finish a test block.

        Stops watching layouts and resets all debug state. If the block raised
        no exception, the recorded responses are then checked against the
        expected ones (AssertionError on mismatch).
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        self.watch_layout(False)
        # copy expected/actual responses before clearing them
        expected_responses = self.expected_responses
        actual_responses = self.actual_responses
        self.reset_debug_features()

        if exc_type is None:
            # If no other exception was raised, evaluate missed responses
            # (raises AssertionError on mismatch)
            self._verify_responses(expected_responses, actual_responses)

    def set_expected_responses(
        self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
    ) -> None:
        """Set a sequence of expected responses to client calls.

        Within a given with-block, the list of received responses from device must
        match the list of expected responses, otherwise an AssertionError is raised.

        If an expected response is given a field value other than None, that field value
        must exactly match the received field value. If a given field is None
        (or unspecified) in the expected response, the received field value is not
        checked.

        Each expected response can also be a tuple (bool, message). In that case, the
        expected response is only evaluated if the first field is True.
        This is useful for differentiating sequences between Trezor models:

        >>> trezor_one = client.features.model == "1"
        >>> client.set_expected_responses([
        >>>     messages.ButtonRequest(code=ConfirmOutput),
        >>>     (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
        >>>     messages.Success(),
        >>> ])
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")

        # make sure all items are (bool, message) tuples
        expected_with_validity = (
            e if isinstance(e, tuple) else (True, e) for e in expected
        )

        # only apply those items that are (True, message)
        self.expected_responses = [
            MessageFilter.from_message_or_type(expected)
            for valid, expected in expected_with_validity
            if valid
        ]
        # Start recording received messages (see _raw_read).
        self.actual_responses = []

    def use_pin_sequence(self, pins: Iterable[str]) -> None:
        """Respond to PIN prompts from device with the provided PINs.
        The sequence must be at least as long as the expected number of PIN prompts.
        """
        self.ui.pins = iter(pins)

    def use_passphrase(self, passphrase: str) -> None:
        """Respond to passphrase prompts from device with the provided passphrase."""
        self.ui.passphrase = Mnemonic.normalize_string(passphrase)

    def use_mnemonic(self, mnemonic: str) -> None:
        """Use the provided mnemonic to respond to device.
        Only applies to T1, where device prompts the host for mnemonic words."""
        self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")

    def _raw_read(self) -> protobuf.MessageType:
        """Read a message, apply filters, and record it if responses are tracked."""
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        resp = super()._raw_read()
        resp = self._filter_message(resp)
        if self.actual_responses is not None:
            self.actual_responses.append(resp)
        return resp

    def _raw_write(self, msg: protobuf.MessageType) -> None:
        """Apply filters to ``msg`` before writing it."""
        return super()._raw_write(self._filter_message(msg))

    @staticmethod
    def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
        """Format expected responses around index ``current`` for an error report.

        Shows up to EXPECTED_RESPONSES_CONTEXT_LINES entries on each side of
        ``current``, marks the current one with ``>>> `` and summarizes how many
        entries were omitted before and after.
        """
        start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
        stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
        output: List[str] = []
        output.append("Expected responses:")
        if start_at > 0:
            output.append(f"    (...{start_at} previous responses omitted)")
        for i in range(start_at, stop_at):
            exp = expected[i]
            prefix = "    " if i != current else ">>> "
            output.append(textwrap.indent(exp.to_string(), prefix))
        if stop_at < len(expected):
            omitted = len(expected) - stop_at
            output.append(f"    (...{omitted} following responses omitted)")

        output.append("")
        return output

    @classmethod
    def _verify_responses(
        cls,
        expected: Optional[List[MessageFilter]],
        actual: Optional[List[protobuf.MessageType]],
    ) -> None:
        """Check received messages against expected filters, in order.

        Returns silently if no expectations were set. Otherwise raises
        AssertionError at the first unexpected extra message, missing message or
        mismatching message.
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        if expected is None and actual is None:
            return

        assert expected is not None
        assert actual is not None

        # zip_longest pads the shorter list with None, which signals an extra
        # or a missing message respectively.
        for i, (exp, act) in enumerate(zip_longest(expected, actual)):
            if exp is None:
                output = cls._expectation_lines(expected, i)
                output.append("No more messages were expected, but we got:")
                for resp in actual[i:]:
                    output.append(
                        textwrap.indent(protobuf.format_message(resp), "    ")
                    )
                raise AssertionError("\n".join(output))

            if act is None:
                output = cls._expectation_lines(expected, i)
                output.append("This and the following message was not received.")
                raise AssertionError("\n".join(output))

            if not exp.match(act):
                output = cls._expectation_lines(expected, i)
                output.append("Actually received:")
                output.append(textwrap.indent(protobuf.format_message(act), "    "))
                raise AssertionError("\n".join(output))

    def mnemonic_callback(self, _) -> str:
        """Supply the next mnemonic word requested by the device.

        Returns the device's fake recovery word if one is set. Otherwise returns
        the word at the requested 1-based position from ``self.mnemonic``.
        Within this class, ``self.mnemonic`` is only set by use_mnemonic().
        Raises RuntimeError if neither a word nor a position is reported.
        """
        word, pos = self.debug.read_recovery_word()
        if word:
            return word
        if pos:
            return self.mnemonic[pos - 1]

        raise RuntimeError("Unexpected call")
666
667
@expect(messages.Success, field="message", ret_type=str)
def load_device(
    client: "TrezorClient",
    mnemonic: Union[str, Iterable[str]],
    pin: Optional[str],
    passphrase_protection: bool,
    label: Optional[str],
    language: str = "en-US",
    skip_checksum: bool = False,
    needs_backup: bool = False,
    no_backup: bool = False,
) -> protobuf.MessageType:
    """Load a seed and settings onto an uninitialized device via LoadDevice.

    ``mnemonic`` may be a single mnemonic string or an iterable of them. Each is
    normalized with ``Mnemonic.normalize_string``. After the LoadDevice call,
    ``client.init_device()`` is invoked. The ``expect`` decorator presumably
    checks for a Success reply and returns its ``message`` field.

    Raises:
        RuntimeError: if the device reports it is already initialized.
    """
    mnemonic_list = [mnemonic] if isinstance(mnemonic, str) else mnemonic
    mnemonics = list(map(Mnemonic.normalize_string, mnemonic_list))

    if client.features.initialized:
        raise RuntimeError(
            "Device is initialized already. Call device.wipe() and try again."
        )

    request = messages.LoadDevice(
        mnemonics=mnemonics,
        pin=pin,
        passphrase_protection=passphrase_protection,
        language=language,
        label=label,
        skip_checksum=skip_checksum,
        needs_backup=needs_backup,
        no_backup=no_backup,
    )
    response = client.call(request)
    # Re-read device features now that a seed has been loaded.
    client.init_device()
    return response


# Backwards-compatible alias for the former function name.
load_device_by_mnemonic = load_device
708
709
@expect(messages.Success, field="message", ret_type=str)
def self_test(client: "TrezorClient") -> protobuf.MessageType:
    """Ask a device in bootloader mode to run its self-test with a fixed payload.

    Raises:
        RuntimeError: unless ``client.features.bootloader_mode`` is exactly True.
    """
    if client.features.bootloader_mode is True:
        payload = b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
        return client.call(messages.SelfTest(payload=payload))

    raise RuntimeError("Device must be in bootloader mode")
720