1# Based on the public domain code from tlslite (Trevor Perrin, http://trevp.net/tlslite/)
2
3import copy
4
# shifts[SC][row] = [offset used by encrypt, offset used by decrypt];
# SC is 0, 1 or 2 for a block of 4, 6 or 8 columns respectively
shifts = [
    [[0, 0], [1, 3], [2, 2], [3, 1]],
    [[0, 0], [1, 5], [2, 4], [3, 3]],
    [[0, 0], [1, 7], [3, 5], [4, 4]],
]

# [keysize][block_size]
num_rounds = {
    16: {16: 10, 24: 12, 32: 14},
    24: {16: 12, 24: 12, 32: 14},
    32: {16: 14, 24: 14, 32: 14},
}

# 8x8 bit matrix applied in the S-box affine transform
A = [
    [1, 1, 1, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 1, 1, 0, 0],
    [0, 0, 1, 1, 1, 1, 1, 0],
    [0, 0, 0, 1, 1, 1, 1, 1],
    [1, 0, 0, 0, 1, 1, 1, 1],
    [1, 1, 0, 0, 0, 1, 1, 1],
    [1, 1, 1, 0, 0, 0, 1, 1],
    [1, 1, 1, 1, 0, 0, 0, 1],
]

# produce log and alog tables, needed for multiplying in the
# field GF(2^8) (generator = 3): alog[k] is 3 raised to the power k
alog = [1]
while len(alog) < 256:
    j = alog[-1]
    j ^= j << 1          # multiply the previous power by 3
    if j > 0xFF:
        j ^= 0x11B       # fold the overflow bit back into the field
    alog.append(j)

# log is the inverse mapping; log[0] and log[1] both stay 0
log = [0] * 256
for i, j in enumerate(alog[1:255], 1):
    log[j] = i
33
# multiply two elements of GF(2^m)
def mul(a, b):
    """Return the field product of a and b using the log/alog tables.

    Either operand being 0 gives 0; otherwise the operands' logarithms are
    added modulo 255 and mapped back through alog.
    """
    if a != 0 and b != 0:
        exponent = log[a & 0xFF] + log[b & 0xFF]
        return alog[exponent % 255]
    return 0
39
# substitution box based on F^{-1}(x)
# row i holds the eight bits, most significant first, of the field
# inverse of i; rows 0 and 1 are written out by hand
box = [[0] * 8, [0, 0, 0, 0, 0, 0, 0, 1]]
for i in range(2, 256):
    j = alog[255 - log[i]]
    box.append([(j >> (7 - t)) & 0x01 for t in range(8)])

B = [0, 1, 1, 0, 0, 0, 1, 1]

# affine transform:  box[i] <- B + A*box[i]
cox = []
for i in range(256):
    # start from the constant vector B, then add the selected bits of box[i]
    cox.append(list(B))
    for t in range(8):
        for j in range(8):
            if A[t][j]:
                cox[i][t] ^= box[i][j]

# S-boxes and inverse S-boxes
S = []
Si = [0] * 256
for i in range(256):
    # reassemble the eight bits of cox[i] into one byte, high bit first
    j = 0
    for t in range(8):
        j = (j << 1) | cox[i][t]
    S.append(j)
    Si[j & 0xFF] = i
66
# T-boxes
G = [[2, 1, 1, 3],
     [3, 2, 1, 1],
     [1, 3, 2, 1],
     [1, 1, 3, 2]]

# Invert G over the field with Gauss-Jordan elimination on the augmented
# matrix AA = [G | I]; afterwards the right half of AA holds the inverse.
AA = [[0] * 8 for i in range(4)]

for i in range(4):
    AA[i][:4] = G[i]
    AA[i][i + 4] = 1

for i in range(4):
    pivot = AA[i][i]
    if pivot == 0:
        # look below for a row with a non-zero entry in column i; the bound
        # is tested first so the search never indexes past the last row
        t = i + 1
        while t < 4 and AA[t][i] == 0:
            t += 1
        assert t != 4, 'G matrix must be invertible'
        # swap exactly once, after the search has settled on a row
        AA[i], AA[t] = AA[t], AA[i]
        pivot = AA[i][i]
    # divide row i by the pivot (subtract logarithms) so the pivot becomes 1
    for j in range(8):
        if AA[i][j] != 0:
            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
    # clear column i in every other row
    for t in range(4):
        if i != t:
            for j in range(i+1, 8):
                AA[t][j] ^= mul(AA[i][j], AA[t][i])
            AA[t][i] = 0

