1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 #include "mpi-priv.h"
6 #include <c_asm.h>
7 
/* Digit-by-digit multiply of a and b through the asm() intrinsic from
 * <c_asm.h>: Plo receives the result of `mulq` and Phi the result of
 * `umulh` (on Alpha these are the low and the high 64 bits of the unsigned
 * 128-bit product).
 *
 * The body is wrapped in do { } while (0) so that `MP_MUL_DxD(...);` is
 * exactly one statement and remains valid inside an unbraced if/else.
 * a and b are each evaluated twice, so they must be free of side effects. */
#define MP_MUL_DxD(a, b, Phi, Plo)                    \
    do {                                              \
        (Plo) = asm("mulq %a0, %a1, %v0", (a), (b));  \
        (Phi) = asm("umulh %a0, %a1, %v0", (a), (b)); \
    } while (0)
13 
/* CARRY_ADD is a hook expanded in the middle of ONE_MUL. It is defined
   empty here, which is the form wanted by the loop in s_mpv_mul_d. */
#define CARRY_ADD

/* One step of the digit-vector multiply. The enclosing scope must supply
 * a, b, c, carry, a_i, a0b0 and a1b1. The step:
 *   - reads the next digit of a and advances a;
 *   - forms the two-digit product with b (a1b1 = high, a0b0 = low);
 *   - adds the incoming carry to the low digit and bumps the high digit
 *     when that addition wrapped around;
 *   - expands CARRY_ADD (as defined at the point of use);
 *   - stores the low digit through c, advances c, and hands the high
 *     digit on as the carry for the next step. */
#define ONE_MUL                     \
    a_i = *a;                       \
    ++a;                            \
    MP_MUL_DxD(a_i, b, a1b1, a0b0); \
    a0b0 += carry;                  \
    a1b1 += (a0b0 < carry);         \
    CARRY_ADD                       \
    *c = a0b0;                      \
    ++c;                            \
    carry = a1b1;
26 
/* Unrolled runs of ONE_MUL. Each name states how many ONE_MUL steps the
   macro expands to; the larger runs are built from the smaller ones. */
#define FOUR_MUL ONE_MUL ONE_MUL ONE_MUL ONE_MUL

#define SIXTEEN_MUL FOUR_MUL FOUR_MUL FOUR_MUL FOUR_MUL

#define THIRTYTWO_MUL SIXTEEN_MUL SIXTEEN_MUL

#define ONETWENTYEIGHT_MUL \
    THIRTYTWO_MUL THIRTYTWO_MUL THIRTYTWO_MUL THIRTYTWO_MUL
48 
/* Common body of the s_mpv_mul_d* entry points. It expects a, a_len, b and
 * c to be in scope, declares carry (starting at 0) together with the
 * scratch variables ONE_MUL needs, and must therefore be the first thing
 * in the function body. On exit `carry` holds the final carry and c has
 * been advanced past every digit written.
 *
 * The low eight bits of a_len are consumed first by inline unrolled runs,
 * one run per set bit. What is left of a_len is a multiple of 256 and is
 * handed to CALL(a, a_len, b, c, carry), whose result becomes the carry. */
#define EXPAND_256(CALL)                                      \
    mp_digit a0b0, a1b1;                                      \
    mp_digit a_i;                                             \
    mp_digit carry = 0;                                       \
    if ((a_len & 0xff) != 0) {                                \
        if (a_len & 0x01) {                                   \
            ONE_MUL                                           \
        }                                                     \
        if (a_len & 0x02) {                                   \
            ONE_MUL                                           \
            ONE_MUL                                           \
        }                                                     \
        if (a_len & 0x04) {                                   \
            FOUR_MUL                                          \
        }                                                     \
        if (a_len & 0x08) {                                   \
            FOUR_MUL                                          \
            FOUR_MUL                                          \
        }                                                     \
        if (a_len & 0x10) {                                   \
            SIXTEEN_MUL                                       \
        }                                                     \
        if (a_len & 0x20) {                                   \
            THIRTYTWO_MUL                                     \
        }                                                     \
        if (a_len & 0x40) {                                   \
            THIRTYTWO_MUL                                     \
            THIRTYTWO_MUL                                     \
        }                                                     \
        if (a_len & 0x80) {                                   \
            ONETWENTYEIGHT_MUL                                \
        }                                                     \
        /* drop the low eight bits that were just handled */ \
        a_len &= (-256);                                      \
    }                                                         \
    if (a_len >= 256) {                                       \
        carry = CALL(a, a_len, b, c, carry);                  \
        /* CALL got c by value; step past what it wrote */   \
        c += a_len;                                           \
    }
