1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package rsa implements RSA encryption as specified in PKCS#1.
6package rsa
7
8// TODO(agl): Add support for PSS padding.
9
10import (
11	"crypto/rand"
12	"crypto/subtle"
13	"errors"
14	"hash"
15	"io"
16	"math/big"
17)
18
// Shared small constants. They are only ever read; never pass them as the
// receiver of a mutating big.Int method.
var (
	bigZero = big.NewInt(0)
	bigOne  = big.NewInt(1)
)
21
// A PublicKey represents the public part of an RSA key.
type PublicKey struct {
	// N is the modulus.
	N *big.Int
	// E is the public exponent.
	E int
}

// Errors reported by checkPub for malformed public keys.
var (
	errPublicModulus       = errors.New("crypto/rsa: missing public modulus")
	errPublicExponentSmall = errors.New("crypto/rsa: public exponent too small")
	errPublicExponentLarge = errors.New("crypto/rsa: public exponent too large")
)

// checkPub sanity checks the public key before we use it.
//
// It rejects a nil modulus, an exponent below 2, and an exponent that does
// not fit in a signed 32-bit integer. The upper bound keeps behavior the
// same whether int is 32 or 64 bits. See also
// http://www.imperialviolet.org/2012/03/16/rsae.html.
func checkPub(pub *PublicKey) error {
	switch {
	case pub.N == nil:
		return errPublicModulus
	case pub.E < 2:
		return errPublicExponentSmall
	case pub.E > 1<<31-1:
		return errPublicExponentLarge
	default:
		return nil
	}
}
51
52// A PrivateKey represents an RSA key
53type PrivateKey struct {
54	PublicKey            // public part.
55	D         *big.Int   // private exponent
56	Primes    []*big.Int // prime factors of N, has >= 2 elements.
57
58	// Precomputed contains precomputed values that speed up private
59	// operations, if available.
60	Precomputed PrecomputedValues
61}
62
// PrecomputedValues holds values derived from a private key's exponent and
// primes. Private-key operations use them to apply the Chinese remainder
// theorem (CRT) instead of a single exponentiation modulo N.
// PrivateKey.Precompute fills them in.
type PrecomputedValues struct {
	Dp   *big.Int // D mod (P-1), where P = Primes[0]
	Dq   *big.Int // D mod (Q-1), where Q = Primes[1]
	Qinv *big.Int // Q^-1 mod P

	// CRTValues is used for the 3rd and subsequent primes. Due to a
	// historical accident, the CRT for the first two primes is handled
	// differently in PKCS#1 and interoperability is sufficiently
	// important that we mirror this.
	CRTValues []CRTValue
}