# the right-hand 4x4 half is the inverse of G
iG = [AA[i][4:] for i in range(4)]
104
def mul4(a, bs):
    """Multiply a by every element of bs and pack the products into one int.

    a  -- field element; 0 short-circuits to a result of 0
    bs -- iterable of field elements

    Each product takes one byte of the result, with the product for the
    first element of bs in the most significant position. A zero element
    contributes a zero byte.
    """
    if a == 0:
        return 0
    packed = 0
    for b in bs:
        packed = (packed << 8) | (mul(a, b) if b != 0 else 0)
    return packed
114
# T1..T4: S[x] multiplied by each row of G, packed into one word per entry
T1 = [mul4(s, G[0]) for s in S]
T2 = [mul4(s, G[1]) for s in S]
T3 = [mul4(s, G[2]) for s in S]
T4 = [mul4(s, G[3]) for s in S]

# T5..T8: Si[x] multiplied by each row of iG
T5 = [mul4(s, iG[0]) for s in Si]
T6 = [mul4(s, iG[1]) for s in Si]
T7 = [mul4(s, iG[2]) for s in Si]
T8 = [mul4(s, iG[3]) for s in Si]

# U1..U4: x itself multiplied by each row of iG
U1 = [mul4(x, iG[0]) for x in range(256)]
U2 = [mul4(x, iG[1]) for x in range(256)]
U3 = [mul4(x, iG[2]) for x in range(256)]
U4 = [mul4(x, iG[3]) for x in range(256)]

# round constants: 30 entries, each one the previous entry times 2
rcon = [1]
while len(rcon) < 30:
    rcon.append(mul(2, rcon[-1]))

