#ifndef MPU_MONTMATH_H
#define MPU_MONTMATH_H

#include "ptypes.h"
#include "mulmod.h"

#if BITS_PER_WORD == 64 && HAVE_STD_U64 && defined(__GNUC__) && defined(__x86_64__)
#define USE_MONTMATH 1
#else
#define USE_MONTMATH 0
#endif

#if USE_MONTMATH

#define mont_get1(n)         _u64div(1,n)
/* Must have npi = mont_inverse(n), mont1 = mont_get1(n) in scope.
 * See the illustrative example at the bottom of this file. */
#define mont_get2(n)         addmod(mont1,mont1,n)
#define mont_geta(a,n)       mulmod(a,mont1,n)
#define mont_mulmod(a,b,n)   _mulredc(a,b,n,npi)
#define mont_sqrmod(a,n)     _mulredc(a,a,n,npi)
#define mont_powmod(a,k,n)   _powredc(a,k,mont1,n,npi)
#define mont_recover(a,n)    mont_mulmod(a,1,n)

/* Save one branch if desired by calling directly */
#define mont_mulmod63(a,b,n)  _mulredc63(a,b,n,npi)
#define mont_mulmod64(a,b,n)  _mulredc64(a,b,n,npi)

/* See https://arxiv.org/pdf/1303.0328.pdf for lots of details on this.
 * The 128-entry table solution is about 20% faster */
static INLINE uint64_t mont_inverse(const uint64_t n) {
  uint64_t ret = (3*n) ^ 2;        /* correct to 5 low bits for odd n  */
  ret *= (uint64_t)2 - n * ret;    /* each Newton step doubles the     */
  ret *= (uint64_t)2 - n * ret;    /* number of correct low bits:      */
  ret *= (uint64_t)2 - n * ret;    /* 5 -> 10 -> 20 -> 40 -> 80 >= 64  */
  ret *= (uint64_t)2 - n * ret;
  return (uint64_t)0 - ret;        /* negate: npi = -n^-1 mod 2^64     */
}
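
/* Illustrative sanity check of the invariant above: for odd n, the value
 * returned satisfies n * npi == -1 (mod 2^64).  _mont_inverse_ok is a
 * hypothetical helper added only for demonstration, not part of the API. */
static INLINE int _mont_inverse_ok(uint64_t n) {
  return n * mont_inverse(n) == (uint64_t)0 - 1;  /* wraps to 0xFF...FF */
}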
/* MULREDC asm from Ben Buhrow */
static INLINE uint64_t _mulredc63(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
  /* Montgomery reduction for n < 2^63: compute t = a*b, then m = lo(t)*npi,
   * add m*n (which zeroes the low word), keep the high word, and finish
   * with a branchless conditional subtraction of n. */
  asm("mulq %2 \n\t"            /* rdx:rax = a*b                        */
      "movq %%rax, %%r10 \n\t"  /* save t = a*b in r11:r10              */
      "movq %%rdx, %%r11 \n\t"
      "mulq %3 \n\t"            /* rax = m = lo(t) * npi mod 2^64       */
      "mulq %4 \n\t"            /* rdx:rax = m*n                        */
      "addq %%r10, %%rax \n\t"  /* t + m*n: low word becomes zero,      */
      "adcq %%r11, %%rdx \n\t"  /* high word (rdx) is the result < 2n   */
      "xorq %%rax, %%rax \n\t"
      "subq %4, %%rdx \n\t"     /* rdx -= n; borrow if rdx < n          */
      "cmovc %4, %%rax \n\t"    /* if we went negative, add n back      */
      "addq %%rdx, %%rax \n\t"
      : "=a"(a)
      : "0"(a), "r"(b), "r"(npi), "r"(n)
      : "rdx", "r10", "r11", "cc");
  return a;
}
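
/* A portable sketch of what the asm above computes, for reference only.
 * It uses GCC's __uint128_t, which exists under this file's
 * __GNUC__/__x86_64__ guard; the name _mulredc63_ref is ours, not part
 * of the API.  Assumes odd n < 2^63 and a,b fully reduced mod n. */
static INLINE uint64_t _mulredc63_ref(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
  __uint128_t t = (__uint128_t) a * b;
  uint64_t m = (uint64_t) t * npi;       /* m = lo(t) * -n^-1 mod 2^64   */
  uint64_t r = (uint64_t) ((t + (__uint128_t) m * n) >> 64);  /* low word is 0 */
  return (r >= n)  ?  r - n  :  r;       /* single conditional subtraction */
}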
static INLINE uint64_t _mulredc64(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
  /* Same reduction for full 64-bit n: t + m*n can carry out of 128 bits,
   * so the carry flag selects (via r12) whether the fixup adds n back. */
  asm("mulq %1 \n\t"            /* rdx:rax = a*b                        */
      "movq %%rax, %%r10 \n\t"  /* save t = a*b in r11:r10              */
      "movq %%rdx, %%r11 \n\t"
      "movq $0, %%r12 \n\t"
      "mulq %2 \n\t"            /* rax = m = lo(t) * npi mod 2^64       */
      "mulq %3 \n\t"            /* rdx:rax = m*n                        */
      "addq %%r10, %%rax \n\t"  /* t + m*n: low word zero, high in rdx, */
      "adcq %%r11, %%rdx \n\t"  /* carry flag holds the 129th bit       */
      "cmovae %3, %%r12 \n\t"   /* r12 = n only if there was no carry   */
      "xorq %%rax, %%rax \n\t"
      "subq %3, %%rdx \n\t"     /* rdx -= n; borrow if rdx < n          */
      "cmovc %%r12, %%rax \n\t" /* on borrow, add n back unless carried */
      "addq %%rdx, %%rax \n\t"
      : "+&a"(a)
      : "r"(b), "r"(npi), "r"(n)
      : "rdx", "r10", "r11", "r12", "cc");
  return a;
}
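
/* Portable sketch of the full 64-bit case, again for reference only
 * (_mulredc64_ref is our name, not part of the API).  Here t + m*n may
 * exceed 2^128, so wraparound of the 128-bit sum must be detected. */
static INLINE uint64_t _mulredc64_ref(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
  __uint128_t t = (__uint128_t) a * b;
  uint64_t m = (uint64_t) t * npi;
  __uint128_t s = t + (__uint128_t) m * n;  /* may wrap past 2^128        */
  int carry = s < t;                        /* wraparound = a 129th bit   */
  uint64_t r = (uint64_t) (s >> 64);        /* true result: carry*2^64 + r */
  if (carry || r >= n)  r -= n;             /* reduce from [0,2n) to [0,n) */
  return r;
}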
/* Dispatch on the high bit of n */
#define _mulredc(a,b,n,npi) ((n & 0x8000000000000000ULL) ? _mulredc64(a,b,n,npi) : _mulredc63(a,b,n,npi))

static INLINE uint64_t _powredc(uint64_t a, uint64_t k, uint64_t one, uint64_t n, uint64_t npi) {
  uint64_t t = one;                 /* Montgomery form of 1 */
  while (k) {                       /* right-to-left binary exponentiation */
    if (k & 1) t = mont_mulmod(t, a, n);
    k >>= 1;
    if (k) a = mont_sqrmod(a, n);
  }
  return t;
}

static INLINE uint64_t _u64div(uint64_t c, uint64_t n) {
  /* Returns (c * 2^64) mod n: one divq of the 128-bit value c:0 by n,
   * keeping the remainder (rdx).  Requires c < n to avoid a divide
   * fault; mont_get1 uses c = 1. */
  asm("divq %4"
      : "=a"(c), "=d"(n)
      : "1"(c), "0"(0), "r"(n));
  return n;
}

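/* End-to-end example (illustrative; example_powmod is our name, not part
 * of the API): computing a^k mod n for odd n > 1 with the macros above.
 * The macros capture local variables, which must be named npi and mont1. */
static INLINE uint64_t example_powmod(uint64_t a, uint64_t k, uint64_t n) {
  const uint64_t npi = mont_inverse(n);   /* -n^-1 mod 2^64 (n odd)       */
  const uint64_t mont1 = mont_get1(n);    /* 2^64 mod n, Montgomery 1     */
  uint64_t ma = mont_geta(a % n, n);      /* convert a to Montgomery form */
  uint64_t mt = mont_powmod(ma, k, n);    /* exponentiate in that domain  */
  return mont_recover(mt, n);             /* map back: multiply by 2^-64  */
}
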
#endif /* USE_MONTMATH */

#endif