# Based on the public domain code from tlslite (Trevor Perrin, http://trevp.net/tlslite/)
"""Pure-Python Rijndael block cipher plus CBC-mode helpers.

The lookup tables (S-boxes, T-boxes, U-boxes and round constants) are
generated once at import time.  ``rijndael`` encrypts/decrypts single blocks
of 16, 24 or 32 bytes with keys of 16, 24 or 32 bytes; ``cbc_encrypt`` and
``cbc_decrypt`` wrap it in CBC mode with a 16-byte block and a padding scheme
in which every padding byte holds the padding length.
"""

import copy  # noqa: F401 -- no longer used below; kept so the module namespace is unchanged

# shifts[SC][row] == [encrypt_shift, decrypt_shift]; SC is 0, 1 or 2 for
# block sizes of 16, 24 and 32 bytes respectively.
shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
          [[0, 0], [1, 5], [2, 4], [3, 3]],
          [[0, 0], [1, 7], [3, 5], [4, 4]]]

# [keysize][block_size]
num_rounds = {16: {16: 10, 24: 12, 32: 14},
              24: {16: 12, 24: 12, 32: 14},
              32: {16: 14, 24: 14, 32: 14}}


def _build_tables():
    """Generate every lookup table used by the cipher.

    Returns a 15-tuple ``(S, Si, T1..T8, U1..U4, rcon)``.  All intermediate
    matrices and the GF(2^8) helpers stay local to this function, so nothing
    temporary leaks into the module namespace.
    """
    # bit matrix and constant vector of the S-box affine transform
    A = [[1, 1, 1, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 1, 1, 0, 0],
         [0, 0, 1, 1, 1, 1, 1, 0],
         [0, 0, 0, 1, 1, 1, 1, 1],
         [1, 0, 0, 0, 1, 1, 1, 1],
         [1, 1, 0, 0, 0, 1, 1, 1],
         [1, 1, 1, 0, 0, 0, 1, 1],
         [1, 1, 1, 1, 0, 0, 0, 1]]
    B = [0, 1, 1, 0, 0, 0, 1, 1]

    # coefficient matrix the T-boxes are built from
    G = [[2, 1, 1, 3],
         [3, 2, 1, 1],
         [1, 3, 2, 1],
         [1, 1, 3, 2]]

    # produce log and alog tables, needed for multiplying in the
    # field GF(2^m) (generator = 3)
    alog = [1]
    for _ in range(255):
        j = (alog[-1] << 1) ^ alog[-1]
        if j & 0x100:
            j ^= 0x11B
        alog.append(j)

    log = [0] * 256
    for i in range(1, 255):
        log[alog[i]] = i

    def mul(a, b):
        """Multiply two elements of GF(2^m)."""
        if a == 0 or b == 0:
            return 0
        return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]

    def mul4(a, bs):
        """Multiply ``a`` by each byte of ``bs`` and pack the products, first byte highest."""
        if a == 0:
            return 0
        r = 0
        for b in bs:
            r <<= 8
            if b != 0:
                r |= mul(a, b)
        return r

    # S-boxes and inverse S-boxes: multiplicative inverse followed by the
    # affine transform  bits <- B + A*bits  (bit 0 of the vector is the MSB)
    S = [0] * 256
    Si = [0] * 256
    for i in range(256):
        inverse = 0 if i == 0 else alog[255 - log[i]]
        bits = [(inverse >> (7 - t)) & 0x01 for t in range(8)]
        value = 0
        for t in range(8):
            bit = B[t]
            for j in range(8):
                bit ^= A[t][j] * bits[j]
            value = (value << 1) | bit
        S[i] = value
        Si[value & 0xFF] = i

    # invert G over GF(2^m) by Gauss-Jordan elimination on [G | I]
    AA = [list(G[i]) + [1 if j == i else 0 for j in range(4)]
          for i in range(4)]
    for i in range(4):
        if AA[i][i] == 0:
            # look for a lower row with a usable pivot; the bound is part of
            # the search itself so a singular matrix cannot index past row 3
            swap = next((t for t in range(i + 1, 4) if AA[t][i] != 0), None)
            if swap is None:
                raise ValueError('G matrix must be invertible')
            AA[i], AA[swap] = AA[swap], AA[i]
        pivot = AA[i][i]
        # scale row i so that its pivot becomes 1
        for j in range(8):
            if AA[i][j] != 0:
                AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
        # clear column i in every other row
        for t in range(4):
            if i != t:
                for j in range(i + 1, 8):
                    AA[t][j] ^= mul(AA[i][j], AA[t][i])
                AA[t][i] = 0
    iG = [row[4:] for row in AA]

    # T-boxes: T1..T4 combine S with the rows of G, T5..T8 combine Si with
    # the rows of iG, U1..U4 apply the rows of iG to a raw byte
    T1, T2, T3, T4 = ([mul4(s, row) for s in S] for row in G)
    T5, T6, T7, T8 = ([mul4(s, row) for s in Si] for row in iG)
    U1, U2, U3, U4 = ([mul4(t, row) for t in range(256)] for row in iG)

    # round constants
    rcon = [1]
    for _ in range(1, 30):
        rcon.append(mul(2, rcon[-1]))

    return (S, Si, T1, T2, T3, T4, T5, T6, T7, T8, U1, U2, U3, U4, rcon)