# remove the construction helpers and scratch names so that only the
# finished tables remain at module level
del A
del AA
del pivot
del B
del G
del box
del log
del alog
del i
del j
del t
del mul
del mul4
del cox
del iG
170
class rijndael:
    """Rijndael block cipher for 16-, 24- or 32-byte keys and blocks.

    Creating an instance expands the key into the encryption schedule
    ``Ke`` and the decryption schedule ``Kd`` (one list of words per round
    key); ``encrypt`` and ``decrypt`` each process exactly one block.
    """

    def __init__(self, key, block_size = 16):
        """Validate the sizes and derive both round-key schedules.

        key        -- indexable sequence of 16, 24 or 32 integer byte values
        block_size -- block length in bytes: 16, 24 or 32

        Raises ValueError when either length is unsupported.
        """
        if block_size not in (16, 24, 32):
            raise ValueError('Invalid block size: ' + str(block_size))
        if len(key) not in (16, 24, 32):
            raise ValueError('Invalid key size: ' + str(len(key)))
        self.block_size = block_size

        rounds = num_rounds[len(key)][block_size]
        columns = block_size // 4       # words per block
        key_words = len(key) // 4       # words per key
        total_words = (rounds + 1) * columns

        # one row of `columns` words for every round key
        enc_keys = [[0] * columns for _ in range(rounds + 1)]
        dec_keys = [[0] * columns for _ in range(rounds + 1)]

        # pack the key bytes big-endian into a working window of words
        window = [
            (key[4 * n] << 24) | (key[4 * n + 1] << 16) |
            (key[4 * n + 2] << 8) | key[4 * n + 3]
            for n in range(key_words)
        ]

        def sbox_byte(word, shift):
            # S-box value of the byte found `shift` bits up in `word`
            return S[(word >> shift) & 0xFF] & 0xFF

        def store(position):
            # Copy window words into both schedules starting at `position`,
            # stopping once the schedules are full; the decryption schedule
            # receives the same words with the round order reversed.
            for word in window:
                if position >= total_words:
                    break
                row, col = divmod(position, columns)
                enc_keys[row][col] = word
                dec_keys[rounds - row][col] = word
                position += 1
            return position

        filled = store(0)
        rcon_index = 0
        half = key_words // 2
        while filled < total_words:
            # extrapolate using phi (the round key evolution function)
            last = window[key_words - 1]
            window[0] ^= (sbox_byte(last, 16) << 24 ^
                          sbox_byte(last, 8) << 16 ^
                          sbox_byte(last, 0) << 8 ^
                          sbox_byte(last, 24) ^
                          (rcon[rcon_index] & 0xFF) << 24)
            rcon_index += 1
            if key_words == 8:
                # 8-word keys substitute once more in the middle of the window
                for n in range(1, half):
                    window[n] ^= window[n - 1]
                middle = window[half - 1]
                window[half] ^= (sbox_byte(middle, 0) ^
                                 sbox_byte(middle, 8) << 8 ^
                                 sbox_byte(middle, 16) << 16 ^
                                 sbox_byte(middle, 24) << 24)
                for n in range(half + 1, key_words):
                    window[n] ^= window[n - 1]
            else:
                for n in range(1, key_words):
                    window[n] ^= window[n - 1]
            filled = store(filled)

        # inverse MixColumn where needed: every decryption round key except
        # the first and the last is passed through the U tables
        for rnd in range(1, rounds):
            dec_keys[rnd][:] = [
                U1[(word >> 24) & 0xFF] ^
                U2[(word >> 16) & 0xFF] ^
                U3[(word >> 8) & 0xFF] ^
                U4[word & 0xFF]
                for word in dec_keys[rnd]
            ]

        self.Ke = enc_keys
        self.Kd = dec_keys

    def _row_offsets(self, direction):
        """Return the column offsets of state rows 1, 2 and 3.

        direction -- 0 picks the encryption entries of ``shifts``,
                     1 the decryption entries
        """
        columns = self.block_size // 4
        table = shifts[{4: 0, 6: 1}.get(columns, 2)]
        return table[1][direction], table[2][direction], table[3][direction]

    def encrypt(self, plaintext):
        """Encrypt a single block and return the ciphertext as bytes.

        plaintext -- indexable sequence of exactly ``block_size`` byte values

        Raises ValueError when the length differs from ``block_size``.
        """
        if len(plaintext) != self.block_size:
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
        schedule = self.Ke
        columns = self.block_size // 4
        last_round = len(schedule) - 1
        o1, o2, o3 = self._row_offsets(0)

        # plaintext to ints + key
        state = [
            (plaintext[4 * c] << 24 |
             plaintext[4 * c + 1] << 16 |
             plaintext[4 * c + 2] << 8 |
             plaintext[4 * c + 3]) ^ schedule[0][c]
            for c in range(columns)
        ]

        # apply round transforms; each new state is computed entirely from
        # the previous one before it replaces it
        for rnd in range(1, last_round):
            round_key = schedule[rnd]
            state = [
                T1[(state[c] >> 24) & 0xFF] ^
                T2[(state[(c + o1) % columns] >> 16) & 0xFF] ^
                T3[(state[(c + o2) % columns] >> 8) & 0xFF] ^
                T4[state[(c + o3) % columns] & 0xFF] ^
                round_key[c]
                for c in range(columns)
            ]

        # last round is special: plain S-box lookups, emitted byte by byte
        final_key = schedule[last_round]
        out = []
        for c in range(columns):
            k = final_key[c]
            out.extend((
                (S[(state[c] >> 24) & 0xFF] ^ (k >> 24)) & 0xFF,
                (S[(state[(c + o1) % columns] >> 16) & 0xFF] ^ (k >> 16)) & 0xFF,
                (S[(state[(c + o2) % columns] >> 8) & 0xFF] ^ (k >> 8)) & 0xFF,
                (S[state[(c + o3) % columns] & 0xFF] ^ k) & 0xFF,
            ))
        return bytes(out)

    def decrypt(self, ciphertext):
        """Decrypt a single block and return the plaintext as bytes.

        ciphertext -- indexable sequence of exactly ``block_size`` byte values

        Raises ValueError when the length differs from ``block_size``.
        """
        if len(ciphertext) != self.block_size:
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
        schedule = self.Kd
        columns = self.block_size // 4
        last_round = len(schedule) - 1
        o1, o2, o3 = self._row_offsets(1)

        # ciphertext to ints + key
        state = [
            (ciphertext[4 * c] << 24 |
             ciphertext[4 * c + 1] << 16 |
             ciphertext[4 * c + 2] << 8 |
             ciphertext[4 * c + 3]) ^ schedule[0][c]
            for c in range(columns)
        ]

        # apply round transforms
        for rnd in range(1, last_round):
            round_key = schedule[rnd]
            state = [
                T5[(state[c] >> 24) & 0xFF] ^
                T6[(state[(c + o1) % columns] >> 16) & 0xFF] ^
                T7[(state[(c + o2) % columns] >> 8) & 0xFF] ^
                T8[state[(c + o3) % columns] & 0xFF] ^
                round_key[c]
                for c in range(columns)
            ]

        # last round is special: inverse S-box lookups, emitted byte by byte
        final_key = schedule[last_round]
        out = []
        for c in range(columns):
            k = final_key[c]
            out.extend((
                (Si[(state[c] >> 24) & 0xFF] ^ (k >> 24)) & 0xFF,
                (Si[(state[(c + o1) % columns] >> 16) & 0xFF] ^ (k >> 16)) & 0xFF,
                (Si[(state[(c + o2) % columns] >> 8) & 0xFF] ^ (k >> 8)) & 0xFF,
                (Si[state[(c + o3) % columns] & 0xFF] ^ k) & 0xFF,
            ))
        return bytes(out)
