1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package rsa implements RSA encryption as specified in PKCS#1.
6//
7// RSA is a single, fundamental operation that is used in this package to
8// implement either public-key encryption or public-key signatures.
9//
10// The original specification for encryption and signatures with RSA is PKCS#1
11// and the terms "RSA encryption" and "RSA signatures" by default refer to
12// PKCS#1 version 1.5. However, that specification has flaws and new designs
// should use version two, usually referred to simply as OAEP and PSS, where
14// possible.
15//
16// Two sets of interfaces are included in this package. When a more abstract
17// interface isn't necessary, there are functions for encrypting/decrypting
18// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
19// over the public-key primitive, the PrivateKey struct implements the
20// Decrypter and Signer interfaces from the crypto package.
21//
22// The RSA operations in this package are not implemented using constant-time algorithms.
23package rsa
24
25import (
26	"crypto"
27	"crypto/rand"
28	"crypto/subtle"
29	"errors"
30	"hash"
31	"io"
32	"math"
33	"math/big"
34
35	"crypto/internal/randutil"
36)
37
// Shared big.Int constants. Every user in the package shares the same
// instances, so they must be treated as read-only and never used as the
// receiver of a mutating big.Int method.
var (
	bigZero = big.NewInt(0)
	bigOne  = big.NewInt(1)
)
40
// A PublicKey represents the public part of an RSA key.
//
// EncryptOAEP, DecryptOAEP and PrivateKey.Validate check it with checkPub,
// which rejects a nil N and any E outside the range [2, 2^31-1].
type PublicKey struct {
	N *big.Int // modulus
	E int      // public exponent
}
46
47// Size returns the modulus size in bytes. Raw signatures and ciphertexts
48// for or by this public key will have the same size.
49func (pub *PublicKey) Size() int {
50	return (pub.N.BitLen() + 7) / 8
51}
52
// OAEPOptions carries options for OAEP decryption through the
// crypto.Decrypter interface: PrivateKey.Decrypt performs OAEP decryption
// when opts is a *OAEPOptions.
type OAEPOptions struct {
	// Hash is the hash function that will be used when generating the mask.
	// PrivateKey.Decrypt calls Hash.New() and passes the result to
	// DecryptOAEP, which uses it both to hash Label and for MGF1 masking.
	Hash crypto.Hash
	// Label is an arbitrary byte string that must be equal to the value
	// used when encrypting.
	Label []byte
}
62
// Errors returned by checkPub when a public key is unusable: respectively a
// nil modulus N, an exponent E below 2, and an exponent E above 2^31-1.
var (
	errPublicModulus       = errors.New("crypto/rsa: missing public modulus")
	errPublicExponentSmall = errors.New("crypto/rsa: public exponent too small")
	errPublicExponentLarge = errors.New("crypto/rsa: public exponent too large")
)
68
69// checkPub sanity checks the public key before we use it.
70// We require pub.E to fit into a 32-bit integer so that we
71// do not have different behavior depending on whether
72// int is 32 or 64 bits. See also
73// https://www.imperialviolet.org/2012/03/16/rsae.html.
74func checkPub(pub *PublicKey) error {
75	if pub.N == nil {
76		return errPublicModulus
77	}
78	if pub.E < 2 {
79		return errPublicExponentSmall
80	}
81	if pub.E > 1<<31-1 {
82		return errPublicExponentLarge
83	}
84	return nil
85}
86
// A PrivateKey represents an RSA key.
//
// Validate checks that its fields are mutually consistent, and Precompute
// fills in Precomputed.
type PrivateKey struct {
	PublicKey            // public part.
	D         *big.Int   // private exponent
	Primes    []*big.Int // prime factors of N, has >= 2 elements.

	// Precomputed contains precomputed values that speed up private
	// operations, if available.
	Precomputed PrecomputedValues
}
97
98// Public returns the public key corresponding to priv.
99func (priv *PrivateKey) Public() crypto.PublicKey {
100	return &priv.PublicKey
101}
102
103// Sign signs digest with priv, reading randomness from rand. If opts is a
104// *PSSOptions then the PSS algorithm will be used, otherwise PKCS#1 v1.5 will
105// be used.
106//
107// This method implements crypto.Signer, which is an interface to support keys
108// where the private part is kept in, for example, a hardware module. Common
109// uses should use the Sign* functions in this package directly.
110func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
111	if pssOpts, ok := opts.(*PSSOptions); ok {
112		return SignPSS(rand, priv, pssOpts.Hash, digest, pssOpts)
113	}
114
115	return SignPKCS1v15(rand, priv, opts.HashFunc(), digest)
116}
117
118// Decrypt decrypts ciphertext with priv. If opts is nil or of type
119// *PKCS1v15DecryptOptions then PKCS#1 v1.5 decryption is performed. Otherwise
120// opts must have type *OAEPOptions and OAEP decryption is done.
121func (priv *PrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
122	if opts == nil {
123		return DecryptPKCS1v15(rand, priv, ciphertext)
124	}
125
126	switch opts := opts.(type) {
127	case *OAEPOptions:
128		return DecryptOAEP(opts.Hash.New(), rand, priv, ciphertext, opts.Label)
129
130	case *PKCS1v15DecryptOptions:
131		if l := opts.SessionKeyLen; l > 0 {
132			plaintext = make([]byte, l)
133			if _, err := io.ReadFull(rand, plaintext); err != nil {
134				return nil, err
135			}
136			if err := DecryptPKCS1v15SessionKey(rand, priv, ciphertext, plaintext); err != nil {
137				return nil, err
138			}
139			return plaintext, nil
140		} else {
141			return DecryptPKCS1v15(rand, priv, ciphertext)
142		}
143
144	default:
145		return nil, errors.New("crypto/rsa: invalid options for Decrypt")
146	}
147}
148
// PrecomputedValues holds values, filled in by PrivateKey.Precompute, that
// let private-key operations use the Chinese remainder theorem instead of a
// single exponentiation modulo N. Below, P and Q denote Primes[0] and
// Primes[1] of the owning PrivateKey.
type PrecomputedValues struct {
	Dp, Dq *big.Int // D mod (P-1) (or mod Q-1)
	Qinv   *big.Int // Q^-1 mod P

	// CRTValues is used for the 3rd and subsequent primes. Due to a
	// historical accident, the CRT for the first two primes is handled
	// differently in PKCS#1 and interoperability is sufficiently
	// important that we mirror this.
	CRTValues []CRTValue
}
159
// CRTValue contains the precomputed Chinese remainder theorem values.
// Precompute stores one CRTValue for each prime after the first two, in the
// same order as PrivateKey.Primes.
type CRTValue struct {
	Exp   *big.Int // D mod (prime-1).
	Coeff *big.Int // R·Coeff ≡ 1 mod Prime.
	R     *big.Int // product of primes prior to this (inc p and q).
}
166
167// Validate performs basic sanity checks on the key.
168// It returns nil if the key is valid, or else an error describing a problem.
169func (priv *PrivateKey) Validate() error {
170	if err := checkPub(&priv.PublicKey); err != nil {
171		return err
172	}
173
174	// Check that Πprimes == n.
175	modulus := new(big.Int).Set(bigOne)
176	for _, prime := range priv.Primes {
177		// Any primes ≤ 1 will cause divide-by-zero panics later.
178		if prime.Cmp(bigOne) <= 0 {
179			return errors.New("crypto/rsa: invalid prime value")
180		}
181		modulus.Mul(modulus, prime)
182	}
183	if modulus.Cmp(priv.N) != 0 {
184		return errors.New("crypto/rsa: invalid modulus")
185	}
186
187	// Check that de ≡ 1 mod p-1, for each prime.
188	// This implies that e is coprime to each p-1 as e has a multiplicative
189	// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
190	// exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
191	// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
192	congruence := new(big.Int)
193	de := new(big.Int).SetInt64(int64(priv.E))
194	de.Mul(de, priv.D)
195	for _, prime := range priv.Primes {
196		pminus1 := new(big.Int).Sub(prime, bigOne)
197		congruence.Mod(de, pminus1)
198		if congruence.Cmp(bigOne) != 0 {
199			return errors.New("crypto/rsa: invalid exponents")
200		}
201	}
202	return nil
203}
204
205// GenerateKey generates an RSA keypair of the given bit size using the
206// random source random (for example, crypto/rand.Reader).
207func GenerateKey(random io.Reader, bits int) (*PrivateKey, error) {
208	return GenerateMultiPrimeKey(random, 2, bits)
209}
210
// GenerateMultiPrimeKey generates a multi-prime RSA keypair of the given bit
// size and the given random source, as suggested in [1]. Although the public
// keys are compatible (actually, indistinguishable) from the 2-prime case,
// the private keys are not. Thus it may not be possible to export multi-prime
// private keys in certain formats or to subsequently import them into other
// code.
//
// Table 1 in [2] suggests maximum numbers of primes for a given size.
//
// The public exponent is always 65537. On success the returned key has a
// modulus N of exactly bits bits, nprimes pairwise-distinct primes, and its
// Precomputed values filled in. An error is returned if nprimes < 2, if bits
// is below 64 and too few primes of the required size are estimated to
// exist, or if rand.Prime fails.
//
// [1] US patent 4405829 (1972, expired)
// [2] http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey, error) {
	// NOTE(review): MaybeReadByte presumably consumes an extra byte from
	// random only some of the time, so callers cannot rely on deterministic
	// output for a fixed stream. Confirm in crypto/internal/randutil.
	randutil.MaybeReadByte(random)

	priv := new(PrivateKey)
	// Fixed public exponent 65537 = 2^16 + 1.
	priv.E = 65537

	if nprimes < 2 {
		return nil, errors.New("crypto/rsa: GenerateMultiPrimeKey: nprimes must be >= 2")
	}

	// For small keys, estimate whether enough distinct primes of the
	// required size exist. Without this, the retry loop below could spin
	// for a very long time.
	if bits < 64 {
		primeLimit := float64(uint64(1) << uint(bits/nprimes))
		// pi approximates the number of primes less than primeLimit
		pi := primeLimit / (math.Log(primeLimit) - 1)
		// Generated primes start with 11 (in binary) so we can only
		// use a quarter of them.
		pi /= 4
		// Use a factor of two to ensure that key generation terminates
		// in a reasonable amount of time.
		pi /= 2
		if pi <= float64(nprimes) {
			return nil, errors.New("crypto/rsa: too few primes of given length to generate an RSA key")
		}
	}

	primes := make([]*big.Int, nprimes)

	// Each iteration draws a fresh set of primes. Any failed check (a
	// duplicate prime, the wrong modulus length, or E not invertible) restarts
	// the loop.
