1 /*
2 * Multiplication and Squaring
3 * (C) 1999-2010,2018 Jack Lloyd
4 *     2016 Matthias Gierlings
5 *
6 * Botan is released under the Simplified BSD License (see license.txt)
7 */
8 
9 #include <botan/internal/mp_core.h>
10 #include <botan/internal/mp_asmi.h>
11 #include <botan/internal/ct_utils.h>
12 #include <botan/mem_ops.h>
13 #include <botan/exceptn.h>
14 
15 namespace Botan {
16 
17 namespace {
18 
/*
* Operand length N below which karatsuba_mul stops recursing and uses a
* comba/basecase multiply; bigint_mul also skips Karatsuba entirely when
* x_sw or y_sw is below this value.
*/
const size_t KARATSUBA_MULTIPLY_THRESHOLD = 32;

/*
* Operand length N below which karatsuba_sqr stops recursing and uses a
* comba/basecase squaring; bigint_sqr also skips Karatsuba entirely when
* x_size is below this value.
*/
const size_t KARATSUBA_SQUARE_THRESHOLD = 32;
21 
22 /*
23 * Simple O(N^2) Multiplication
24 */
basecase_mul(word z[],size_t z_size,const word x[],size_t x_size,const word y[],size_t y_size)25 void basecase_mul(word z[], size_t z_size,
26                   const word x[], size_t x_size,
27                   const word y[], size_t y_size)
28    {
29    if(z_size < x_size + y_size)
30       throw Invalid_Argument("basecase_mul z_size too small");
31 
32    const size_t x_size_8 = x_size - (x_size % 8);
33 
34    clear_mem(z, z_size);
35 
36    for(size_t i = 0; i != y_size; ++i)
37       {
38       const word y_i = y[i];
39 
40       word carry = 0;
41 
42       for(size_t j = 0; j != x_size_8; j += 8)
43          carry = word8_madd3(z + i + j, x + j, y_i, carry);
44 
45       for(size_t j = x_size_8; j != x_size; ++j)
46          z[i+j] = word_madd3(x[j], y_i, z[i+j], &carry);
47 
48       z[x_size+i] = carry;
49       }
50    }
51 
basecase_sqr(word z[],size_t z_size,const word x[],size_t x_size)52 void basecase_sqr(word z[], size_t z_size,
53                   const word x[], size_t x_size)
54    {
55    if(z_size < 2*x_size)
56       throw Invalid_Argument("basecase_sqr z_size too small");
57 
58    const size_t x_size_8 = x_size - (x_size % 8);
59 
60    clear_mem(z, z_size);
61 
62    for(size_t i = 0; i != x_size; ++i)
63       {
64       const word x_i = x[i];
65 
66       word carry = 0;
67 
68       for(size_t j = 0; j != x_size_8; j += 8)
69          carry = word8_madd3(z + i + j, x + j, x_i, carry);
70 
71       for(size_t j = x_size_8; j != x_size; ++j)
72          z[i+j] = word_madd3(x[j], x_i, z[i+j], &carry);
73 
74       z[x_size+i] = carry;
75       }
76    }
77 
78 /*
79 * Karatsuba Multiplication Operation
80 */
karatsuba_mul(word z[],const word x[],const word y[],size_t N,word workspace[])81 void karatsuba_mul(word z[], const word x[], const word y[], size_t N,
82                    word workspace[])
83    {
84    if(N < KARATSUBA_MULTIPLY_THRESHOLD || N % 2)
85       {
86       switch(N)
87          {
88          case 6:
89             return bigint_comba_mul6(z, x, y);
90          case 8:
91             return bigint_comba_mul8(z, x, y);
92          case 9:
93             return bigint_comba_mul9(z, x, y);
94          case 16:
95             return bigint_comba_mul16(z, x, y);
96          case 24:
97             return bigint_comba_mul24(z, x, y);
98          default:
99             return basecase_mul(z, 2*N, x, N, y, N);
100          }
101       }
102 
103    const size_t N2 = N / 2;
104 
105    const word* x0 = x;
106    const word* x1 = x + N2;
107    const word* y0 = y;
108    const word* y1 = y + N2;
109    word* z0 = z;
110    word* z1 = z + N;
111 
112    word* ws0 = workspace;
113    word* ws1 = workspace + N;
114 
115    clear_mem(workspace, 2*N);
116 
117    /*
118    * If either of cmp0 or cmp1 is zero then z0 or z1 resp is zero here,
119    * resulting in a no-op - z0*z1 will be equal to zero so we don't need to do
120    * anything, clear_mem above already set the correct result.
121    *
122    * However we ignore the result of the comparisons and always perform the
123    * subtractions and recursively multiply to avoid the timing channel.
124    */
125 
126    // First compute (X_lo - X_hi)*(Y_hi - Y_lo)
127    const auto cmp0 = bigint_sub_abs(z0, x0, x1, N2, workspace);
128    const auto cmp1 = bigint_sub_abs(z1, y1, y0, N2, workspace);
129    const auto neg_mask = ~(cmp0 ^ cmp1);
130 
131    karatsuba_mul(ws0, z0, z1, N2, ws1);
132 
133    // Compute X_lo * Y_lo
134    karatsuba_mul(z0, x0, y0, N2, ws1);
135 
136    // Compute X_hi * Y_hi
137    karatsuba_mul(z1, x1, y1, N2, ws1);
138 
139    const word ws_carry = bigint_add3_nc(ws1, z0, N, z1, N);
140    word z_carry = bigint_add2_nc(z + N2, N, ws1, N);
141 
142    z_carry += bigint_add2_nc(z + N + N2, N2, &ws_carry, 1);
143    bigint_add2_nc(z + N + N2, N2, &z_carry, 1);
144 
145    clear_mem(workspace + N, N2);
146 
147    bigint_cnd_add_or_sub(neg_mask, z + N2, workspace, 2*N-N2);
148    }
149 
150 /*
151 * Karatsuba Squaring Operation
152 */
karatsuba_sqr(word z[],const word x[],size_t N,word workspace[])153 void karatsuba_sqr(word z[], const word x[], size_t N, word workspace[])
154    {
155    if(N < KARATSUBA_SQUARE_THRESHOLD || N % 2)
156       {
157       switch(N)
158          {
159          case 6:
160             return bigint_comba_sqr6(z, x);
161          case 8:
162             return bigint_comba_sqr8(z, x);
163          case 9:
164             return bigint_comba_sqr9(z, x);
165          case 16:
166             return bigint_comba_sqr16(z, x);
167          case 24:
168             return bigint_comba_sqr24(z, x);
169          default:
170             return basecase_sqr(z, 2*N, x, N);
171          }
172       }
173 
174    const size_t N2 = N / 2;
175 
176    const word* x0 = x;
177    const word* x1 = x + N2;
178    word* z0 = z;
179    word* z1 = z + N;
180 
181    word* ws0 = workspace;
182    word* ws1 = workspace + N;
183 
184    clear_mem(workspace, 2*N);
185 
186    // See comment in karatsuba_mul
187    bigint_sub_abs(z0, x0, x1, N2, workspace);
188    karatsuba_sqr(ws0, z0, N2, ws1);
189 
190    karatsuba_sqr(z0, x0, N2, ws1);
191    karatsuba_sqr(z1, x1, N2, ws1);
192 
193    const word ws_carry = bigint_add3_nc(ws1, z0, N, z1, N);
194    word z_carry = bigint_add2_nc(z + N2, N, ws1, N);
195 
196    z_carry += bigint_add2_nc(z + N + N2, N2, &ws_carry, 1);
197    bigint_add2_nc(z + N + N2, N2, &z_carry, 1);
198 
199    /*
200    * This is only actually required if cmp (result of bigint_sub_abs) is != 0,
201    * however if cmp==0 then ws0[0:N] == 0 and avoiding the jump hides a
202    * timing channel.
203    */
204    bigint_sub2(z + N2, 2*N-N2, ws0, N);
205    }
206 
/*
* Pick a good size for the Karatsuba multiply
*
* Returns an even operand length N usable for both inputs, i.e. one with
* max(x_sw, y_sw) <= N <= min(x_size, y_size), or 0 if no such length is
* found. Except when the two bounds coincide, 2*N <= z_size is also
* required, and a candidate that is 2 mod 4 is bumped to the next
* multiple of 4 when that still fits everywhere.
*/
size_t karatsuba_size(size_t z_size,
                      size_t x_size, size_t x_sw,
                      size_t y_size, size_t y_sw)
   {
   // Each operand's used length must fit inside both buffers
   if(x_sw > x_size || x_sw > y_size)
      return 0;
   if(y_sw > x_size || y_sw > y_size)
      return 0;

   // A completely full buffer of odd length leaves no even candidate
   if(x_size == x_sw && (x_size % 2) != 0)
      return 0;
   if(y_size == y_sw && (y_size % 2) != 0)
      return 0;

   const size_t lowest = (x_sw > y_sw) ? x_sw : y_sw;
   const size_t highest = (x_size < y_size) ? x_size : y_size;

   // Only one candidate: accept it if it is even (z_size is not consulted)
   if(lowest == highest)
      return (lowest % 2 == 0) ? lowest : 0;

   for(size_t n = lowest; n <= highest; ++n)
      {
      if((n % 2) != 0)
         continue;

      if(2*n > z_size)
         return 0;

      if(n < x_sw || n > x_size || n < y_sw || n > y_size)
         continue;

      // Prefer a multiple of 4 when the buffers allow it
      const size_t bumped = n + 2;
      if(n % 4 == 2 &&
         bumped <= x_size && bumped <= y_size && 2*bumped <= z_size)
         return bumped;

      return n;
      }

   return 0;
   }
250 
/*
* Pick a good size for the Karatsuba squaring
*
* Returns an even operand length N with x_sw <= N <= x_size, or 0 if
* none is found. Except when x_sw == x_size, 2*N <= z_size is also
* required, and a candidate that is 2 mod 4 is bumped to the next
* multiple of 4 when that still fits in both x and z.
*/
size_t karatsuba_size(size_t z_size, size_t x_size, size_t x_sw)
   {
   // Only one candidate: accept it if it is even (z_size is not consulted)
   if(x_sw == x_size)
      return (x_sw % 2 == 0) ? x_sw : 0;

   for(size_t n = x_sw; n <= x_size; ++n)
      {
      if((n % 2) != 0)
         continue;

      if(2*n > z_size)
         return 0;

      // Prefer a multiple of 4 when the buffers allow it
      const size_t bumped = n + 2;
      if(n % 4 == 2 && bumped <= x_size && 2*bumped <= z_size)
         return bumped;

      return n;
      }

   return 0;
   }
278 
/*
* True when a fixed SZ-word comba multiply can be applied: both inputs
* use at most SZ words (x_sw, y_sw), both input buffers hold at least SZ
* words (x_size, y_size), and z has room for the 2*SZ word result.
*/
template<size_t SZ>
inline bool sized_for_comba_mul(size_t x_sw, size_t x_size,
                                size_t y_sw, size_t y_size,
                                size_t z_size)
   {
   if(x_sw > SZ || x_size < SZ)
      return false;

   if(y_sw > SZ || y_size < SZ)
      return false;

   return z_size >= 2*SZ;
   }
288 
/*
* True when a fixed SZ-word comba squaring can be applied: the input
* uses at most SZ words (x_sw), its buffer holds at least SZ words
* (x_size), and z has room for the 2*SZ word result.
*/
template<size_t SZ>
inline bool sized_for_comba_sqr(size_t x_sw, size_t x_size,
                                size_t z_size)
   {
   if(x_sw > SZ || x_size < SZ)
      return false;

   return z_size >= 2*SZ;
   }
295 
296 }
297 
bigint_mul(word z[],size_t z_size,const word x[],size_t x_size,size_t x_sw,const word y[],size_t y_size,size_t y_sw,word workspace[],size_t ws_size)298 void bigint_mul(word z[], size_t z_size,
299                 const word x[], size_t x_size, size_t x_sw,
300                 const word y[], size_t y_size, size_t y_sw,
301                 word workspace[], size_t ws_size)
302    {
303    clear_mem(z, z_size);
304 
305    if(x_sw == 1)
306       {
307       bigint_linmul3(z, y, y_sw, x[0]);
308       }
309    else if(y_sw == 1)
310       {
311       bigint_linmul3(z, x, x_sw, y[0]);
312       }
313    else if(sized_for_comba_mul<4>(x_sw, x_size, y_sw, y_size, z_size))
314       {
315       bigint_comba_mul4(z, x, y);
316       }
317    else if(sized_for_comba_mul<6>(x_sw, x_size, y_sw, y_size, z_size))
318       {
319       bigint_comba_mul6(z, x, y);
320       }
321    else if(sized_for_comba_mul<8>(x_sw, x_size, y_sw, y_size, z_size))
322       {
323       bigint_comba_mul8(z, x, y);
324       }
325    else if(sized_for_comba_mul<9>(x_sw, x_size, y_sw, y_size, z_size))
326       {
327       bigint_comba_mul9(z, x, y);
328       }
329    else if(sized_for_comba_mul<16>(x_sw, x_size, y_sw, y_size, z_size))
330       {
331       bigint_comba_mul16(z, x, y);
332       }
333    else if(sized_for_comba_mul<24>(x_sw, x_size, y_sw, y_size, z_size))
334       {
335       bigint_comba_mul24(z, x, y);
336       }
337    else if(x_sw < KARATSUBA_MULTIPLY_THRESHOLD ||
338            y_sw < KARATSUBA_MULTIPLY_THRESHOLD ||
339            !workspace)
340       {
341       basecase_mul(z, z_size, x, x_sw, y, y_sw);
342       }
343    else
344       {
345       const size_t N = karatsuba_size(z_size, x_size, x_sw, y_size, y_sw);
346 
347       if(N && z_size >= 2*N && ws_size >= 2*N)
348          karatsuba_mul(z, x, y, N, workspace);
349       else
350          basecase_mul(z, z_size, x, x_sw, y, y_sw);
351       }
352    }
353 
354 /*
355 * Squaring Algorithm Dispatcher
356 */
bigint_sqr(word z[],size_t z_size,const word x[],size_t x_size,size_t x_sw,word workspace[],size_t ws_size)357 void bigint_sqr(word z[], size_t z_size,
358                 const word x[], size_t x_size, size_t x_sw,
359                 word workspace[], size_t ws_size)
360    {
361    clear_mem(z, z_size);
362 
363    BOTAN_ASSERT(z_size/2 >= x_sw, "Output size is sufficient");
364 
365    if(x_sw == 1)
366       {
367       bigint_linmul3(z, x, x_sw, x[0]);
368       }
369    else if(sized_for_comba_sqr<4>(x_sw, x_size, z_size))
370       {
371       bigint_comba_sqr4(z, x);
372       }
373    else if(sized_for_comba_sqr<6>(x_sw, x_size, z_size))
374       {
375       bigint_comba_sqr6(z, x);
376       }
377    else if(sized_for_comba_sqr<8>(x_sw, x_size, z_size))
378       {
379       bigint_comba_sqr8(z, x);
380       }
381    else if(sized_for_comba_sqr<9>(x_sw, x_size, z_size))
382       {
383       bigint_comba_sqr9(z, x);
384       }
385    else if(sized_for_comba_sqr<16>(x_sw, x_size, z_size))
386       {
387       bigint_comba_sqr16(z, x);
388       }
389    else if(sized_for_comba_sqr<24>(x_sw, x_size, z_size))
390       {
391       bigint_comba_sqr24(z, x);
392       }
393    else if(x_size < KARATSUBA_SQUARE_THRESHOLD || !workspace)
394       {
395       basecase_sqr(z, z_size, x, x_sw);
396       }
397    else
398       {
399       const size_t N = karatsuba_size(z_size, x_size, x_sw);
400 
401       if(N && z_size >= 2*N && ws_size >= 2*N)
402          karatsuba_sqr(z, x, N, workspace);
403       else
404          basecase_sqr(z, z_size, x, x_sw);
405       }
406    }
407 
408 }
409