1# This file is part of the Trezor project.
2#
3# Copyright (C) 2012-2022 SatoshiLabs and contributors
4#
5# This library is free software: you can redistribute it and/or modify
6# it under the terms of the GNU Lesser General Public License version 3
7# as published by the Free Software Foundation.
8#
9# This library is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12# GNU Lesser General Public License for more details.
13#
14# You should have received a copy of the License along with this library.
15# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
16
17import os
18import sys
19from typing import Any, Callable, Optional, Union
20
21import click
22from mnemonic import Mnemonic
23from typing_extensions import Protocol
24
25from . import device, messages
26from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
27from .exceptions import Cancelled
28from .messages import PinMatrixRequestType, WordRequestType
29
# Help text echoed before PIN entry by ClickUI.get_pin (once per instance, or on
# every request when always_prompt is set). Each key names a grid position; the
# letter keys map c v b / d f g / e r t -> 1 2 3 / 4 5 6 / 7 8 9, matching the
# translation table used in ClickUI.get_pin.
PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad or lowercase letters to describe number positions.

The layout is:

    7 8 9        e r t
    4 5 6  -or-  d f g
    1 2 3        c v b
""".strip()

# Help text for matrix-based recovery input. The second layout corresponds to
# the keys matrix_words accepts for WordRequestType.Matrix6 (1/4/7 and 3/6/9).
# NOTE(review): not referenced in this part of the file; presumably printed by
# whichever caller drives matrix_words - verify.
RECOVERY_MATRIX_DESCRIPTION = """
Use the numeric keypad to describe positions.
For the word list use only left and right keys.
Use backspace to correct an entry.

The keypad layout is:
    7 8 9     7 | 9
    4 5 6     4 | 6
    1 2 3     1 | 3
