1 /*
2  * Copyright (c) 2018, Henry Corrigan-Gibbs
3  *
4  * This Source Code Form is subject to the terms of the Mozilla Public
5  * License, v. 2.0. If a copy of the MPL was not distributed with this
6  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7  */
8 
9 #include <mprio.h>
10 
11 #include "config.h"
12 #include "poly.h"
13 #include "util.h"
14 
15 /*
16  * A nice exposition of the recursive FFT/DFT algorithm we implement
17  * is in the book:
18  *
19  *   "Modern Computer Algebra"
20  *    by Von zur Gathen and Gerhard.
21  *    Cambridge University Press, 2013.
22  *
23  * They present this algorithm as Algorithm 8.14.
24  */
25 
26 static SECStatus
fft_recurse(mp_int * out,const mp_int * mod,int n,const mp_int * roots,const mp_int * ys,mp_int * tmp,mp_int * ySub,mp_int * rootsSub)27 fft_recurse(mp_int* out, const mp_int* mod, int n, const mp_int* roots,
28             const mp_int* ys, mp_int* tmp, mp_int* ySub, mp_int* rootsSub)
29 {
30   if (n == 1) {
31     MP_CHECK(mp_copy(&ys[0], &out[0]));
32     return SECSuccess;
33   }
34 
35   // Recurse on the first half
36   for (int i = 0; i < n / 2; i++) {
37     MP_CHECK(mp_addmod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i]));
38     MP_CHECK(mp_copy(&roots[2 * i], &rootsSub[i]));
39   }
40 
41   MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2],
42                        &ySub[n / 2], &rootsSub[n / 2]));
43   for (int i = 0; i < n / 2; i++) {
44     MP_CHECK(mp_copy(&tmp[i], &out[2 * i]));
45   }
46 
47   // Recurse on the second half
48   for (int i = 0; i < n / 2; i++) {
49     MP_CHECK(mp_submod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i]));
50     MP_CHECK(mp_mulmod(&ySub[i], &roots[i], mod, &ySub[i]));
51   }
52 
53   MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2],
54                        &ySub[n / 2], &rootsSub[n / 2]));
55   for (int i = 0; i < n / 2; i++) {
56     MP_CHECK(mp_copy(&tmp[i], &out[2 * i + 1]));
57   }
58 
59   return SECSuccess;
60 }
61 
62 static SECStatus
fft_interpolate_raw(mp_int * out,const mp_int * ys,int nPoints,const_MPArray roots,const mp_int * mod,bool invert)63 fft_interpolate_raw(mp_int* out, const mp_int* ys, int nPoints,
64                     const_MPArray roots, const mp_int* mod, bool invert)
65 {
66   SECStatus rv = SECSuccess;
67   MPArray tmp = NULL;
68   MPArray ySub = NULL;
69   MPArray rootsSub = NULL;
70 
71   P_CHECKA(tmp = MPArray_new(nPoints));
72   P_CHECKA(ySub = MPArray_new(nPoints));
73   P_CHECKA(rootsSub = MPArray_new(nPoints));
74 
75   mp_int n_inverse;
76   MP_DIGITS(&n_inverse) = NULL;
77 
78   MP_CHECKC(fft_recurse(out, mod, nPoints, roots->data, ys, tmp->data,
79                         ySub->data, rootsSub->data));
80 
81   if (invert) {
82     MP_CHECKC(mp_init(&n_inverse));
83 
84     mp_set(&n_inverse, nPoints);
85     MP_CHECKC(mp_invmod(&n_inverse, mod, &n_inverse));
86     for (int i = 0; i < nPoints; i++) {
87       MP_CHECKC(mp_mulmod(&out[i], &n_inverse, mod, &out[i]));
88     }
89   }
90 
91 cleanup:
92   MPArray_clear(tmp);
93   MPArray_clear(ySub);
94   MPArray_clear(rootsSub);
95   mp_clear(&n_inverse);
96 
97   return rv;
98 }
99 
100 /*
101  * The PrioConfig object has a list of N-th roots of unity for large N.
102  * This routine returns the n-th roots of unity for n < N, where n is
103  * a power of two. If the `invert` flag is set, it returns the inverses
104  * of the n-th roots of unity.
105  */
106 SECStatus
poly_fft_get_roots(MPArray roots_out,int n_points,const_PrioConfig cfg,bool invert)107 poly_fft_get_roots(MPArray roots_out, int n_points, const_PrioConfig cfg,
108                    bool invert)
109 {
110   if (n_points < 1) {
111     return SECFailure;
112   }
113 
114   if (n_points != roots_out->len) {
115     return SECFailure;
116   }
117 
118   if (n_points > cfg->n_roots) {
119     return SECFailure;
120   }
121 
122   mp_set(&roots_out->data[0], 1);
123   if (n_points == 1) {
124     return SECSuccess;
125   }
126 
127   const int step_size = cfg->n_roots / n_points;
128   mp_int* gen = &roots_out->data[1];
129 
130   MP_CHECK(mp_copy(&cfg->generator, gen));
131 
132   if (invert) {
133     MP_CHECK(mp_invmod(gen, &cfg->modulus, gen));
134   }
135 
136   // Compute g' = g^step_size
137   // Now, g' generates a subgroup of order n_points.
138   MP_CHECK(mp_exptmod_d(gen, step_size, &cfg->modulus, gen));
139 
140   for (int i = 2; i < n_points; i++) {
141     // Compute g^i for all i in {0,..., n-1}
142     MP_CHECK(mp_mulmod(gen, &roots_out->data[i - 1], &cfg->modulus,
143                        &roots_out->data[i]));
144   }
145 
146   return SECSuccess;
147 }
148 
149 SECStatus
poly_fft(MPArray points_out,const_MPArray points_in,const_PrioConfig cfg,bool invert)150 poly_fft(MPArray points_out, const_MPArray points_in, const_PrioConfig cfg,
151          bool invert)
152 {
153   SECStatus rv = SECSuccess;
154   const int n_points = points_in->len;
155   MPArray scaled_roots = NULL;
156 
157   if (points_out->len != points_in->len)
158     return SECFailure;
159   if (n_points > cfg->n_roots)
160     return SECFailure;
161   if (cfg->n_roots % n_points != 0)
162     return SECFailure;
163 
164   P_CHECKA(scaled_roots = MPArray_new(n_points));
165   P_CHECKC(poly_fft_get_roots(scaled_roots, n_points, cfg, invert));
166 
167   P_CHECKC(fft_interpolate_raw(points_out->data, points_in->data, n_points,
168                                scaled_roots, &cfg->modulus, invert));
169 
170 cleanup:
171   MPArray_clear(scaled_roots);
172 
173   return SECSuccess;
174 }
175 
176 SECStatus
poly_eval(mp_int * value,const_MPArray coeffs,const mp_int * eval_at,const_PrioConfig cfg)177 poly_eval(mp_int* value, const_MPArray coeffs, const mp_int* eval_at,
178           const_PrioConfig cfg)
179 {
180   SECStatus rv = SECSuccess;
181   const int n = coeffs->len;
182 
183   // Use Horner's method to evaluate the polynomial at the point
184   // `eval_at`
185   MP_CHECK(mp_copy(&coeffs->data[n - 1], value));
186   for (int i = n - 2; i >= 0; i--) {
187     MP_CHECK(mp_mulmod(value, eval_at, &cfg->modulus, value));
188     MP_CHECK(mp_addmod(value, &coeffs->data[i], &cfg->modulus, value));
189   }
190 
191   return rv;
192 }
193 
194 SECStatus
poly_interp_evaluate(mp_int * value,const_MPArray poly_points,const mp_int * eval_at,const_PrioConfig cfg)195 poly_interp_evaluate(mp_int* value, const_MPArray poly_points,
196                      const mp_int* eval_at, const_PrioConfig cfg)
197 {
198   SECStatus rv;
199   MPArray coeffs = NULL;
200   const int N = poly_points->len;
201 
202   P_CHECKA(coeffs = MPArray_new(N));
203 
204   // Interpolate polynomial through roots of unity
205   P_CHECKC(poly_fft(coeffs, poly_points, cfg, true))
206   P_CHECKC(poly_eval(value, coeffs, eval_at, cfg));
207 
208 cleanup:
209   MPArray_clear(coeffs);
210   return rv;
211 }
212