1 /* ecc-mod-inv.c
2 
3    Copyright (C) 2013, 2014 Niels Möller
4 
5    This file is part of GNU Nettle.
6 
7    GNU Nettle is free software: you can redistribute it and/or
8    modify it under the terms of either:
9 
10      * the GNU Lesser General Public License as published by the Free
11        Software Foundation; either version 3 of the License, or (at your
12        option) any later version.
13 
14    or
15 
16      * the GNU General Public License as published by the Free
17        Software Foundation; either version 2 of the License, or (at your
18        option) any later version.
19 
20    or both in parallel, as here.
21 
22    GNU Nettle is distributed in the hope that it will be useful,
23    but WITHOUT ANY WARRANTY; without even the implied warranty of
24    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
25    General Public License for more details.
26 
27    You should have received copies of the GNU General Public License and
28    the GNU Lesser General Public License along with this program.  If
29    not, see http://www.gnu.org/licenses/.
30 */
31 
32 /* Development of Nettle's ECC support was funded by the .SE Internet Fund. */
33 
34 #if HAVE_CONFIG_H
35 # include "config.h"
36 #endif
37 
38 #include <assert.h>
39 
40 #include "ecc-internal.h"
41 
42 static void
cnd_neg(int cnd,mp_limb_t * rp,const mp_limb_t * ap,mp_size_t n)43 cnd_neg (int cnd, mp_limb_t *rp, const mp_limb_t *ap, mp_size_t n)
44 {
45   mp_limb_t cy = (cnd != 0);
46   mp_limb_t mask = -cy;
47   mp_size_t i;
48 
49   for (i = 0; i < n; i++)
50     {
51       mp_limb_t r = (ap[i] ^ mask) + cy;
52       cy = r < cy;
53       rp[i] = r;
54     }
55 }
56 
/* Compute a^{-1} mod m, with running time depending only on the size.
   Returns zero if a == 0 (mod m), to be consistent with a^{phi(m)-1}.
   Also needs (m+1)/2 (read from m->mp1h), and m must be odd.

   Needs 2n limbs available at vp, and 2n additional scratch limbs.
*/
63 
64 /* FIXME: Could use mpn_sec_invert (in GMP-6), but with a bit more
65    scratch need since it doesn't precompute (m+1)/2. */
/* Modular inversion using a binary gcd style iteration with a fixed
   number of rounds.

   m        Modulus description; this function reads m->size (the limb
	    count n), m->bit_size, m->m (the modulus) and m->mp1h
	    (used as (m+1)/2).
   vp       Work and output area of 2n limbs. The low n limbs hold v
	    and the high n limbs hold u, see the invariants below.
   in_ap    Input a, n limbs. Only read; it is copied into scratch.
   scratch  2n limbs; the low n limbs hold the working value a and the
	    high n limbs hold b. Must not be the same area as vp.  */
void
ecc_mod_inv (const struct ecc_modulo *m,
	     mp_limb_t *vp, const mp_limb_t *in_ap,
	     mp_limb_t *scratch)
{
  /* Names for the four n-limb working values. Note that bp and up
     expand to expressions using the local variable n. */
#define ap scratch
#define bp (scratch + n)
#define up (vp + n)

  mp_size_t n = m->size;
  /* Avoid the mp_bitcnt_t type for compatibility with older GMP
     versions. */
  unsigned i;

  /* Maintain

       a = u * orig_a (mod m)
       b = v * orig_a (mod m)

     and b odd at all times. Initially,

       a = orig_a, u = 1
       b = m,      v = 0
     */

  assert (ap != vp);

  /* Set up the initial state described above: u = 1, b = m, v = 0,
     a = input. */
  up[0] = 1;
  mpn_zero (up+1, n - 1);
  mpn_copyi (bp, m->m, n);
  mpn_zero (vp, n);
  mpn_copyi (ap, in_ap, n);

  /* The iteration count depends only on the modulus description, never
     on the values being processed. It is at least as large as the
     bound derived in the comment below, since a occupies at most
     GMP_NUMB_BITS * n bits. */
  for (i = m->bit_size + GMP_NUMB_BITS * n; i-- > 0; )
    {
      mp_limb_t odd, swap, cy;

      /* Always maintain b odd. The logic of the iteration is as
	 follows. For a, b:

	   odd = a & 1
	   a -= odd * b
	   if (underflow from a-b)
	     {
	       b += a, assigns old a
	       a = B^n-a
	     }

	   a /= 2

	 For u, v:

	   if (underflow from a - b)
	     swap u, v
	   u -= odd * v
	   if (underflow from u - v)
	     u += m

	   u /= 2
	   if (a one bit was shifted out)
	     u += (m+1)/2

	 As long as a > 0, the quantity

	   (bitsize of a) + (bitsize of b)

	 is reduced by at least one bit per iteration, hence after
	 (bit_size of orig_a) + (bit_size of m) - 1 iterations we
	 surely have a = 0. Then b = gcd(orig_a, m) and if b = 1 then
	 also v = orig_a^{-1} (mod m)
      */

      assert (bp[0] & 1);
      odd = ap[0] & 1;

      /* Update of a and b. swap receives the borrow from the
	 conditional subtraction, i.e., it is non-zero when a - b
	 underflowed; in that case b takes the old value of a, and a
	 is negated to become the old b - a. */
      swap = cnd_sub_n (odd, ap, bp, n);
      cnd_add_n (swap, bp, ap, n);
      cnd_neg (swap, ap, ap, n);

      /* Matching update of u and v. An underflow in u - v is undone
	 by adding m; the borrow from the subtraction and the carry
	 from the addition are expected to cancel, leaving cy zero. */
      cnd_swap (swap, up, vp, n);
      cy = cnd_sub_n (odd, up, vp, n);
      cy -= cnd_add_n (cy, up, m->m, n);
      assert (cy == 0);

      /* a is even at this point, so halving it shifts out a zero
	 bit. */
      cy = mpn_rshift (ap, ap, n, 1);
      assert (cy == 0);
      /* Halve u modulo m: if a one bit was shifted out, compensate by
	 adding m->mp1h, i.e., (m+1)/2. */
      cy = mpn_rshift (up, up, n, 1);
      cy = cnd_add_n (cy, up, m->mp1h, n);
      assert (cy == 0);
    }
  /* Sanity check that a has been reduced to zero; only the lowest and
     the highest limb are examined. */
  assert ( (ap[0] | ap[n-1]) == 0);
#undef ap
#undef bp
#undef up
}
161