1 /*
2  * Copyright (c) 2003, 2007-14 Matteo Frigo
3  * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 2 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18  *
19  */
20 
21 #include "rdft/rdft.h"
22 
23 /*
24  * Compute DHTs of prime sizes using Rader's trick: turn them
25  * into convolutions of size n - 1, which we then perform via a pair
26  * of FFTs.   (We can then do prime real FFTs via rdft-dht.c.)
27  *
28  * Optionally (determined by the "pad" field of the solver), we can
29  * perform the (cyclic) convolution by zero-padding to a size
30  * >= 2*(n-1) - 1.  This is advantageous if n-1 has large prime factors.
31  *
32  */
33 
/* Solver descriptor for this file: the generic solver header followed by
   the single option that distinguishes the two registered variants.
   `super` is the first member; mkplan() casts the incoming solver
   pointer to (const S *). */
typedef struct {
     solver super;
     int pad; /* nonzero: mkplan() zero-pads the convolution to
		 choose_transform_size(2*(n-1) - 1); zero: size n - 1 */
} S;
38 
39 typedef struct {
40      plan_rdft super;
41 
42      plan *cld1, *cld2;
43      R *omega;
44      INT n, npad, g, ginv;
45      INT is, os;
46      plan *cld_omega;
47 } P;
48 
/* File-scope list of already-computed omega arrays.  mkomega() looks
   entries up with the key (n, npad + 1, ginv) and inserts new ones;
   free_omega() hands entries back through rader_tl_delete. */
static rader_tl *omegas = 0;
50 
51 /***************************************************************************/
52 
/* If R2HC_ONLY_CONV is 1, we use a trick to perform the convolution
   purely in terms of R2HC transforms, as opposed to R2HC followed by HC2R.
   This requires a few more operations, but allows us to share the same
   plan/codelets for both Rader children.  The macro selects the kind of
   the second child plan in mkplan() and the matching pointwise-multiply
   and unshuffle code in apply(). */
#define R2HC_ONLY_CONV 1
58 
/*
 * Execute the plan: read n input samples from I (stride ego->is) and
 * write n output samples to O (stride ego->os).
 *
 * Outline:
 *   1. gather I[g^k] for k = 0..n-2 into a scratch buffer, zero-padded
 *      up to npad entries;
 *   2. transform the buffer in place with cld1;
 *   3. emit O[0] = I[0] + buf[0], then multiply the buffer pointwise by
 *      the precomputed omega array;
 *   4. add I[0] to buf[0] and transform in place with cld2;
 *   5. scatter the result to O at indices ginv^k.
 *
 * The scratch buffer is allocated on every call and released before
 * returning.  ego->omega, ego->g and ego->ginv are set by awake().
 */
