1# This file is part of the Trezor project.
2#
3# Copyright (C) 2012-2022 SatoshiLabs and contributors
4#
5# This library is free software: you can redistribute it and/or modify
6# it under the terms of the GNU Lesser General Public License version 3
7# as published by the Free Software Foundation.
8#
9# This library is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12# GNU Lesser General Public License for more details.
13#
14# You should have received a copy of the License along with this library.
15# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
16
17import hashlib
18from enum import Enum
19from hashlib import blake2s
20from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
21
22import construct as c
23import ecdsa
24
25from . import cosi, messages
26from .tools import session
27
28if TYPE_CHECKING:
29    from .client import TrezorClient
30
# Number of signature slots in v1-style headers: `LegacyFirmware.key_indexes` /
# `signatures` and `FirmwareHeader.v1_key_indexes` / `v1_signatures` are arrays
# of this length.
V1_SIGNATURE_SLOTS = 3
# Public keys accepted for v1-style signatures, as uncompressed secp256k1 points
# (0x04 prefix followed by 64 bytes). `check_sig_v1` strips the prefix byte and
# addresses this list with the 1-based key indexes stored in the header.
V1_BOOTLOADER_KEYS = [
    bytes.fromhex(key)
    for key in (
        "04d571b7f148c5e4232c3814f777d8faeaf1a84216c78d569b71041ffc768a5b2d810fc3bb134dd026b57e65005275aedef43e155f48fc11a32ec790a93312bd58",
        "0463279c0c0866e50c05c799d32bd6bab0188b6de06536d1109d2ed9ce76cb335c490e55aee10cc901215132e853097d5432eda06b792073bd7740c94ce4516cb1",
        "0443aedbb6f7e71c563f8ed2ef64ec9981482519e7ef4f4aa98b27854e8c49126d4956d300ab45fdc34cd26bc8710de0a31dbdf6de7435fd0b492be70ac75fde58",
        "04877c39fd7c62237e038235e9c075dab261630f78eeb8edb92487159fffedfdf6046c6f8b881fa407c4a4ce6c28de0b19c1f4e29f1fcbc5a58ffd1432a3e0938a",
        "047384c51ae81add0a523adbb186c91b906ffb64c2c765802bf26dbd13bdf12c319e80c2213a136c8ee03d7874fd22b70d68e7dee469decfbbb510ee9a460cda45",
    )
]
42
# 32-byte public keys (hex-encoded). Not referenced in this part of the module;
# presumably the production boardloader keys — TODO confirm against callers.
V2_BOARDLOADER_KEYS = [
    bytes.fromhex(key)
    for key in (
        "0eb9856be9ba7e972c7f34eac1ed9b6fd0efd172ec00faf0c589759da4ddfba0",
        "ac8ab40b32c98655798fd5da5e192be27a22306ea05c6d277cdff4a3f4125cd8",
        "ce0fcd12543ef5936cf2804982136707863d17295faced72af171d6e6513ff06",
    )
]

# 32-byte public keys (hex-encoded). Also unreferenced here; presumably the
# development-build counterpart of V2_BOARDLOADER_KEYS — verify.
V2_BOARDLOADER_DEV_KEYS = [
    bytes.fromhex(key)
    for key in (
        "db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d",
        "2152f8d19b791d24453242e15f2eab6cb7cffa7b6a5ed30097960e069881db12",
        "22fc297792f0b6ffc0bfcfdb7edb0c0aa14e025a365ec0e342e86e3829cb74b6",
    )
]

# 32-byte public keys that `validate_v2` uses to verify the vendor header's
# CoSi signature (with V2_SIGS_REQUIRED as the required signature count).
V2_BOOTLOADER_KEYS = [
    bytes.fromhex(key)
    for key in (
        "c2c87a49c5a3460977fbb2ec9dfe60f06bd694db8244bd4981fe3b7a26307f3f",
        "80d036b08739b846f4cb77593078deb25dc9487aedcf52e30b4fb7cd7024178a",
        "b8307a71f552c60a4cbb317ff48b82cdbf6b6bb5f04c920fec7badf017883751",
    )
]
69
# Number of V2_BOOTLOADER_KEYS signatures required on a vendor header
# (the "m" of the m-of-n scheme passed to cosi.verify in validate_v2).
V2_SIGS_REQUIRED = 2

