1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements unsigned multi-precision integers (natural
6// numbers). They are the building blocks for the implementation
7// of signed integers, rationals, and floating-point numbers.
8
9package big
10
11import (
12	"math/bits"
13	"math/rand"
14	"sync"
15)
16
// A nat represents an unsigned multi-precision integer.
//
// An unsigned integer x of the form
//
//   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements (x[0] is the least
// significant digit).
//
// A number is normalized if the slice contains no leading 0 digits,
// i.e., if its last element is non-zero.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
//
type nat []Word
30
// Small constant values in nat form.
var (
	natOne = nat{1} // the value 1
	natTwo = nat{2} // the value 2
	natTen = nat{10} // the value 10
)
36
37func (z nat) clear() {
38	for i := range z {
39		z[i] = 0
40	}
41}
42
43func (z nat) norm() nat {
44	i := len(z)
45	for i > 0 && z[i-1] == 0 {
46		i--
47	}
48	return z[0:i]
49}
50
51func (z nat) make(n int) nat {
52	if n <= cap(z) {
53		return z[:n] // reuse z
54	}
55	// Choosing a good value for e has significant performance impact
56	// because it increases the chance that a value can be reused.
57	const e = 4 // extra capacity
58	return make(nat, n, n+e)
59}
60
61func (z nat) setWord(x Word) nat {
62	if x == 0 {
63		return z[:0]
64	}
65	z = z.make(1)
66	z[0] = x
67	return z
68}
69
70func (z nat) setUint64(x uint64) nat {
71	// single-word value
72	if w := Word(x); uint64(w) == x {
73		return z.setWord(w)
74	}
75	// 2-word value
76	z = z.make(2)
77	z[1] = Word(x >> 32)
78	z[0] = Word(x)
79	return z
80}
81
82func (z nat) set(x nat) nat {
83	z = z.make(len(x))
84	copy(z, x)
85	return z
86}
87
88func (z nat) add(x, y nat) nat {
89	m := len(x)
90	n := len(y)
91
92	switch {
93	case m < n:
94		return z.add(y, x)
95	case m == 0:
96		// n == 0 because m >= n; result is 0
97		return z[:0]
98	case n == 0:
99		// result is x
100		return z.set(x)
101	}
102	// m > 0
103
104	z = z.make(m + 1)
105	c := addVV(z[0:n], x, y)
106	if m > n {
107		c = addVW(z[n:m], x[n:], c)
108	}
109	z[m] = c
110
111	return z.norm()
112}
113
114func (z nat) sub(x, y nat) nat {
115	m := len(x)
116	n := len(y)
117
118	switch {
119	case m < n:
120		panic("underflow")
121	case m == 0:
122		// n == 0 because m >= n; result is 0
123		return z[:0]
124	case n == 0:
125		// result is x
126		return z.set(x)
127	}
128	// m > 0
129
130	z = z.make(m)
131	c := subVV(z[0:n], x, y)
132	if m > n {
133		c = subVW(z[n:], x[n:], c)
134	}
135	if c != 0 {
136		panic("underflow")
137	}
138
139	return z.norm()
140}
141
142func (x nat) cmp(y nat) (r int) {
143	m := len(x)
144	n := len(y)
145	if m != n || m == 0 {
146		switch {
147		case m < n:
148			r = -1
149		case m > n:
150			r = 1
151		}
152		return
153	}
154
155	i := m - 1
156	for i > 0 && x[i] == y[i] {
157		i--
158	}
159
160	switch {
161	case x[i] < y[i]:
162		r = -1
163	case x[i] > y[i]:
164		r = 1
165	}
166	return
167}
168
169func (z nat) mulAddWW(x nat, y, r Word) nat {
170	m := len(x)
171	if m == 0 || y == 0 {
172		return z.setWord(r) // result is r
173	}
174	// m > 0
175
176	z = z.make(m + 1)
177	z[m] = mulAddVWW(z[0:m], x, y, r)
178
179	return z.norm()
180}
181
182// basicMul multiplies x and y and leaves the result in z.
183// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
184func basicMul(z, x, y nat) {
185	z[0 : len(x)+len(y)].clear() // initialize z
186	for i, d := range y {
187		if d != 0 {
188			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
189		}
190	}
191}
192
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// It panics if x, y and m do not all have length n.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	z = z.make(n)
	z.clear()
	// c is the carry out of the top word, kept across iterations.
	var c Word
	for i := 0; i < n; i++ {
		// z += x*y[i]
		d := y[i]
		c2 := addMulVVW(z, x, d)
		// z += t*m; with k = -1/m mod 2**_W this choice of t
		// makes the low word of z zero.
		t := z[0] * k
		c3 := addMulVVW(z, m, t)
		// Drop the low word: shift z down by one word.
		copy(z, z[1:])
		// The new top word is the sum of the three carries.
		cx := c + c2
		cy := cx + c3
		z[n-1] = cy
		// An unsigned sum that is smaller than one of its operands
		// has wrapped around; remember that as the next carry.
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// A leftover carry means the value does not fit in n words;
	// subtract m once to bring it back into range.
	if c != 0 {
		subVV(z, z, m)
	}
	return z
}
233
234// Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
235// Factored out for readability - do not use outside karatsuba.
236func karatsubaAdd(z, x nat, n int) {
237	if c := addVV(z[0:n], z, x); c != 0 {
238		addVW(z[n:n+n>>1], z[n:], c)
239	}
240}
241
242// Like karatsubaAdd, but does subtract.
243func karatsubaSub(z, x nat, n int) {
244	if c := subVV(z[0:n], z, x); c != 0 {
245		subVW(z[n:n+n>>1], z[n:], c)
246	}
247}
248
// Operands that are shorter than karatsubaThreshold (measured in words)
// are multiplied using "grade school" multiplication; for longer operands
// the Karatsuba algorithm is used.
var karatsubaThreshold int = 40 // computed by calibrate.go
253
// karatsuba multiplies x and y and leaves the result in z.
// Both x and y must have the same length n and n must be a
// power of 2. The result vector z must have len(z) >= 6*n.
// The (non-normalized) result is placed in z[0 : 2*n].
// The part of z above 2*n is used as scratch space and is
// overwritten.
func karatsuba(z, x, y nat) {
	n := len(y)

	// Switch to basic multiplication if numbers are odd or small.
	// (n is always even if karatsubaThreshold is even, but be
	// conservative)
	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}
	// n&1 == 0 && n >= karatsubaThreshold && n >= 2

	// Karatsuba multiplication is based on the observation that
	// for two numbers x and y with:
	//
	//   x = x1*b + x0
	//   y = y1*b + y0
	//
	// the product x*y can be obtained with 3 products z2, z1, z0
	// instead of 4:
	//
	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
	//       =    z2*b*b +              z1*b +    z0
	//
	// with:
	//
	//   xd = x1 - x0
	//   yd = y0 - y1
	//
	//   z1 =      xd*yd                    + z2 + z0
	//      = (x1-x0)*(y0 - y1)             + z2 + z0
	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
	//      = x1*y0                 + x0*y1

	// split x, y into "digits"
	n2 := n >> 1              // n2 >= 1
	x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0
	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0

	// z is used for the result and temporary storage:
	//
	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
	//
	// For each recursive call of karatsuba, an unused slice of
	// z is passed in that has (at least) half the length of the
	// caller's z.

	// compute z0 and z2 with the result "in place" in z
	karatsuba(z, x0, y0)     // z0 = x0*y0
	karatsuba(z[n:], x1, y1) // z2 = x1*y1

	// compute xd (or the negative value if underflow occurs)
	s := 1 // sign of product xd*yd
	xd := z[2*n : 2*n+n2]
	if subVV(xd, x1, x0) != 0 { // x1-x0
		s = -s
		subVV(xd, x0, x1) // x0-x1
	}

	// compute yd (or the negative value if underflow occurs)
	yd := z[2*n+n2 : 3*n]
	if subVV(yd, y0, y1) != 0 { // y0-y1
		s = -s
		subVV(yd, y1, y0) // y1-y0
	}

	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
	p := z[n*3:]
	karatsuba(p, xd, yd)

	// save original z2:z0
	// (ok to use upper half of z since we're done recursing)
	r := z[n*4:]
	copy(r, z[:n*2])

	// add up all partial products
	//
	//   2*n     n     0
	// z = [ z2  | z0  ]
	//   +    [ z0  ]
	//   +    [ z2  ]
	//   +    [  p  ]
	//
	// (p is subtracted instead of added if its sign s is negative)
	karatsubaAdd(z[n2:], r, n)
	karatsubaAdd(z[n2:], r[n:], n)
	if s > 0 {
		karatsubaAdd(z[n2:], p, n)
	} else {
		karatsubaSub(z[n2:], p, n)
	}
}
352
353// alias reports whether x and y share the same base array.
354func alias(x, y nat) bool {
355	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
356}
357
358// addAt implements z += x<<(_W*i); z must be long enough.
359// (we don't use nat.add because we need z to stay the same
360// slice, and we don't need to normalize z after each addition)
361func addAt(z, x nat, i int) {
362	if n := len(x); n > 0 {
363		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
364			j := i + n
365			if j < len(z) {
366				addVW(z[j:], z[j:], c)
367			}
368		}
369	}
370}
371
// max returns the larger of x and y.
func max(x, y int) int {
	if y >= x {
		return y
	}
	return x
}
378
379// karatsubaLen computes an approximation to the maximum k <= n such that
380// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
381// result is the largest number that can be divided repeatedly by 2 before
382// becoming about the value of karatsubaThreshold.
383func karatsubaLen(n int) int {
384	i := uint(0)
385	for n > karatsubaThreshold {
386		n >>= 1
387		i++
388	}
389	return n << i
390}
391
// mul sets z = x*y and returns the normalized result. z's storage is
// reused only if it does not alias x or y. Short operands are multiplied
// with basicMul; operands with at least karatsubaThreshold words use
// karatsuba for the low parts and recursive calls of mul for the rest.
func (z nat) mul(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		// make x the longer operand
		return z.mul(y, x)
	case m == 0 || n == 0:
		return z[:0]
	case n == 1:
		return z.mulAddWW(x, y[0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(z, x) || alias(z, y) {
		z = nil // z is an alias for x or y - cannot reuse
	}

	// use basic multiplication if the numbers are small
	if n < karatsubaThreshold {
		z = z.make(m + n)
		basicMul(z, x, y)
		return z.norm()
	}
	// m >= n && n >= karatsubaThreshold && n >= 2

	// determine Karatsuba length k such that
	//
	//   x = xh*b + x0  (0 <= x0 < b)
	//   y = yh*b + y0  (0 <= y0 < b)
	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
	//
	k := karatsubaLen(n)
	// k <= n

	// multiply x0 and y0 via Karatsuba
	x0 := x[0:k]              // x0 is not normalized
	y0 := y[0:k]              // y0 is not normalized
	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
	karatsuba(z, x0, y0)
	z = z[0 : m+n]  // z has final length but may be incomplete
	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)

	// If xh != 0 or yh != 0, add the missing terms to z. For
	//
	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
	//   yh =                         y1*b (0 <= y1 < b)
	//
	// the missing terms are
	//
	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
	//
	// since all the yi for i > 1 are 0 by choice of k: If any of them
	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
	// be a larger valid threshold contradicting the assumption about k.
	//
	if k < n || m != n {
		// t is scratch space for the partial products; it is reused
		// from one multiplication to the next.
		var t nat

		// add x0*y1*b
		x0 := x0.norm()
		y1 := y[k:]       // y1 is normalized because y is
		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
		addAt(z, t, k)

		// add xi*y0<<i, xi*y1*b<<(i+k)
		y0 := y0.norm()
		for i := k; i < len(x); i += k {
			// xi is the next (at most k words long) digit of x
			xi := x[i:]
			if len(xi) > k {
				xi = xi[:k]
			}
			xi = xi.norm()
			t = t.mul(xi, y0)
			addAt(z, t, i)
			t = t.mul(xi, y1)
			addAt(z, t, i+k)
		}
	}

	return z.norm()
}
475
476// mulRange computes the product of all the unsigned integers in the
477// range [a, b] inclusively. If a > b (empty range), the result is 1.
478func (z nat) mulRange(a, b uint64) nat {
479	switch {
480	case a == 0:
481		// cut long ranges short (optimization)
482		return z.setUint64(0)
483	case a > b:
484		return z.setUint64(1)
485	case a == b:
486		return z.setUint64(a)
487	case a+1 == b:
488		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
489	}
490	m := (a + b) / 2
491	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
492}
493
494// q = (x-r)/y, with 0 <= r < y
495func (z nat) divW(x nat, y Word) (q nat, r Word) {
496	m := len(x)
497	switch {
498	case y == 0:
499		panic("division by zero")
500	case y == 1:
501		q = z.set(x) // result is x
502		return
503	case m == 0:
504		q = z[:0] // result is 0
505		return
506	}
507	// m > 0
508	z = z.make(m)
509	r = divWVW(z, 0, x, y)
510	q = z.norm()
511	return
512}
513
514func (z nat) div(z2, u, v nat) (q, r nat) {
515	if len(v) == 0 {
516		panic("division by zero")
517	}
518
519	if u.cmp(v) < 0 {
520		q = z[:0]
521		r = z2.set(u)
522		return
523	}
524
525	if len(v) == 1 {
526		var r2 Word
527		q, r2 = z.divW(u, v[0])
528		r = z2.setWord(r2)
529		return
530	}
531
532	q, r = z.divLarge(z2, u, v)
533	return
534}
535
536// getNat returns a *nat of len n. The contents may not be zero.
537// The pool holds *nat to avoid allocation when converting to interface{}.
538func getNat(n int) *nat {
539	var z *nat
540	if v := natPool.Get(); v != nil {
541		z = v.(*nat)
542	}
543	if z == nil {
544		z = new(nat)
545	}
546	*z = z.make(n)
547	return z
548}
549
// putNat hands x to natPool so that a later getNat call can reuse it.
func putNat(x *nat) {
	natPool.Put(x)
}
553
// natPool holds the *nat values handed out by getNat and returned by putNat.
var natPool sync.Pool
555
// divLarge computes the quotient q and remainder r of uIn divided by v:
// q = (uIn-r)/v, with 0 <= r < v
// Uses z as storage for q, and u as storage for r if possible.
// See Knuth, Volume 2, section 4.3.1, Algorithm D.
// Preconditions:
//    len(v) >= 2
//    len(uIn) >= len(v)
func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
	n := len(v)
	m := len(uIn) - n

	// determine if z can be reused
	// TODO(gri) should find a better solution - this if statement
	//           is very costly (see e.g. time pidigits -s -n 10000)
	if alias(z, u) || alias(z, uIn) || alias(z, v) {
		z = nil // z is an alias for u or uIn or v - cannot reuse
	}
	q = z.make(m + 1)

	// qhatv is scratch space for the product q̂*v (n+1 words).
	qhatvp := getNat(n + 1)
	qhatv := *qhatvp
	if alias(u, uIn) || alias(u, v) {
		u = nil // u is an alias for uIn or v - cannot reuse
	}
	u = u.make(len(uIn) + 1)
	u.clear() // TODO(gri) no need to clear if we allocated a new u

	// D1.
	// Shift v and uIn left by the same amount so that the most
	// significant bit of v's top word is set.
	var v1p *nat
	shift := nlz(v[n-1])
	if shift > 0 {
		// do not modify v, it may be used by another goroutine simultaneously
		v1p = getNat(n)
		v1 := *v1p
		shlVU(v1, v, shift)
		v = v1
	}
	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)

	// D2.
	// Compute the quotient words from most to least significant.
	vn1 := v[n-1]
	for j := m; j >= 0; j-- {
		// D3.
		// Estimate the quotient word q̂ from the top words of u and v.
		qhat := Word(_M)
		if ujn := u[j+n]; ujn != vn1 {
			var rhat Word
			qhat, rhat = divWW(ujn, u[j+n-1], vn1)

			// x1 | x2 = q̂v_{n-2}
			vn2 := v[n-2]
			x1, x2 := mulWW(qhat, vn2)
			// test if q̂v_{n-2} > br̂ + u_{j+n-2}
			ujn2 := u[j+n-2]
			for greaterThan(x1, x2, rhat, ujn2) {
				qhat--
				prevRhat := rhat
				rhat += vn1
				// v[n-1] >= 0, so this tests for overflow.
				if rhat < prevRhat {
					break
				}
				x1, x2 = mulWW(qhat, vn2)
			}
		}

		// D4.
		// Multiply and subtract: u[j:j+n+1] -= q̂*v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)

		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
		if c != 0 {
			// The subtraction borrowed, so q̂ was one too large:
			// add v back and decrement q̂.
			c := addVV(u[j:j+n], u[j:], v)
			u[j+n] += c
			qhat--
		}

		q[j] = qhat
	}
	// Return the scratch buffers to the pool.
	if v1p != nil {
		putNat(v1p)
	}
	putNat(qhatvp)

	q = q.norm()
	// Undo the shift from step D1 to obtain the remainder.
	shrVU(u, u, shift)
	r = u.norm()

	return q, r
}
643
644// Length of x in bits. x must be normalized.
645func (x nat) bitLen() int {
646	if i := len(x) - 1; i >= 0 {
647		return i*_W + bits.Len(uint(x[i]))
648	}
649	return 0
650}
651
652// trailingZeroBits returns the number of consecutive least significant zero
653// bits of x.
654func (x nat) trailingZeroBits() uint {
655	if len(x) == 0 {
656		return 0
657	}
658	var i uint
659	for x[i] == 0 {
660		i++
661	}
662	// x[i] != 0
663	return i*_W + uint(bits.TrailingZeros(uint(x[i])))
664}
665
666// z = x << s
667func (z nat) shl(x nat, s uint) nat {
668	m := len(x)
669	if m == 0 {
670		return z[:0]
671	}
672	// m > 0
673
674	n := m + int(s/_W)
675	z = z.make(n + 1)
676	z[n] = shlVU(z[n-m:n], x, s%_W)
677	z[0 : n-m].clear()
678
679	return z.norm()
680}
681
682// z = x >> s
683func (z nat) shr(x nat, s uint) nat {
684	m := len(x)
685	n := m - int(s/_W)
686	if n <= 0 {
687		return z[:0]
688	}
689	// n > 0
690
691	z = z.make(n)
692	shrVU(z, x[m-n:], s%_W)
693
694	return z.norm()
695}
696
697func (z nat) setBit(x nat, i uint, b uint) nat {
698	j := int(i / _W)
699	m := Word(1) << (i % _W)
700	n := len(x)
701	switch b {
702	case 0:
703		z = z.make(n)
704		copy(z, x)
705		if j >= n {
706			// no need to grow
707			return z
708		}
709		z[j] &^= m
710		return z.norm()
711	case 1:
712		if j >= n {
713			z = z.make(j + 1)
714			z[n:].clear()
715		} else {
716			z = z.make(n)
717		}
718		copy(z, x)
719		z[j] |= m
720		// no need to normalize
721		return z
722	}
723	panic("set bit is not 0 or 1")
724}
725
726// bit returns the value of the i'th bit, with lsb == bit 0.
727func (x nat) bit(i uint) uint {
728	j := i / _W
729	if j >= uint(len(x)) {
730		return 0
731	}
732	// 0 <= j < len(x)
733	return uint(x[j] >> (i % _W) & 1)
734}
735
736// sticky returns 1 if there's a 1 bit within the
737// i least significant bits, otherwise it returns 0.
738func (x nat) sticky(i uint) uint {
739	j := i / _W
740	if j >= uint(len(x)) {
741		if len(x) == 0 {
742			return 0
743		}
744		return 1
745	}
746	// 0 <= j < len(x)
747	for _, x := range x[:j] {
748		if x != 0 {
749			return 1
750		}
751	}
752	if x[j]<<(_W-i%_W) != 0 {
753		return 1
754	}
755	return 0
756}
757
758func (z nat) and(x, y nat) nat {
759	m := len(x)
760	n := len(y)
761	if m > n {
762		m = n
763	}
764	// m <= n
765
766	z = z.make(m)
767	for i := 0; i < m; i++ {
768		z[i] = x[i] & y[i]
769	}
770
771	return z.norm()
772}
773
774func (z nat) andNot(x, y nat) nat {
775	m := len(x)
776	n := len(y)
777	if n > m {
778		n = m
779	}
780	// m >= n
781
782	z = z.make(m)
783	for i := 0; i < n; i++ {
784		z[i] = x[i] &^ y[i]
785	}
786	copy(z[n:m], x[n:m])
787
788	return z.norm()
789}
790
791func (z nat) or(x, y nat) nat {
792	m := len(x)
793	n := len(y)
794	s := x
795	if m < n {
796		n, m = m, n
797		s = y
798	}
799	// m >= n
800
801	z = z.make(m)
802	for i := 0; i < n; i++ {
803		z[i] = x[i] | y[i]
804	}
805	copy(z[n:m], s[n:m])
806
807	return z.norm()
808}
809
810func (z nat) xor(x, y nat) nat {
811	m := len(x)
812	n := len(y)
813	s := x
814	if m < n {
815		n, m = m, n
816		s = y
817	}
818	// m >= n
819
820	z = z.make(m)
821	for i := 0; i < n; i++ {
822		z[i] = x[i] ^ y[i]
823	}
824	copy(z[n:m], s[n:m])
825
826	return z.norm()
827}
828
829// greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
830func greaterThan(x1, x2, y1, y2 Word) bool {
831	return x1 > y1 || x1 == y1 && x2 > y2
832}
833
834// modW returns x % d.
835func (x nat) modW(d Word) (r Word) {
836	// TODO(agl): we don't actually need to store the q value.
837	var q nat
838	q = q.make(len(x))
839	return divWVW(q, 0, x, d)
840}
841
842// random creates a random integer in [0..limit), using the space in z if
843// possible. n is the bit length of limit.
844func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
845	if alias(z, limit) {
846		z = nil // z is an alias for limit - cannot reuse
847	}
848	z = z.make(len(limit))
849
850	bitLengthOfMSW := uint(n % _W)
851	if bitLengthOfMSW == 0 {
852		bitLengthOfMSW = _W
853	}
854	mask := Word((1 << bitLengthOfMSW) - 1)
855
856	for {
857		switch _W {
858		case 32:
859			for i := range z {
860				z[i] = Word(rand.Uint32())
861			}
862		case 64:
863			for i := range z {
864				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
865			}
866		default:
867			panic("unknown word size")
868		}
869		z[len(limit)-1] &= mask
870		if z.cmp(limit) < 0 {
871			break
872		}
873	}
874
875	return z.norm()
876}
877
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
// For a modulus together with a multi-word exponent and a base > 1 the
// work is delegated to expNNMontgomery (odd m) or expNNWindowed (even m);
// all other cases use bit-by-bit square-and-multiply.
func (z nat) expNN(x, y, m nat) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// x**1 mod m == x mod m
	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
		_, z = z.div(z, x, m)
		return z
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))
	}
	// Start with z = x; this accounts for the leading 1 bit of y.
	z = z.set(x)

	// If the base is non-trivial and the exponent is large, we use
	// 4-bit, windowed exponentiation. This involves precomputing 14 values
	// (x^2...x^15) but then reduces the number of multiply-reduces by a
	// third. Even for a 32-bit exponent, this reduces the number of
	// operations. Uses Montgomery method for odd moduli.
	if x.cmp(natOne) > 0 && len(y) > 1 && len(m) > 0 {
		if m[0]&1 == 1 {
			return z.expNNMontgomery(x, y, m)
		}
		return z.expNNWindowed(x, y, m)
	}

	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	// Shift out the leading zeros and the leading 1 bit of v; the
	// next exponent bit to process is then always the top bit of v.
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	// mask selects the top bit of a Word.
	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	// w is the number of bits of the top exponent word left to process.
	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.mul(z, z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			zz, r = zz.div(r, z, m)
			// The remainder becomes the new z; the other
			// buffers are rotated for reuse as scratch space.
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Process the remaining exponent words, all _W bits of each.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.mul(z, z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(r, z, m)
				// rotate the buffers as above
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}
978
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
// The exponent is consumed 4 bits at a time, most significant bits
// first: z is squared 4 times and then multiplied by the precomputed
// power of x selected by those 4 bits, reducing mod m after each step.
func (z nat) expNNWindowed(x, y, m nat) nat {
	// zz and r are used to avoid allocating in mul and div as otherwise
	// the arguments would alias.
	var zz, r nat

	// n is the window size in bits.
	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	powers[0] = natOne
	powers[1] = x
	// Each iteration fills in an even power (the square of powers[i/2])
	// and the following odd power (that even power times x), both
	// reduced mod m.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.mul(*p2, *p2)
		zz, r = zz.div(r, *p, m)
		*p, r = r, *p
		*p1 = p1.mul(*p, x)
		zz, r = zz.div(r, *p1, m)
		*p1, r = r, *p1
	}

	z = z.setWord(1)

	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// The squarings are skipped for the very first window
			// (z is still 1 at that point).
			if i != len(y)-1 || j != 0 {
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z

				zz = zz.mul(z, z)
				zz, z = z, zz
				zz, r = zz.div(r, z, m)
				z, r = r, z
			}

			// Multiply by the power of x selected by the top n bits of yi.
			zz = zz.mul(z, powers[yi>>(_W-n)])
			zz, z = z, zz
			zz, r = zz.div(r, z, m)
			z, r = r, z

			// Move the next window into the top n bits.
			yi <<= n
		}
	}

	return z.norm()
}
1041
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// The low bit of m[0] must be set (expNN only calls this for odd m).
func (z nat) expNNMontgomery(x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// Pad x with leading zero words to numWords words.
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).shl(RR, uint(2*numWords*_W))
	_, RR = RR.div(RR, zz, m)
	if len(RR) < numWords {
		// Pad RR with leading zero words to numWords words.
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	// n is the window size in bits.
	const n = 4
	// powers[i] contains x^i (in Montgomery representation)
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Square 4 times, alternating between zz and z as the
			// destination; skipped for the very first window.
			if i != len(y)-1 || j != 0 {
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// Multiply by the power selected by the top n bits of yi.
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(nil, zz, m)
		}
	}

	return zz.norm()
}
1134
1135// bytes writes the value of z into buf using big-endian encoding.
1136// len(buf) must be >= len(z)*_S. The value of z is encoded in the
1137// slice buf[i:]. The number i of unused bytes at the beginning of
1138// buf is returned as result.
1139func (z nat) bytes(buf []byte) (i int) {
1140	i = len(buf)
1141	for _, d := range z {
1142		for j := 0; j < _S; j++ {
1143			i--
1144			buf[i] = byte(d)
1145			d >>= 8
1146		}
1147	}
1148
1149	for i < len(buf) && buf[i] == 0 {
1150		i++
1151	}
1152
1153	return
1154}
1155
1156// setBytes interprets buf as the bytes of a big-endian unsigned
1157// integer, sets z to that value, and returns z.
1158func (z nat) setBytes(buf []byte) nat {
1159	z = z.make((len(buf) + _S - 1) / _S)
1160
1161	k := 0
1162	s := uint(0)
1163	var d Word
1164	for i := len(buf); i > 0; i-- {
1165		d |= Word(buf[i-1]) << s
1166		if s += 8; s == _S*8 {
1167			z[k] = d
1168			k++
1169			s = 0
1170			d = 0
1171		}
1172	}
1173	if k < len(z) {
1174		z[k] = d
1175	}
1176
1177	return z.norm()
1178}
1179
// sqrt sets z = ⌊√x⌋ and returns the result.
// z's storage is reused only if it does not alias x.
func (z nat) sqrt(x nat) nat {
	// For x <= 1 the square root is x itself.
	if x.cmp(natOne) <= 0 {
		return z.set(x)
	}
	if alias(z, x) {
		z = nil
	}

	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
	// https://members.loria.fr/PZimmermann/mca/pub226.html
	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
	// otherwise it converges to the correct z and stays there.
	//
	// z1 is the current estimate and z2 receives the next one; the
	// two swap roles after every iteration. z1 starts out using z's
	// storage.
	var z1, z2 nat
	z1 = z
	z1 = z1.setUint64(1)
	z1 = z1.shl(z1, uint(x.bitLen()/2+1)) // must be ≥ √x
	for n := 0; ; n++ {
		// z2 = ⌊(⌊x/z1⌋ + z1)/2⌋
		z2, _ = z2.div(nil, x, z1)
		z2 = z2.add(z2, z1)
		z2 = z2.shr(z2, 1)
		if z2.cmp(z1) >= 0 {
			// z1 is answer.
			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
			if n&1 == 0 {
				return z1
			}
			return z.set(z1)
		}
		z1, z2 = z2, z1
	}
}
1213