329
def cbc_encrypt(plaintext, key, IV):
    """Encrypt plaintext with the 16-byte-block cipher in CBC mode.

    plaintext -- bytes or bytearray of any length; it is never modified
    key       -- cipher key of 16, 24 or 32 bytes (checked by rijndael)
    IV        -- initial chaining block; its first 16 bytes are used

    Between 1 and 16 padding bytes are appended, each equal to the number
    of padding bytes, so the result is always a non-empty multiple of 16
    bytes. Returns the ciphertext as bytes.
    """
    # padding -- build a new object rather than using `+=`, which would
    # extend a caller-supplied bytearray in place
    padding_size = 16 - (len(plaintext) % 16)
    padded = plaintext + bytes([padding_size]) * padding_size

    # init
    chainBytes = IV
    cipher = rijndael(key)
    blocks = []

    # CBC Mode: For each block...
    for start in range(0, len(padded), 16):

        # XOR with the chaining block
        block = bytes([padded[start + y] ^ chainBytes[y] for y in range(16)])

        # Encrypt and chain
        chainBytes = cipher.encrypt(block)
        blocks.append(chainBytes)

    # join once instead of growing a bytes object block by block
    return b''.join(blocks)
351
def cbc_decrypt(ciphertext, key, IV):
    """Decrypt CBC ciphertext made of 16-byte blocks and strip the padding.

    ciphertext -- non-empty bytes-like value whose length is a multiple of 16
    key        -- cipher key of 16, 24 or 32 bytes (checked by rijndael)
    IV         -- initial chaining block; its first 16 bytes are used

    The last plaintext byte is read as the padding length and must lie in
    1..16; the padding bytes themselves are not compared. Returns the
    plaintext without padding as bytes.

    Raises ValueError for empty or misaligned ciphertext and for a padding
    length outside 1..16.
    """
    # sanity checks, done before the key schedule is built
    if len(ciphertext) == 0:
        raise ValueError('ciphertext is empty')
    if len(ciphertext)%16: raise ValueError('ciphertext not an integral number of blocks')

    chainBytes = IV
    cipher = rijndael(key)
    pieces = []

    # CBC Mode: For each block...
    for start in range(0, len(ciphertext), 16):

        # Decrypt it
        block = ciphertext[start : start + 16]
        decrypted = cipher.decrypt(block)

        # XOR with the chaining block and add to plaintext
        pieces.append(bytes([decrypted[y] ^ chainBytes[y] for y in range(16)]))

        # Set the next chaining block
        chainBytes = block

    # join once instead of growing a bytes object block by block
    plaintext = b''.join(pieces)

    # padding -- a length of 0 is rejected too: plaintext[:-0] would
    # silently return an empty result
    padding_size = plaintext[-1]
    if padding_size < 1 or padding_size > 16: raise ValueError('invalid padding')

    return plaintext[:-padding_size]
378