# Chunk sizes used when computing the per-chunk code hashes:
# 64 KiB for Trezor One v2 images, 128 KiB for Trezor T images.
ONEV2_CHUNK_SIZE = 64 * 1024
V2_CHUNK_SIZE = 128 * 1024
74
75
76def _transform_vendor_trust(data: bytes) -> bytes:
77    """Byte-swap and bit-invert the VendorTrust field.
78
79    Vendor trust is interpreted as a bitmask in a 16-bit little-endian integer,
80    with the added twist that 0 means set and 1 means unset.
81    We feed it to a `BitStruct` that expects a big-endian sequence where bits have
82    the traditional meaning. We must therefore do a bitwise negation of each byte,
83    and return them in reverse order. This is the same transformation both ways,
84    fortunately, so we don't need two separate functions.
85    """
86    return bytes(~b & 0xFF for b in data)[::-1]
87
88
class FirmwareIntegrityError(Exception):
    """Base error: a firmware image is malformed or fails validation."""


class InvalidSignatureError(FirmwareIntegrityError):
    """A firmware or header signature is present but does not verify."""


class Unsigned(FirmwareIntegrityError):
    """A v1-style image carries no signatures at all (every slot is empty)."""
99
100
class ToifMode(Enum):
    """Format byte that follows the ``TOI`` magic in a TOIF image."""

    full_color = b"f"
    grayscale = b"g"
104
105
class HeaderType(Enum):
    """Four-byte magic that opens a `FirmwareHeader`."""

    FIRMWARE = b"TRZF"
    BOOTLOADER = b"TRZB"
109
110
class EnumAdapter(c.Adapter):
    """Construct adapter that maps raw field values to members of an `Enum`.

    Parsing yields the matching enum member, or the raw value unchanged when it
    is not a valid member value. Building accepts either an enum member or a raw
    value. This keeps the adapter symmetric: any container produced by parsing
    can be rebuilt, including ones holding unrecognised values.
    """

    def __init__(self, subcon: Any, enum: Any) -> None:
        self.enum = enum
        super().__init__(subcon)

    def _encode(self, obj: Any, ctx: Any, path: Any):
        # `_decode` passes unknown raw values through, so they must be accepted
        # here too; otherwise `obj.value` raises AttributeError on rebuild.
        if isinstance(obj, self.enum):
            return obj.value
        return obj

    def _decode(self, obj: Any, ctx: Any, path: Any):
        try:
            return self.enum(obj)
        except ValueError:
            # Not a member value: keep the raw value instead of failing the parse.
            return obj
