1# Driver for Mboot, the MicroPython boot loader
2# MIT license; Copyright (c) 2018 Damien P. George
3
4import struct, time, os, hashlib
5
6
# Command codes of the mboot I2C protocol.  A request is a single I2C write
# whose first byte is one of these codes.  They are numbered consecutively
# from 1 and must match the values the mboot firmware expects.
(
    I2C_CMD_ECHO,  # 1
    I2C_CMD_GETID,  # 2
    I2C_CMD_GETCAPS,  # 3
    I2C_CMD_RESET,  # 4
    I2C_CMD_CONFIG,  # 5
    I2C_CMD_GETLAYOUT,  # 6
    I2C_CMD_MASSERASE,  # 7
    I2C_CMD_PAGEERASE,  # 8
    I2C_CMD_SETRDADDR,  # 9
    I2C_CMD_SETWRADDR,  # 10
    I2C_CMD_READ,  # 11
    I2C_CMD_WRITE,  # 12
    I2C_CMD_COPY,  # 13
    I2C_CMD_CALCHASH,  # 14
    I2C_CMD_MARKVALID,  # 15
) = range(1, 16)
22
23
class Bootloader:
    """Host-side driver for an mboot boot loader reachable over I2C.

    Every command is sent as a single I2C write whose first byte is one of
    the ``I2C_CMD_*`` codes (optionally followed by packed arguments).  The
    device's reply is then polled for with ``wait_response``.
    """

    def __init__(self, i2c, addr):
        """Bind to the mboot device at I2C address *addr* on bus *i2c*.

        Raises Exception if an empty probe write to *addr* fails.
        """
        self.i2c = i2c
        self.addr = addr
        self.buf1 = bytearray(1)  # reusable buffer for the 1-byte length header
        try:
            # An empty write only checks that something acknowledges *addr*.
            self.i2c.writeto(addr, b"")
        except OSError:
            raise Exception("no I2C mboot device found")

    def wait_response(self, timeout_ms=5000):
        """Poll for and return the payload of the device's response.

        The reply starts with a one-byte length header.  Reads that fail
        with OSError (presumably while the device is still busy) are retried
        every 500 us.  A header value of 129 or more is an error status.

        Returns the payload as bytes (``b""`` when the length is 0).
        Raises Exception("timeout") if no header could be read within
        *timeout_ms* milliseconds, or Exception(status) for an error status.
        """
        start = time.ticks_ms()
        while True:
            try:
                self.i2c.readfrom_into(self.addr, self.buf1)
                n = self.buf1[0]
                break
            except OSError:
                time.sleep_us(500)
            if time.ticks_diff(time.ticks_ms(), start) > timeout_ms:
                raise Exception("timeout")
        if n >= 129:
            raise Exception(n)
        if n == 0:
            return b""
        return self.i2c.readfrom(self.addr, n)

    def wait_empty_response(self):
        """Wait for a response and check that it carries no payload.

        Raises Exception if the device returned any data.
        """
        ret = self.wait_response()
        if ret:
            raise Exception("expected empty response got %r" % ret)
        return None

    def echo(self, data):
        """Send *data* with the ECHO command and return the device's reply."""
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_ECHO) + data)
        return self.wait_response()

    def getid(self):
        """Return ``(unique_id, mcu_name, board_name)`` reported by the device.

        *unique_id* is the first 12 bytes of the reply.  The two names follow
        it, separated by a single NUL byte, and are decoded as ASCII.
        """
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_GETID))
        ret = self.wait_response()
        unique_id = ret[:12]
        mcu_name, board_name = ret[12:].split(b"\x00")
        return unique_id, str(mcu_name, "ascii"), str(board_name, "ascii")

    def reset(self):
        """Ask the device to reset.  No response is expected or read."""
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_RESET))

    def getlayout(self):
        """Query the flash layout and return it as a list of pages.

        The device replies with a string of the form
        ``@Internal Flash  /<hex base>/<count>*<size>Kg,...``, where each
        ``<count>*<size>Kg`` run describes *count* consecutive pages of
        *size* KiB starting where the previous run ended.

        Returns a list of ``(page_address, page_size_in_bytes)`` tuples in
        address order.  Raises Exception if the string is not in that form.
        """
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_GETLAYOUT))
        layout = self.wait_response()
        name, flash_addr, runs = layout.split(b"/")
        # Explicit checks rather than assert, so they also run in optimised builds.
        if name != b"@Internal Flash  ":
            raise Exception("unexpected flash region %r" % name)
        flash_addr = int(flash_addr, 16)
        pages = []
        for chunk in runs.split(b","):
            count, sz = chunk.split(b"*")
            if not sz.endswith(b"Kg"):
                raise Exception("unsupported page size %r" % sz)
            sz = int(sz[:-2]) * 1024
            for _ in range(int(count)):
                pages.append((flash_addr, sz))
                flash_addr += sz
        return pages

    def pageerase(self, addr):
        """Erase the flash page at *addr*.

        The device presumably erases the whole page containing *addr*.
        Raises Exception if the device sends a non-empty reply.
        """
        self.i2c.writeto(self.addr, struct.pack("<BI", I2C_CMD_PAGEERASE, addr))
        self.wait_empty_response()

    def setrdaddr(self, addr):
        """Set the device address used by subsequent read/hash commands."""
        self.i2c.writeto(self.addr, struct.pack("<BI", I2C_CMD_SETRDADDR, addr))
        self.wait_empty_response()

    def setwraddr(self, addr):
        """Set the device address used by subsequent write commands."""
        self.i2c.writeto(self.addr, struct.pack("<BI", I2C_CMD_SETWRADDR, addr))
        self.wait_empty_response()

    def read(self, n):
        """Request *n* bytes (0..255, packed as one byte) and return the reply.

        The bytes presumably come from the address set with ``setrdaddr``.
        """
        self.i2c.writeto(self.addr, struct.pack("<BB", I2C_CMD_READ, n))
        return self.wait_response()

    def write(self, buf):
        """Send *buf* to be programmed at the device's current write address."""
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_WRITE) + buf)
        self.wait_empty_response()

    def calchash(self, n):
        """Return the device-computed hash covering *n* bytes.

        ``deployfile`` calls ``setrdaddr`` first and compares the result with
        a local SHA256 digest.  The hash therefore presumably covers *n*
        bytes starting at the read address.
        """
        self.i2c.writeto(self.addr, struct.pack("<BI", I2C_CMD_CALCHASH, n))
        return self.wait_response()

    def markvalid(self):
        """Mark the application firmware on the device as valid."""
        self.i2c.writeto(self.addr, struct.pack("<B", I2C_CMD_MARKVALID))
        self.wait_empty_response()

    def _erase_for_write(self, pages, page_erased, addr, length, start_addr, fsize):
        """Erase, at most once each, every page overlapping [addr, addr + length).

        *pages* is the list returned by ``getlayout``.  *page_erased* holds
        one flag per page and is updated in place.  *start_addr* and *fsize*
        are used only for the progress message.  Only *addr* itself is
        validated against the layout.

        Raises Exception if *addr* is not inside any page.
        """
        if not any(p_addr <= addr < p_addr + p_size for p_addr, p_size in pages):
            raise Exception("address 0x%08x not valid" % addr)
        end = addr + length
        for i, (p_addr, p_size) in enumerate(pages):
            if page_erased[i] or p_addr + p_size <= addr or p_addr >= end:
                continue
            # Erase via an address lying in both the page and the span: the
            # chunk address for its own page, the page start for later pages.
            erase_addr = max(addr, p_addr)
            print(
                "\r% 3u%% erase 0x%08x" % (100 * (erase_addr - start_addr) // fsize, erase_addr),
                end="",
            )
            self.pageerase(erase_addr)
            page_erased[i] = True

    def deployfile(self, filename, addr):
        """Flash the contents of *filename* to the device starting at *addr*.

        The file is sent in 128-byte chunks.  Each page is erased just before
        it is first written.  Afterwards a locally computed SHA256 is compared
        with the device's ``calchash`` result.  The firmware is marked valid
        only if they match, and the device is reset in either case.

        Raises Exception if the file is empty or a chunk address falls
        outside the flash layout.
        """
        fsize = os.stat(filename)[6]
        if fsize == 0:
            # Nothing to flash; reject it before sending any command.
            raise Exception("%s is empty" % filename)
        pages = self.getlayout()
        page_erased = [False] * len(pages)
        buf = bytearray(128)  # maximum payload supported by I2C protocol
        start_addr = addr
        ntotal = 0
        self.setwraddr(addr)
        local_sha = hashlib.sha256()
        print("Deploying %s to location 0x%08x" % (filename, addr))
        with open(filename, "rb") as f:
            t0 = time.ticks_ms()
            while True:
                n = f.readinto(buf)
                if not n:
                    break

                if n < len(buf):
                    # Short (final) chunk: the whole buffer is still sent, so
                    # pad the tail with the erased value 0xff instead of
                    # programming stale bytes left over from the previous chunk.
                    for j in range(n, len(buf)):
                        buf[j] = 0xFF

                # Erase every page the full buffer will land in.
                self._erase_for_write(pages, page_erased, addr, len(buf), start_addr, fsize)

                self.write(buf)

                # Hash the first image byte with its two low bits set.  This
                # presumably mirrors what the device stores there until
                # markvalid; confirm against the mboot firmware.
                if addr == start_addr:
                    buf[0] |= 3
                local_sha.update(buf if n == len(buf) else buf[:n])

                addr += n
                ntotal = addr - start_addr
                if ntotal % 2048 == 0 or ntotal == fsize:
                    print("\r% 3u%% % 7u bytes   " % (100 * ntotal // fsize, ntotal), end="")
            t1 = time.ticks_ms()
        print()
        # ticks_diff handles tick wrap-around.  bytes/ms * 1000 / 1024 gives KiB/s.
        elapsed_ms = max(time.ticks_diff(t1, t0), 1)
        print("rate: %.2f KiB/sec" % (ntotal * 1000 / 1024 / elapsed_ms))

        local_digest = local_sha.digest()
        print("Local SHA256: ", "".join("%02x" % x for x in local_digest))

        self.setrdaddr(start_addr)
        remote_digest = self.calchash(ntotal)
        print("Remote SHA256:", "".join("%02x" % x for x in remote_digest))

        if local_digest == remote_digest:
            print("Marking app firmware as valid")
            self.markvalid()
        else:
            print("SHA256 mismatch; app firmware not marked valid")

        self.reset()
182