#ifndef MPU_MONTMATH_H
#define MPU_MONTMATH_H

#include "ptypes.h"
#include "mulmod.h"

#if BITS_PER_WORD == 64 && HAVE_STD_U64 && defined(__GNUC__) && defined(__x86_64__)
#define USE_MONTMATH 1
#else
#define USE_MONTMATH 0
#endif

#if USE_MONTMATH

#define mont_get1(n)              _u64div(1,n)
/* Must have npi = mont_inverse(n), mont1 = mont_get1(n) */
#define mont_get2(n)              addmod(mont1,mont1,n)
#define mont_geta(a,n)            mulmod(a,mont1,n)
#define mont_mulmod(a,b,n)        _mulredc(a,b,n,npi)
#define mont_sqrmod(a,n)          _mulredc(a,a,n,npi)
#define mont_powmod(a,k,n)        _powredc(a,k,mont1,n,npi)
#define mont_recover(a,n)         mont_mulmod(a,1,n)

/* Save one branch if desired by calling directly */
#define mont_mulmod63(a,b,n)        _mulredc63(a,b,n,npi)
#define mont_mulmod64(a,b,n)        _mulredc64(a,b,n,npi)
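
/* Example usage (a sketch, assuming an odd 64-bit modulus n > 1; the local
 * variable names npi and mont1 are required, since the macros above refer to
 * them by name):
 *
 *   uint64_t npi   = mont_inverse(n);        // -n^-1 mod 2^64
 *   uint64_t mont1 = mont_get1(n);           // R mod n, with R = 2^64
 *   uint64_t ma    = mont_geta(a, n);        // a in Montgomery form
 *   uint64_t mr    = mont_powmod(ma, k, n);  // a^k in Montgomery form
 *   uint64_t r     = mont_recover(mr, n);    // back to a^k mod n
 */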

/* See https://arxiv.org/pdf/1303.0328.pdf for lots of details on this.
 * The 128-entry table solution is about 20% faster */
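/* Computes -n^-1 mod 2^64 for odd n, the npi value REDC needs.  The seed
 * (3*n) ^ 2 is correct to 5 low bits, each Newton step doubles the number of
 * correct bits (5 -> 10 -> 20 -> 40 -> 80 >= 64), and the final negation
 * turns n^-1 into -n^-1. */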
static INLINE uint64_t mont_inverse(const uint64_t n) {
  uint64_t ret = (3*n) ^ 2;
  ret *= (uint64_t)2 - n * ret;
  ret *= (uint64_t)2 - n * ret;
  ret *= (uint64_t)2 - n * ret;
  ret *= (uint64_t)2 - n * ret;
  return (uint64_t)0 - ret;
}

/* MULREDC asm from Ben Buhrow */
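/* _mulredc63 computes a*b*R^-1 mod n (R = 2^64), assuming a,b < n < 2^63:
 * it forms the 128-bit product a*b, multiplies its low word by npi to get the
 * reduction factor m, adds m*n, and keeps the high 64 bits.  That value is
 * less than 2n < 2^64, so the high-word addition cannot carry out and a
 * single conditional subtraction of n finishes the reduction. */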
static INLINE uint64_t _mulredc63(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
    asm("mulq %2 \n\t"
        "movq %%rax, %%r10 \n\t"
        "movq %%rdx, %%r11 \n\t"
        "mulq %3 \n\t"
        "mulq %4 \n\t"
        "addq %%r10, %%rax \n\t"
        "adcq %%r11, %%rdx \n\t"
        "xorq %%rax, %%rax \n\t"
        "subq %4, %%rdx \n\t"
        "cmovc %4, %%rax \n\t"
        "addq %%rdx, %%rax \n\t"
        : "=a"(a)
        : "0"(a), "r"(b), "r"(npi), "r"(n)
        : "rdx", "r10", "r11", "cc");
  return a;
}
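/* _mulredc64 is the full-range variant, valid for any odd 64-bit n: it also
 * records (via %r12 and cmovae) whether the high-word addition carried out of
 * 64 bits, so the final conditional correction remains valid when n has its
 * top bit set. */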
static INLINE uint64_t _mulredc64(uint64_t a, uint64_t b, uint64_t n, uint64_t npi) {
    asm("mulq %1 \n\t"
        "movq %%rax, %%r10 \n\t"
        "movq %%rdx, %%r11 \n\t"
        "movq $0, %%r12 \n\t"
        "mulq %2 \n\t"
        "mulq %3 \n\t"
        "addq %%r10, %%rax \n\t"
        "adcq %%r11, %%rdx \n\t"
        "cmovae %3, %%r12 \n\t"
        "xorq %%rax, %%rax \n\t"
        "subq %3, %%rdx \n\t"
        "cmovc %%r12, %%rax \n\t"
        "addq %%rdx, %%rax \n\t"
        : "+&a"(a)
        : "r"(b), "r"(npi), "r"(n)
        : "rdx", "r10", "r11", "r12", "cc");
  return a;
}
#define _mulredc(a,b,n,npi) ((n & 0x8000000000000000ULL) ? _mulredc64(a,b,n,npi) : _mulredc63(a,b,n,npi))

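/* Right-to-left binary exponentiation done entirely in Montgomery form:
 * 'a' must already be in Montgomery form, 'one' is mont_get1(n) (R mod n),
 * and the result is returned in Montgomery form. */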
static INLINE UV _powredc(uint64_t a, uint64_t k, uint64_t one, uint64_t n, uint64_t npi) {
  uint64_t t = one;
  while (k) {
    if (k & 1) t = mont_mulmod(t, a, n);
    k >>= 1;
    if (k)     a = mont_sqrmod(a, n);
  }
  return t;
}

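/* Computes (c * 2^64) mod n with a single divq: the dividend is c in rdx and
 * 0 in rax, so c must be less than n to avoid a divide overflow.  mont_get1
 * uses this with c = 1 to obtain R mod n, the Montgomery form of 1. */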
static INLINE uint64_t _u64div(uint64_t c, uint64_t n) {
  asm("divq %4"
      : "=a"(c), "=d"(n)
      : "1"(c), "0"(0), "r"(n));
  return n;
}

#endif /* use_montmath */

#endif