124
125
126# fmt: off
# TOIF image: "TOI" magic, one format byte, width/height as little-endian
# uint16, then the image payload prefixed by its uint32 length.
Toif = c.Struct(
    "magic" / c.Const(b"TOI"),
    # Decoded to a ToifMode member when recognised, otherwise kept as raw bytes.
    "format" / EnumAdapter(c.Bytes(1), ToifMode),
    "width" / c.Int16ul,
    "height" / c.Int16ul,
    # Image payload, prefixed by its byte length (little-endian uint32).
    "data" / c.Prefixed(c.Int32ul, c.GreedyBytes),
)
134
135
# 2-byte vendor trust field. `_transform_vendor_trust` converts between the
# on-wire form (little-endian, inverted bits) and the big-endian, normal-polarity
# form this BitStruct reads. Fields are listed most-significant bit first:
# 9 reserved bits, three flags, and a 4-bit delay in the lowest bits.
VendorTrust = c.Transformed(c.BitStruct(
    "_reserved" / c.Default(c.BitsInteger(9), 0),
    "show_vendor_string" / c.Flag,
    "require_user_click" / c.Flag,
    "red_background" / c.Flag,
    "delay" / c.BitsInteger(4),
), _transform_vendor_trust, 2, _transform_vendor_trust, 2)
143
144
# Vendor header that precedes the firmware image in Trezor T binaries
# (see VendorFirmware). `_start_offset` / `_end_offset` record stream positions
# so the padding before the signature block can be computed.
VendorHeader = c.Struct(
    "_start_offset" / c.Tell,
    "magic" / c.Const(b"TRZV"),
    "header_len" / c.Int32ul,
    "expiry" / c.Int32ul,
    "version" / c.Struct(
        "major" / c.Int8ul,
        "minor" / c.Int8ul,
    ),
    # m-of-n scheme for the firmware signature: validate_v2 passes sig_m to
    # cosi.verify together with `pubkeys`; sig_n is rebuilt as len(pubkeys).
    "sig_m" / c.Int8ul,
    "sig_n" / c.Rebuild(c.Int8ul, c.len_(c.this.pubkeys)),
    "trust" / VendorTrust,
    "_reserved" / c.Padding(14),
    "pubkeys" / c.Bytes(32)[c.this.sig_n],
    # Length-prefixed UTF-8 vendor string, aligned to a multiple of 4 bytes.
    "text" / c.Aligned(4, c.PascalString(c.Int8ul, "utf-8")),
    "image" / Toif,
    "_end_offset" / c.Tell,

    # The final 65 bytes of the header are sigmask (1) + signature (64). The
    # header must leave room for them after the content and be 512-byte aligned.
    # NOTE(review): the strict `>` rejects a header whose content ends exactly
    # 65 bytes before header_len (i.e. zero padding) — confirm this is intended.
    "_min_header_len" / c.Check(c.this.header_len > (c.this._end_offset - c.this._start_offset) + 65),
    "_header_len_aligned" / c.Check(c.this.header_len % 512 == 0),

    # Padding from the end of the content up to the signature block.
    c.Padding(c.this.header_len - c.this._end_offset + c.this._start_offset - 65),
    "sigmask" / c.Byte,
    "signature" / c.Bytes(64),
)
170
171
# Four-part version number (major.minor.patch.build), one unsigned byte each.
VersionLong = c.Struct(
    "major" / c.Int8ul,
    "minor" / c.Int8ul,
    "patch" / c.Int8ul,
    "build" / c.Int8ul,
)
178
179
# Header of a new-style firmware or bootloader image (magic TRZF or TRZB).
# `_start_offset` / `_end_offset` record stream positions used to rebuild
# `header_len`.
FirmwareHeader = c.Struct(
    "_start_offset" / c.Tell,
    # Decoded to a HeaderType member; unrecognised magics are kept as raw bytes.
    "magic" / EnumAdapter(c.Bytes(4), HeaderType),
    "header_len" / c.Int32ul,
    "expiry" / c.Int32ul,
    # When building: the length of the sibling `code` field of the enclosing
    # struct (e.g. FirmwareImage) if present, else the given value (or 0).
    "code_length" / c.Rebuild(
        c.Int32ul,
        lambda this:
            len(this._.code) if "code" in this._
            else (this.code_length or 0)
    ),
    "version" / VersionLong,
    "fix_version" / VersionLong,
    "_reserved" / c.Padding(8),
    # 16 chunk hashes of 32 bytes each; see calculate_code_hashes().
    "hashes" / c.Bytes(32)[16],

    # v1-style signatures, checked by check_sig_v1; key indexes are 1-based
    # into V1_BOOTLOADER_KEYS, and 0 marks an empty slot.
    "v1_signatures" / c.Bytes(64)[V1_SIGNATURE_SLOTS],
    "v1_key_indexes" / c.Int8ul[V1_SIGNATURE_SLOTS],  # pylint: disable=E1136

    "_reserved" / c.Padding(220),
    # v2-style signature data passed to cosi.verify (sigmask presumably selects
    # which keys took part — confirm in cosi).
    "sigmask" / c.Byte,
    "signature" / c.Bytes(64),

    "_end_offset" / c.Tell,

    # Only for version.major > 1: when building, seek back to the header_len
    # field (start + 4) and overwrite it with the actual header size.
    "_rebuild_header_len" / c.If(
        c.this.version.major > 1,
        c.Pointer(
            c.this._start_offset + 4,
            c.Rebuild(c.Int32ul, c.this._end_offset - c.this._start_offset)
        ),
    ),
)
213
214
215"""Raw firmware image.
216
217Consists of firmware header and code block.
218This is the expected format of firmware binaries for Trezor One, or bootloader images
219for Trezor T."""
# `_code_offset` is the stream position where the code starts; chunk hashing
# uses it because the first hashed chunk also contains everything before the code.
FirmwareImage = c.Struct(
    "header" / FirmwareHeader,
    "_code_offset" / c.Tell,
    "code" / c.Bytes(c.this.header.code_length),
    # Nothing may follow the code block.
    c.Terminated,
)
226
227
228"""Firmware image prefixed by a vendor header.
229
230This is the expected format of firmware binaries for Trezor T."""
# The nested image is parsed from the same stream, so its `_code_offset`
# also counts the bytes of the vendor header.
VendorFirmware = c.Struct(
    "vendor_header" / VendorHeader,
    "image" / FirmwareImage,
    c.Terminated,
)
236
237
238"""Legacy firmware image.
239Consists of a custom header and code block.
240This is the expected format of firmware binaries for Trezor One pre-1.8.0.
241
242The code block can optionally be interpreted as a new-style firmware image. That is the
243expected format of firmware binary for Trezor One version 1.8.0, which can be installed
244by both the older and the newer bootloader."""
LegacyFirmware = c.Struct(
    "magic" / c.Const(b"TRZR"),
    "code_length" / c.Rebuild(c.Int32ul, c.len_(c.this.code)),
    # 1-based indexes into V1_BOOTLOADER_KEYS, one per signature slot (0 = empty).
    "key_indexes" / c.Int8ul[V1_SIGNATURE_SLOTS],  # pylint: disable=E1136
    # One flags byte; restore_storage is its least-significant bit.
    "flags" / c.BitStruct(
        c.Padding(7),
        "restore_storage" / c.Flag,
    ),
    "_reserved" / c.Padding(52),
    "signatures" / c.Bytes(64)[V1_SIGNATURE_SLOTS],
    "code" / c.Bytes(c.this.code_length),
    c.Terminated,

    # Re-parse `code` as a new-style FirmwareImage from its own stream;
    # None when the code block is not such an image.
    "embedded_onev2" / c.RestreamData(c.this.code, c.Optional(FirmwareImage)),
)
260
261# fmt: on
262
263
class FirmwareFormat(Enum):
    """Container layout detected by `parse` from the image's leading magic."""

    TREZOR_ONE = 1  # legacy TRZR image (LegacyFirmware)
    TREZOR_T = 2  # TRZV vendor header + firmware image (VendorFirmware)
    TREZOR_ONE_V2 = 3  # bare TRZF firmware image (FirmwareImage)
