1 /*
2 * Copyright (c) 2018, Henry Corrigan-Gibbs
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public
5 * License, v. 2.0. If a copy of the MPL was not distributed with this
6 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 */
8
9 #include <mprio.h>
10
11 #include "config.h"
12 #include "poly.h"
13 #include "util.h"
14
15 /*
16 * A nice exposition of the recursive FFT/DFT algorithm we implement
17 * is in the book:
18 *
 * "Modern Computer Algebra" (3rd edition)
 * by Joachim von zur Gathen and Juergen Gerhard.
 * Cambridge University Press, 2013.
22 *
23 * They present this algorithm as Algorithm 8.14.
24 */
25
26 static SECStatus
fft_recurse(mp_int * out,const mp_int * mod,int n,const mp_int * roots,const mp_int * ys,mp_int * tmp,mp_int * ySub,mp_int * rootsSub)27 fft_recurse(mp_int* out, const mp_int* mod, int n, const mp_int* roots,
28 const mp_int* ys, mp_int* tmp, mp_int* ySub, mp_int* rootsSub)
29 {
30 if (n == 1) {
31 MP_CHECK(mp_copy(&ys[0], &out[0]));
32 return SECSuccess;
33 }
34
35 // Recurse on the first half
36 for (int i = 0; i < n / 2; i++) {
37 MP_CHECK(mp_addmod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i]));
38 MP_CHECK(mp_copy(&roots[2 * i], &rootsSub[i]));
39 }
40
41 MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2],
42 &ySub[n / 2], &rootsSub[n / 2]));
43 for (int i = 0; i < n / 2; i++) {
44 MP_CHECK(mp_copy(&tmp[i], &out[2 * i]));
45 }
46
47 // Recurse on the second half
48 for (int i = 0; i < n / 2; i++) {
49 MP_CHECK(mp_submod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i]));
50 MP_CHECK(mp_mulmod(&ySub[i], &roots[i], mod, &ySub[i]));
51 }
52
53 MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2],
54 &ySub[n / 2], &rootsSub[n / 2]));
55 for (int i = 0; i < n / 2; i++) {
56 MP_CHECK(mp_copy(&tmp[i], &out[2 * i + 1]));
57 }
58
59 return SECSuccess;
60 }
61
62 static SECStatus
fft_interpolate_raw(mp_int * out,const mp_int * ys,int nPoints,const_MPArray roots,const mp_int * mod,bool invert)63 fft_interpolate_raw(mp_int* out, const mp_int* ys, int nPoints,
64 const_MPArray roots, const mp_int* mod, bool invert)
65 {
66 SECStatus rv = SECSuccess;
67 MPArray tmp = NULL;
68 MPArray ySub = NULL;
69 MPArray rootsSub = NULL;
70
71 P_CHECKA(tmp = MPArray_new(nPoints));
72 P_CHECKA(ySub = MPArray_new(nPoints));
73 P_CHECKA(rootsSub = MPArray_new(nPoints));
74
75 mp_int n_inverse;
76 MP_DIGITS(&n_inverse) = NULL;
77
78 MP_CHECKC(fft_recurse(out, mod, nPoints, roots->data, ys, tmp->data,
79 ySub->data, rootsSub->data));
80
81 if (invert) {
82 MP_CHECKC(mp_init(&n_inverse));
83
84 mp_set(&n_inverse, nPoints);
85 MP_CHECKC(mp_invmod(&n_inverse, mod, &n_inverse));
86 for (int i = 0; i < nPoints; i++) {
87 MP_CHECKC(mp_mulmod(&out[i], &n_inverse, mod, &out[i]));
88 }
89 }
90
91 cleanup:
92 MPArray_clear(tmp);
93 MPArray_clear(ySub);
94 MPArray_clear(rootsSub);
95 mp_clear(&n_inverse);
96
97 return rv;
98 }
99
100 /*
101 * The PrioConfig object has a list of N-th roots of unity for large N.
102 * This routine returns the n-th roots of unity for n < N, where n is
103 * a power of two. If the `invert` flag is set, it returns the inverses
104 * of the n-th roots of unity.
105 */
106 SECStatus
poly_fft_get_roots(MPArray roots_out,int n_points,const_PrioConfig cfg,bool invert)107 poly_fft_get_roots(MPArray roots_out, int n_points, const_PrioConfig cfg,
108 bool invert)
109 {
110 if (n_points < 1) {
111 return SECFailure;
112 }
113
114 if (n_points != roots_out->len) {
115 return SECFailure;
116 }
117
118 if (n_points > cfg->n_roots) {
119 return SECFailure;
120 }
121
122 mp_set(&roots_out->data[0], 1);
123 if (n_points == 1) {
124 return SECSuccess;
125 }
126
127 const int step_size = cfg->n_roots / n_points;
128 mp_int* gen = &roots_out->data[1];
129
130 MP_CHECK(mp_copy(&cfg->generator, gen));
131
132 if (invert) {
133 MP_CHECK(mp_invmod(gen, &cfg->modulus, gen));
134 }
135
136 // Compute g' = g^step_size
137 // Now, g' generates a subgroup of order n_points.
138 MP_CHECK(mp_exptmod_d(gen, step_size, &cfg->modulus, gen));
139
140 for (int i = 2; i < n_points; i++) {
141 // Compute g^i for all i in {0,..., n-1}
142 MP_CHECK(mp_mulmod(gen, &roots_out->data[i - 1], &cfg->modulus,
143 &roots_out->data[i]));
144 }
145
146 return SECSuccess;
147 }
148
149 SECStatus
poly_fft(MPArray points_out,const_MPArray points_in,const_PrioConfig cfg,bool invert)150 poly_fft(MPArray points_out, const_MPArray points_in, const_PrioConfig cfg,
151 bool invert)
152 {
153 SECStatus rv = SECSuccess;
154 const int n_points = points_in->len;
155 MPArray scaled_roots = NULL;
156
157 if (points_out->len != points_in->len)
158 return SECFailure;
159 if (n_points > cfg->n_roots)
160 return SECFailure;
161 if (cfg->n_roots % n_points != 0)
162 return SECFailure;
163
164 P_CHECKA(scaled_roots = MPArray_new(n_points));
165 P_CHECKC(poly_fft_get_roots(scaled_roots, n_points, cfg, invert));
166
167 P_CHECKC(fft_interpolate_raw(points_out->data, points_in->data, n_points,
168 scaled_roots, &cfg->modulus, invert));
169
170 cleanup:
171 MPArray_clear(scaled_roots);
172
173 return SECSuccess;
174 }
175
176 SECStatus
poly_eval(mp_int * value,const_MPArray coeffs,const mp_int * eval_at,const_PrioConfig cfg)177 poly_eval(mp_int* value, const_MPArray coeffs, const mp_int* eval_at,
178 const_PrioConfig cfg)
179 {
180 SECStatus rv = SECSuccess;
181 const int n = coeffs->len;
182
183 // Use Horner's method to evaluate the polynomial at the point
184 // `eval_at`
185 MP_CHECK(mp_copy(&coeffs->data[n - 1], value));
186 for (int i = n - 2; i >= 0; i--) {
187 MP_CHECK(mp_mulmod(value, eval_at, &cfg->modulus, value));
188 MP_CHECK(mp_addmod(value, &coeffs->data[i], &cfg->modulus, value));
189 }
190
191 return rv;
192 }
193
194 SECStatus
poly_interp_evaluate(mp_int * value,const_MPArray poly_points,const mp_int * eval_at,const_PrioConfig cfg)195 poly_interp_evaluate(mp_int* value, const_MPArray poly_points,
196 const mp_int* eval_at, const_PrioConfig cfg)
197 {
198 SECStatus rv;
199 MPArray coeffs = NULL;
200 const int N = poly_points->len;
201
202 P_CHECKA(coeffs = MPArray_new(N));
203
204 // Interpolate polynomial through roots of unity
205 P_CHECKC(poly_fft(coeffs, poly_points, cfg, true))
206 P_CHECKC(poly_eval(value, coeffs, eval_at, cfg));
207
208 cleanup:
209 MPArray_clear(coeffs);
210 return rv;
211 }
212