1 /* ************************************************************************
2  * Copyright 2013 Advanced Micro Devices, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  * ************************************************************************/
16 
17 
18 #pragma once
19 #if !defined( CLFFT_FFTWTRANSFORM_H )
20 #define CLFFT_FFTWTRANSFORM_H
21 
22 #include <vector>
23 #include "fftw3.h"
24 #include "buffer.h"
25 #include "../client/openCL.misc.h" // we need this to leverage the CLFFT_INPLACE and _OUTOFPLACE enums
26 
// Direction of a complex-to-complex transform. The numeric value is passed
// straight through as the sign argument of fftw(f)_plan_many_dft, so the
// enumerators carry FFTW's own sign convention (-1 forward, +1 backward).
enum fftw_direction {forward=-1, backward=+1};
28 
// Kind of transform a plan is built for: complex-to-complex (c2c),
// real-to-complex (r2c) or complex-to-real (c2r). Selects which
// fftw(f)_plan_many_* planner the wrappers below call.
enum fftw_transform_type {c2c, r2c, c2r};
30 
// Primary template: intentionally empty. Only the explicit specializations
// below (<float, fftwf_complex> and <double, fftw_complex>) provide plan
// management; any other type combination yields a class with no members.
template <typename T, typename fftw_T>
class fftw_wrapper
{};
34 
35 template <>
36 class fftw_wrapper<float, fftwf_complex>
37 {
38 public:
39 	fftwf_plan plan;
40 
make_plan(int x,int y,int z,int num_dimensions,int batch_size,fftwf_complex * input_ptr,fftwf_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)41 	void make_plan( int x, int y, int z, int num_dimensions, int batch_size, fftwf_complex* input_ptr, fftwf_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type )
42 	{
43 		// we need to swap x,y,z dimensions because of a row-column discrepancy between clfft and fftw
44 		int lengths[max_dimension] = {z, y, x};
45 
46 		if( type == c2c )
47 		{
48 			plan = fftwf_plan_many_dft( num_dimensions,
49 										// because we swapped dimensions up above, we need to start
50 										// at the end of the array and count backwards to get the
51 										// correct dimensions passed in to fftw
52 										// e.g. if max_dimension is 3 and number_of_dimensions is 2:
53 										// lengths = {dimz, dimy, dimx}
54 										// lengths + 3 - 2 = lengths + 1
55 										// so we will skip dimz and pass in a pointer to {dimy, dimx}
56 										lengths+max_dimension-num_dimensions,
57 										batch_size,
58 										input_ptr, NULL,
59 										1, num_points_in_single_batch,
60 										output_ptr, NULL,
61 										1, num_points_in_single_batch,
62 										direction, FFTW_ESTIMATE);
63 		}
64 		else if( type == r2c )
65 		{
66 			plan = fftwf_plan_many_dft_r2c( num_dimensions,
67 											// because we swapped dimensions up above, we need to start
68 											// at the end of the array and count backwards to get the
69 											// correct dimensions passed in to fftw
70 											// e.g. if max_dimension is 3 and number_of_dimensions is 2:
71 											// lengths = {dimz, dimy, dimx}
72 											// lengths + 3 - 2 = lengths + 1
73 											// so we will skip dimz and pass in a pointer to {dimy, dimx}
74 											lengths+max_dimension-num_dimensions,
75 											batch_size,
76 											reinterpret_cast<float*>(input_ptr), NULL,
77 											1, num_points_in_single_batch,
78 											output_ptr, NULL,
79 											1, (x/2 + 1) * y * z,
80 											FFTW_ESTIMATE);
81 		}
82 		else if( type == c2r )
83 		{
84 			plan = fftwf_plan_many_dft_c2r( num_dimensions,
85 											// because we swapped dimensions up above, we need to start
86 											// at the end of the array and count backwards to get the
87 											// correct dimensions passed in to fftw
88 											// e.g. if max_dimension is 3 and number_of_dimensions is 2:
89 											// lengths = {dimz, dimy, dimx}
90 											// lengths + 3 - 2 = lengths + 1
91 											// so we will skip dimz and pass in a pointer to {dimy, dimx}
92 											lengths+max_dimension-num_dimensions,
93 											batch_size,
94 											input_ptr, NULL,
95 											1, (x/2 + 1) * y * z,
96 											reinterpret_cast<float*>(output_ptr), NULL,
97 											1, num_points_in_single_batch,
98 											FFTW_ESTIMATE);
99 		}
100 		else
101 			throw std::runtime_error( "invalid transform type in <float>make_plan" );
102 	}
103 
fftw_wrapper(int x,int y,int z,int num_dimensions,int batch_size,fftwf_complex * input_ptr,fftwf_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)104 	fftw_wrapper( int x, int y, int z, int num_dimensions, int batch_size, fftwf_complex* input_ptr, fftwf_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type )
105 	{
106 		make_plan( x, y, z, num_dimensions, batch_size, input_ptr, output_ptr, num_points_in_single_batch, direction, type );
107 	}
108 
destroy_plan()109 	void destroy_plan()
110 	{
111 		fftwf_destroy_plan(plan);
112 	}
113 
~fftw_wrapper()114 	~fftw_wrapper()
115 	{
116 		destroy_plan();
117 	}
118 
execute()119 	void execute()
120 	{
121 		fftwf_execute(plan);
122 	}
123 };
124 
125 template <>
126 class fftw_wrapper<double, fftw_complex>
127 {
128 public:
129 	fftw_plan plan;
130 
make_plan(int x,int y,int z,int num_dimensions,int batch_size,fftw_complex * input_ptr,fftw_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)131 	void make_plan( int x, int y, int z, int num_dimensions, int batch_size, fftw_complex* input_ptr, fftw_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type )
132 	{
133 		// we need to swap x,y,z dimensions because of a row-column discrepancy between clfft and fftw
134 		int lengths[max_dimension] = {z, y, x};
135 
136 		if( type == c2c )
137 		{
138 			plan = fftw_plan_many_dft( num_dimensions,
139 									// because we swapped dimensions up above, we need to start
140 									// at the end of the array and count backwards to get the
141 									// correct dimensions passed in to fftw
142 									// e.g. if max_dimension is 3 and number_of_dimensions is 2:
143 									// lengths = {dimz, dimy, dimx}
144 									// lengths + 3 - 2 = lengths + 1
145 									// so we will skip dimz and pass in a pointer to {dimy, dimx}
146 									lengths+max_dimension-num_dimensions,
147 									batch_size,
148 									input_ptr, NULL,
149 									1, num_points_in_single_batch,
150 									output_ptr, NULL,
151 									1, num_points_in_single_batch,
152 									direction, FFTW_ESTIMATE);
153 		}
154 		else if( type == r2c )
155 		{
156 			plan = fftw_plan_many_dft_r2c( num_dimensions,
157 											// because we swapped dimensions up above, we need to start
158 											// at the end of the array and count backwards to get the
159 											// correct dimensions passed in to fftw
160 											// e.g. if max_dimension is 3 and number_of_dimensions is 2:
161 											// lengths = {dimz, dimy, dimx}
162 											// lengths + 3 - 2 = lengths + 1
163 											// so we will skip dimz and pass in a pointer to {dimy, dimx}
164 											lengths+max_dimension-num_dimensions,
165 											batch_size,
166 											reinterpret_cast<double*>(input_ptr), NULL,
167 											1, num_points_in_single_batch,
168 											output_ptr, NULL,
169 											1, (x/2 + 1) * y * z,
170 											FFTW_ESTIMATE);
171 		}
172 		else if( type == c2r )
173 		{
174 			plan = fftw_plan_many_dft_c2r( num_dimensions,
175 											// because we swapped dimensions up above, we need to start
176 											// at the end of the array and count backwards to get the
177 											// correct dimensions passed in to fftw
178 											// e.g. if max_dimension is 3 and number_of_dimensions is 2:
179 											// lengths = {dimz, dimy, dimx}
180 											// lengths + 3 - 2 = lengths + 1
181 											// so we will skip dimz and pass in a pointer to {dimy, dimx}
182 											lengths+max_dimension-num_dimensions,
183 											batch_size,
184 											input_ptr, NULL,
185 											1, (x/2 + 1) * y * z,
186 											reinterpret_cast<double*>(output_ptr), NULL,
187 											1, num_points_in_single_batch,
188 											FFTW_ESTIMATE);
189 		}
190 		else
191 			throw std::runtime_error( "invalid transform type in <double>make_plan" );
192 	}
193 
fftw_wrapper(int x,int y,int z,int num_dimensions,int batch_size,fftw_complex * input_ptr,fftw_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)194 	fftw_wrapper( int x, int y, int z, int num_dimensions, int batch_size, fftw_complex* input_ptr, fftw_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type )
195 	{
196 		make_plan( x, y, z, num_dimensions, batch_size, input_ptr, output_ptr, num_points_in_single_batch, direction, type );
197 	}
198 
destroy_plan()199 	void destroy_plan()
200 	{
201 		fftw_destroy_plan(plan);
202 	}
203 
~fftw_wrapper()204 	~fftw_wrapper()
205 	{
206 		destroy_plan();
207 	}
208 
execute()209 	void execute()
210 	{
211 		fftw_execute(plan);
212 	}
213 };
214 
215 /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
216 /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
217 template <typename T, typename fftw_T>
218 class fftw {
219 private:
220 	static const size_t tightly_packed_distance = 0;
221 
222 	std::vector<size_t> _lengths;
223 	fftw_direction _direction;
224 	fftw_transform_type _type;
225 	layout::buffer_layout_t _input_layout, _output_layout;
226 	size_t _batch_size;
227 	buffer<T> input;
228 	buffer<T> output;
229 	fftw_wrapper<T, fftw_T> fftw_guts;
230 
231 	T _forward_scale, _backward_scale;
232 public:
233 	/*****************************************************/
fftw(const size_t number_of_dimensions_in,const size_t * lengths_in,const size_t batch_size_in,fftw_transform_type type_in)234 	fftw( const size_t number_of_dimensions_in, const size_t* lengths_in, const size_t batch_size_in, fftw_transform_type type_in )
235 		: _lengths( initialized_lengths( number_of_dimensions_in, lengths_in ) )
236 		, _direction( forward )
237 		, _type( type_in )
238 		, _input_layout( initialized_input_layout() )
239 		, _output_layout( initialized_output_layout() )
240 		, _batch_size( batch_size_in )
241 		, input( number_of_dimensions_in,
242 				lengths_in,
243 				NULL,
244 				batch_size_in,
245 				tightly_packed_distance,
246 				_input_layout,
247 				CLFFT_OUTOFPLACE )
248 		, output( number_of_dimensions_in,
249 				lengths_in,
250 				NULL,
251 				batch_size_in,
252 				tightly_packed_distance,
253 				_output_layout,
254 				CLFFT_OUTOFPLACE )
255 		, _forward_scale( 1.0f )
256 		, _backward_scale( 1.0f/T(input.number_of_data_points_single_batch()) )
257 		, fftw_guts( (int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz],
258 					 (int)number_of_dimensions_in, (int)batch_size_in,
259 					 reinterpret_cast<fftw_T*>(input_ptr()),
260 					 reinterpret_cast<fftw_T*>(output_ptr()),
261 					 (int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type)
262 	{
263 		clear_data_buffer();
264 	}
265 
266 	/*****************************************************/
~fftw()267 	~fftw() {}
268 
269 	/*****************************************************/
initialized_input_layout()270 	layout::buffer_layout_t initialized_input_layout()
271 	{
272 		if( _type == c2c )
273 			return layout::complex_interleaved;
274 		else if( _type == r2c )
275 			return layout::real;
276 		else if( _type == c2r )
277 			return layout::hermitian_interleaved;
278 		else
279 			throw std::runtime_error( "invalid transform type in initialized_input_layout" );
280 	}
281 
282 	/*****************************************************/
initialized_output_layout()283 	layout::buffer_layout_t initialized_output_layout()
284 	{
285 		if( _type == c2c )
286 			return layout::complex_interleaved;
287 		else if( _type == r2c )
288 			return layout::hermitian_interleaved;
289 		else if( _type == c2r )
290 			return layout::real;
291 		else
292 			throw std::runtime_error( "invalid transform type in initialized_input_layout" );
293 	}
294 
295 	/*****************************************************/
initialized_lengths(const size_t number_of_dimensions,const size_t * lengths_in)296 	std::vector<size_t> initialized_lengths( const size_t number_of_dimensions, const size_t* lengths_in )
297 	{
298 		std::vector<size_t> lengths( 3, 1 ); // start with 1, 1, 1
299 
300 		for( size_t i = 0; i < number_of_dimensions; i++ )
301 		{
302 			lengths[i] = lengths_in[i];
303 		}
304 
305 		return lengths;
306 	}
307 
308 	/*****************************************************/
input_ptr()309 	T* input_ptr()
310 	{
311 		if( _input_layout == layout::real )
312 			return input.real_ptr();
313 		else if( _input_layout == layout::complex_interleaved )
314 			return input.interleaved_ptr();
315 		else if( _input_layout == layout::hermitian_interleaved )
316 			return input.interleaved_ptr();
317 		else
318 			throw std::runtime_error( "invalid layout in fftw::input_ptr" );
319 	}
320 
321 	/*****************************************************/
output_ptr()322 	T* output_ptr()
323 	{
324 		if( _output_layout == layout::real )
325 			return output.real_ptr();
326 		else if( _output_layout == layout::complex_interleaved )
327 			return output.interleaved_ptr();
328 		else if( _output_layout == layout::hermitian_interleaved )
329 			return output.interleaved_ptr();
330 		else
331 			throw std::runtime_error( "invalid layout in fftw::output_ptr" );
332 	}
333 
334 	// you must call either set_forward_transform() or
335 	// set_backward_transform() before setting the input buffer
336 	/*****************************************************/
set_forward_transform()337 	void set_forward_transform()
338 	{
339 		if( _type != c2c )
340 			throw std::runtime_error( "do not use set_forward_transform() except with c2c transforms" );
341 
342 		if( _direction != forward )
343 		{
344 			_direction = forward;
345 			fftw_guts.destroy_plan();
346 			fftw_guts.make_plan((int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz],
347 								(int)input.number_of_dimensions(), (int)input.batch_size(),
348 								reinterpret_cast<fftw_T*>(input.interleaved_ptr()), reinterpret_cast<fftw_T*>(output.interleaved_ptr()),
349 								(int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type);
350 		}
351 	}
352 
353 	/*****************************************************/
set_backward_transform()354 	void set_backward_transform()
355 	{
356 		if( _type != c2c )
357 			throw std::runtime_error( "do not use set_backward_transform() except with c2c transforms" );
358 
359 		if( _direction != backward )
360 		{
361 			_direction = backward;
362 			fftw_guts.destroy_plan();
363 			fftw_guts.make_plan((int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz],
364 								(int)input.number_of_dimensions(), (int)input.batch_size(),
365 								reinterpret_cast<fftw_T*>(input.interleaved_ptr()), reinterpret_cast<fftw_T*>(output.interleaved_ptr()),
366 								(int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type);
367 		}
368 	}
369 
370 	/*****************************************************/
size_of_data_in_bytes()371 	size_t size_of_data_in_bytes()
372 	{
373 		return input.size_in_bytes();
374 	}
375 
376 	/*****************************************************/
forward_scale(T in)377 	void forward_scale( T in )
378 	{
379 		_forward_scale = in;
380 	}
381 
382 	/*****************************************************/
backward_scale(T in)383 	void backward_scale( T in )
384 	{
385 		_backward_scale = in;
386 	}
387 
388 	/*****************************************************/
forward_scale()389 	T forward_scale()
390 	{
391 		return _forward_scale;
392 	}
393 
394 	/*****************************************************/
backward_scale()395 	T backward_scale()
396 	{
397 		return _backward_scale;
398 	}
399 
400 	/*****************************************************/
set_all_data_to_value(T value)401 	void set_all_data_to_value( T value )
402 	{
403 		input.set_all_to_value( value );
404 	}
405 
406 	/*****************************************************/
set_all_data_to_value(T real_value,T imag_value)407 	void set_all_data_to_value( T real_value, T imag_value )
408 	{
409 		input.set_all_to_value( real_value, imag_value );
410 	}
411 
412 	/*****************************************************/
set_data_to_sawtooth(T max)413 	void set_data_to_sawtooth(T max)
414 	{
415 		input.set_all_to_sawtooth( max );
416 	}
417 
418 	/*****************************************************/
set_data_to_increase_linearly()419 	void set_data_to_increase_linearly()
420 	{
421 		input.set_all_to_linear_increase();
422 	}
423 
424 	/*****************************************************/
set_data_to_impulse()425 	void set_data_to_impulse()
426 	{
427 		input.set_all_to_impulse();
428 	}
429 
430 	/*****************************************************/
431 	// yes, the "super duper global seed" is horrible
432 	// alas, i'll have TODO it better later
set_data_to_random()433 	void set_data_to_random()
434 	{
435 		input.set_all_to_random_data( 10, super_duper_global_seed );
436 	}
437 
438 	/*****************************************************/
set_input_to_buffer(buffer<T> other_buffer)439 	void set_input_to_buffer( buffer<T> other_buffer ) {
440 		input = other_buffer;
441 	}
442 
set_output_postcallback()443 	void set_output_postcallback()
444 	{
445 		//postcallback user data
446 		buffer<T> userdata( 	output.number_of_dimensions(),
447 					output.lengths(),
448 					output.strides(),
449 					output.batch_size(),
450 					output.distance(),
451 					layout::real ,
452 					CLFFT_INPLACE
453 					);
454 
455 		userdata.set_all_to_random_data(_lengths[0], 10);
456 
457 		output *= userdata;
458 	}
459 
set_input_precallback()460 	void set_input_precallback()
461 	{
462 		//precallback user data
463 		buffer<T> userdata( 	input.number_of_dimensions(),
464 					input.lengths(),
465 					input.strides(),
466 					input.batch_size(),
467 					input.distance(),
468 					layout::real ,
469 					CLFFT_INPLACE
470 					);
471 
472 		userdata.set_all_to_random_data(_lengths[0], 10);
473 
474 		input *= userdata;
475 	}
476 
set_input_precallback_special()477 	void set_input_precallback_special()
478 	{
479 		//precallback user data
480 		buffer<T> userdata( 	input.number_of_dimensions(),
481 					input.lengths(),
482 					input.strides(),
483 					input.batch_size(),
484 					input.distance(),
485 					layout::real ,
486 					CLFFT_INPLACE
487 					);
488 
489 		userdata.set_all_to_random_data(_lengths[0], 10);
490 
491 		input.multiply_3pt_average(userdata);
492 	}
493 
set_output_postcallback_special()494 	void set_output_postcallback_special()
495 	{
496 		//postcallback user data
497 		buffer<T> userdata( 	output.number_of_dimensions(),
498 					output.lengths(),
499 					output.strides(),
500 					output.batch_size(),
501 					output.distance(),
502 					layout::real ,
503 					CLFFT_INPLACE
504 					);
505 
506 		userdata.set_all_to_random_data(_lengths[0], 10);
507 
508 		output.multiply_3pt_average(userdata);
509 	}
510 
511 	/*****************************************************/
clear_data_buffer()512 	void clear_data_buffer()
513 	{
514 		if( _input_layout == layout::real )
515 		{
516 			set_all_data_to_value( 0.0f );
517 		}
518 		else
519 		{
520 			set_all_data_to_value( 0.0f, 0.0f );
521 		}
522 	}
523 
524 	/*****************************************************/
transform()525 	void transform()
526 	{
527 		fftw_guts.execute();
528 
529 		if( _type == c2c )
530 		{
531 			if( _direction == forward  ) {
532 				output.scale_data( static_cast<T>( forward_scale( ) ) );
533 			}
534 			else if( _direction == backward  ) {
535 				output.scale_data( static_cast<T>( backward_scale( ) ) );
536 			}
537 		}
538 		else if( _type == r2c )
539 		{
540 			output.scale_data( static_cast<T>( forward_scale( ) ) );
541 		}
542 		else if( _type == c2r )
543 		{
544 			output.scale_data( static_cast<T>( backward_scale( ) ) );
545 		}
546 		else
547 			throw std::runtime_error( "invalid transform type in fftw::transform()" );
548 	}
549 
550 	/*****************************************************/
result()551 	buffer<T> & result()
552 	{
553 		return output;
554 	}
555 
556 	/*****************************************************/
input_buffer()557 	buffer<T> & input_buffer()
558 	{
559 		return input;
560 	}
561 };
562 
563 #endif
564