// CRTValue contains the precomputed Chinese remainder theorem values for one
// prime beyond the first two.
type CRTValue struct {
	Exp   *big.Int // D mod (prime-1).
	Coeff *big.Int // R·Coeff ≡ 1 mod Prime.
	R     *big.Int // product of primes prior to this (inc p and q).
}
80
81// Validate performs basic sanity checks on the key.
82// It returns nil if the key is valid, or else an error describing a problem.
83func (priv *PrivateKey) Validate() error {
84	if err := checkPub(&priv.PublicKey); err != nil {
85		return err
86	}
87
88	// Check that the prime factors are actually prime. Note that this is
89	// just a sanity check. Since the random witnesses chosen by
90	// ProbablyPrime are deterministic, given the candidate number, it's
91	// easy for an attack to generate composites that pass this test.
92	for _, prime := range priv.Primes {
93		if !prime.ProbablyPrime(20) {
94			return errors.New("crypto/rsa: prime factor is composite")
95		}
96	}
97
98	// Check that Πprimes == n.
99	modulus := new(big.Int).Set(bigOne)
100	for _, prime := range priv.Primes {
101		modulus.Mul(modulus, prime)
102	}
103	if modulus.Cmp(priv.N) != 0 {
104		return errors.New("crypto/rsa: invalid modulus")
105	}
106
107	// Check that de ≡ 1 mod p-1, for each prime.
108	// This implies that e is coprime to each p-1 as e has a multiplicative
109	// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
110	// exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
111	// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
112	congruence := new(big.Int)
113	de := new(big.Int).SetInt64(int64(priv.E))
114	de.Mul(de, priv.D)
115	for _, prime := range priv.Primes {
116		pminus1 := new(big.Int).Sub(prime, bigOne)
117		congruence.Mod(de, pminus1)
118		if congruence.Cmp(bigOne) != 0 {
119			return errors.New("crypto/rsa: invalid exponents")
120		}
121	}
122	return nil
123}
124
125// GenerateKey generates an RSA keypair of the given bit size.
126func GenerateKey(random io.Reader, bits int) (priv *PrivateKey, err error) {
127	return GenerateMultiPrimeKey(random, 2, bits)
128}
129
130// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
131// size, as suggested in [1]. Although the public keys are compatible
132// (actually, indistinguishable) from the 2-prime case, the private keys are
133// not. Thus it may not be possible to export multi-prime private keys in
134// certain formats or to subsequently import them into other code.
135//
136// Table 1 in [2] suggests maximum numbers of primes for a given size.
137//
138// [1] US patent 4405829 (1972, expired)
139// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
140func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (priv *PrivateKey, err error) {
141	priv = new(PrivateKey)
142	priv.E = 65537
143
144	if nprimes < 2 {
145		return nil, errors.New("crypto/rsa: GenerateMultiPrimeKey: nprimes must be >= 2")
146	}
147
148	primes := make([]*big.Int, nprimes)
149
150NextSetOfPrimes:
151	for {
152		todo := bits
153		// crypto/rand should set the top two bits in each prime.
154		// Thus each prime has the form
155		//   p_i = 2^bitlen(p_i) × 0.11... (in base 2).
156		// And the product is:
157		//   P = 2^todo × α
158		// where α is the product of nprimes numbers of the form 0.11...
159		//
160		// If α < 1/2 (which can happen for nprimes > 2), we need to
161		// shift todo to compensate for lost bits: the mean value of 0.11...
162		// is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
163		// will give good results.
164		if nprimes >= 7 {
165			todo += (nprimes - 2) / 5
166		}
167		for i := 0; i < nprimes; i++ {
168			primes[i], err = rand.Prime(random, todo/(nprimes-i))
169			if err != nil {
170				return nil, err
171			}
172			todo -= primes[i].BitLen()
173		}
174
175		// Make sure that primes is pairwise unequal.
176		for i, prime := range primes {
177			for j := 0; j < i; j++ {
178				if prime.Cmp(primes[j]) == 0 {
179					continue NextSetOfPrimes
180				}
181			}
182		}
183
184		n := new(big.Int).Set(bigOne)
185		totient := new(big.Int).Set(bigOne)
186		pminus1 := new(big.Int)
187		for _, prime := range primes {
188			n.Mul(n, prime)
189			pminus1.Sub(prime, bigOne)
190			totient.Mul(totient, pminus1)
191		}
192		if n.BitLen() != bits {
193			// This should never happen for nprimes == 2 because
194			// crypto/rand should set the top two bits in each prime.
195			// For nprimes > 2 we hope it does not happen often.
196			continue NextSetOfPrimes
197		}
198
199		g := new(big.Int)
200		priv.D = new(big.Int)
201		y := new(big.Int)
202		e := big.NewInt(int64(priv.E))
203		g.GCD(priv.D, y, e, totient)
204
205		if g.Cmp(bigOne) == 0 {
206			if priv.D.Sign() < 0 {
207				priv.D.Add(priv.D, totient)
208			}
209			priv.Primes = primes
210			priv.N = n
211
212			break
213		}
214	}
215
216	priv.Precompute()
217	return
218}
219
// incCounter adds one to c, treated as a four-byte big-endian integer.
// The counter wraps from 0xffffffff back to zero.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			// No carry out of this byte, so the higher bytes stay as they are.
			return
		}
	}
}

// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
// specified in PKCS#1 v2.1. The mask is hash(seed || counter) for
// counter = 0, 1, 2, ..., where the counter is four bytes big-endian. The
// blocks are concatenated and truncated to len(out). hash is Reset after
// each block it produces.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
	var counter [4]byte
	var block []byte

	pos := 0
	for pos < len(out) {
		hash.Write(seed)
		hash.Write(counter[:])
		block = hash.Sum(block[:0])
		hash.Reset()

		for _, b := range block {
			if pos == len(out) {
				break
			}
			out[pos] ^= b
			pos++
		}
		incCounter(&counter)
	}
}
254
var (
	// ErrMessageTooLong is returned when attempting to encrypt a message
	// that is too large for the size of the public key (for example by
	// EncryptOAEP).
	ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size")
)
258
259func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
260	e := big.NewInt(int64(pub.E))
261	c.Exp(m, e, pub.N)
262	return c
263}
264
265// EncryptOAEP encrypts the given message with RSA-OAEP.
266// The message must be no longer than the length of the public modulus less
267// twice the hash length plus 2.
268func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) (out []byte, err error) {
269	if err := checkPub(pub); err != nil {
270		return nil, err
271	}
272	hash.Reset()
273	k := (pub.N.BitLen() + 7) / 8
274	if len(msg) > k-2*hash.Size()-2 {
275		err = ErrMessageTooLong
276		return
277	}
278
279	hash.Write(label)
280	lHash := hash.Sum(nil)
281	hash.Reset()
282
283	em := make([]byte, k)
284	seed := em[1 : 1+hash.Size()]
285	db := em[1+hash.Size():]
286
287	copy(db[0:hash.Size()], lHash)
288	db[len(db)-len(msg)-1] = 1
289	copy(db[len(db)-len(msg):], msg)
290
291	_, err = io.ReadFull(random, seed)
292	if err != nil {
293		return
294	}
295
296	mgf1XOR(db, hash, seed)
297	mgf1XOR(seed, hash, db)
298
299	m := new(big.Int)
300	m.SetBytes(em)
301	c := encrypt(new(big.Int), pub, m)
302	out = c.Bytes()
303
304	if len(out) < k {
305		// If the output is too small, we need to left-pad with zeros.
306		t := make([]byte, k)
307		copy(t[k-len(out):], out)
308		out = t
309	}
310
311	return
312}
313
var (
	// ErrDecryption represents a failure to decrypt a message.
	// It is deliberately vague to avoid adaptive attacks.
	ErrDecryption = errors.New("crypto/rsa: decryption error")

	// ErrVerification represents a failure to verify a signature.
	// It is deliberately vague to avoid adaptive attacks.
	ErrVerification = errors.New("crypto/rsa: verification error")
)
321
322// modInverse returns ia, the inverse of a in the multiplicative group of prime
323// order n. It requires that a be a member of the group (i.e. less than n).
324func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
325	g := new(big.Int)
326	x := new(big.Int)
327	y := new(big.Int)
328	g.GCD(x, y, a, n)
329	if g.Cmp(bigOne) != 0 {
330		// In this case, a and n aren't coprime and we cannot calculate
331		// the inverse. This happens because the values of n are nearly
332		// prime (being the product of two primes) rather than truly
333		// prime.
334		return
335	}
336
337	if x.Cmp(bigOne) < 0 {
338		// 0 is not the multiplicative inverse of any element so, if x
339		// < 1, then x is negative.
340		x.Add(x, n)
341	}
342
343	return x, true
344}
345
346// Precompute performs some calculations that speed up private key operations
347// in the future.
348func (priv *PrivateKey) Precompute() {
349	if priv.Precomputed.Dp != nil {
350		return
351	}
352
353	priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
354	priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
355
356	priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
357	priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
358
359	priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
360
361	r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
362	priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
363	for i := 2; i < len(priv.Primes); i++ {
364		prime := priv.Primes[i]
365		values := &priv.Precomputed.CRTValues[i-2]
366
367		values.Exp = new(big.Int).Sub(prime, bigOne)
368		values.Exp.Mod(priv.D, values.Exp)
369
370		values.R = new(big.Int).Set(r)
371		values.Coeff = new(big.Int).ModInverse(r, prime)
372
373		r.Mul(r, prime)
374	}
375}
376
377// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
378// random source is given, RSA blinding is used.
379func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
380	// TODO(agl): can we get away with reusing blinds?
381	if c.Cmp(priv.N) > 0 {
382		err = ErrDecryption
383		return
384	}
385
386	var ir *big.Int
387	if random != nil {
388		// Blinding enabled. Blinding involves multiplying c by r^e.
389		// Then the decryption operation performs (m^e * r^e)^d mod n
390		// which equals mr mod n. The factor of r can then be removed
391		// by multiplying by the multiplicative inverse of r.
392
393		var r *big.Int
394
395		for {
396			r, err = rand.Int(random, priv.N)
397			if err != nil {
398				return
399			}
400			if r.Cmp(bigZero) == 0 {
401				r = bigOne
402			}
403			var ok bool
404			ir, ok = modInverse(r, priv.N)
405			if ok {
406				break
407			}
408		}
409		bigE := big.NewInt(int64(priv.E))
410		rpowe := new(big.Int).Exp(r, bigE, priv.N)
411		cCopy := new(big.Int).Set(c)
412		cCopy.Mul(cCopy, rpowe)
413		cCopy.Mod(cCopy, priv.N)
414		c = cCopy
415	}
416
417	if priv.Precomputed.Dp == nil {
418		m = new(big.Int).Exp(c, priv.D, priv.N)
419	} else {
420		// We have the precalculated values needed for the CRT.
421		m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
422		m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
423		m.Sub(m, m2)
424		if m.Sign() < 0 {
425			m.Add(m, priv.Primes[0])
426		}
427		m.Mul(m, priv.Precomputed.Qinv)
428		m.Mod(m, priv.Primes[0])
429		m.Mul(m, priv.Primes[1])
430		m.Add(m, m2)
431
432		for i, values := range priv.Precomputed.CRTValues {
433			prime := priv.Primes[2+i]
434			m2.Exp(c, values.Exp, prime)
435			m2.Sub(m2, m)
436			m2.Mul(m2, values.Coeff)
437			m2.Mod(m2, prime)
438			if m2.Sign() < 0 {
439				m2.Add(m2, prime)
440			}
441			m2.Mul(m2, values.R)
442			m.Add(m, m2)
443		}
444	}
445
446	if ir != nil {
447		// Unblind.
448		m.Mul(m, ir)
449		m.Mod(m, priv.N)
450	}
451
452	return
453}
454
455// DecryptOAEP decrypts ciphertext using RSA-OAEP.
456// If random != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks.
457func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err error) {
458	if err := checkPub(&priv.PublicKey); err != nil {
459		return nil, err
460	}
461	k := (priv.N.BitLen() + 7) / 8
462	if len(ciphertext) > k ||
463		k < hash.Size()*2+2 {
464		err = ErrDecryption
465		return
466	}
467
468	c := new(big.Int).SetBytes(ciphertext)
469
470	m, err := decrypt(random, priv, c)
471	if err != nil {
472		return
473	}
474
475	hash.Write(label)
476	lHash := hash.Sum(nil)
477	hash.Reset()
478
479	// Converting the plaintext number to bytes will strip any
480	// leading zeros so we may have to left pad. We do this unconditionally
481	// to avoid leaking timing information. (Although we still probably
482	// leak the number of leading zeros. It's not clear that we can do
483	// anything about this.)
484	em := leftPad(m.Bytes(), k)
485
486	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
487
488	seed := em[1 : hash.Size()+1]
489	db := em[hash.Size()+1:]
490
491	mgf1XOR(seed, hash, db)
492	mgf1XOR(db, hash, seed)
493
494	lHash2 := db[0:hash.Size()]
495
496	// We have to validate the plaintext in constant time in order to avoid
497	// attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
498	// Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
499	// v2.0. In J. Kilian, editor, Advances in Cryptology.
500	lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
501
502	// The remainder of the plaintext must be zero or more 0x00, followed
503	// by 0x01, followed by the message.
504	//   lookingForIndex: 1 iff we are still looking for the 0x01
505	//   index: the offset of the first 0x01 byte
506	//   invalid: 1 iff we saw a non-zero byte before the 0x01.
507	var lookingForIndex, index, invalid int
508	lookingForIndex = 1
509	rest := db[hash.Size():]
510
511	for i := 0; i < len(rest); i++ {
512		equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
513		equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
514		index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
515		lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
516		invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
517	}
518
519	if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
520		err = ErrDecryption
521		return
522	}
523
524	msg = rest[index+1:]
525	return
526}
527
// leftPad returns a newly allocated slice of length size containing input
// right-aligned, preceded by zero bytes. If input is longer than size, only
// its first size bytes are copied.
func leftPad(input []byte, size int) (out []byte) {
	out = make([]byte, size)
	offset := size - len(input)
	if offset < 0 {
		offset = 0
	}
	copy(out[offset:], input)
	return out
}
539