1import time
2from struct import pack
3
4from electrum import ecc
5from electrum.i18n import _
6from electrum.util import UserCancelled, UserFacingException
7from electrum.keystore import bip39_normalize_passphrase
8from electrum.bip32 import BIP32Node, convert_bip32_path_to_list_of_uint32 as parse_path
9from electrum.logging import Logger
10from electrum.plugin import runs_in_hwd_thread
11from electrum.plugins.hw_wallet.plugin import OutdatedHwFirmwareException, HardwareClientBase
12
13from trezorlib.client import TrezorClient, PASSPHRASE_ON_DEVICE
14from trezorlib.exceptions import TrezorFailure, Cancelled, OutdatedFirmwareError
15from trezorlib.messages import WordRequestType, FailureType, RecoveryDeviceType, ButtonRequestType
16import trezorlib.btc
17import trezorlib.device
18
# Prompts shown to the user while the device waits for a button press,
# keyed by ButtonRequestType.  Each text contains one '{}' placeholder that
# is filled with the device name; the 'default' entry is the fallback used
# for request codes that have no dedicated prompt.
MESSAGES = dict((
    (ButtonRequestType.ConfirmOutput,
     _("Confirm the transaction output on your {} device")),
    (ButtonRequestType.ResetDevice,
     _("Complete the initialization process on your {} device")),
    (ButtonRequestType.ConfirmWord,
     _("Write down the seed word shown on your {}")),
    (ButtonRequestType.WipeDevice,
     _("Confirm on your {} that you want to wipe it clean")),
    (ButtonRequestType.ProtectCall,
     _("Confirm on your {} device the message to sign")),
    (ButtonRequestType.SignTx,
     _("Confirm the total amount spent and the transaction fee on your {} device")),
    (ButtonRequestType.Address,
     _("Confirm wallet address on your {} device")),
    (ButtonRequestType._Deprecated_ButtonRequest_PassphraseType,
     _("Choose on your {} device where to enter your passphrase")),
    (ButtonRequestType.PassphraseEntry,
     _("Please enter your passphrase on the {} device")),
    ('default', _("Check your {} device to continue")),
))
40
41
class TrezorClientBase(HardwareClientBase, Logger):
    """Electrum-side wrapper around a trezorlib ``TrezorClient``.

    This object is handed to ``TrezorClient`` as its ``ui``, so the UI
    callbacks defined below (``button_request``, ``get_pin``,
    ``get_passphrase``) are invoked by trezorlib during device communication
    and forwarded to the Electrum ``handler``.

    Device operations are wrapped in a "flow" (``with self.run_flow(...):``),
    which sets the prompt text, suspends the inactivity timeout, and on exit
    translates trezorlib exceptions into Electrum ones.
    """

    def __init__(self, transport, handler, plugin):
        HardwareClientBase.__init__(self, plugin=plugin)
        if plugin.is_outdated_fw_ignored():
            # NOTE: this replaces the method on the TrezorClient *class*,
            # so it affects every TrezorClient instance in this process.
            TrezorClient.is_outdated = lambda *args, **kwargs: False
        self.client = TrezorClient(transport, ui=self)
        self.device = plugin.device
        self.handler = handler
        Logger.__init__(self)

        # Per-flow state, set by run_flow() and reset by end_flow().
        self.msg = None  # prompt overriding the MESSAGES table, or None
        self.creating_wallet = False  # selects the get_passphrase() prompt

        self.in_flow = False

        self.used()

    def run_flow(self, message=None, creating_wallet=False):
        """Start a device-interaction flow; use as ``with self.run_flow(...)``.

        :param message: optional prompt format string (one ``{}`` placeholder
            for the device name) shown on button requests instead of the
            MESSAGES entry for the request code.
        :param creating_wallet: whether the passphrase prompt should use the
            wallet-creation wording.
        :returns: ``self``, so the flow is closed by ``__exit__``.
        :raises RuntimeError: if another flow is already in progress.
        """
        if self.in_flow:
            raise RuntimeError("Overlapping call to run_flow")

        self.in_flow = True
        self.msg = message
        self.creating_wallet = creating_wallet
        # A flow may wait on user input indefinitely; never time it out.
        self.prevent_timeouts()
        return self

    def end_flow(self):
        """Finish the current flow.

        Resets the per-flow state, calls ``handler.finished()`` and restarts
        the inactivity clock.
        """
        self.in_flow = False
        self.msg = None
        self.creating_wallet = False
        self.handler.finished()
        self.used()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, e, traceback):
        """End the flow and map trezorlib exceptions to Electrum ones.

        ``Cancelled`` becomes ``UserCancelled``, ``TrezorFailure`` becomes
        ``RuntimeError`` and ``OutdatedFirmwareError`` becomes
        ``OutdatedHwFirmwareException``.  Any other exception propagates
        unchanged.
        """
        self.end_flow()
        if e is not None:
            if isinstance(e, Cancelled):
                raise UserCancelled() from e
            elif isinstance(e, TrezorFailure):
                raise RuntimeError(str(e)) from e
            elif isinstance(e, OutdatedFirmwareError):
                raise OutdatedHwFirmwareException(e) from e
            else:
                return False
        return True

    @property
    def features(self):
        """The ``features`` object reported by the underlying TrezorClient."""
        return self.client.features

    def __str__(self):
        return "%s/%s" % (self.label(), self.features.device_id)

    def label(self):
        """Return the user-assigned device label."""
        return self.features.label

    def get_soft_device_id(self):
        """Return the device id reported in the device features."""
        return self.features.device_id

    def is_initialized(self):
        """Return the device's ``initialized`` feature flag."""
        return self.features.initialized

    def is_pairable(self):
        """A device can be paired unless it is in bootloader mode."""
        return not self.features.bootloader_mode

    @runs_in_hwd_thread
    def has_usable_connection_with_device(self):
        """Return True if the device still responds.

        Inside a flow the connection is assumed usable (presumably so the
        probe does not interfere with the ongoing exchange).  Otherwise the
        device is re-initialised and any error counts as "not usable".
        """
        if self.in_flow:
            return True

        try:
            self.client.init_device()
        except BaseException:
            return False
        return True

    def used(self):
        """Record "now" as the time of the last device operation."""
        self.last_operation = time.time()

    def prevent_timeouts(self):
        """Make timeout() a no-op until the next call to used()."""
        self.last_operation = float('inf')

    @runs_in_hwd_thread
    def timeout(self, cutoff):
        '''Time out the client if the last operation was before cutoff.'''
        if self.last_operation < cutoff:
            self.logger.info("timed out")
            self.clear_session()

    def i4b(self, x):
        """Pack ``x`` as a 4-byte big-endian unsigned integer."""
        return pack('>I', x)

    @runs_in_hwd_thread
    def get_xpub(self, bip32_path, xtype, creating=False):
        """Fetch the public node at ``bip32_path`` and return it as an xpub
        string of script type ``xtype``."""
        address_n = parse_path(bip32_path)
        with self.run_flow(creating_wallet=creating):
            node = trezorlib.btc.get_public_node(self.client, address_n).node
        return BIP32Node(xtype=xtype,
                         eckey=ecc.ECPubkey(node.public_key),
                         chaincode=node.chain_code,
                         depth=node.depth,
                         fingerprint=self.i4b(node.fingerprint),
                         child_number=self.i4b(node.child_num)).to_xpub()

    @runs_in_hwd_thread
    def toggle_passphrase(self):
        """Flip the device's passphrase-protection setting."""
        if self.features.passphrase_protection:
            msg = _("Confirm on your {} device to disable passphrases")
        else:
            msg = _("Confirm on your {} device to enable passphrases")
        enabled = not self.features.passphrase_protection
        with self.run_flow(msg):
            trezorlib.device.apply_settings(self.client, use_passphrase=enabled)

    @runs_in_hwd_thread
    def change_label(self, label):
        """Set the device label."""
        with self.run_flow(_("Confirm the new label on your {} device")):
            trezorlib.device.apply_settings(self.client, label=label)

    @runs_in_hwd_thread
    def change_homescreen(self, homescreen):
        """Set the device home-screen image."""
        with self.run_flow(_("Confirm on your {} device to change your home screen")):
            trezorlib.device.apply_settings(self.client, homescreen=homescreen)

    @runs_in_hwd_thread
    def set_pin(self, remove):
        """Set, change or (if ``remove``) disable the device PIN."""
        if remove:
            msg = _("Confirm on your {} device to disable PIN protection")
        elif self.features.pin_protection:
            msg = _("Confirm on your {} device to change your PIN")
        else:
            msg = _("Confirm on your {} device to set a PIN")
        with self.run_flow(msg):
            trezorlib.device.change_pin(self.client, remove)

    @runs_in_hwd_thread
    def clear_session(self):
        '''Clear the session to force pin (and passphrase if enabled)
        re-entry.  Does not leak exceptions.'''
        self.logger.info(f"clear session: {self}")
        self.prevent_timeouts()
        try:
            self.client.clear_session()
        except BaseException as e:
            # If the device was removed it has the same effect...
            self.logger.info(f"clear_session: ignoring error {e}")

    @runs_in_hwd_thread
    def close(self):
        '''Called when Our wallet was closed or the device removed.'''
        self.logger.info("closing client")
        self.clear_session()

    @runs_in_hwd_thread
    def is_uptodate(self):
        """Return True unless trezorlib flags the firmware as outdated or it
        is older than the plugin's ``minimum_firmware``."""
        if self.client.is_outdated():
            return False
        return self.client.version >= self.plugin.minimum_firmware

    def get_trezor_model(self):
        """Returns '1' for Trezor One, 'T' for Trezor T."""
        return self.features.model

    def device_model_name(self):
        """Human-readable model name, or None for unrecognised models."""
        model = self.get_trezor_model()
        if model == '1':
            return "Trezor One"
        elif model == 'T':
            return "Trezor T"
        return None

    @runs_in_hwd_thread
    def show_address(self, address_str, script_type, multisig=None):
        """Display the address at BIP32 path ``address_str`` on the device
        and return trezorlib's result."""
        coin_name = self.plugin.get_coin_name()
        address_n = parse_path(address_str)
        with self.run_flow():
            return trezorlib.btc.get_address(
                self.client,
                coin_name,
                address_n,
                show_display=True,
                script_type=script_type,
                multisig=multisig)

    @runs_in_hwd_thread
    def sign_message(self, address_str, message):
        """Sign ``message`` with the key at BIP32 path ``address_str``."""
        coin_name = self.plugin.get_coin_name()
        address_n = parse_path(address_str)
        with self.run_flow():
            return trezorlib.btc.sign_message(
                self.client,
                coin_name,
                address_n,
                message)

    @runs_in_hwd_thread
    def recover_device(self, recovery_type, *args, **kwargs):
        """Run trezorlib's device recovery, answering word/matrix requests
        through the callback chosen by mnemonic_callback()."""
        input_callback = self.mnemonic_callback(recovery_type)
        with self.run_flow():
            return trezorlib.device.recover(
                self.client,
                *args,
                input_callback=input_callback,
                type=recovery_type,
                **kwargs)

    # ========= Unmodified trezorlib methods =========

    @runs_in_hwd_thread
    def sign_tx(self, *args, **kwargs):
        with self.run_flow():
            return trezorlib.btc.sign_tx(self.client, *args, **kwargs)

    @runs_in_hwd_thread
    def reset_device(self, *args, **kwargs):
        with self.run_flow():
            return trezorlib.device.reset(self.client, *args, **kwargs)

    @runs_in_hwd_thread
    def wipe_device(self, *args, **kwargs):
        with self.run_flow():
            return trezorlib.device.wipe(self.client, *args, **kwargs)

    # ========= UI methods ==========

    def button_request(self, code):
        """Show a "confirm on device" prompt; the handler is given
        ``self.client.cancel`` so the user can abort."""
        message = self.msg or MESSAGES.get(code) or MESSAGES['default']
        self.handler.show_message(message.format(self.device), self.client.cancel)

    def _max_pin_length(self):
        """Return the longest PIN (in characters) the connected device accepts.

        Older firmware accepts at most 9 digits; Trezor One firmware >= 1.10.0
        and Trezor T firmware >= 2.3.3 accept up to 50 digits (see
        https://github.com/trezor/trezor-firmware/issues/1167).
        """
        model = self.get_trezor_model()
        version = self.client.version
        if model == '1' and version >= (1, 10, 0):
            return 50
        if model == 'T' and version >= (2, 3, 3):
            return 50
        return 9

    def get_pin(self, code=None):
        """Ask the user for a PIN.

        ``code`` selects the prompt: 2 asks for a new PIN, 3 asks to re-enter
        it, anything else asks for the current PIN.

        :raises Cancelled: if the user enters nothing, or the PIN is longer
            than the device accepts.
        """
        show_strength = True
        if code == 2:
            msg = _("Enter a new PIN for your {}:")
        elif code == 3:
            msg = (_("Re-enter the new PIN for your {}.\n\n"
                     "NOTE: the positions of the numbers have changed!"))
        else:
            msg = _("Enter your current {} PIN:")
            show_strength = False
        pin = self.handler.get_pin(msg.format(self.device), show_strength=show_strength)
        if not pin:
            raise Cancelled
        # The maximum PIN length depends on model and firmware version.
        limit = self._max_pin_length()
        if len(pin) > limit:
            self.handler.show_error(
                _('The PIN cannot be longer than {} characters.').format(limit))
            raise Cancelled
        return pin

    def get_passphrase(self, available_on_device):
        """Ask the user for the wallet passphrase.

        Returns ``PASSPHRASE_ON_DEVICE`` unchanged if the handler returns it
        (entry on the device itself).  Otherwise returns the passphrase after
        BIP39 normalisation.

        :raises Cancelled: if the user cancels, or the normalised passphrase
            exceeds 50 characters.
        """
        if self.creating_wallet:
            msg = _("Enter a passphrase to generate this wallet.  Each time "
                    "you use this wallet your {} will prompt you for the "
                    "passphrase.  If you forget the passphrase you cannot "
                    "access the bitcoins in the wallet.").format(self.device)
        else:
            msg = _("Enter the passphrase to unlock this wallet:")

        self.handler.passphrase_on_device = available_on_device
        passphrase = self.handler.get_passphrase(msg, self.creating_wallet)
        if passphrase is PASSPHRASE_ON_DEVICE:
            return passphrase
        if passphrase is None:
            raise Cancelled
        passphrase = bip39_normalize_passphrase(passphrase)
        length = len(passphrase)
        if length > 50:
            self.handler.show_error(_("Too long passphrase ({} > 50 chars).").format(length))
            raise Cancelled
        return passphrase

    def _matrix_char(self, matrix_type):
        """Read one character from a 9-key (``Matrix9``) or 6-key matrix via
        the handler; the reply 'x' cancels."""
        num = 9 if matrix_type == WordRequestType.Matrix9 else 6
        char = self.handler.get_matrix(num)
        if char == 'x':
            raise Cancelled
        return char

    def mnemonic_callback(self, recovery_type):
        """Return the input callback trezorlib should use during recovery.

        Returns None when ``recovery_type`` is None, ``_matrix_char`` for
        matrix recovery, and otherwise a callback that prompts for seed words
        one at a time.
        """
        if recovery_type is None:
            return None

        if recovery_type == RecoveryDeviceType.Matrix:
            return self._matrix_char

        step = 0

        def word_callback(_ignored):
            nonlocal step
            step += 1
            # The displayed total is always 24, whatever the seed length.
            msg = _("Step {}/24.  Enter seed word as explained on your {}:").format(step, self.device)
            word = self.handler.get_word(msg)
            if not word:
                raise Cancelled
            return word
        return word_callback
339