""".strip()

# Short aliases for the PinMatrixRequestType values handled by ClickUI.get_pin.
# PIN_GENERIC (None) means the request kind is unspecified; get_pin then labels
# the prompt simply "PIN".
PIN_GENERIC = None
PIN_CURRENT = PinMatrixRequestType.Current
PIN_NEW = PinMatrixRequestType.NewFirst
PIN_CONFIRM = PinMatrixRequestType.NewSecond
WIPE_CODE_NEW = PinMatrixRequestType.WipeCodeFirst
WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
57
# Workaround for limitation of Git Bash
# getpass function does not work correctly on Windows when not using a real terminal
# (the hidden input is not allowed and it also freezes the script completely)
# Details: https://bugs.python.org/issue44762
#
# Hidden input is only attempted when stdin is an interactive terminal.
# sys.stdin may be None (no attached stdin), in which case the bare
# `sys.stdin and ...` expression would yield None; coerce to a real bool so the
# flag always has the type its name implies.
CAN_HANDLE_HIDDEN_INPUT = bool(sys.stdin and sys.stdin.isatty())
63
64
class TrezorClientUI(Protocol):
    """Structural interface for the interactive callbacks a client UI provides."""

    def button_request(self, br: messages.ButtonRequest) -> None:
        """Handle a ButtonRequest (e.g. tell the user to confirm on the device)."""

    def get_pin(self, code: Optional[PinMatrixRequestType]) -> str:
        """Return the PIN for the request kind *code* (None when unspecified)."""

    def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
        """Return the passphrase, or a sentinel object such as PASSPHRASE_ON_DEVICE."""
74
75
def echo(*args: Any, **kwargs: Any) -> None:
    """Write a message to stderr; all arguments are forwarded to click.echo."""
    click.echo(*args, err=True, **kwargs)
78
79
def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
    """Ask the user for input on stderr via click.prompt and return the answer.

    When hidden input is requested but CAN_HANDLE_HIDDEN_INPUT is false, the
    input is shown instead and a warning is appended to the prompt text.
    Remaining keyword arguments are forwarded to click.prompt.
    """
    if hide_input and not CAN_HANDLE_HIDDEN_INPUT:
        # getpass is unreliable here (see CAN_HANDLE_HIDDEN_INPUT): fall back
        # to visible input and make that explicit to the user.
        return click.prompt(
            text + " (WARNING: will be displayed!)",
            hide_input=False,
            err=True,
            **kwargs,
        )
    return click.prompt(text, hide_input=hide_input, err=True, **kwargs)
86
87
class ClickUI:
    """Terminal implementation of the client UI callbacks, built on click.

    All messages and prompts go to stderr through the module's echo/prompt
    helpers.

    always_prompt: when true, the "confirm on device" notice and the PIN matrix
        help are repeated on every request instead of being shown only once.
    passphrase_on_host: when true, the passphrase is asked for on the host
        even if the device offers on-device entry.
    """

    def __init__(
        self, always_prompt: bool = False, passphrase_on_host: bool = False
    ) -> None:
        # One-time notice flags; they stay False when always_prompt is set.
        self.pinmatrix_shown = False
        self.prompt_shown = False
        self.always_prompt = always_prompt
        self.passphrase_on_host = passphrase_on_host

    def button_request(self, _br: messages.ButtonRequest) -> None:
        """Tell the user to confirm on the device (once, unless always_prompt)."""
        if self.prompt_shown:
            return
        echo("Please confirm action on your Trezor device.")
        if not self.always_prompt:
            self.prompt_shown = True

    def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
        """Prompt (hidden) for a PIN or wipe code and return it as digits 1-9.

        code selects the prompt wording; None or any unlisted value falls back
        to "PIN". An entry made only of the letter keys cvbdfgert is translated
        to digits; mixed letters/digits, other characters, or entries longer
        than MAX_PIN_LENGTH are rejected and the user is asked again.

        Raises Cancelled if the prompt is aborted.
        """
        labels = (
            (PIN_CURRENT, "current PIN"),
            (PIN_NEW, "new PIN"),
            (PIN_CONFIRM, "new PIN again"),
            (WIPE_CODE_NEW, "new wipe code"),
            (WIPE_CODE_CONFIRM, "new wipe code again"),
        )
        desc = next((label for kind, label in labels if code == kind), "PIN")

        if not self.pinmatrix_shown:
            echo(PIN_MATRIX_DESCRIPTION)
            # Remains False under always_prompt so the help is shown every time.
            self.pinmatrix_shown = not self.always_prompt

        letter_keys = "cvbdfgert"
        digit_keys = "123456789"
        letters_to_digits = str.maketrans(letter_keys, digit_keys)
        letter_set = set(letter_keys)
        digit_set = set(digit_keys)

        while True:
            try:
                pin = prompt(f"Please enter {desc}", hide_input=True)
            except click.Abort:
                raise Cancelled from None

            # Letters are only translated when the whole entry uses letter keys.
            if set(pin) <= letter_set:
                pin = pin.translate(letters_to_digits)

            if not set(pin) <= digit_set:
                echo(
                    "The value may only consist of digits 1 to 9 or letters cvbdfgert."
                )
                continue
            if len(pin) > MAX_PIN_LENGTH:
                echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
                continue
            return pin

    def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
        """Return the passphrase, or PASSPHRASE_ON_DEVICE for on-device entry.

        Order of preference: on-device entry (if offered and not overridden by
        passphrase_on_host), then the PASSPHRASE environment variable, then an
        interactive prompt. When hidden input works, the passphrase must be
        typed twice and the two entries must match.

        Raises Cancelled if a prompt is aborted.
        """
        if available_on_device and not self.passphrase_on_host:
            return PASSPHRASE_ON_DEVICE

        from_environment = os.environ.get("PASSPHRASE")
        if from_environment is not None:
            echo("Passphrase required. Using PASSPHRASE environment variable.")
            return from_environment

        try:
            while True:
                first = prompt(
                    "Passphrase required",
                    hide_input=True,
                    default="",
                    show_default=False,
                )
                # Input was echoed on screen, so a confirmation round is not needed.
                if not CAN_HANDLE_HIDDEN_INPUT:
                    return first
                repeated = prompt(
                    "Confirm your passphrase",
                    hide_input=True,
                    default="",
                    show_default=False,
                )
                if first == repeated:
                    return first
                echo("Passphrase did not match. Please try again.")
        except click.Abort:
            raise Cancelled from None
173
174
def mnemonic_words(
    expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]:
    """Return a callback that prompts for one mnemonic word per call.

    expand: when true, the typed text is resolved against the ``language``
        wordlist from Mnemonic. An exact word, or a prefix matching exactly one
        word, is accepted and expanded to the full word. Otherwise a hint is
        printed (the candidate words, or that nothing matches) and the user is
        prompted again. When false, the text is returned exactly as typed.
    language: wordlist name passed to Mnemonic; only used when expand is true.

    The returned function takes a WordRequestType (only Plain is supported)
    and returns the word. It raises Cancelled when the prompt is aborted.
    """
    wordlist = Mnemonic(language).wordlist if expand else []
    # Set for O(1) exact-match lookups; the list keeps wordlist order for hints.
    known_words = frozenset(wordlist)

    def expand_word(word: str) -> Optional[str]:
        # Resolve the typed text to a wordlist entry, or return None after
        # telling the user why it could not be resolved.
        if not expand:
            return word
        # Tolerate stray whitespace and capitals.
        # NOTE(review): assumes the wordlist entries are lowercase (true for the
        # BIP-39 lists) - confirm if custom wordlists are ever used.
        word = word.strip().lower()
        if not word:
            # Whitespace-only input: re-prompt instead of listing every word.
            return None
        if word in known_words:
            return word
        matches = [w for w in wordlist if w.startswith(word)]
        if len(matches) == 1:
            return matches[0]
        if matches:
            echo("Choose one of: " + ", ".join(matches))
        else:
            # Without this branch the user would see an empty candidate list.
            echo(f"No word in the {language} wordlist starts with {word!r}.")
        return None

    def get_word(type: WordRequestType) -> str:
        # Only typed (plain) entry is handled here; matrix entry is handled by
        # matrix_words.
        assert type == WordRequestType.Plain
        while True:
            try:
                word = prompt("Enter one word of mnemonic")
            except click.Abort:
                raise Cancelled from None
            expanded = expand_word(word)
            if expanded is not None:
                return expanded

    return get_word
206
207
def matrix_words(type: WordRequestType) -> str:
    """Read key presses for matrix-based recovery until one is meaningful.

    - Ctrl+D, Esc, or an empty read: raise Cancelled.
    - Backspace or Del: return device.RECOVERY_BACK.
    - A single digit valid for the matrix type: return that digit. Valid digits
      are 1/4/7/3/6/9 for Matrix6 and 1-9 for Matrix9.

    Anything else is ignored. This includes multi-character reads from
    click.getchar, such as escape sequences or several keys arriving at once.
    KeyboardInterrupt/EOFError raised by click.getchar are turned into
    Cancelled.
    """
    # Exact set membership on purpose: a substring test like
    # `ch in "123456789"` would accept a multi-key read such as "12" and
    # return it as one position.
    if type == WordRequestType.Matrix6:
        valid_keys = frozenset("147369")
    elif type == WordRequestType.Matrix9:
        valid_keys = frozenset("123456789")
    else:
        valid_keys = frozenset()

    while True:
        try:
            ch = click.getchar()
        except (KeyboardInterrupt, EOFError):
            raise Cancelled from None

        # Ctrl+D, Esc; an empty read is treated as end of input.
        if ch in ("", "\x04", "\x1b"):
            raise Cancelled
        # Backspace, Del
        if ch in ("\x08", "\x7f"):
            return device.RECOVERY_BACK
        if ch in valid_keys:
            return ch
225