(S, Si, T1, T2, T3, T4, T5, T6, T7, T8, U1, U2, U3, U4, rcon) = _build_tables()


class rijndael:
    """Rijndael block cipher operating on one block at a time.

    Attributes:
        block_size: block length in bytes (16, 24 or 32).
        Ke: encryption round keys, ``ROUNDS + 1`` rows of ``block_size // 4`` words.
        Kd: decryption round keys, same shape as ``Ke``.
    """

    def __init__(self, key, block_size = 16):
        """Expand ``key`` into the encryption and decryption round keys.

        Args:
            key: indexable sequence of 16, 24 or 32 byte values (e.g. ``bytes``).
            block_size: 16, 24 or 32.

        Raises:
            ValueError: if the block size or the key length is unsupported.
        """
        if block_size != 16 and block_size != 24 and block_size != 32:
            raise ValueError('Invalid block size: ' + str(block_size))
        if len(key) != 16 and len(key) != 24 and len(key) != 32:
            raise ValueError('Invalid key size: ' + str(len(key)))
        self.block_size = block_size

        ROUNDS = num_rounds[len(key)][block_size]
        BC = block_size // 4
        # encryption round keys
        Ke = [[0] * BC for i in range(ROUNDS + 1)]
        # decryption round keys
        Kd = [[0] * BC for i in range(ROUNDS + 1)]
        ROUND_KEY_COUNT = (ROUNDS + 1) * BC
        KC = len(key) // 4

        # copy user material bytes into temporary ints
        tk = []
        for i in range(0, KC):
            tk.append((key[i * 4] << 24) | (key[i * 4 + 1] << 16) |
                      (key[i * 4 + 2] << 8) | key[i * 4 + 3])

        def store(t):
            """Copy words of ``tk`` into both key arrays from slot ``t``; return the next free slot."""
            j = 0
            while j < KC and t < ROUND_KEY_COUNT:
                Ke[t // BC][t % BC] = tk[j]
                # decryption uses the same words with the round order reversed
                Kd[ROUNDS - (t // BC)][t % BC] = tk[j]
                j += 1
                t += 1
            return t

        # copy values into round key arrays
        t = store(0)
        rconpointer = 0
        while t < ROUND_KEY_COUNT:
            # extrapolate using phi (the round key evolution function)
            tt = tk[KC - 1]
            tk[0] ^= (((S[(tt >> 16) & 0xFF] & 0xFF) << 24) ^
                      ((S[(tt >> 8) & 0xFF] & 0xFF) << 16) ^
                      ((S[tt & 0xFF] & 0xFF) << 8) ^
                      (S[(tt >> 24) & 0xFF] & 0xFF) ^
                      ((rcon[rconpointer] & 0xFF) << 24))
            rconpointer += 1
            if KC != 8:
                for i in range(1, KC):
                    tk[i] ^= tk[i - 1]
            else:
                # 32-byte keys get an extra substitution half way through
                for i in range(1, KC // 2):
                    tk[i] ^= tk[i - 1]
                tt = tk[KC // 2 - 1]
                tk[KC // 2] ^= ((S[tt & 0xFF] & 0xFF) ^
                                ((S[(tt >> 8) & 0xFF] & 0xFF) << 8) ^
                                ((S[(tt >> 16) & 0xFF] & 0xFF) << 16) ^
                                ((S[(tt >> 24) & 0xFF] & 0xFF) << 24))
                for i in range(KC // 2 + 1, KC):
                    tk[i] ^= tk[i - 1]
            # copy values into round key arrays
            t = store(t)

        # inverse MixColumn where needed
        for r in range(1, ROUNDS):
            for j in range(BC):
                tt = Kd[r][j]
                Kd[r][j] = (U1[(tt >> 24) & 0xFF] ^
                            U2[(tt >> 16) & 0xFF] ^
                            U3[(tt >> 8) & 0xFF] ^
                            U4[tt & 0xFF])
        self.Ke = Ke
        self.Kd = Kd

    def _crypt(self, block, round_keys, tables, sbox, direction):
        """Run the round function shared by ``encrypt`` and ``decrypt``.

        Args:
            block: exactly ``self.block_size`` byte values.
            round_keys: ``self.Ke`` or ``self.Kd``.
            tables: the four T-boxes applied to the four bytes of each word.
            sbox: byte substitution table used in the final round.
            direction: column of ``shifts`` to read (0 encrypt, 1 decrypt).

        Raises:
            ValueError: if ``block`` is not exactly one block long.
        """
        if len(block) != self.block_size:
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(block)))

        BC = self.block_size // 4
        ROUNDS = len(round_keys) - 1
        # BC is 4, 6 or 8, which maps onto rows 0, 1 and 2 of ``shifts``
        row_shifts = shifts[(BC - 4) // 2]
        s1 = row_shifts[1][direction]
        s2 = row_shifts[2][direction]
        s3 = row_shifts[3][direction]
        Ta, Tb, Tc, Td = tables

        # input bytes to ints + key
        first_key = round_keys[0]
        t = [(block[i * 4] << 24 |
              block[i * 4 + 1] << 16 |
              block[i * 4 + 2] << 8 |
              block[i * 4 + 3]) ^ first_key[i]
             for i in range(BC)]

        # apply round transforms; each round reads the previous state only,
        # so the new state is built as a fresh list before rebinding ``t``
        for r in range(1, ROUNDS):
            round_key = round_keys[r]
            t = [(Ta[(t[i] >> 24) & 0xFF] ^
                  Tb[(t[(i + s1) % BC] >> 16) & 0xFF] ^
                  Tc[(t[(i + s2) % BC] >> 8) & 0xFF] ^
                  Td[t[(i + s3) % BC] & 0xFF]) ^ round_key[i]
                 for i in range(BC)]

        # last round is special
        last_key = round_keys[ROUNDS]
        result = bytearray()
        for i in range(BC):
            tt = last_key[i]
            result.append((sbox[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
            result.append((sbox[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
            result.append((sbox[(t[(i + s2) % BC] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
            result.append((sbox[t[(i + s3) % BC] & 0xFF] ^ tt) & 0xFF)
        return bytes(result)

    def encrypt(self, plaintext):
        """Encrypt one block and return it as ``bytes``.

        Raises:
            ValueError: if ``plaintext`` is not exactly ``block_size`` long.
        """
        return self._crypt(plaintext, self.Ke, (T1, T2, T3, T4), S, 0)

    def decrypt(self, ciphertext):
        """Decrypt one block and return it as ``bytes``.

        Raises:
            ValueError: if ``ciphertext`` is not exactly ``block_size`` long.
        """
        return self._crypt(ciphertext, self.Kd, (T5, T6, T7, T8), Si, 1)


def cbc_encrypt(plaintext, key, IV):
    """Pad ``plaintext`` and encrypt it in CBC mode with 16-byte blocks.

    Between 1 and 16 padding bytes are always appended, each holding the
    padding length.  The caller's ``plaintext`` object is never modified.

    Args:
        plaintext: bytes-like data to encrypt.
        key: 16, 24 or 32 byte key.
        IV: initialisation vector; only ``IV[0]`` .. ``IV[15]`` are read.

    Returns:
        The ciphertext as ``bytes`` (a non-zero multiple of 16 bytes).

    Raises:
        ValueError: if the key length is unsupported.
    """
    # padding -- built as a new object so a mutable argument (bytearray,
    # list) passed by the caller is left untouched
    padding_size = 16 - (len(plaintext) % 16)
    padded = bytes(plaintext) + bytes([padding_size]) * padding_size

    # init
    chainBytes = IV
    cipher = rijndael(key)
    blocks = []

    # CBC Mode: For each block...
    for offset in range(0, len(padded), 16):
        # XOR with the chaining block
        block = bytes([padded[offset + y] ^ chainBytes[y] for y in range(16)])

        # Encrypt and chain
        chainBytes = cipher.encrypt(block)
        blocks.append(chainBytes)

    return b''.join(blocks)


def cbc_decrypt(ciphertext, key, IV):
    """Decrypt CBC-mode ``ciphertext`` and strip the padding.

    Only the final byte is used as the padding length; the remaining padding
    bytes are not inspected.

    Args:
        ciphertext: non-empty data whose length is a multiple of 16.
        key: 16, 24 or 32 byte key.
        IV: initialisation vector; only ``IV[0]`` .. ``IV[15]`` are read.

    Returns:
        The recovered plaintext as ``bytes``.

    Raises:
        ValueError: if the key length is unsupported, the ciphertext is empty
            or not a whole number of blocks, or the padding length is outside
            1..16.
    """
    chainBytes = IV
    cipher = rijndael(key)

    # sanity check
    if len(ciphertext) == 0:
        raise ValueError('ciphertext is empty')
    if len(ciphertext) % 16:
        raise ValueError('ciphertext not an integral number of blocks')

    chunks = []
    # CBC Mode: For each block...
    for offset in range(0, len(ciphertext), 16):
        # Decrypt it
        block = ciphertext[offset:offset + 16]
        decrypted = cipher.decrypt(block)

        # XOR with the chaining block and add to plaintext
        chunks.append(bytes([decrypted[y] ^ chainBytes[y] for y in range(16)]))

        # Set the next chaining block
        chainBytes = block
    plaintext = b''.join(chunks)

    # padding -- a length of 0 must be rejected explicitly: slicing with
    # ``[:-0]`` would otherwise return an empty result instead of failing
    padding_size = plaintext[-1]
    if padding_size < 1 or padding_size > 16:
        raise ValueError('invalid padding')

    return plaintext[:-padding_size]