268
269
# Result of `parse`: the detected format and the parsed construct container.
ParsedFirmware = Tuple[FirmwareFormat, c.Container]
271
272
def parse(data: bytes) -> ParsedFirmware:
    """Detect the format of a firmware binary from its magic and parse it.

    Raises:
        ValueError: if the first four bytes are not a known magic.
        FirmwareIntegrityError: if the binary does not match the schema of the
            detected format (the parser error is chained as ``__cause__``).
    """
    magic = data[:4]
    if magic == b"TRZR":
        version, schema = FirmwareFormat.TREZOR_ONE, LegacyFirmware
    elif magic == b"TRZV":
        version, schema = FirmwareFormat.TREZOR_T, VendorFirmware
    elif magic == b"TRZF":
        version, schema = FirmwareFormat.TREZOR_ONE_V2, FirmwareImage
    else:
        raise ValueError("Unrecognized firmware image type")

    try:
        parsed = schema.parse(data)
    except Exception as exc:
        raise FirmwareIntegrityError("Invalid firmware image") from exc
    return version, parsed
291
292
def digest_onev1(fw: c.Container) -> bytes:
    """Return the SHA-256 digest of a legacy image's code block."""
    hasher = hashlib.sha256()
    hasher.update(fw.code)
    return hasher.digest()