87 
/* Expands to the signature shared by the 256-digit helper routines, so the
 * same text serves both as a prototype (when followed by `;`) and as the
 * header of a definition. The helper named NAME takes the digit pointer a,
 * the digit count a_len, the multiplier b, the output pointer c and the
 * incoming carry, and returns an mp_digit. */
#define FUNC_NAME(NAME)                                         \
    mp_digit NAME(const mp_digit *a, mp_size a_len, mp_digit b, \
                  mp_digit *c, mp_digit carry)
93 
/* Defines the helper FNAME with the FUNC_NAME signature. Every pass of the
 * loop performs 256 ONE_MUL steps (two runs of 128) and then takes 256 off
 * a_len, so the caller is expected to pass a multiple of 256, which is
 * what EXPAND_256 does. The carry left after the last step is returned.
 * CARRY_ADD is expanded as it is defined where this macro is invoked. */
#define DECLARE_MUL_256(FNAME)             \
    FUNC_NAME(FNAME)                       \
    {                                      \
        mp_digit a0b0, a1b1;               \
        mp_digit a_i;                      \
        for (; a_len != 0; a_len -= 256) { \
            ONETWENTYEIGHT_MUL             \
            ONETWENTYEIGHT_MUL             \
        }                                  \
        return carry;                      \
    }
106 
/* Unrolling the loop in s_mpv_mul_d appeared to slow down the (admittedly
   small) set of tests, i.e. timetest, used to measure performance, so
   defining DO_NOT_EXPAND switches that optimization off. */
#define DO_NOT_EXPAND 1

/* The helper is declared ahead of its use so that its definition can be
   placed after the routine that calls it; this helps locality somewhat. */
#ifndef DO_NOT_EXPAND
FUNC_NAME(s_mpv_mul_d_MUL256);
#endif
117 
118 /* c = a * b */
119 void
s_mpv_mul_d(const mp_digit * a,mp_size a_len,mp_digit b,mp_digit * c)120 s_mpv_mul_d(const mp_digit *a, mp_size a_len,
121             mp_digit b, mp_digit *c)
122 {
123 #if defined(DO_NOT_EXPAND)
124     mp_digit carry = 0;
125     while (a_len--) {
126         mp_digit a_i = *a++;
127         mp_digit a0b0, a1b1;
128 
129         MP_MUL_DxD(a_i, b, a1b1, a0b0);
130 
131         a0b0 += carry;
132         if (a0b0 < carry)
133             ++a1b1;
134         *c++ = a0b0;
135         carry = a1b1;
136     }
137 #else
138     EXPAND_256(s_mpv_mul_d_MUL256)
139 #endif
140     *c = carry;
141 }
142 
143 #if !defined(DO_NOT_EXPAND)
144 DECLARE_MUL_256(s_mpv_mul_d_MUL256)
145 #endif
146 
#undef CARRY_ADD
/* From here on CARRY_ADD is the accumulate step used by the
   s_mpv_mul_d_add loops: the digit currently at *c is loaded (reusing a_i
   as scratch) and added into the low product digit, and the high digit is
   bumped when that addition wraps around. */
#define CARRY_ADD \
    a_i = *c;     \
    a0b0 += a_i;  \
    a1b1 += (a0b0 < a_i);

/* Declared ahead of its definition so that the helper can be instantiated
   between the two routines that use it; this helps locality somewhat. */
FUNC_NAME(s_mpv_mul_d_add_MUL256);
157 
158 /* c += a * b */
159 void
s_mpv_mul_d_add(const mp_digit * a,mp_size a_len,mp_digit b,mp_digit * c)160 s_mpv_mul_d_add(const mp_digit *a, mp_size a_len,
161                 mp_digit b, mp_digit *c)
162 {
163     EXPAND_256(s_mpv_mul_d_add_MUL256)
164     *c = carry;
165 }
166 
167 /* Instantiate multiply 256 routine here */
DECLARE_MUL_256(s_mpv_mul_d_add_MUL256)168 DECLARE_MUL_256(s_mpv_mul_d_add_MUL256)
169 
170 /* Presently, this is only used by the Montgomery arithmetic code. */
171 /* c += a * b */
172 void
173 s_mpv_mul_d_add_prop(const mp_digit *a, mp_size a_len,
174                      mp_digit b, mp_digit *c)
175 {
176     EXPAND_256(s_mpv_mul_d_add_MUL256)
177     while (carry) {
178         mp_digit c_i = *c;
179         carry += c_i;
180         *c++ = carry;
181         carry = carry < c_i;
182     }
183 }
184