1 /* Rijndael Block Cipher - rijndael.c
2 
3    Written by Mike Scott 21st April 1999
4    mike@compapp.dcu.ie
5 
6    Permission for free direct or derivative use is granted subject
7    to compliance with any conditions that the originators of the
8    algorithm place on its exploitation.
9 
10 */
11 
12 #include <stdio.h>
13 #include <string.h>
14 #include "rijndael.h"
15 
16 #ifndef ASSERT
17     #define ASSERT(cond)
18 #endif
19 
/* Rotate an 8-bit value left by one bit. The expression itself is
   wider than 8 bits, so the result has to be stored back into (or
   cast to) an 8-bit type to drop the surplus high bits. */

#define ROTL(x) (((x)<<1)|((x)>>7))

/* Rotate a 32-bit word left by 1, 2 or 3 bytes */

#define ROTL8(x)  (((x)>>24)|((x)<<8))
#define ROTL16(x) (((x)>>16)|((x)<<16))
#define ROTL24(x) (((x)>>8)|((x)<<24))
29 
30 /* Fixed Data */
31 
32 static int tab_init = 0;
33 
34 static u8 InCo[4]={0xB,0xD,0x9,0xE};  /* Inverse Coefficients */
35 
36 static u8 fbsub[256];
37 static u8 rbsub[256];
38 static u8 ptab[256],ltab[256];
39 static u32 ftable[256], ftable8[256], ftable16[256], ftable24[256];
40 static u32 rtable[256], rtable8[256], rtable16[256], rtable24[256];
41 static u32 rco[30];
42 
pack(u8 * b)43 static u32 pack(u8 *b)
44 {
45   /* pack bytes into a 32-bit Word */
46     return ((u32)b[3]<<24)|((u32)b[2]<<16)|((u32)b[1]<<8)|(u32)b[0];
47 }
48 
static void unpack(u32 a,u8 *b)
{
    /* Split a word into four bytes, least significant byte first
       (the inverse of pack()). */
    int idx;

    for (idx = 0; idx < 4; idx++)
    {
        b[idx] = (u8)a;
        a >>= 8;
    }
}
57 
static u8 xtime(u8 a)
{
    /* Shift left by one bit; if the top bit was set before the
       shift, fold it back in by xoring with 0x1B. */
    const u8 fold = (a & 0x80) ? 0x1B : 0;

    return (u8)((a << 1) ^ fold);
}
67 
static u8 bmul(u8 x,u8 y)
{
    /* Table-driven product: x.y = AntiLog((Log(x) + Log(y)) mod 255).
       A zero operand has no usable log, so it yields 0 directly. */
    if (!x || !y)
        return 0;

    return ptab[(ltab[x] + ltab[y]) % 255];
}
74 
static u32 SubByte(u32 a)
{
    /* Replace each of the four bytes of the word by its fbsub[] entry. */
    u8 lane[4];
    int idx;

    unpack(a, lane);
    for (idx = 0; idx < 4; idx++)
        lane[idx] = fbsub[lane[idx]];

    return pack(lane);
}
85 
product(u32 x,u32 y)86 static u8 product(u32 x,u32 y)
87 {
88   /* dot product of two 4-byte arrays */
89     u8 xb[4],yb[4];
90     unpack(x,xb);
91     unpack(y,yb);
92     return bmul(xb[0],yb[0])^bmul(xb[1],yb[1])^bmul(xb[2],yb[2])^bmul(xb[3],yb[3]);
93 }
94 
static u32 InvMixCol(u32 x)
{
    /* Matrix multiplication: every output byte is the dot product of
       x with the InCo coefficient word, which is rotated by 24 bits
       between rows. Rows are produced from byte 3 down to byte 0. */
    u32 coeff = pack(InCo);
    u8 out[4];
    int row;

    for (row = 3; row >= 0; row--)
    {
        out[row] = product(coeff, x);
        coeff = ROTL24(coeff);
    }

    return pack(out);
}
112 
static u8 ByteSub(u8 x)
{
    /* Start from the multiplicative inverse taken from the log
       tables, xor in its four successive one-bit left rotations,
       then xor with the constant 0x63. */
    u8 acc = ptab[255 - ltab[x]];
    u8 rot = acc;
    int step;

    for (step = 0; step < 4; step++)
    {
        rot = (u8)ROTL(rot);
        acc ^= rot;
    }

    return (u8)(acc ^ 0x63);
}
123 
gentables(void)124 static void gentables(void)
125 {
126   /* generate tables */
127     int i;
128     u8 y,b[4];
129 
130     if (tab_init)
131 	return;
132     tab_init = 1;
133 
134     /* use 3 as primitive root to generate power and log tables */
135 
136     ltab[0]=0;
137     ptab[0]=1;  ltab[1]=0;
138     ptab[1]=3;  ltab[3]=1;
139     for (i=2; i<256; i++)
140     {
141 	ptab[i]=ptab[i-1]^xtime(ptab[i-1]);
142 	ltab[ptab[i]]=i;
143     }
144 
145     /* affine transformation:- each bit is xored with itself shifted one bit */
146 
147     fbsub[0]=0x63;
148     rbsub[0x63]=0;
149     for (i=1; i<256; i++)
150     {
151 	y=ByteSub((u8)i);
152 	fbsub[i]=y; rbsub[y]=i;
153     }
154 
155     for (i=0,y=1; i<30; i++)
156     {
157 	rco[i]=y;
158 	y=xtime(y);
159     }
160 
161     /* calculate forward and reverse tables */
162     for (i=0; i<256; i++)
163     {
164 	y=fbsub[i];
165 	b[3]=y^xtime(y); b[2]=y;
166 	b[1]=y;          b[0]=xtime(y);
167 	ftable[i]=pack(b);
168 
169 	y=rbsub[i];
170 	b[3]=bmul(InCo[0],y); b[2]=bmul(InCo[1],y);
171 	b[1]=bmul(InCo[2],y); b[0]=bmul(InCo[3],y);
172 	rtable[i]=pack(b);
173     }
174 
175     /* calculate rotated tables */
176     for ( i=0; i<256; i++ )
177     {
178 	ftable8 [i] = ROTL8 (ftable[i]);
179 	ftable16[i] = ROTL16(ftable[i]);
180 	ftable24[i] = ROTL24(ftable[i]);
181 
182 	rtable8 [i] = ROTL8 (rtable[i]);
183 	rtable16[i] = ROTL16(rtable[i]);
184 	rtable24[i] = ROTL24(rtable[i]);
185     }
186 }
187 
static void gkey ( aes_key_t *akey, int nb, int nk, char *key )
{
    /* blocksize=32*nb bits. Key=32*nk bits              */
    /* currently nb,nk = 4, 6 or 8                       */
    /* key comes as 4*nk bytes                           */
    /* Key scheduler: stores Nb, Nk and Nr in akey, then */
    /* fills the shift offsets fi[]/ri[], the expanded   */
    /* encryption key fkey[] and the decryption key      */
    /* rkey[]. Relies on the tables built by gentables() */
    int w, pos, round, total;
    int blk, klen;
    int shift2, shift3;
    const int shift1 = 1;

    akey->Nb = nb;
    akey->Nk = nk;
    blk  = akey->Nb;
    klen = akey->Nk;

    /* number of rounds: 6 plus the larger of Nb and Nk */
    akey->Nr = 6 + (blk >= klen ? blk : klen);

    if (blk < 8)
    {
        shift2 = 2;
        shift3 = 3;
    }
    else
    {
        shift2 = 3;
        shift3 = 4;
    }

    /* pre-calculate forward and reverse increments, three per column */
    for (pos = 0; pos < nb; pos++)
    {
        const int base = 3 * pos;

        akey->fi[base]   = (pos + shift1) % nb;
        akey->fi[base+1] = (pos + shift2) % nb;
        akey->fi[base+2] = (pos + shift3) % nb;
        akey->ri[base]   = (nb + pos - shift1) % nb;
        akey->ri[base+1] = (nb + pos - shift2) % nb;
        akey->ri[base+2] = (nb + pos - shift3) % nb;
    }

    total = blk * (akey->Nr + 1);

    /* the first Nk words are the cipher key itself */
    for (w = 0; w < klen; w++)
        akey->fkey[w] = pack((u8 *)&key[4*w]);

    /* expand Nk words at a time, never writing past word total-1 */
    for (pos = klen, round = 0; pos < total; pos += klen, round++)
    {
        akey->fkey[pos] = akey->fkey[pos-klen]
                        ^ SubByte(ROTL24(akey->fkey[pos-1]))
                        ^ rco[round];

        for (w = 1; w < klen && (pos + w) < total; w++)
        {
            u32 prev = akey->fkey[pos+w-1];

            /* keys longer than 6 words substitute the 5th word too */
            if (klen > 6 && w == 4)
                prev = SubByte(prev);

            akey->fkey[pos+w] = akey->fkey[pos+w-klen] ^ prev;
        }
    }

    /* now for the expanded decrypt key in reverse order:   */
    /* first and last round keys are copied unchanged, the  */
    /* ones in between are passed through InvMixCol()       */

    for (w = 0; w < blk; w++)
        akey->rkey[total-blk+w] = akey->fkey[w];

    for (pos = blk; pos < total - blk; pos += blk)
    {
        const int dst = total - blk - pos;

        for (w = 0; w < blk; w++)
            akey->rkey[dst+w] = InvMixCol(akey->fkey[pos+w]);
    }

    for (w = 0; w < blk; w++)
        akey->rkey[w] = akey->fkey[total-blk+w];
}
260 
261 
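/* The obvious time/space trade-off is taken in favour of speed:  *
 * besides ftable[] and rtable[] there are three pre-rotated      *
 * copies of each, which saves the ROTL8, ROTL16 and ROTL24       *
 * overhead in the main rounds. Only the last round, which works  *
 * on fbsub[]/rbsub[], still rotates at run time.                 */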
265 
encrypt(const aes_key_t * akey,char * buff)266 static void encrypt ( const aes_key_t * akey, char *buff )
267 {
268     int i,j,k,m;
269     u32 a[8],b[8],*x,*y,*t;
270 
271     for (i=j=0; i<akey->Nb; i++,j+=4)
272     {
273 	a[i]=pack((u8 *)&buff[j]);
274 	a[i]^=akey->fkey[i];
275     }
276     k=akey->Nb;
277     x=a; y=b;
278 
279     /* State alternates between a and b */
280     for (i=1; i<akey->Nr; i++)
281     { /* akey->Nr is number of rounds. May be odd. */
282 
283 	/* if akey->Nb is fixed - unroll this next
284 	   loop and hard-code in the values of akey->fi[]  */
285 
286 	for (m=j=0; j<akey->Nb; j++,m+=3)
287 	{ /* deal with each 32-bit element of the State */
288 	    /* This is the time-critical bit */
289 	    y[j]=akey->fkey[k++]^ftable[(u8)x[j]]^
290 		 ftable8 [(u8)(x[akey->fi[m]]>>8)]^
291 		 ftable16[(u8)(x[akey->fi[m+1]]>>16)]^
292 		 ftable24[(u8)(x[akey->fi[m+2]]>>24)];
293 	}
294 	t=x; x=y; y=t;      /* swap pointers */
295     }
296 
297     /* Last Round - unroll if possible */
298     for (m=j=0; j<akey->Nb; j++,m+=3)
299     {
300 	y[j]=akey->fkey[k++]^(u32)fbsub[(u8)x[j]]^
301 	     ROTL8((u32)fbsub[(u8)(x[akey->fi[m]]>>8)])^
302 	     ROTL16((u32)fbsub[(u8)(x[akey->fi[m+1]]>>16)])^
303 	     ROTL24((u32)fbsub[(u8)(x[akey->fi[m+2]]>>24)]);
304     }
305     for (i=j=0; i<akey->Nb; i++,j+=4)
306     {
307 	unpack(y[i],(u8 *)&buff[j]);
308 	x[i]=y[i]=0;   /* clean up stack */
309     }
310     return;
311 }
312 
decrypt(const aes_key_t * akey,char * buff)313 static void decrypt ( const aes_key_t * akey, char *buff )
314 {
315     int i,j,k,m;
316     u32 a[8],b[8],*x,*y,*t;
317 
318     for (i=j=0; i<akey->Nb; i++,j+=4)
319     {
320 	a[i]=pack((u8 *)&buff[j]);
321 	a[i]^=akey->rkey[i];
322     }
323     k=akey->Nb;
324     x=a; y=b;
325 
326     /* State alternates between a and b */
327     for (i=1; i<akey->Nr; i++)
328     { /* akey->Nr is number of rounds. May be odd. */
329 
330 	/* if akey->Nb is fixed - unroll this next
331 	   loop and hard-code in the values of akey->ri[]  */
332 
333 	for (m=j=0; j<akey->Nb; j++,m+=3)
334 	{ /* This is the time-critical bit */
335 	    y[j]=akey->rkey[k++]^rtable[(u8)x[j]]^
336 		 rtable8 [(u8)(x[akey->ri[m]]>>8)]^
337 		 rtable16[(u8)(x[akey->ri[m+1]]>>16)]^
338 		 rtable24[(u8)(x[akey->ri[m+2]]>>24)];
339 	}
340 	t=x; x=y; y=t;      /* swap pointers */
341     }
342 
343     /* Last Round - unroll if possible */
344     for (m=j=0; j<akey->Nb; j++,m+=3)
345     {
346 	y[j]=akey->rkey[k++]^(u32)rbsub[(u8)x[j]]^
347 	     ROTL8((u32)rbsub[(u8)(x[akey->ri[m]]>>8)])^
348 	     ROTL16((u32)rbsub[(u8)(x[akey->ri[m+1]]>>16)])^
349 	     ROTL24((u32)rbsub[(u8)(x[akey->ri[m+2]]>>24)]);
350     }
351     for (i=j=0; i<akey->Nb; i++,j+=4)
352     {
353 	unpack(y[i],(u8 *)&buff[j]);
354 	x[i]=y[i]=0;   /* clean up stack */
355     }
356     return;
357 }
358 
void wd_aes_set_key ( aes_key_t * akey, const void * key )
{
    /* Prepare akey for a block of 4 words and a key of 4 words,
       i.e. 16 bytes are read from key. */
    const int block_words = 4;
    const int key_words   = 4;

    gentables();    /* does nothing once the tables exist */
    gkey( akey, block_words, key_words, (char*)key );
}
364 
365 // CBC mode decryption
wd_aes_decrypt(const aes_key_t * akey,const void * p_iv,const void * p_inbuf,void * p_outbuf,u64 len)366 void wd_aes_decrypt
367 (
368 	const aes_key_t * akey,
369 	const void *p_iv,
370 	const void *p_inbuf,
371 	void *p_outbuf,
372 	u64 len
373 )
374 {
375     const u8 * iv	= p_iv;
376     const u8 * inbuf	= p_inbuf;
377 	  u8 * outbuf	= p_outbuf;
378 
379     ASSERT( inbuf != outbuf ); //no inplace decryption possible
380 
381     u8 block[16];
382     const u8 *ctext_ptr;
383     unsigned int blockno = 0, i;
384 
385     //printf("aes_decrypt(%p, %p, %p, %lld)\n", iv, inbuf, outbuf, len);
386 
387     for (blockno = 0; blockno <= (len / sizeof(block)); blockno++)
388     {
389 	unsigned int fraction;
390 	if (blockno == (len / sizeof(block)))   // last block
391 	{
392 	    fraction = len % sizeof(block);
393 	    if (fraction == 0) break;
394 	    memset(block, 0, sizeof(block));
395 	}
396 	else fraction = 16;
397 
398 	//    debug_printf("block %d: fraction = %d\n", blockno, fraction);
399 	memcpy(block, inbuf + blockno * sizeof(block), fraction);
400 	decrypt(akey,(char*)block);
401 
402 	if (blockno == 0)
403 	{
404 	    ctext_ptr = iv;
405 	}
406 	else
407 	{
408 	    ctext_ptr = inbuf + (blockno-1) * sizeof(block);
409 	}
410 
411 	for (i=0; i < fraction; i++)
412 	    outbuf[blockno * sizeof(block) + i] =
413 		ctext_ptr[i] ^ block[i];
414 	//    debug_printf("Block %d output: ", blockno);
415 	//    hexdump(outbuf + blockno*sizeof(block), 16);
416     }
417 }
418 
419 // CBC mode encryption
wd_aes_encrypt(const aes_key_t * akey,const void * p_iv,const void * p_inbuf,void * p_outbuf,u64 len)420 void wd_aes_encrypt
421 (
422 	const aes_key_t * akey,
423 	const void *p_iv,
424 	const void *p_inbuf,
425 	void *p_outbuf,
426 	u64 len
427 )
428 {
429     const u8 * inbuf	= p_inbuf;
430 	  u8 * outbuf	= p_outbuf;
431 
432     u8 block[16], iv[16];
433     memcpy(iv, p_iv, sizeof(iv));
434     unsigned int blockno = 0, i;
435 
436     //  debug_printf("aes_decrypt(%p, %p, %p, %lld)\n", iv, inbuf, outbuf, len);
437 
438     for (blockno = 0; blockno <= (len / sizeof(block)); blockno++)
439     {
440 	unsigned int fraction;
441 	if (blockno == (len / sizeof(block)))   // last block
442 	{
443 	    fraction = len % sizeof(block);
444 	    if (fraction == 0) break;
445 	    memset(block, 0, sizeof(block));
446 	}
447 	else fraction = 16;
448 
449 	//    debug_printf("block %d: fraction = %d\n", blockno, fraction);
450 	memcpy(block, inbuf + blockno * sizeof(block), fraction);
451 
452 	for (i=0; i < fraction; i++)
453 	    block[i] = inbuf[blockno * sizeof(block) + i] ^ iv[i];
454 
455 	encrypt(akey,(char*)block);
456 	memcpy(iv, block, sizeof(block));
457 	memcpy(outbuf + blockno * sizeof(block), block, sizeof(block));
458 	//    debug_printf("Block %d output: ", blockno);
459 	//    hexdump(outbuf + blockno*sizeof(block), 16);
460     }
461 }
462 
463