295
296
def check_sig_v1(
    digest: bytes, key_indexes: List[int], signatures: List[bytes]
) -> None:
    """Verify v1-style secp256k1 ECDSA signatures over ``digest``.

    Args:
        digest: the digest that was signed.
        key_indexes: per-slot 1-based indexes into V1_BOOTLOADER_KEYS;
            0 marks an empty slot.
        signatures: per-slot signatures, aligned with ``key_indexes``.

    Raises:
        Unsigned: if every slot is empty.
        InvalidSignatureError: if some slot is empty or repeats a key, refers
            to an unknown key, or holds a signature that does not verify.
    """
    used_keys = {idx for idx in key_indexes if idx != 0}
    if not used_keys:
        raise Unsigned

    # Every slot must be filled, each with a different key.
    if len(used_keys) < len(key_indexes):
        raise InvalidSignatureError(
            f"Not enough distinct signatures (found {len(used_keys)}, need {len(key_indexes)})"
        )

    for slot, key_index in enumerate(key_indexes):
        key_pos = key_index - 1
        signature = signatures[slot]

        if key_pos >= len(V1_BOOTLOADER_KEYS):
            # unknown pubkey
            raise InvalidSignatureError(f"Unknown key in slot {slot}")

        # Drop the 0x04 prefix: from_string expects the raw 64-byte point.
        raw_point = V1_BOOTLOADER_KEYS[key_pos][1:]
        verifier = ecdsa.VerifyingKey.from_string(
            raw_point, curve=ecdsa.curves.SECP256k1
        )
        try:
            verifier.verify_digest(signature, digest)
        except ecdsa.BadSignatureError as exc:
            raise InvalidSignatureError(f"Invalid signature in slot {slot}") from exc
323
324
def header_digest(header: c.Container, hash_function: Callable = blake2s) -> bytes:
    """Return the digest of ``header`` rebuilt with its signature fields zeroed.

    ``sigmask``, ``signature`` and the v1 signature slots are blanked before the
    header is rebuilt and hashed. digest_v2 / digest_onev2 and validate_v2 use
    the result as the fingerprint checked by signature verification. The work
    is done on a copy, so ``header`` itself is left unchanged.

    Args:
        header: a parsed VendorHeader (magic ``b"TRZV"``) or FirmwareHeader.
        hash_function: hashlib-style constructor called as
            ``hash_function(data)``; defaults to BLAKE2s.
    """
    stripped_header = header.copy()
    stripped_header.sigmask = 0
    stripped_header.signature = b"\0" * 64
    # The v1 fields only exist in FirmwareHeader; use the same slot count as the
    # schema instead of a hard-coded literal.
    stripped_header.v1_key_indexes = [0] * V1_SIGNATURE_SLOTS
    stripped_header.v1_signatures = [b"\0" * 64] * V1_SIGNATURE_SLOTS
    if header.magic == b"TRZV":
        header_type = VendorHeader
    else:
        # FirmwareHeader magics are decoded to HeaderType, never equal to bytes.
        header_type = FirmwareHeader
    header_bytes = header_type.build(stripped_header)
    return hash_function(header_bytes).digest()
337
338
def digest_v2(fw: c.Container) -> bytes:
    """Return the BLAKE2s header digest of a Trezor T image's firmware header.

    Only ``fw.image.header`` is hashed; the vendor header is not included.
    """
    firmware_header = fw.image.header
    return header_digest(firmware_header, hash_function=blake2s)
341
342
def digest_onev2(fw: c.Container) -> bytes:
    """Return the SHA-256 header digest of a Trezor One v2 firmware image."""
    firmware_header = fw.header
    return header_digest(firmware_header, hash_function=hashlib.sha256)
345
346
def calculate_code_hashes(
    code: bytes,
    code_offset: int,
    hash_function: Callable = blake2s,
    chunk_size: int = V2_CHUNK_SIZE,
    padding_byte: Optional[bytes] = None,
) -> List[bytes]:
    """Compute the 16 per-chunk code hashes stored in a firmware header.

    Chunk boundaries are counted from the start of the image, not the start of
    ``code``. The first chunk is therefore ``code_offset`` bytes shorter, and
    every later boundary is shifted by the same amount.

    Args:
        code: the code block of the image.
        code_offset: position of ``code`` within the binary.
        hash_function: hashlib-style constructor, called as ``hash_function(chunk)``.
        chunk_size: size of one chunk in bytes.
        padding_byte: if given, the chunk in which the code ends is padded up to
            its full size with the first byte of ``padding_byte``.

    Returns:
        16 digests; a chunk containing no code is represented by 32 zero bytes.
        Any code past the 16th chunk is not hashed.
    """
    code_len = len(code)
    empty_digest = b"\0" * 32
    digests: List[bytes] = []
    chunk_start = 0
    for chunk_no in range(1, 17):
        chunk_end = chunk_no * chunk_size - code_offset
        chunk = code[chunk_start:chunk_end]
        code_ends_here = chunk_start < code_len < chunk_end
        if padding_byte is not None and code_ends_here:
            chunk += padding_byte[0:1] * (chunk_end - chunk_start - len(chunk))
        digests.append(hash_function(chunk).digest() if chunk else empty_digest)
        chunk_start = chunk_end
    return digests
