1 package org.bouncycastle.crypto.engines; 2 3 import org.bouncycastle.crypto.CipherParameters; 4 import org.bouncycastle.crypto.DataLengthException; 5 import org.bouncycastle.crypto.MaxBytesExceededException; 6 import org.bouncycastle.crypto.OutputLengthException; 7 import org.bouncycastle.crypto.SkippingStreamCipher; 8 import org.bouncycastle.crypto.params.KeyParameter; 9 import org.bouncycastle.crypto.params.ParametersWithIV; 10 import org.bouncycastle.util.Integers; 11 import org.bouncycastle.util.Pack; 12 import org.bouncycastle.util.Strings; 13 14 /** 15 * Implementation of Daniel J. Bernstein's Salsa20 stream cipher, Snuffle 2005 16 */ 17 public class Salsa20Engine 18 implements SkippingStreamCipher 19 { 20 public final static int DEFAULT_ROUNDS = 20; 21 22 /** Constants */ 23 private final static int STATE_SIZE = 16; // 16, 32 bit ints = 64 bytes 24 25 private final static int[] TAU_SIGMA = Pack.littleEndianToInt(Strings.toByteArray("expand 16-byte k" + "expand 32-byte k"), 0, 8); 26 packTauOrSigma(int keyLength, int[] state, int stateOffset)27 protected void packTauOrSigma(int keyLength, int[] state, int stateOffset) 28 { 29 int tsOff = (keyLength - 16) / 4; 30 state[stateOffset ] = TAU_SIGMA[tsOff ]; 31 state[stateOffset + 1] = TAU_SIGMA[tsOff + 1]; 32 state[stateOffset + 2] = TAU_SIGMA[tsOff + 2]; 33 state[stateOffset + 3] = TAU_SIGMA[tsOff + 3]; 34 } 35 36 /** @deprecated */ 37 protected final static byte[] 38 sigma = Strings.toByteArray("expand 32-byte k"), 39 tau = Strings.toByteArray("expand 16-byte k"); 40 41 protected int rounds; 42 43 /* 44 * variables to hold the state of the engine 45 * during encryption and decryption 46 */ 47 private int index = 0; 48 protected int[] engineState = new int[STATE_SIZE]; // state 49 protected int[] x = new int[STATE_SIZE] ; // internal buffer 50 private byte[] keyStream = new byte[STATE_SIZE * 4]; // expanded state, 64 bytes 51 private boolean initialised = false; 52 53 /* 54 * internal counter 55 */ 56 private int cW0, cW1, cW2; 57 58 /** 59 * Creates a 20 round Salsa20 engine. 60 */ Salsa20Engine()61 public Salsa20Engine() 62 { 63 this(DEFAULT_ROUNDS); 64 } 65 66 /** 67 * Creates a Salsa20 engine with a specific number of rounds. 68 * @param rounds the number of rounds (must be an even number). 69 */ Salsa20Engine(int rounds)70 public Salsa20Engine(int rounds) 71 { 72 if (rounds <= 0 || (rounds & 1) != 0) 73 { 74 throw new IllegalArgumentException("'rounds' must be a positive, even number"); 75 } 76 77 this.rounds = rounds; 78 } 79 80 /** 81 * initialise a Salsa20 cipher. 82 * 83 * @param forEncryption whether or not we are for encryption. 84 * @param params the parameters required to set up the cipher. 85 * @exception IllegalArgumentException if the params argument is 86 * inappropriate. 87 */ init( boolean forEncryption, CipherParameters params)88 public void init( 89 boolean forEncryption, 90 CipherParameters params) 91 { 92 /* 93 * Salsa20 encryption and decryption is completely 94 * symmetrical, so the 'forEncryption' is 95 * irrelevant. (Like 90% of stream ciphers) 96 */ 97 98 if (!(params instanceof ParametersWithIV)) 99 { 100 throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must include an IV"); 101 } 102 103 ParametersWithIV ivParams = (ParametersWithIV) params; 104 105 byte[] iv = ivParams.getIV(); 106 if (iv == null || iv.length != getNonceSize()) 107 { 108 throw new IllegalArgumentException(getAlgorithmName() + " requires exactly " + getNonceSize() 109 + " bytes of IV"); 110 } 111 112 CipherParameters keyParam = ivParams.getParameters(); 113 if (keyParam == null) 114 { 115 if (!initialised) 116 { 117 throw new IllegalStateException(getAlgorithmName() + " KeyParameter can not be null for first initialisation"); 118 } 119 120 setKey(null, iv); 121 } 122 else if (keyParam instanceof KeyParameter) 123 { 124 setKey(((KeyParameter)keyParam).getKey(), iv); 125 } 126 else 127 { 128 throw new IllegalArgumentException(getAlgorithmName() + " Init parameters must contain a KeyParameter (or null for re-init)"); 129 } 130 131 reset(); 132 133 initialised = true; 134 } 135 getNonceSize()136 protected int getNonceSize() 137 { 138 return 8; 139 } 140 getAlgorithmName()141 public String getAlgorithmName() 142 { 143 String name = "Salsa20"; 144 if (rounds != DEFAULT_ROUNDS) 145 { 146 name += "/" + rounds; 147 } 148 return name; 149 } 150 returnByte(byte in)151 public byte returnByte(byte in) 152 { 153 if (limitExceeded()) 154 { 155 throw new MaxBytesExceededException("2^70 byte limit per IV; Change IV"); 156 } 157 158 byte out = (byte)(keyStream[index]^in); 159 index = (index + 1) & 63; 160 161 if (index == 0) 162 { 163 advanceCounter(); 164 generateKeyStream(keyStream); 165 } 166 167 return out; 168 } 169 advanceCounter(long diff)170 protected void advanceCounter(long diff) 171 { 172 int hi = (int)(diff >>> 32); 173 int lo = (int)diff; 174 175 if (hi > 0) 176 { 177 engineState[9] += hi; 178 } 179 180 int oldState = engineState[8]; 181 182 engineState[8] += lo; 183 184 if (oldState != 0 && engineState[8] < oldState) 185 { 186 engineState[9]++; 187 } 188 } 189 advanceCounter()190 protected void advanceCounter() 191 { 192 if (++engineState[8] == 0) 193 { 194 ++engineState[9]; 195 } 196 } 197 retreatCounter(long diff)198 protected void retreatCounter(long diff) 199 { 200 int hi = (int)(diff >>> 32); 201 int lo = (int)diff; 202 203 if (hi != 0) 204 { 205 if ((engineState[9] & 0xffffffffL) >= (hi & 0xffffffffL)) 206 { 207 engineState[9] -= hi; 208 } 209 else 210 { 211 throw new IllegalStateException("attempt to reduce counter past zero."); 212 } 213 } 214 215 if ((engineState[8] & 0xffffffffL) >= (lo & 0xffffffffL)) 216 { 217 engineState[8] -= lo; 218 } 219 else 220 { 221 if (engineState[9] != 0) 222 { 223 --engineState[9]; 224 engineState[8] -= lo; 225 } 226 else 227 { 228 throw new IllegalStateException("attempt to reduce counter past zero."); 229 } 230 } 231 } 232 retreatCounter()233 protected void retreatCounter() 234 { 235 if (engineState[8] == 0 && engineState[9] == 0) 236 { 237 throw new IllegalStateException("attempt to reduce counter past zero."); 238 } 239 240 if (--engineState[8] == -1) 241 { 242 --engineState[9]; 243 } 244 } 245 processBytes( byte[] in, int inOff, int len, byte[] out, int outOff)246 public int processBytes( 247 byte[] in, 248 int inOff, 249 int len, 250 byte[] out, 251 int outOff) 252 { 253 if (!initialised) 254 { 255 throw new IllegalStateException(getAlgorithmName() + " not initialised"); 256 } 257 258 if ((inOff + len) > in.length) 259 { 260 throw new DataLengthException("input buffer too short"); 261 } 262 263 if ((outOff + len) > out.length) 264 { 265 throw new OutputLengthException("output buffer too short"); 266 } 267 268 if (limitExceeded(len)) 269 { 270 throw new MaxBytesExceededException("2^70 byte limit per IV would be exceeded; Change IV"); 271 } 272 273 for (int i = 0; i < len; i++) 274 { 275 out[i + outOff] = (byte)(keyStream[index] ^ in[i + inOff]); 276 index = (index + 1) & 63; 277 278 if (index == 0) 279 { 280 advanceCounter(); 281 generateKeyStream(keyStream); 282 } 283 } 284 285 return len; 286 } 287 skip(long numberOfBytes)288 public long skip(long numberOfBytes) 289 { 290 if (numberOfBytes >= 0) 291 { 292 long remaining = numberOfBytes; 293 294 if (remaining >= 64) 295 { 296 long count = remaining / 64; 297 298 advanceCounter(count); 299 300 remaining -= count * 64; 301 } 302 303 int oldIndex = index; 304 305 index = (index + (int)remaining) & 63; 306 307 if (index < oldIndex) 308 { 309 advanceCounter(); 310 } 311 } 312 else 313 { 314 long remaining = -numberOfBytes; 315 316 if (remaining >= 64) 317 { 318 long count = remaining / 64; 319 320 retreatCounter(count); 321 322 remaining -= count * 64; 323 } 324 325 for (long i = 0; i < remaining; i++) 326 { 327 if (index == 0) 328 { 329 retreatCounter(); 330 } 331 332 index = (index - 1) & 63; 333 } 334 } 335 336 generateKeyStream(keyStream); 337 338 return numberOfBytes; 339 } 340 seekTo(long position)341 public long seekTo(long position) 342 { 343 reset(); 344 345 return skip(position); 346 } 347 getPosition()348 public long getPosition() 349 { 350 return getCounter() * 64 + index; 351 } 352 reset()353 public void reset() 354 { 355 index = 0; 356 resetLimitCounter(); 357 resetCounter(); 358 359 generateKeyStream(keyStream); 360 } 361 getCounter()362 protected long getCounter() 363 { 364 return ((long)engineState[9] << 32) | (engineState[8] & 0xffffffffL); 365 } 366 resetCounter()367 protected void resetCounter() 368 { 369 engineState[8] = engineState[9] = 0; 370 } 371 setKey(byte[] keyBytes, byte[] ivBytes)372 protected void setKey(byte[] keyBytes, byte[] ivBytes) 373 { 374 if (keyBytes != null) 375 { 376 if ((keyBytes.length != 16) && (keyBytes.length != 32)) 377 { 378 throw new IllegalArgumentException(getAlgorithmName() + " requires 128 bit or 256 bit key"); 379 } 380 381 int tsOff = (keyBytes.length - 16) / 4; 382 engineState[0 ] = TAU_SIGMA[tsOff ]; 383 engineState[5 ] = TAU_SIGMA[tsOff + 1]; 384 engineState[10] = TAU_SIGMA[tsOff + 2]; 385 engineState[15] = TAU_SIGMA[tsOff + 3]; 386 387 // Key 388 Pack.littleEndianToInt(keyBytes, 0, engineState, 1, 4); 389 Pack.littleEndianToInt(keyBytes, keyBytes.length - 16, engineState, 11, 4); 390 } 391 392 // IV 393 Pack.littleEndianToInt(ivBytes, 0, engineState, 6, 2); 394 } 395 generateKeyStream(byte[] output)396 protected void generateKeyStream(byte[] output) 397 { 398 salsaCore(rounds, engineState, x); 399 Pack.intToLittleEndian(x, output, 0); 400 } 401 402 /** 403 * Salsa20 function 404 * 405 * @param input input data 406 */ salsaCore(int rounds, int[] input, int[] x)407 public static void salsaCore(int rounds, int[] input, int[] x) 408 { 409 if (input.length != 16) 410 { 411 throw new IllegalArgumentException(); 412 } 413 if (x.length != 16) 414 { 415 throw new IllegalArgumentException(); 416 } 417 if (rounds % 2 != 0) 418 { 419 throw new IllegalArgumentException("Number of rounds must be even"); 420 } 421 422 int x00 = input[ 0]; 423 int x01 = input[ 1]; 424 int x02 = input[ 2]; 425 int x03 = input[ 3]; 426 int x04 = input[ 4]; 427 int x05 = input[ 5]; 428 int x06 = input[ 6]; 429 int x07 = input[ 7]; 430 int x08 = input[ 8]; 431 int x09 = input[ 9]; 432 int x10 = input[10]; 433 int x11 = input[11]; 434 int x12 = input[12]; 435 int x13 = input[13]; 436 int x14 = input[14]; 437 int x15 = input[15]; 438 439 for (int i = rounds; i > 0; i -= 2) 440 { 441 x04 ^= Integers.rotateLeft(x00 + x12, 7); 442 x08 ^= Integers.rotateLeft(x04 + x00, 9); 443 x12 ^= Integers.rotateLeft(x08 + x04, 13); 444 x00 ^= Integers.rotateLeft(x12 + x08, 18); 445 x09 ^= Integers.rotateLeft(x05 + x01, 7); 446 x13 ^= Integers.rotateLeft(x09 + x05, 9); 447 x01 ^= Integers.rotateLeft(x13 + x09, 13); 448 x05 ^= Integers.rotateLeft(x01 + x13, 18); 449 x14 ^= Integers.rotateLeft(x10 + x06, 7); 450 x02 ^= Integers.rotateLeft(x14 + x10, 9); 451 x06 ^= Integers.rotateLeft(x02 + x14, 13); 452 x10 ^= Integers.rotateLeft(x06 + x02, 18); 453 x03 ^= Integers.rotateLeft(x15 + x11, 7); 454 x07 ^= Integers.rotateLeft(x03 + x15, 9); 455 x11 ^= Integers.rotateLeft(x07 + x03, 13); 456 x15 ^= Integers.rotateLeft(x11 + x07, 18); 457 458 x01 ^= Integers.rotateLeft(x00 + x03, 7); 459 x02 ^= Integers.rotateLeft(x01 + x00, 9); 460 x03 ^= Integers.rotateLeft(x02 + x01, 13); 461 x00 ^= Integers.rotateLeft(x03 + x02, 18); 462 x06 ^= Integers.rotateLeft(x05 + x04, 7); 463 x07 ^= Integers.rotateLeft(x06 + x05, 9); 464 x04 ^= Integers.rotateLeft(x07 + x06, 13); 465 x05 ^= Integers.rotateLeft(x04 + x07, 18); 466 x11 ^= Integers.rotateLeft(x10 + x09, 7); 467 x08 ^= Integers.rotateLeft(x11 + x10, 9); 468 x09 ^= Integers.rotateLeft(x08 + x11, 13); 469 x10 ^= Integers.rotateLeft(x09 + x08, 18); 470 x12 ^= Integers.rotateLeft(x15 + x14, 7); 471 x13 ^= Integers.rotateLeft(x12 + x15, 9); 472 x14 ^= Integers.rotateLeft(x13 + x12, 13); 473 x15 ^= Integers.rotateLeft(x14 + x13, 18); 474 } 475 476 x[ 0] = x00 + input[ 0]; 477 x[ 1] = x01 + input[ 1]; 478 x[ 2] = x02 + input[ 2]; 479 x[ 3] = x03 + input[ 3]; 480 x[ 4] = x04 + input[ 4]; 481 x[ 5] = x05 + input[ 5]; 482 x[ 6] = x06 + input[ 6]; 483 x[ 7] = x07 + input[ 7]; 484 x[ 8] = x08 + input[ 8]; 485 x[ 9] = x09 + input[ 9]; 486 x[10] = x10 + input[10]; 487 x[11] = x11 + input[11]; 488 x[12] = x12 + input[12]; 489 x[13] = x13 + input[13]; 490 x[14] = x14 + input[14]; 491 x[15] = x15 + input[15]; 492 } 493 resetLimitCounter()494 private void resetLimitCounter() 495 { 496 cW0 = 0; 497 cW1 = 0; 498 cW2 = 0; 499 } 500 limitExceeded()501 private boolean limitExceeded() 502 { 503 if (++cW0 == 0) 504 { 505 if (++cW1 == 0) 506 { 507 return (++cW2 & 0x20) != 0; // 2^(32 + 32 + 6) 508 } 509 } 510 511 return false; 512 } 513 514 /* 515 * this relies on the fact len will always be positive. 516 */ limitExceeded(int len)517 private boolean limitExceeded(int len) 518 { 519 cW0 += len; 520 if (cW0 < len && cW0 >= 0) 521 { 522 if (++cW1 == 0) 523 { 524 return (++cW2 & 0x20) != 0; // 2^(32 + 32 + 6) 525 } 526 } 527 528 return false; 529 } 530 } 531