static void apply(const plan *ego_, R *I, R *O)
{
     const P *ego = (const P *) ego_;
     INT n = ego->n; /* prime */
     INT npad = ego->npad; /* == n - 1 for unpadded Rader; always even */
     INT is = ego->is, os;
     INT k, gpower, g;
     R *buf, *omega;
     R r0;

     buf = (R *) MALLOC(sizeof(R) * npad, BUFFERS);

     /* First, permute the input, storing in buf: buf[k] = I[g^k mod n],
	so I[0] is not part of the buffer and is handled separately. */
     g = ego->g;
     for (gpower = 1, k = 0; k < n - 1; ++k, gpower = MULMOD(gpower, g, n)) {
	  buf[k] = I[gpower * is];
     }
     /* gpower == g^(n-1) mod n == 1 */;

     A(n - 1 <= npad);
     for (k = n - 1; k < npad; ++k) /* optionally, zero-pad convolution */
	  buf[k] = 0;

     os = ego->os;

     /* compute RDFT of buf, storing in buf (i.e., in-place): */
     {
	    plan_rdft *cld = (plan_rdft *) ego->cld1;
	    cld->apply((plan *) cld, buf, buf);
     }

     /* set output DC component; I[0] is saved in r0 because it is needed
	again below, after O[0] has been written: */
     O[0] = (r0 = I[0]) + buf[0];

     /* now, multiply by omega.  Entries k and npad - k are treated as the
	real and imaginary parts of one complex value, for both buf and
	omega; (a, b) is their complex product. */
     omega = ego->omega;
     buf[0] *= omega[0];
     for (k = 1; k < npad/2; ++k) {
	  E rB, iB, rW, iW, a, b;
	  rW = omega[k];
	  iW = omega[npad - k];
	  rB = buf[k];
	  iB = buf[npad - k];
	  a = rW * rB - iW * iB;
	  b = rW * iB + iW * rB;
#if R2HC_ONLY_CONV
	  /* store sum and difference so that a second R2HC (cld2) can be
	     used; the unshuffle code below recombines the pairs */
	  buf[k] = a + b;
	  buf[npad - k] = a - b;
#else
	  buf[k] = a;
	  buf[npad - k] = b;
#endif
     }
     /* Nyquist component (k == npad/2 on loop exit): */
     A(k + k == npad); /* since npad is even */
     buf[k] *= omega[k];

     /* this will add input[0] to all of the outputs after the ifft */
     buf[0] += r0;

     /* inverse FFT: */
     {
	    plan_rdft *cld = (plan_rdft *) ego->cld2;
	    cld->apply((plan *) cld, buf, buf);
     }

     /* do inverse permutation to unshuffle the output: */
     A(gpower == 1);
#if R2HC_ONLY_CONV
     /* k == 0 term goes to index ginv^0 == 1; the loops below therefore
	start at k == 1 with gpower == ginv. */
     O[os] = buf[0];
     gpower = g = ego->ginv;
     A(npad == n - 1 || npad/2 >= n - 1);
     if (npad == n - 1) {
	  /* unpadded: all npad entries are outputs; the lower half uses the
	     sum of the pair, the middle entry is taken as is, and the
	     upper half uses the difference */
	  for (k = 1; k < npad/2; ++k, gpower = MULMOD(gpower, g, n)) {
	       O[gpower * os] = buf[k] + buf[npad - k];
	  }
	  O[gpower * os] = buf[k];
	  ++k, gpower = MULMOD(gpower, g, n);
	  for (; k < npad; ++k, gpower = MULMOD(gpower, g, n)) {
	       O[gpower * os] = buf[npad - k] - buf[k];
	  }
     }
     else {
	  /* padded: only the first n - 1 entries are outputs, and by the
	     assertion above they all lie below npad/2, so only the sum
	     form is needed */
	  for (k = 1; k < n - 1; ++k, gpower = MULMOD(gpower, g, n)) {
	       O[gpower * os] = buf[k] + buf[npad - k];
	  }
     }
#else
     g = ego->ginv;
     for (k = 0; k < n - 1; ++k, gpower = MULMOD(gpower, g, n)) {
	  O[gpower * os] = buf[k];
     }
#endif
     /* ginv has been multiplied in n - 1 times in total */
     A(gpower == 1);

     X(ifree)(buf);
}
156 
mkomega(enum wakefulness wakefulness,plan * p_,INT n,INT npad,INT ginv)157 static R *mkomega(enum wakefulness wakefulness,
158 		  plan *p_, INT n, INT npad, INT ginv)
159 {
160      plan_rdft *p = (plan_rdft *) p_;
161      R *omega;
162      INT i, gpower;
163      trigreal scale;
164      triggen *t;
165 
166      if ((omega = X(rader_tl_find)(n, npad + 1, ginv, omegas)))
167 	  return omega;
168 
169      omega = (R *)MALLOC(sizeof(R) * npad, TWIDDLES);
170 
171      scale = npad; /* normalization for convolution */
172 
173      t = X(mktriggen)(wakefulness, n);
174      for (i = 0, gpower = 1; i < n-1; ++i, gpower = MULMOD(gpower, ginv, n)) {
175 	  trigreal w[2];
176 	  t->cexpl(t, gpower, w);
177 	  omega[i] = (w[0] + w[1]) / scale;
178      }
179      X(triggen_destroy)(t);
180      A(gpower == 1);
181 
182      A(npad == n - 1 || npad >= 2*(n - 1) - 1);
183 
184      for (; i < npad; ++i)
185 	  omega[i] = K(0.0);
186      if (npad > n - 1)
187 	  for (i = 1; i < n-1; ++i)
188 	       omega[npad - i] = omega[n - 1 - i];
189 
190      p->apply(p_, omega, omega);
191 
192      X(rader_tl_insert)(n, npad + 1, ginv, omega, &omegas);
193      return omega;
194 }
195 
/* Give an omega array obtained from mkomega() back to the file-level
   `omegas` list.  What rader_tl_delete does with the storage (immediate
   free vs. reference counting) is not visible in this file. */
static void free_omega(R *omega)
{
     X(rader_tl_delete)(omega, &omegas);
}
200 
201 /***************************************************************************/
202 
/*
 * Plan wake/sleep hook.  The state change is first forwarded to the
 * three child plans.  On SLEEPY the omega array is handed back and the
 * pointer cleared; on any other state the generator g of n, its
 * inverse ginv = g^(n-2) mod n, and the omega array are (re)computed.
 */