373
374
def validate_code_hashes(fw: c.Container, version: FirmwareFormat) -> None:
    """Check that the chunk hashes stored in the firmware header match the code.

    Args:
        fw: parsed container: a FirmwareImage for TREZOR_ONE_V2, or a
            VendorFirmware (whose ``image`` is the FirmwareImage) for TREZOR_T.
        version: the format of ``fw``.

    Raises:
        FirmwareIntegrityError: if the recomputed hashes differ from the header.
        ValueError: if ``version`` has no per-chunk code hashes (TREZOR_ONE
            legacy images have no ``hashes`` field).
    """
    hash_function: Callable
    padding_byte: Optional[bytes]
    if version == FirmwareFormat.TREZOR_ONE_V2:
        image = fw
        hash_function = hashlib.sha256
        chunk_size = ONEV2_CHUNK_SIZE
        # Trezor One v2 pads the last partial chunk with 0xFF before hashing.
        padding_byte = b"\xff"
    elif version == FirmwareFormat.TREZOR_T:
        image = fw.image
        hash_function = blake2s
        chunk_size = V2_CHUNK_SIZE
        padding_byte = None
    else:
        # Previously any other format fell through to the TREZOR_T branch and
        # failed with an obscure AttributeError on `fw.image`.
        raise ValueError(f"Code hashes are not available for {version}")

    expected_hashes = calculate_code_hashes(
        image.code, image._code_offset, hash_function, chunk_size, padding_byte
    )
    if expected_hashes != image.header.hashes:
        raise FirmwareIntegrityError("Invalid firmware data.")
394
395
def validate_onev2(fw: c.Container, allow_unsigned: bool = False) -> None:
    """Validate a Trezor One v2 image: v1-style header signatures, then code hashes.

    Raises:
        Unsigned: if no signature slot is filled and ``allow_unsigned`` is False.
        InvalidSignatureError: if the signatures are present but invalid.
        FirmwareIntegrityError: if the code hashes do not match the header.
    """
    fingerprint = digest_onev2(fw)
    header = fw.header
    try:
        check_sig_v1(fingerprint, header.v1_key_indexes, header.v1_signatures)
    except Unsigned:
        # An unsigned image may be tolerated, but its code hashes are still checked.
        if not allow_unsigned:
            raise

    validate_code_hashes(fw, FirmwareFormat.TREZOR_ONE_V2)
407    validate_code_hashes(fw, FirmwareFormat.TREZOR_ONE_V2)
408
409
def validate_onev1(fw: c.Container, allow_unsigned: bool = False) -> None:
    """Validate a legacy Trezor One image and, if present, its embedded v2 image.

    Raises:
        Unsigned: if the image (or embedded image) is unsigned and
            ``allow_unsigned`` is False.
        InvalidSignatureError / FirmwareIntegrityError: on validation failure.
    """
    fingerprint = digest_onev1(fw)
    try:
        check_sig_v1(fingerprint, fw.key_indexes, fw.signatures)
    except Unsigned:
        if not allow_unsigned:
            raise

    # The code block may itself be a new-style image (None when it is not).
    embedded_image = fw.embedded_onev2
    if embedded_image:
        validate_onev2(embedded_image, allow_unsigned)
418
419
def validate_v2(fw: c.Container, skip_vendor_header: bool = False) -> None:
    """Validate a Trezor T image (vendor header followed by a firmware image).

    Checks performed, in order:

    1. Unless ``skip_vendor_header``, the vendor header's CoSi signature is
       verified against V2_BOOTLOADER_KEYS with V2_SIGS_REQUIRED signatures.
    2. The firmware header's CoSi signature is verified against the vendor
       header's ``pubkeys`` with ``sig_m`` signatures.
    3. The per-chunk code hashes are verified.

    Raises:
        InvalidSignatureError: if either signature check fails; the error raised
            by ``cosi.verify`` is chained as ``__cause__``.
        FirmwareIntegrityError: if the code hashes do not match.
    """
    vendor_fingerprint = header_digest(fw.vendor_header)
    fingerprint = digest_v2(fw)

    if not skip_vendor_header:
        try:
            # if you want to validate a custom vendor header, you can modify
            # the global variables to match your keys and m-of-n scheme
            cosi.verify(
                fw.vendor_header.signature,
                vendor_fingerprint,
                V2_SIGS_REQUIRED,
                V2_BOOTLOADER_KEYS,
                fw.vendor_header.sigmask,
            )
        except Exception as e:
            # Any verification failure is reported uniformly, keeping the cause.
            raise InvalidSignatureError("Invalid vendor header signature.") from e

    # NOTE: the `expiry` fields of the vendor header and firmware header are
    # currently not checked.

    try:
        cosi.verify(
            fw.image.header.signature,
            fingerprint,
            fw.vendor_header.sig_m,
            fw.vendor_header.pubkeys,
            fw.image.header.sigmask,
        )
    except Exception as e:
        raise InvalidSignatureError("Invalid firmware signature.") from e

    validate_code_hashes(fw, FirmwareFormat.TREZOR_T)
