1 /*
2  * Copyright (c) 2003, 2007-14 Matteo Frigo
3  * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 2 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18  *
19  */
20 
21 
22 /* Solve an R2HC/HC2R problem via post/pre processing of a DHT.  This
23    is mainly useful because we can use Rader to compute DHTs of prime
24    sizes.  It also allows us to express hc2r problems in terms of r2hc
25    (via dht-r2hc), and to do hc2r problems without destroying the input. */
26 
27 #include "rdft/rdft.h"
28 
29 typedef struct {
30      solver super;
31 } S;
32 
33 typedef struct {
34      plan_rdft super;
35      plan *cld;
36      INT is, os;
37      INT n;
38 } P;
39 
/*
 * r2hc: run the child DHT from I into O, then post-process O in place.
 *
 * For every index pair 0 < k < n - k, with H the DHT output, the pair
 * (H[k], H[n-k]) is replaced by
 *     ((H[k] + H[n-k]) / 2,  (H[n-k] - H[k]) / 2)   when FFT_SIGN == -1,
 *     ((H[k] + H[n-k]) / 2,  (H[k] - H[n-k]) / 2)   otherwise.
 * This works because H[k] = C[k] + S[k] and H[n-k] = C[k] - S[k], where C
 * and S are the cosine and sine sums.  O[0] and, for even n, the middle
 * element are left exactly as the DHT produced them.
 */
static void apply_r2hc(const plan *ego_, R *I, R *O)
{
     const P *ego = (const P *) ego_;
     plan_rdft *child = (plan_rdft *) ego->cld;
     const INT n = ego->n;
     const INT os = ego->os;
     INT k;

     child->apply((plan *) child, I, O);

     for (k = 1; k < n - k; ++k) {
          R *lo = O + os * k;
          R *hi = O + os * (n - k);
          E half_lo = K(0.5) * lo[0];
          E half_hi = K(0.5) * hi[0];

          lo[0] = half_lo + half_hi;
#if FFT_SIGN == -1
          hi[0] = half_hi - half_lo;
#else
          hi[0] = half_lo - half_hi;
#endif
     }
}
65 
66 /* hc2r, destroying input as usual */
apply_hc2r(const plan * ego_,R * I,R * O)67 static void apply_hc2r(const plan *ego_, R *I, R *O)
68 {
69      const P *ego = (const P *) ego_;
70      INT is = ego->is;
71      INT i, n = ego->n;
72 
73      for (i = 1; i < n - i; ++i) {
74 	  E a, b;
75 	  a = I[is * i];
76 	  b = I[is * (n - i)];
77 #if FFT_SIGN == -1
78 	  I[is * i] = a - b;
79 	  I[is * (n - i)] = a + b;
80 #else
81 	  I[is * i] = a + b;
82 	  I[is * (n - i)] = a - b;
83 #endif
84      }
85 
86      {
87 	  plan_rdft *cld = (plan_rdft *) ego->cld;
88 	  cld->apply((plan *) cld, I, O);
89      }
90 }
91 
92 /* hc2r, without destroying input */
apply_hc2r_save(const plan * ego_,R * I,R * O)93 static void apply_hc2r_save(const plan *ego_, R *I, R *O)
94 {
95      const P *ego = (const P *) ego_;
96      INT is = ego->is, os = ego->os;
97      INT i, n = ego->n;
98 
99      O[0] = I[0];
100      for (i = 1; i < n - i; ++i) {
101 	  E a, b;
102 	  a = I[is * i];
103 	  b = I[is * (n - i)];
104 #if FFT_SIGN == -1
105 	  O[os * i] = a - b;
106 	  O[os * (n - i)] = a + b;
107 #else
108 	  O[os * i] = a + b;
109 	  O[os * (n - i)] = a - b;
110 #endif
111      }
112      if (i == n - i)
113 	  O[os * i] = I[is * i];
114 
115      {
116 	  plan_rdft *cld = (plan_rdft *) ego->cld;
117 	  cld->apply((plan *) cld, O, O);
118      }
119 }
120 
/* Forward a wakefulness change to the child DHT plan. */
static void awake(plan *ego_, enum wakefulness wakefulness)
{
     plan *child = ((P *) ego_)->cld;

     X(plan_awake)(child, wakefulness);
}
126 
/*
 * Release the child DHT plan.  Only the child is released here; this
 * callback does not free the P object itself (presumably the generic plan
 * code does that -- confirm against the caller).
 */