static void awake(plan *ego_, enum wakefulness wakefulness)
{
     P *ego = (P *) ego_;

     X(plan_awake)(ego->cld1, wakefulness);
     X(plan_awake)(ego->cld2, wakefulness);
     X(plan_awake)(ego->cld_omega, wakefulness);

     if (wakefulness == SLEEPY) {
	  /* going to sleep: drop the kernel and nothing else */
	  free_omega(ego->omega);
	  ego->omega = 0;
	  return;
     }

     /* waking up: number-theoretic constants first, since mkomega()
	needs ginv */
     ego->g = X(find_generator)(ego->n);
     ego->ginv = X(power_mod)(ego->g, ego->n - 2, ego->n);
     A(MULMOD(ego->g, ego->ginv, ego->n) == 1);

     /* omega must have been released by a preceding SLEEPY transition */
     A(!ego->omega);
     ego->omega = mkomega(wakefulness,
			  ego->cld_omega, ego->n, ego->npad, ego->ginv);
}
227 
/* Plan destructor hook (installed through the plan_adt in mkplan()):
   releases the three child plans.  The omega array is not touched here;
   in this file it is only released by awake() on SLEEPY. */
static void destroy(plan *ego_)
{
     P *ego = (P *) ego_;
     X(plan_destroy_internal)(ego->cld_omega);
     X(plan_destroy_internal)(ego->cld2);
     X(plan_destroy_internal)(ego->cld1);
}
235 
/*
 * Plan printing hook: emits "(dht-rader-<n>/<npad>...", the first child
 * plan, then the second child and the omega child only when they are
 * distinct objects from the ones already printed, then ")".
 */
static void print(const plan *ego_, printer *p)
{
     const P *self = (const P *) ego_;
     plan *first = self->cld1;
     plan *second = self->cld2;
     plan *kernel = self->cld_omega;

     p->print(p, "(dht-rader-%D/%D%ois=%oos=%(%p%)",
              self->n, self->npad, self->is, self->os, first);

     /* avoid printing the same child twice */
     if (second != first)
          p->print(p, "%(%p%)", second);
     if (!(kernel == first || kernel == second))
          p->print(p, "%(%p%)", kernel);

     p->putchr(p, ')');
}
248 
applicable(const solver * ego,const problem * p_,const planner * plnr)249 static int applicable(const solver *ego, const problem *p_, const planner *plnr)
250 {
251      const problem_rdft *p = (const problem_rdft *) p_;
252      UNUSED(ego);
253      return (1
254 	     && p->sz->rnk == 1
255 	     && p->vecsz->rnk == 0
256 	     && p->kind[0] == DHT
257 	     && X(is_prime)(p->sz->dims[0].n)
258 	     && p->sz->dims[0].n > 2
259 	     && CIMPLIES(NO_SLOWP(plnr), p->sz->dims[0].n > RADER_MAX_SLOW)
260 	     /* proclaim the solver SLOW if p-1 is not easily
261 		factorizable.  Unlike in the complex case where
262 		Bluestein can solve the problem, in the DHT case we
263 		may have no other choice */
264 	     && CIMPLIES(NO_SLOWP(plnr), X(factors_into_small_primes)(p->sz->dims[0].n - 1))
265 	  );
266 }
267 
/* Return the smallest even integer >= minsz that factors_into() accepts
   for the prime list {2, 3, 5}. */
static INT choose_transform_size(INT minsz)
{
     static const INT primes[] = { 2, 3, 5, 0 };

     for (;;) {
	  /* accept the first candidate that is both smooth and even */
	  if (X(factors_into)(minsz, primes) && !(minsz % 2))
	       return minsz;
	  ++minsz;
     }
}
275 
/*
 * Solver entry point: build a plan for problem p_, or return 0 if the
 * problem is not applicable or any of the three child plans cannot be
 * created.
 *
 * The convolution size npad is n - 1 for the unpadded solver and
 * choose_transform_size(2*(n-1) - 1) for the padded one.  Three
 * in-place, unit-stride child plans of size npad are created on a
 * scratch buffer: cld1 (R2HC), cld2 (R2HC when R2HC_ONLY_CONV is set,
 * HC2R otherwise) and cld_omega (R2HC, used only by mkomega()).  The
 * scratch buffer exists only for planning and is freed before the plan
 * is returned; omega is left null until awake() builds it.
 */
