1 #include "tommath_private.h"
2 #ifdef BN_S_MP_KARATSUBA_SQR_C
3 /* LibTomMath, multiple-precision integer library -- Tom St Denis */
4 /* SPDX-License-Identifier: Unlicense */
5 
6 /* Karatsuba squaring, computes b = a*a using three
7  * half size squarings
8  *
9  * See comments of karatsuba_mul for details.  It
10  * is essentially the same algorithm but merely
11  * tuned to perform recursive squarings.
12  */
s_mp_karatsuba_sqr(const mp_int * a,mp_int * b)13 mp_err s_mp_karatsuba_sqr(const mp_int *a, mp_int *b)
14 {
15    mp_int  x0, x1, t1, t2, x0x0, x1x1;
16    int     B;
17    mp_err  err = MP_MEM;
18 
19    /* min # of digits */
20    B = a->used;
21 
22    /* now divide in two */
23    B = B >> 1;
24 
25    /* init copy all the temps */
26    if (mp_init_size(&x0, B) != MP_OKAY)
27       goto LBL_ERR;
28    if (mp_init_size(&x1, a->used - B) != MP_OKAY)
29       goto X0;
30 
31    /* init temps */
32    if (mp_init_size(&t1, a->used * 2) != MP_OKAY)
33       goto X1;
34    if (mp_init_size(&t2, a->used * 2) != MP_OKAY)
35       goto T1;
36    if (mp_init_size(&x0x0, B * 2) != MP_OKAY)
37       goto T2;
38    if (mp_init_size(&x1x1, (a->used - B) * 2) != MP_OKAY)
39       goto X0X0;
40 
41    {
42       int x;
43       mp_digit *dst, *src;
44 
45       src = a->dp;
46 
47       /* now shift the digits */
48       dst = x0.dp;
49       for (x = 0; x < B; x++) {
50          *dst++ = *src++;
51       }
52 
53       dst = x1.dp;
54       for (x = B; x < a->used; x++) {
55          *dst++ = *src++;
56       }
57    }
58 
59    x0.used = B;
60    x1.used = a->used - B;
61 
62    mp_clamp(&x0);
63 
64    /* now calc the products x0*x0 and x1*x1 */
65    if (mp_sqr(&x0, &x0x0) != MP_OKAY)
66       goto X1X1;           /* x0x0 = x0*x0 */
67    if (mp_sqr(&x1, &x1x1) != MP_OKAY)
68       goto X1X1;           /* x1x1 = x1*x1 */
69 
70    /* now calc (x1+x0)**2 */
71    if (s_mp_add(&x1, &x0, &t1) != MP_OKAY)
72       goto X1X1;           /* t1 = x1 - x0 */
73    if (mp_sqr(&t1, &t1) != MP_OKAY)
74       goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */
75 
76    /* add x0y0 */
77    if (s_mp_add(&x0x0, &x1x1, &t2) != MP_OKAY)
78       goto X1X1;           /* t2 = x0x0 + x1x1 */
79    if (s_mp_sub(&t1, &t2, &t1) != MP_OKAY)
80       goto X1X1;           /* t1 = (x1+x0)**2 - (x0x0 + x1x1) */
81 
82    /* shift by B */
83    if (mp_lshd(&t1, B) != MP_OKAY)
84       goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
85    if (mp_lshd(&x1x1, B * 2) != MP_OKAY)
86       goto X1X1;           /* x1x1 = x1x1 << 2*B */
87 
88    if (mp_add(&x0x0, &t1, &t1) != MP_OKAY)
89       goto X1X1;           /* t1 = x0x0 + t1 */
90    if (mp_add(&t1, &x1x1, b) != MP_OKAY)
91       goto X1X1;           /* t1 = x0x0 + t1 + x1x1 */
92 
93    err = MP_OKAY;
94 
95 X1X1:
96    mp_clear(&x1x1);
97 X0X0:
98    mp_clear(&x0x0);
99 T2:
100    mp_clear(&t2);
101 T1:
102    mp_clear(&t1);
103 X1:
104    mp_clear(&x1);
105 X0:
106    mp_clear(&x0);
107 LBL_ERR:
108    return err;
109 }
110 #endif
111