458
459
def digest(version: FirmwareFormat, fw: c.Container) -> bytes:
    """Return the signed digest of ``fw`` using the scheme matching ``version``.

    Raises:
        ValueError: if ``version`` is not a known FirmwareFormat.
    """
    handlers = (
        (FirmwareFormat.TREZOR_ONE, digest_onev1),
        (FirmwareFormat.TREZOR_ONE_V2, digest_onev2),
        (FirmwareFormat.TREZOR_T, digest_v2),
    )
    for fmt, digest_fn in handlers:
        if version == fmt:
            return digest_fn(fw)
    raise ValueError("Unrecognized firmware version")
469
470
def validate(
    version: FirmwareFormat, fw: c.Container, allow_unsigned: bool = False
) -> None:
    """Validate ``fw`` with the checks appropriate for ``version``.

    ``allow_unsigned`` only applies to the Trezor One formats; Trezor T images
    always require valid signatures.

    Raises:
        ValueError: if ``version`` is not a known FirmwareFormat.
        FirmwareIntegrityError (or a subclass): if validation fails.
    """
    if version == FirmwareFormat.TREZOR_T:
        return validate_v2(fw)
    if version == FirmwareFormat.TREZOR_ONE_V2:
        return validate_onev2(fw, allow_unsigned)
    if version == FirmwareFormat.TREZOR_ONE:
        return validate_onev1(fw, allow_unsigned)
    raise ValueError("Unrecognized firmware version")
482
483
484# ====== Client functions ====== #
485
486
@session
def update(
    client: "TrezorClient",
    data: bytes,
    progress_update: Callable[[int], Any] = lambda _: None,
):
    """Upload the firmware binary ``data`` to a device in bootloader mode.

    Two protocols are supported:
    - the Trezor One bootloader answers FirmwareErase with Success, and the
      whole image is then sent in a single FirmwareUpload;
    - the Trezor T bootloader answers with FirmwareRequest messages asking
      for (offset, length) slices, each sent with its BLAKE2s hash.

    Args:
        client: connected client.
        data: the firmware binary to upload.
        progress_update: called with the number of bytes sent after each upload.

    Raises:
        RuntimeError: if the device explicitly reports it is not in bootloader
            mode, sends a FirmwareRequest without offset/length, or replies
            with an unexpected message.
    """
    # Only an explicit False aborts; a missing (None) value is not treated as an error.
    if client.features.bootloader_mode is False:
        raise RuntimeError("Device must be in bootloader mode")

    resp = client.call(messages.FirmwareErase(length=len(data)))

    # TREZORv1 method
    if isinstance(resp, messages.Success):
        resp = client.call(messages.FirmwareUpload(payload=data))
        progress_update(len(data))
        if isinstance(resp, messages.Success):
            return
        raise RuntimeError(f"Unexpected result {resp}")

    # TREZORv2 method
    while isinstance(resp, messages.FirmwareRequest):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if resp.offset is None or resp.length is None:
            raise RuntimeError(f"Malformed firmware request {resp}")
        offset = resp.offset
        length = resp.length
        payload = data[offset : offset + length]
        # Named so it does not shadow the module-level digest() function.
        chunk_digest = blake2s(payload).digest()
        resp = client.call(messages.FirmwareUpload(payload=payload, hash=chunk_digest))
        progress_update(length)

    if isinstance(resp, messages.Success):
        return
    raise RuntimeError(f"Unexpected message {resp}")
521