static void destroy(plan *ego_)
{
     plan *child = ((P *) ego_)->cld;

     X(plan_destroy_internal)(child);
}
132 
/*
 * Describe the plan as "(r2hc-dht-<n>" or "(hc2r-dht-<n>", followed by the
 * child plan and a closing parenthesis.  Any apply routine other than
 * apply_r2hc is reported as hc2r.
 */
static void print(const plan *ego_, printer *p)
{
     const P *ego = (const P *) ego_;
     const char *direction = (ego->super.apply == apply_r2hc) ? "r2hc" : "hc2r";

     p->print(p, "(%s-dht-%D%(%p%))", direction, ego->n, ego->cld);
}
140 
applicable0(const solver * ego_,const problem * p_)141 static int applicable0(const solver *ego_, const problem *p_)
142 {
143      const problem_rdft *p = (const problem_rdft *) p_;
144      UNUSED(ego_);
145 
146      return (1
147 	     && p->sz->rnk == 1
148 	     && p->vecsz->rnk == 0
149 	     && (p->kind[0] == R2HC || p->kind[0] == HC2R)
150 
151 	     /* hack: size-2 DHT etc. are defined as being equivalent
152 		to size-2 R2HC in problem.c, so we need this to prevent
153 		infinite loops for size 2 in EXHAUSTIVE mode: */
154 	     && p->sz->dims[0].n > 2
155 	  );
156 }
157 
/*
 * Applicability test including planner policy.  The solver is rejected
 * whenever the planner has NO_SLOW set; otherwise it defers to applicable0.
 */
static int applicable(const solver *ego, const problem *p_,
                      const planner *plnr)
{
     if (NO_SLOWP(plnr))
          return 0;

     return applicable0(ego, p_);
}
163 
mkplan(const solver * ego_,const problem * p_,planner * plnr)164 static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
165 {
166      P *pln;
167      const problem_rdft *p;
168      problem *cldp;
169      plan *cld;
170 
171      static const plan_adt padt = {
172 	  X(rdft_solve), awake, print, destroy
173      };
174 
175      if (!applicable(ego_, p_, plnr))
176           return (plan *)0;
177 
178      p = (const problem_rdft *) p_;
179 
180      if (p->kind[0] == R2HC || !NO_DESTROY_INPUTP(plnr))
181 	  cldp = X(mkproblem_rdft_1)(p->sz, p->vecsz, p->I, p->O, DHT);
182      else {
183 	  tensor *sz = X(tensor_copy_inplace)(p->sz, INPLACE_OS);
184 	  cldp = X(mkproblem_rdft_1)(sz, p->vecsz, p->O, p->O, DHT);
185 	  X(tensor_destroy)(sz);
186      }
187      cld = X(mkplan_d)(plnr, cldp);
188      if (!cld) return (plan *)0;
189 
190      pln = MKPLAN_RDFT(P, &padt, p->kind[0] == R2HC ?
191 		       apply_r2hc : (NO_DESTROY_INPUTP(plnr) ?
192 				     apply_hc2r_save : apply_hc2r));
193      pln->n = p->sz->dims[0].n;
194      pln->is = p->sz->dims[0].is;
195      pln->os = p->sz->dims[0].os;
196      pln->cld = cld;
197 
198      pln->super.super.ops = cld->ops;
199      pln->super.super.ops.other += 4 * ((pln->n - 1)/2);
200      pln->super.super.ops.add += 2 * ((pln->n - 1)/2);
201      if (p->kind[0] == R2HC)
202 	  pln->super.super.ops.mul += 2 * ((pln->n - 1)/2);
203      if (pln->super.apply == apply_hc2r_save)
204 	  pln->super.super.ops.other += 2 + (pln->n % 2 ? 0 : 2);
205 
206      return &(pln->super.super);
207 }
208 
209 /* constructor */
mksolver(void)210 static solver *mksolver(void)
211 {
212      static const solver_adt sadt = { PROBLEM_RDFT, mkplan, 0 };
213      S *slv = MKSOLVER(S, &sadt);
214      return &(slv->super);
215 }
216 
X(rdft_dht_register)217 void X(rdft_dht_register)(planner *p)
218 {
219      REGISTER_SOLVER(p, mksolver());
220 }
221