NextSetOfPrimes:
	for {
		todo := bits
		// crypto/rand should set the top two bits in each prime.
		// Thus each prime has the form
		//   p_i = 2^bitlen(p_i) × 0.11... (in base 2).
		// And the product is:
		//   P = 2^todo × α
		// where α is the product of nprimes numbers of the form 0.11...
		//
		// If α < 1/2 (which can happen for nprimes > 2), we need to
		// shift todo to compensate for lost bits: the mean value of 0.11...
		// is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
		// will give good results.
		if nprimes >= 7 {
			todo += (nprimes - 2) / 5
		}
		// Split the remaining bit budget evenly over the primes still to be
		// generated, charging each prime's actual bit length as it is drawn.
		for i := 0; i < nprimes; i++ {
			var err error
			primes[i], err = rand.Prime(random, todo/(nprimes-i))
			if err != nil {
				return nil, err
			}
			todo -= primes[i].BitLen()
		}

		// Make sure that primes is pairwise unequal.
		for i, prime := range primes {
			for j := 0; j < i; j++ {
				if prime.Cmp(primes[j]) == 0 {
					continue NextSetOfPrimes
				}
			}
		}

		// n = Π p_i and totient = Π (p_i - 1), which is Euler's totient of
		// n because the primes are distinct.
		n := new(big.Int).Set(bigOne)
		totient := new(big.Int).Set(bigOne)
		pminus1 := new(big.Int)
		for _, prime := range primes {
			n.Mul(n, prime)
			pminus1.Sub(prime, bigOne)
			totient.Mul(totient, pminus1)
		}
		if n.BitLen() != bits {
			// This should never happen for nprimes == 2 because
			// crypto/rand should set the top two bits in each prime.
			// For nprimes > 2 we hope it does not happen often.
			continue NextSetOfPrimes
		}

		// D = E^-1 mod totient. ModInverse yields nil when E and totient
		// are not coprime, in which case a new set of primes is drawn.
		priv.D = new(big.Int)
		e := big.NewInt(int64(priv.E))
		ok := priv.D.ModInverse(e, totient)

		if ok != nil {
			priv.Primes = primes
			priv.N = n
			break
		}
	}

	priv.Precompute()
	return priv, nil
}
313
// incCounter adds one to the big-endian 32-bit counter stored in c,
// propagating carries toward c[0]. It wraps from 0xFFFFFFFF back to zero.
func incCounter(c *[4]byte) {
	for i := len(c) - 1; i >= 0; i-- {
		c[i]++
		if c[i] != 0 {
			// No carry out of this byte, so the higher bytes are unchanged.
			return
		}
	}
}
327
328// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
329// specified in PKCS#1 v2.1.
330func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
331	var counter [4]byte
332	var digest []byte
333
334	done := 0
335	for done < len(out) {
336		hash.Write(seed)
337		hash.Write(counter[0:4])
338		digest = hash.Sum(digest[:0])
339		hash.Reset()
340
341		for i := 0; i < len(digest) && done < len(out); i++ {
342			out[done] ^= digest[i]
343			done++
344		}
345		incCounter(&counter)
346	}
347}
348
// ErrMessageTooLong is returned when attempting to encrypt a message which is
// too large for the size of the public key. EncryptOAEP returns it when the
// message is longer than the modulus size in bytes, minus twice the hash
// size, minus 2.
var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size")
352
353func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int {
354	e := big.NewInt(int64(pub.E))
355	c.Exp(m, e, pub.N)
356	return c
357}
358
359// EncryptOAEP encrypts the given message with RSA-OAEP.
360//
361// OAEP is parameterised by a hash function that is used as a random oracle.
362// Encryption and decryption of a given message must use the same hash function
363// and sha256.New() is a reasonable choice.
364//
365// The random parameter is used as a source of entropy to ensure that
366// encrypting the same message twice doesn't result in the same ciphertext.
367//
368// The label parameter may contain arbitrary data that will not be encrypted,
369// but which gives important context to the message. For example, if a given
370// public key is used to decrypt two types of messages then distinct label
371// values could be used to ensure that a ciphertext for one purpose cannot be
372// used for another by an attacker. If not required it can be empty.
373//
374// The message must be no longer than the length of the public modulus minus
375// twice the hash length, minus a further 2.
376func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, label []byte) ([]byte, error) {
377	if err := checkPub(pub); err != nil {
378		return nil, err
379	}
380	hash.Reset()
381	k := pub.Size()
382	if len(msg) > k-2*hash.Size()-2 {
383		return nil, ErrMessageTooLong
384	}
385
386	hash.Write(label)
387	lHash := hash.Sum(nil)
388	hash.Reset()
389
390	em := make([]byte, k)
391	seed := em[1 : 1+hash.Size()]
392	db := em[1+hash.Size():]
393
394	copy(db[0:hash.Size()], lHash)
395	db[len(db)-len(msg)-1] = 1
396	copy(db[len(db)-len(msg):], msg)
397
398	_, err := io.ReadFull(random, seed)
399	if err != nil {
400		return nil, err
401	}
402
403	mgf1XOR(db, hash, seed)
404	mgf1XOR(seed, hash, db)
405
406	m := new(big.Int)
407	m.SetBytes(em)
408	c := encrypt(new(big.Int), pub, m)
409	out := c.Bytes()
410
411	if len(out) < k {
412		// If the output is too small, we need to left-pad with zeros.
413		t := make([]byte, k)
414		copy(t[k-len(out):], out)
415		out = t
416	}
417
418	return out, nil
419}
420
var (
	// ErrDecryption represents a failure to decrypt a message.
	// It is deliberately vague to avoid adaptive attacks.
	ErrDecryption = errors.New("crypto/rsa: decryption error")

	// ErrVerification represents a failure to verify a signature.
	// It is deliberately vague to avoid adaptive attacks.
	ErrVerification = errors.New("crypto/rsa: verification error")
)
428
429// Precompute performs some calculations that speed up private key operations
430// in the future.
431func (priv *PrivateKey) Precompute() {
432	if priv.Precomputed.Dp != nil {
433		return
434	}
435
436	priv.Precomputed.Dp = new(big.Int).Sub(priv.Primes[0], bigOne)
437	priv.Precomputed.Dp.Mod(priv.D, priv.Precomputed.Dp)
438
439	priv.Precomputed.Dq = new(big.Int).Sub(priv.Primes[1], bigOne)
440	priv.Precomputed.Dq.Mod(priv.D, priv.Precomputed.Dq)
441
442	priv.Precomputed.Qinv = new(big.Int).ModInverse(priv.Primes[1], priv.Primes[0])
443
444	r := new(big.Int).Mul(priv.Primes[0], priv.Primes[1])
445	priv.Precomputed.CRTValues = make([]CRTValue, len(priv.Primes)-2)
446	for i := 2; i < len(priv.Primes); i++ {
447		prime := priv.Primes[i]
448		values := &priv.Precomputed.CRTValues[i-2]
449
450		values.Exp = new(big.Int).Sub(prime, bigOne)
451		values.Exp.Mod(priv.D, values.Exp)
452
453		values.R = new(big.Int).Set(r)
454		values.Coeff = new(big.Int).ModInverse(r, prime)
455
456		r.Mul(r, prime)
457	}
458}
459
460// decrypt performs an RSA decryption, resulting in a plaintext integer. If a
461// random source is given, RSA blinding is used.
462func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
463	// TODO(agl): can we get away with reusing blinds?
464	if c.Cmp(priv.N) > 0 {
465		err = ErrDecryption
466		return
467	}
468	if priv.N.Sign() == 0 {
469		return nil, ErrDecryption
470	}
471
472	var ir *big.Int
473	if random != nil {
474		randutil.MaybeReadByte(random)
475
476		// Blinding enabled. Blinding involves multiplying c by r^e.
477		// Then the decryption operation performs (m^e * r^e)^d mod n
478		// which equals mr mod n. The factor of r can then be removed
479		// by multiplying by the multiplicative inverse of r.
480
481		var r *big.Int
482		ir = new(big.Int)
483		for {
484			r, err = rand.Int(random, priv.N)
485			if err != nil {
486				return
487			}
488			if r.Cmp(bigZero) == 0 {
489				r = bigOne
490			}
491			ok := ir.ModInverse(r, priv.N)
492			if ok != nil {
493				break
494			}
495		}
496		bigE := big.NewInt(int64(priv.E))
497		rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0
498		cCopy := new(big.Int).Set(c)
499		cCopy.Mul(cCopy, rpowe)
500		cCopy.Mod(cCopy, priv.N)
501		c = cCopy
502	}
503
504	if priv.Precomputed.Dp == nil {
505		m = new(big.Int).Exp(c, priv.D, priv.N)
506	} else {
507		// We have the precalculated values needed for the CRT.
508		m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
509		m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
510		m.Sub(m, m2)
511		if m.Sign() < 0 {
512			m.Add(m, priv.Primes[0])
513		}
514		m.Mul(m, priv.Precomputed.Qinv)
515		m.Mod(m, priv.Primes[0])
516		m.Mul(m, priv.Primes[1])
517		m.Add(m, m2)
518
519		for i, values := range priv.Precomputed.CRTValues {
520			prime := priv.Primes[2+i]
521			m2.Exp(c, values.Exp, prime)
522			m2.Sub(m2, m)
523			m2.Mul(m2, values.Coeff)
524			m2.Mod(m2, prime)
525			if m2.Sign() < 0 {
526				m2.Add(m2, prime)
527			}
528			m2.Mul(m2, values.R)
529			m.Add(m, m2)
530		}
531	}
532
533	if ir != nil {
534		// Unblind.
535		m.Mul(m, ir)
536		m.Mod(m, priv.N)
537	}
538
539	return
540}
541
542func decryptAndCheck(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) {
543	m, err = decrypt(random, priv, c)
544	if err != nil {
545		return nil, err
546	}
547
548	// In order to defend against errors in the CRT computation, m^e is
549	// calculated, which should match the original ciphertext.
550	check := encrypt(new(big.Int), &priv.PublicKey, m)
551	if c.Cmp(check) != 0 {
552		return nil, errors.New("rsa: internal error")
553	}
554	return m, nil
555}
556
557// DecryptOAEP decrypts ciphertext using RSA-OAEP.
558
559// OAEP is parameterised by a hash function that is used as a random oracle.
560// Encryption and decryption of a given message must use the same hash function
561// and sha256.New() is a reasonable choice.
562//
563// The random parameter, if not nil, is used to blind the private-key operation
564// and avoid timing side-channel attacks. Blinding is purely internal to this
565// function – the random data need not match that used when encrypting.
566//
567// The label parameter must match the value given when encrypting. See
568// EncryptOAEP for details.
569func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) ([]byte, error) {
570	if err := checkPub(&priv.PublicKey); err != nil {
571		return nil, err
572	}
573	k := priv.Size()
574	if len(ciphertext) > k ||
575		k < hash.Size()*2+2 {
576		return nil, ErrDecryption
577	}
578
579	c := new(big.Int).SetBytes(ciphertext)
580
581	m, err := decrypt(random, priv, c)
582	if err != nil {
583		return nil, err
584	}
585
586	hash.Write(label)
587	lHash := hash.Sum(nil)
588	hash.Reset()
589
590	// Converting the plaintext number to bytes will strip any
591	// leading zeros so we may have to left pad. We do this unconditionally
592	// to avoid leaking timing information. (Although we still probably
593	// leak the number of leading zeros. It's not clear that we can do
594	// anything about this.)
595	em := leftPad(m.Bytes(), k)
596
597	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
598
599	seed := em[1 : hash.Size()+1]
600	db := em[hash.Size()+1:]
601
602	mgf1XOR(seed, hash, db)
603	mgf1XOR(db, hash, seed)
604
605	lHash2 := db[0:hash.Size()]
606
607	// We have to validate the plaintext in constant time in order to avoid
608	// attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
609	// Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
610	// v2.0. In J. Kilian, editor, Advances in Cryptology.
611	lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
612
613	// The remainder of the plaintext must be zero or more 0x00, followed
614	// by 0x01, followed by the message.
615	//   lookingForIndex: 1 iff we are still looking for the 0x01
616	//   index: the offset of the first 0x01 byte
617	//   invalid: 1 iff we saw a non-zero byte before the 0x01.
618	var lookingForIndex, index, invalid int
619	lookingForIndex = 1
620	rest := db[hash.Size():]
621
622	for i := 0; i < len(rest); i++ {
623		equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
624		equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
625		index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
626		lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
627		invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
628	}
629
630	if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
631		return nil, ErrDecryption
632	}
633
634	return rest[index+1:], nil
635}
636
// leftPad returns a freshly allocated slice of length size with input
// right-aligned in it, preceded by zero bytes. If input is longer than
// size, only its first size bytes are copied.
func leftPad(input []byte, size int) (out []byte) {
	out = make([]byte, size)
	if len(input) > size {
		input = input[:size]
	}
	copy(out[size-len(input):], input)
	return out
}
648