static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
{
     const S *ego = (const S *) ego_;
     const problem_rdft *p = (const problem_rdft *) p_;
     P *pln;
     INT n, npad;
     INT is, os;
     plan *cld1 = (plan *) 0;
     plan *cld2 = (plan *) 0;
     plan *cld_omega = (plan *) 0;
     R *buf = (R *) 0;
     problem *cldp;

     /* hook table for plans created here */
     static const plan_adt padt = {
	  X(rdft_solve), awake, print, destroy
     };

     if (!applicable(ego_, p_, plnr))
	  return (plan *) 0;

     n = p->sz->dims[0].n;
     is = p->sz->dims[0].is;
     os = p->sz->dims[0].os;

     if (ego->pad)
	  npad = choose_transform_size(2 * (n - 1) - 1);
     else
	  npad = n - 1;

     /* initial allocation for the purpose of planning */
     buf = (R *) MALLOC(sizeof(R) * npad, BUFFERS);

     /* first child: transform of the permuted input.
	NOTE(review): the trailing NO_SLOW, 0, 0 arguments are planner
	flag arguments of mkplan_f_d; their exact meaning is defined
	outside this file. */
     cld1 = X(mkplan_f_d)(plnr,
			  X(mkproblem_rdft_1_d)(X(mktensor_1d)(npad, 1, 1),
						X(mktensor_1d)(1, 0, 0),
						buf, buf,
						R2HC),
			  NO_SLOW, 0, 0);
     if (!cld1) goto nada;

     /* second child: transform applied after the multiplication by omega */
     cldp =
          X(mkproblem_rdft_1_d)(
               X(mktensor_1d)(npad, 1, 1),
               X(mktensor_1d)(1, 0, 0),
	       buf, buf,
#if R2HC_ONLY_CONV
	       R2HC
#else
	       HC2R
#endif
	       );
     if (!(cld2 = X(mkplan_f_d)(plnr, cldp, NO_SLOW, 0, 0)))
	  goto nada;

     /* plan for omega (note the additional ESTIMATE flag argument) */
     cld_omega = X(mkplan_f_d)(plnr,
			       X(mkproblem_rdft_1_d)(
				    X(mktensor_1d)(npad, 1, 1),
				    X(mktensor_1d)(1, 0, 0),
				    buf, buf, R2HC),
			       NO_SLOW, ESTIMATE, 0);
     if (!cld_omega) goto nada;

     /* deallocate buffers; let awake() or apply() allocate them for real */
     X(ifree)(buf);
     buf = 0;

     pln = MKPLAN_RDFT(P, &padt, apply);
     pln->cld1 = cld1;
     pln->cld2 = cld2;
     pln->cld_omega = cld_omega;
     pln->omega = 0;
     pln->n = n;
     pln->npad = npad;
     pln->is = is;
     pln->os = os;

     /* cost estimate: both child transforms plus the work done directly
	in apply(); the ego->pad terms adjust the counts for the padded
	variant */
     X(ops_add)(&cld1->ops, &cld2->ops, &pln->super.super.ops);
     pln->super.super.ops.other += (npad/2-1)*6 + npad + n + (n-1) * ego->pad;
     pln->super.super.ops.add += (npad/2-1)*2 + 2 + (n-1) * ego->pad;
     pln->super.super.ops.mul += (npad/2-1)*4 + 2 + ego->pad;
#if R2HC_ONLY_CONV
     pln->super.super.ops.other += n-2 - ego->pad;
     pln->super.super.ops.add += (npad/2-1)*2 + (n-2) - ego->pad;
#endif

     return &(pln->super.super);

     /* failure path: release whatever was created so far; child plan
	pointers that were never assigned are still null here */
 nada:
     X(ifree0)(buf);
     X(plan_destroy_internal)(cld_omega);
     X(plan_destroy_internal)(cld2);
     X(plan_destroy_internal)(cld1);
     return 0;
}
371 
372 /* constructors */
373 
mksolver(int pad)374 static solver *mksolver(int pad)
375 {
376      static const solver_adt sadt = { PROBLEM_RDFT, mkplan, 0 };
377      S *slv = MKSOLVER(S, &sadt);
378      slv->pad = pad;
379      return &(slv->super);
380 }
381 
X(dht_rader_register)382 void X(dht_rader_register)(planner *p)
383 {
384      REGISTER_SOLVER(p, mksolver(0));
385      REGISTER_SOLVER(p, mksolver(1));
386 }
387