1 package org.bouncycastle.crypto.engines;
2 
3 import org.bouncycastle.crypto.CipherParameters;
4 import org.bouncycastle.crypto.DataLengthException;
5 import org.bouncycastle.crypto.MaxBytesExceededException;
6 import org.bouncycastle.crypto.OutputLengthException;
7 import org.bouncycastle.crypto.SkippingStreamCipher;
8 import org.bouncycastle.crypto.params.KeyParameter;
9 import org.bouncycastle.crypto.params.ParametersWithIV;
10 import org.bouncycastle.util.Integers;
11 import org.bouncycastle.util.Pack;
12 import org.bouncycastle.util.Strings;
13 
14 /**
15  * Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005
16  */
17 public class Salsa20Engine
18     implements SkippingStreamCipher
19 {
20     public final static int DEFAULT_ROUNDS = 20;
21 
22     /** Constants */
23     private final static int STATE_SIZE = 16; // 16, 32 bit ints = 64 bytes
24 
25     private final static int[] TAU_SIGMA = Pack.littleEndianToInt(Strings.toByteArray("expand 16-byte k" + "expand 32-byte k"), 0, 8);
26 
packTauOrSigma(int keyLength, int[] state, int stateOffset)27     protected void packTauOrSigma(int keyLength, int[] state, int stateOffset)
28     {
29         int tsOff = (keyLength - 16) / 4;
30         state[stateOffset    ] = TAU_SIGMA[tsOff    ];
31         state[stateOffset + 1] = TAU_SIGMA[tsOff + 1];
32         state[stateOffset + 2] = TAU_SIGMA[tsOff + 2];
33         state[stateOffset + 3] = TAU_SIGMA[tsOff + 3];
34     }
35 
36     /** @deprecated */
37     protected final static byte[]
38         sigma = Strings.toByteArray("expand 32-byte k"),
39         tau   = Strings.toByteArray("expand 16-byte k");
40 
41     protected int rounds;
42 
43     /*
44      * variables to hold the state of the engine
45      * during encryption and decryption
46      */
47     private int         index = 0;
48     protected int[]     engineState = new int[STATE_SIZE]; // state
49     protected int[]     x = new int[STATE_SIZE] ; // internal buffer
50     private byte[]      keyStream   = new byte[STATE_SIZE * 4]; // expanded state, 64 bytes
51     private boolean     initialised = false;
52 
53     /*
54      * internal counter
55      */
56     private int cW0, cW1, cW2;
57 
58     /**
59      * Creates a 20 round Salsa20 engine.
60      */
Salsa20Engine()61     public Salsa20Engine()
62     {
63         this(DEFAULT_ROUNDS);
64     }
65 
66     /**
67      * Creates a Salsa20 engine with a specific number of rounds.
68      * @param rounds the number of rounds (must be an even number).
69      */
Salsa20Engine(int rounds)70     public Salsa20Engine(int rounds)
71     {
72         if (rounds <= 0 || (rounds & 1) != 0)
73         {
74             throw new IllegalArgumentException("'rounds' must be a positive, even number");
75         }
76 
77         this.rounds = rounds;
78     }
79 
80     /**
81      * initialise a Salsa20 cipher.
82      *
83      * @param forEncryption whether or not we are for encryption.
84      * @param params the parameters required to set up the cipher.
85      * @exception IllegalArgumentException if the params argument is
86      * inappropriate.
87      */
init( boolean forEncryption, CipherParameters params)88     public void init(
89         boolean             forEncryption,
90         CipherParameters     params)
91     {
92         /*
93         * Salsa20 encryption and decryption is completely
94         * symmetrical, so the 'forEncryption' is
95         * irrelevant. (Like 90% of stream ciphers)
96         */
97 
98         if (!(params instanceof ParametersWithIV))
99         {
100             throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must include an IV");
101         }
102 
103         ParametersWithIV ivParams = (ParametersWithIV) params;
104 
105         byte[] iv = ivParams.getIV();
106         if (iv == null || iv.length != getNonceSize())
107         {
108             throw new IllegalArgumentException(getAlgorithmName() + " requires exactly " + getNonceSize()
109                     + " bytes of IV");
110         }
111 
112         CipherParameters keyParam = ivParams.getParameters();
113         if (keyParam == null)
114         {
115             if (!initialised)
116             {
117                 throw new IllegalStateException(getAlgorithmName() + " KeyParameter can not be null for first initialisation");
118             }
119 
120             setKey(null, iv);
121         }
122         else if (keyParam instanceof KeyParameter)
123         {
124             setKey(((KeyParameter)keyParam).getKey(), iv);
125         }
126         else
127         {
128             throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must contain a KeyParameter (or null for re-init)");
129         }
130 
131         reset();
132 
133         initialised = true;
134     }
135 
getNonceSize()136     protected int getNonceSize()
137     {
138         return 8;
139     }
140 
getAlgorithmName()141     public String getAlgorithmName()
142     {
143         String name = "Salsa20";
144         if (rounds != DEFAULT_ROUNDS)
145         {
146             name += "/" + rounds;
147         }
148         return name;
149     }
150 
returnByte(byte in)151     public byte returnByte(byte in)
152     {
153         if (limitExceeded())
154         {
155             throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV");
156         }
157 
158         byte out = (byte)(keyStream[index]^in);
159         index = (index + 1) & 63;
160 
161         if (index == 0)
162         {
163             advanceCounter();
164             generateKeyStream(keyStream);
165         }
166 
167         return out;
168     }
169 
advanceCounter(long diff)170     protected void advanceCounter(long diff)
171     {
172         int hi = (int)(diff >>> 32);
173         int lo = (int)diff;
174 
175         if (hi > 0)
176         {
177             engineState[9] += hi;
178         }
179 
180         int oldState = engineState[8];
181 
182         engineState[8] += lo;
183 
184         if (oldState != 0 && engineState[8] < oldState)
185         {
186             engineState[9]++;
187         }
188     }
189 
advanceCounter()190     protected void advanceCounter()
191     {
192         if (++engineState[8] == 0)
193         {
194             ++engineState[9];
195         }
196     }
197 
retreatCounter(long diff)198     protected void retreatCounter(long diff)
199     {
200         int hi = (int)(diff >>> 32);
201         int lo = (int)diff;
202 
203         if (hi != 0)
204         {
205             if ((engineState[9] & 0xffffffffL) >= (hi & 0xffffffffL))
206             {
207                 engineState[9] -= hi;
208             }
209             else
210             {
211                 throw new IllegalStateException("attempt to reduce counter past zero.");
212             }
213         }
214 
215         if ((engineState[8] & 0xffffffffL) >= (lo & 0xffffffffL))
216         {
217             engineState[8] -= lo;
218         }
219         else
220         {
221             if (engineState[9] != 0)
222             {
223                 --engineState[9];
224                 engineState[8] -= lo;
225             }
226             else
227             {
228                 throw new IllegalStateException("attempt to reduce counter past zero.");
229             }
230         }
231     }
232 
retreatCounter()233     protected void retreatCounter()
234     {
235         if (engineState[8] == 0 && engineState[9] == 0)
236         {
237             throw new IllegalStateException("attempt to reduce counter past zero.");
238         }
239 
240         if (--engineState[8] == -1)
241         {
242             --engineState[9];
243         }
244     }
245 
processBytes( byte[] in, int inOff, int len, byte[] out, int outOff)246     public int processBytes(
247         byte[]     in,
248         int     inOff,
249         int     len,
250         byte[]     out,
251         int     outOff)
252     {
253         if (!initialised)
254         {
255             throw new IllegalStateException(getAlgorithmName() + " not initialised");
256         }
257 
258         if ((inOff + len) > in.length)
259         {
260             throw new DataLengthException("input buffer too short");
261         }
262 
263         if ((outOff + len) > out.length)
264         {
265             throw new OutputLengthException("output buffer too short");
266         }
267 
268         if (limitExceeded(len))
269         {
270             throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV");
271         }
272 
273         for (int i = 0; i < len; i++)
274         {
275             out[i + outOff] = (byte)(keyStream[index] ^ in[i + inOff]);
276             index = (index + 1) & 63;
277 
278             if (index == 0)
279             {
280                 advanceCounter();
281                 generateKeyStream(keyStream);
282             }
283         }
284 
285         return len;
286     }
287 
skip(long numberOfBytes)288     public long skip(long numberOfBytes)
289     {
290         if (numberOfBytes >= 0)
291         {
292             long remaining = numberOfBytes;
293 
294             if (remaining >= 64)
295             {
296                 long count = remaining / 64;
297 
298                 advanceCounter(count);
299 
300                 remaining -= count * 64;
301             }
302 
303             int oldIndex = index;
304 
305             index = (index + (int)remaining) & 63;
306 
307             if (index < oldIndex)
308             {
309                 advanceCounter();
310             }
311         }
312         else
313         {
314             long remaining = -numberOfBytes;
315 
316             if (remaining >= 64)
317             {
318                 long count = remaining / 64;
319 
320                 retreatCounter(count);
321 
322                 remaining -= count * 64;
323             }
324 
325             for (long i = 0; i < remaining; i++)
326             {
327                 if (index == 0)
328                 {
329                     retreatCounter();
330                 }
331 
332                 index = (index - 1) & 63;
333             }
334         }
335 
336         generateKeyStream(keyStream);
337 
338         return numberOfBytes;
339     }
340 
seekTo(long position)341     public long seekTo(long position)
342     {
343         reset();
344 
345         return skip(position);
346     }
347 
getPosition()348     public long getPosition()
349     {
350         return getCounter() * 64 + index;
351     }
352 
reset()353     public void reset()
354     {
355         index = 0;
356         resetLimitCounter();
357         resetCounter();
358 
359         generateKeyStream(keyStream);
360     }
361 
getCounter()362     protected long getCounter()
363     {
364         return ((long)engineState[9] << 32) | (engineState[8] & 0xffffffffL);
365     }
366 
resetCounter()367     protected void resetCounter()
368     {
369         engineState[8] = engineState[9] = 0;
370     }
371 
setKey(byte[] keyBytes, byte[] ivBytes)372     protected void setKey(byte[] keyBytes, byte[] ivBytes)
373     {
374         if (keyBytes != null)
375         {
376             if ((keyBytes.length != 16) && (keyBytes.length != 32))
377             {
378                 throw new IllegalArgumentException(getAlgorithmName() + " requires 128 bit or 256 bit key");
379             }
380 
381             int tsOff = (keyBytes.length - 16) / 4;
382             engineState[0 ] = TAU_SIGMA[tsOff    ];
383             engineState[5 ] = TAU_SIGMA[tsOff + 1];
384             engineState[10] = TAU_SIGMA[tsOff + 2];
385             engineState[15] = TAU_SIGMA[tsOff + 3];
386 
387             // Key
388             Pack.littleEndianToInt(keyBytes, 0, engineState, 1, 4);
389             Pack.littleEndianToInt(keyBytes, keyBytes.length - 16, engineState, 11, 4);
390         }
391 
392         // IV
393         Pack.littleEndianToInt(ivBytes, 0, engineState, 6, 2);
394     }
395 
generateKeyStream(byte[] output)396     protected void generateKeyStream(byte[] output)
397     {
398         salsaCore(rounds, engineState, x);
399         Pack.intToLittleEndian(x, output, 0);
400     }
401 
402     /**
403      * Salsa20 function
404      *
405      * @param   input   input data
406      */
salsaCore(int rounds, int[] input, int[] x)407     public static void salsaCore(int rounds, int[] input, int[] x)
408     {
409         if (input.length != 16)
410         {
411             throw new IllegalArgumentException();
412         }
413         if (x.length != 16)
414         {
415             throw new IllegalArgumentException();
416         }
417         if (rounds % 2 != 0)
418         {
419             throw new IllegalArgumentException("Number of rounds must be even");
420         }
421 
422         int x00 = input[ 0];
423         int x01 = input[ 1];
424         int x02 = input[ 2];
425         int x03 = input[ 3];
426         int x04 = input[ 4];
427         int x05 = input[ 5];
428         int x06 = input[ 6];
429         int x07 = input[ 7];
430         int x08 = input[ 8];
431         int x09 = input[ 9];
432         int x10 = input[10];
433         int x11 = input[11];
434         int x12 = input[12];
435         int x13 = input[13];
436         int x14 = input[14];
437         int x15 = input[15];
438 
439         for (int i = rounds; i > 0; i -= 2)
440         {
441             x04 ^= Integers.rotateLeft(x00 + x12, 7);
442             x08 ^= Integers.rotateLeft(x04 + x00, 9);
443             x12 ^= Integers.rotateLeft(x08 + x04, 13);
444             x00 ^= Integers.rotateLeft(x12 + x08, 18);
445             x09 ^= Integers.rotateLeft(x05 + x01, 7);
446             x13 ^= Integers.rotateLeft(x09 + x05, 9);
447             x01 ^= Integers.rotateLeft(x13 + x09, 13);
448             x05 ^= Integers.rotateLeft(x01 + x13, 18);
449             x14 ^= Integers.rotateLeft(x10 + x06, 7);
450             x02 ^= Integers.rotateLeft(x14 + x10, 9);
451             x06 ^= Integers.rotateLeft(x02 + x14, 13);
452             x10 ^= Integers.rotateLeft(x06 + x02, 18);
453             x03 ^= Integers.rotateLeft(x15 + x11, 7);
454             x07 ^= Integers.rotateLeft(x03 + x15, 9);
455             x11 ^= Integers.rotateLeft(x07 + x03, 13);
456             x15 ^= Integers.rotateLeft(x11 + x07, 18);
457 
458             x01 ^= Integers.rotateLeft(x00 + x03, 7);
459             x02 ^= Integers.rotateLeft(x01 + x00, 9);
460             x03 ^= Integers.rotateLeft(x02 + x01, 13);
461             x00 ^= Integers.rotateLeft(x03 + x02, 18);
462             x06 ^= Integers.rotateLeft(x05 + x04, 7);
463             x07 ^= Integers.rotateLeft(x06 + x05, 9);
464             x04 ^= Integers.rotateLeft(x07 + x06, 13);
465             x05 ^= Integers.rotateLeft(x04 + x07, 18);
466             x11 ^= Integers.rotateLeft(x10 + x09, 7);
467             x08 ^= Integers.rotateLeft(x11 + x10, 9);
468             x09 ^= Integers.rotateLeft(x08 + x11, 13);
469             x10 ^= Integers.rotateLeft(x09 + x08, 18);
470             x12 ^= Integers.rotateLeft(x15 + x14, 7);
471             x13 ^= Integers.rotateLeft(x12 + x15, 9);
472             x14 ^= Integers.rotateLeft(x13 + x12, 13);
473             x15 ^= Integers.rotateLeft(x14 + x13, 18);
474         }
475 
476         x[ 0] = x00 + input[ 0];
477         x[ 1] = x01 + input[ 1];
478         x[ 2] = x02 + input[ 2];
479         x[ 3] = x03 + input[ 3];
480         x[ 4] = x04 + input[ 4];
481         x[ 5] = x05 + input[ 5];
482         x[ 6] = x06 + input[ 6];
483         x[ 7] = x07 + input[ 7];
484         x[ 8] = x08 + input[ 8];
485         x[ 9] = x09 + input[ 9];
486         x[10] = x10 + input[10];
487         x[11] = x11 + input[11];
488         x[12] = x12 + input[12];
489         x[13] = x13 + input[13];
490         x[14] = x14 + input[14];
491         x[15] = x15 + input[15];
492     }
493 
resetLimitCounter()494     private void resetLimitCounter()
495     {
496         cW0 = 0;
497         cW1 = 0;
498         cW2 = 0;
499     }
500 
limitExceeded()501     private boolean limitExceeded()
502     {
503         if (++cW0 == 0)
504         {
505             if (++cW1 == 0)
506             {
507                 return (++cW2 & 0x20) != 0;          // 2^(32 + 32 + 6)
508             }
509         }
510 
511         return false;
512     }
513 
514     /*
515      * this relies on the fact len will always be positive.
516      */
limitExceeded(int len)517     private boolean limitExceeded(int len)
518     {
519         cW0 += len;
520         if (cW0 < len && cW0 >= 0)
521         {
522             if (++cW1 == 0)
523             {
524                 return (++cW2 & 0x20) != 0;          // 2^(32 + 32 + 6)
525             }
526         }
527 
528         return false;
529     }
530 }
531