1 /* ************************************************************************ 2 * Copyright 2013 Advanced Micro Devices, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * ************************************************************************/ 16 17 18 #pragma once 19 #if !defined( CLFFT_FFTWTRANSFORM_H ) 20 #define CLFFT_FFTWTRANSFORM_H 21 22 #include <vector> 23 #include "fftw3.h" 24 #include "buffer.h" 25 #include "../client/openCL.misc.h" // we need this to leverage the CLFFT_INPLACE and _OUTOFPLACE enums 26 27 enum fftw_direction {forward=-1, backward=+1}; 28 29 enum fftw_transform_type {c2c, r2c, c2r}; 30 31 template <typename T, typename fftw_T> 32 class fftw_wrapper 33 {}; 34 35 template <> 36 class fftw_wrapper<float, fftwf_complex> 37 { 38 public: 39 fftwf_plan plan; 40 make_plan(int x,int y,int z,int num_dimensions,int batch_size,fftwf_complex * input_ptr,fftwf_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)41 void make_plan( int x, int y, int z, int num_dimensions, int batch_size, fftwf_complex* input_ptr, fftwf_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type ) 42 { 43 // we need to swap x,y,z dimensions because of a row-column discrepancy between clfft and fftw 44 int lengths[max_dimension] = {z, y, x}; 45 46 if( type == c2c ) 47 { 48 plan = fftwf_plan_many_dft( num_dimensions, 49 // because we swapped dimensions up above, we need to start 50 // at the end of the array and count backwards to get the 51 // correct dimensions passed in to fftw 52 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 53 // lengths = {dimz, dimy, dimx} 54 // lengths + 3 - 2 = lengths + 1 55 // so we will skip dimz and pass in a pointer to {dimy, dimx} 56 lengths+max_dimension-num_dimensions, 57 batch_size, 58 input_ptr, NULL, 59 1, num_points_in_single_batch, 60 output_ptr, NULL, 61 1, num_points_in_single_batch, 62 direction, FFTW_ESTIMATE); 63 } 64 else if( type == r2c ) 65 { 66 plan = fftwf_plan_many_dft_r2c( num_dimensions, 67 // because we swapped dimensions up above, we need to start 68 // at the end of the array and count backwards to get the 69 // correct dimensions passed in to fftw 70 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 71 // lengths = {dimz, dimy, dimx} 72 // lengths + 3 - 2 = lengths + 1 73 // so we will skip dimz and pass in a pointer to {dimy, dimx} 74 lengths+max_dimension-num_dimensions, 75 batch_size, 76 reinterpret_cast<float*>(input_ptr), NULL, 77 1, num_points_in_single_batch, 78 output_ptr, NULL, 79 1, (x/2 + 1) * y * z, 80 FFTW_ESTIMATE); 81 } 82 else if( type == c2r ) 83 { 84 plan = fftwf_plan_many_dft_c2r( num_dimensions, 85 // because we swapped dimensions up above, we need to start 86 // at the end of the array and count backwards to get the 87 // correct dimensions passed in to fftw 88 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 89 // lengths = {dimz, dimy, dimx} 90 // lengths + 3 - 2 = lengths + 1 91 // so we will skip dimz and pass in a pointer to {dimy, dimx} 92 lengths+max_dimension-num_dimensions, 93 batch_size, 94 input_ptr, NULL, 95 1, (x/2 + 1) * y * z, 96 reinterpret_cast<float*>(output_ptr), NULL, 97 1, num_points_in_single_batch, 98 FFTW_ESTIMATE); 99 } 100 else 101 throw std::runtime_error( "invalid transform type in <float>make_plan" ); 102 } 103 fftw_wrapper(int x,int y,int z,int num_dimensions,int batch_size,fftwf_complex * input_ptr,fftwf_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)104 fftw_wrapper( int x, int y, int z, int num_dimensions, int batch_size, fftwf_complex* input_ptr, fftwf_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type ) 105 { 106 make_plan( x, y, z, num_dimensions, batch_size, input_ptr, output_ptr, num_points_in_single_batch, direction, type ); 107 } 108 destroy_plan()109 void destroy_plan() 110 { 111 fftwf_destroy_plan(plan); 112 } 113 ~fftw_wrapper()114 ~fftw_wrapper() 115 { 116 destroy_plan(); 117 } 118 execute()119 void execute() 120 { 121 fftwf_execute(plan); 122 } 123 }; 124 125 template <> 126 class fftw_wrapper<double, fftw_complex> 127 { 128 public: 129 fftw_plan plan; 130 make_plan(int x,int y,int z,int num_dimensions,int batch_size,fftw_complex * input_ptr,fftw_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)131 void make_plan( int x, int y, int z, int num_dimensions, int batch_size, fftw_complex* input_ptr, fftw_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type ) 132 { 133 // we need to swap x,y,z dimensions because of a row-column discrepancy between clfft and fftw 134 int lengths[max_dimension] = {z, y, x}; 135 136 if( type == c2c ) 137 { 138 plan = fftw_plan_many_dft( num_dimensions, 139 // because we swapped dimensions up above, we need to start 140 // at the end of the array and count backwards to get the 141 // correct dimensions passed in to fftw 142 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 143 // lengths = {dimz, dimy, dimx} 144 // lengths + 3 - 2 = lengths + 1 145 // so we will skip dimz and pass in a pointer to {dimy, dimx} 146 lengths+max_dimension-num_dimensions, 147 batch_size, 148 input_ptr, NULL, 149 1, num_points_in_single_batch, 150 output_ptr, NULL, 151 1, num_points_in_single_batch, 152 direction, FFTW_ESTIMATE); 153 } 154 else if( type == r2c ) 155 { 156 plan = fftw_plan_many_dft_r2c( num_dimensions, 157 // because we swapped dimensions up above, we need to start 158 // at the end of the array and count backwards to get the 159 // correct dimensions passed in to fftw 160 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 161 // lengths = {dimz, dimy, dimx} 162 // lengths + 3 - 2 = lengths + 1 163 // so we will skip dimz and pass in a pointer to {dimy, dimx} 164 lengths+max_dimension-num_dimensions, 165 batch_size, 166 reinterpret_cast<double*>(input_ptr), NULL, 167 1, num_points_in_single_batch, 168 output_ptr, NULL, 169 1, (x/2 + 1) * y * z, 170 FFTW_ESTIMATE); 171 } 172 else if( type == c2r ) 173 { 174 plan = fftw_plan_many_dft_c2r( num_dimensions, 175 // because we swapped dimensions up above, we need to start 176 // at the end of the array and count backwards to get the 177 // correct dimensions passed in to fftw 178 // e.g. if max_dimension is 3 and number_of_dimensions is 2: 179 // lengths = {dimz, dimy, dimx} 180 // lengths + 3 - 2 = lengths + 1 181 // so we will skip dimz and pass in a pointer to {dimy, dimx} 182 lengths+max_dimension-num_dimensions, 183 batch_size, 184 input_ptr, NULL, 185 1, (x/2 + 1) * y * z, 186 reinterpret_cast<double*>(output_ptr), NULL, 187 1, num_points_in_single_batch, 188 FFTW_ESTIMATE); 189 } 190 else 191 throw std::runtime_error( "invalid transform type in <double>make_plan" ); 192 } 193 fftw_wrapper(int x,int y,int z,int num_dimensions,int batch_size,fftw_complex * input_ptr,fftw_complex * output_ptr,int num_points_in_single_batch,fftw_direction direction,fftw_transform_type type)194 fftw_wrapper( int x, int y, int z, int num_dimensions, int batch_size, fftw_complex* input_ptr, fftw_complex* output_ptr, int num_points_in_single_batch, fftw_direction direction, fftw_transform_type type ) 195 { 196 make_plan( x, y, z, num_dimensions, batch_size, input_ptr, output_ptr, num_points_in_single_batch, direction, type ); 197 } 198 destroy_plan()199 void destroy_plan() 200 { 201 fftw_destroy_plan(plan); 202 } 203 ~fftw_wrapper()204 ~fftw_wrapper() 205 { 206 destroy_plan(); 207 } 208 execute()209 void execute() 210 { 211 fftw_execute(plan); 212 } 213 }; 214 215 /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/ 216 /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/ 217 template <typename T, typename fftw_T> 218 class fftw { 219 private: 220 static const size_t tightly_packed_distance = 0; 221 222 std::vector<size_t> _lengths; 223 fftw_direction _direction; 224 fftw_transform_type _type; 225 layout::buffer_layout_t _input_layout, _output_layout; 226 size_t _batch_size; 227 buffer<T> input; 228 buffer<T> output; 229 fftw_wrapper<T, fftw_T> fftw_guts; 230 231 T _forward_scale, _backward_scale; 232 public: 233 /*****************************************************/ fftw(const size_t number_of_dimensions_in,const size_t * lengths_in,const size_t batch_size_in,fftw_transform_type type_in)234 fftw( const size_t number_of_dimensions_in, const size_t* lengths_in, const size_t batch_size_in, fftw_transform_type type_in ) 235 : _lengths( initialized_lengths( number_of_dimensions_in, lengths_in ) ) 236 , _direction( forward ) 237 , _type( type_in ) 238 , _input_layout( initialized_input_layout() ) 239 , _output_layout( initialized_output_layout() ) 240 , _batch_size( batch_size_in ) 241 , input( number_of_dimensions_in, 242 lengths_in, 243 NULL, 244 batch_size_in, 245 tightly_packed_distance, 246 _input_layout, 247 CLFFT_OUTOFPLACE ) 248 , output( number_of_dimensions_in, 249 lengths_in, 250 NULL, 251 batch_size_in, 252 tightly_packed_distance, 253 _output_layout, 254 CLFFT_OUTOFPLACE ) 255 , _forward_scale( 1.0f ) 256 , _backward_scale( 1.0f/T(input.number_of_data_points_single_batch()) ) 257 , fftw_guts( (int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz], 258 (int)number_of_dimensions_in, (int)batch_size_in, 259 reinterpret_cast<fftw_T*>(input_ptr()), 260 reinterpret_cast<fftw_T*>(output_ptr()), 261 (int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type) 262 { 263 clear_data_buffer(); 264 } 265 266 /*****************************************************/ ~fftw()267 ~fftw() {} 268 269 /*****************************************************/ initialized_input_layout()270 layout::buffer_layout_t initialized_input_layout() 271 { 272 if( _type == c2c ) 273 return layout::complex_interleaved; 274 else if( _type == r2c ) 275 return layout::real; 276 else if( _type == c2r ) 277 return layout::hermitian_interleaved; 278 else 279 throw std::runtime_error( "invalid transform type in initialized_input_layout" ); 280 } 281 282 /*****************************************************/ initialized_output_layout()283 layout::buffer_layout_t initialized_output_layout() 284 { 285 if( _type == c2c ) 286 return layout::complex_interleaved; 287 else if( _type == r2c ) 288 return layout::hermitian_interleaved; 289 else if( _type == c2r ) 290 return layout::real; 291 else 292 throw std::runtime_error( "invalid transform type in initialized_input_layout" ); 293 } 294 295 /*****************************************************/ initialized_lengths(const size_t number_of_dimensions,const size_t * lengths_in)296 std::vector<size_t> initialized_lengths( const size_t number_of_dimensions, const size_t* lengths_in ) 297 { 298 std::vector<size_t> lengths( 3, 1 ); // start with 1, 1, 1 299 300 for( size_t i = 0; i < number_of_dimensions; i++ ) 301 { 302 lengths[i] = lengths_in[i]; 303 } 304 305 return lengths; 306 } 307 308 /*****************************************************/ input_ptr()309 T* input_ptr() 310 { 311 if( _input_layout == layout::real ) 312 return input.real_ptr(); 313 else if( _input_layout == layout::complex_interleaved ) 314 return input.interleaved_ptr(); 315 else if( _input_layout == layout::hermitian_interleaved ) 316 return input.interleaved_ptr(); 317 else 318 throw std::runtime_error( "invalid layout in fftw::input_ptr" ); 319 } 320 321 /*****************************************************/ output_ptr()322 T* output_ptr() 323 { 324 if( _output_layout == layout::real ) 325 return output.real_ptr(); 326 else if( _output_layout == layout::complex_interleaved ) 327 return output.interleaved_ptr(); 328 else if( _output_layout == layout::hermitian_interleaved ) 329 return output.interleaved_ptr(); 330 else 331 throw std::runtime_error( "invalid layout in fftw::output_ptr" ); 332 } 333 334 // you must call either set_forward_transform() or 335 // set_backward_transform() before setting the input buffer 336 /*****************************************************/ set_forward_transform()337 void set_forward_transform() 338 { 339 if( _type != c2c ) 340 throw std::runtime_error( "do not use set_forward_transform() except with c2c transforms" ); 341 342 if( _direction != forward ) 343 { 344 _direction = forward; 345 fftw_guts.destroy_plan(); 346 fftw_guts.make_plan((int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz], 347 (int)input.number_of_dimensions(), (int)input.batch_size(), 348 reinterpret_cast<fftw_T*>(input.interleaved_ptr()), reinterpret_cast<fftw_T*>(output.interleaved_ptr()), 349 (int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type); 350 } 351 } 352 353 /*****************************************************/ set_backward_transform()354 void set_backward_transform() 355 { 356 if( _type != c2c ) 357 throw std::runtime_error( "do not use set_backward_transform() except with c2c transforms" ); 358 359 if( _direction != backward ) 360 { 361 _direction = backward; 362 fftw_guts.destroy_plan(); 363 fftw_guts.make_plan((int)_lengths[dimx], (int)_lengths[dimy], (int)_lengths[dimz], 364 (int)input.number_of_dimensions(), (int)input.batch_size(), 365 reinterpret_cast<fftw_T*>(input.interleaved_ptr()), reinterpret_cast<fftw_T*>(output.interleaved_ptr()), 366 (int)(_lengths[dimx]*_lengths[dimy]*_lengths[dimz]), _direction, _type); 367 } 368 } 369 370 /*****************************************************/ size_of_data_in_bytes()371 size_t size_of_data_in_bytes() 372 { 373 return input.size_in_bytes(); 374 } 375 376 /*****************************************************/ forward_scale(T in)377 void forward_scale( T in ) 378 { 379 _forward_scale = in; 380 } 381 382 /*****************************************************/ backward_scale(T in)383 void backward_scale( T in ) 384 { 385 _backward_scale = in; 386 } 387 388 /*****************************************************/ forward_scale()389 T forward_scale() 390 { 391 return _forward_scale; 392 } 393 394 /*****************************************************/ backward_scale()395 T backward_scale() 396 { 397 return _backward_scale; 398 } 399 400 /*****************************************************/ set_all_data_to_value(T value)401 void set_all_data_to_value( T value ) 402 { 403 input.set_all_to_value( value ); 404 } 405 406 /*****************************************************/ set_all_data_to_value(T real_value,T imag_value)407 void set_all_data_to_value( T real_value, T imag_value ) 408 { 409 input.set_all_to_value( real_value, imag_value ); 410 } 411 412 /*****************************************************/ set_data_to_sawtooth(T max)413 void set_data_to_sawtooth(T max) 414 { 415 input.set_all_to_sawtooth( max ); 416 } 417 418 /*****************************************************/ set_data_to_increase_linearly()419 void set_data_to_increase_linearly() 420 { 421 input.set_all_to_linear_increase(); 422 } 423 424 /*****************************************************/ set_data_to_impulse()425 void set_data_to_impulse() 426 { 427 input.set_all_to_impulse(); 428 } 429 430 /*****************************************************/ 431 // yes, the "super duper global seed" is horrible 432 // alas, i'll have TODO it better later set_data_to_random()433 void set_data_to_random() 434 { 435 input.set_all_to_random_data( 10, super_duper_global_seed ); 436 } 437 438 /*****************************************************/ set_input_to_buffer(buffer<T> other_buffer)439 void set_input_to_buffer( buffer<T> other_buffer ) { 440 input = other_buffer; 441 } 442 set_output_postcallback()443 void set_output_postcallback() 444 { 445 //postcallback user data 446 buffer<T> userdata( output.number_of_dimensions(), 447 output.lengths(), 448 output.strides(), 449 output.batch_size(), 450 output.distance(), 451 layout::real , 452 CLFFT_INPLACE 453 ); 454 455 userdata.set_all_to_random_data(_lengths[0], 10); 456 457 output *= userdata; 458 } 459 set_input_precallback()460 void set_input_precallback() 461 { 462 //precallback user data 463 buffer<T> userdata( input.number_of_dimensions(), 464 input.lengths(), 465 input.strides(), 466 input.batch_size(), 467 input.distance(), 468 layout::real , 469 CLFFT_INPLACE 470 ); 471 472 userdata.set_all_to_random_data(_lengths[0], 10); 473 474 input *= userdata; 475 } 476 set_input_precallback_special()477 void set_input_precallback_special() 478 { 479 //precallback user data 480 buffer<T> userdata( input.number_of_dimensions(), 481 input.lengths(), 482 input.strides(), 483 input.batch_size(), 484 input.distance(), 485 layout::real , 486 CLFFT_INPLACE 487 ); 488 489 userdata.set_all_to_random_data(_lengths[0], 10); 490 491 input.multiply_3pt_average(userdata); 492 } 493 set_output_postcallback_special()494 void set_output_postcallback_special() 495 { 496 //postcallback user data 497 buffer<T> userdata( output.number_of_dimensions(), 498 output.lengths(), 499 output.strides(), 500 output.batch_size(), 501 output.distance(), 502 layout::real , 503 CLFFT_INPLACE 504 ); 505 506 userdata.set_all_to_random_data(_lengths[0], 10); 507 508 output.multiply_3pt_average(userdata); 509 } 510 511 /*****************************************************/ clear_data_buffer()512 void clear_data_buffer() 513 { 514 if( _input_layout == layout::real ) 515 { 516 set_all_data_to_value( 0.0f ); 517 } 518 else 519 { 520 set_all_data_to_value( 0.0f, 0.0f ); 521 } 522 } 523 524 /*****************************************************/ transform()525 void transform() 526 { 527 fftw_guts.execute(); 528 529 if( _type == c2c ) 530 { 531 if( _direction == forward ) { 532 output.scale_data( static_cast<T>( forward_scale( ) ) ); 533 } 534 else if( _direction == backward ) { 535 output.scale_data( static_cast<T>( backward_scale( ) ) ); 536 } 537 } 538 else if( _type == r2c ) 539 { 540 output.scale_data( static_cast<T>( forward_scale( ) ) ); 541 } 542 else if( _type == c2r ) 543 { 544 output.scale_data( static_cast<T>( backward_scale( ) ) ); 545 } 546 else 547 throw std::runtime_error( "invalid transform type in fftw::transform()" ); 548 } 549 550 /*****************************************************/ result()551 buffer<T> & result() 552 { 553 return output; 554 } 555 556 /*****************************************************/ input_buffer()557 buffer<T> & input_buffer() 558 { 559 return input; 560 } 561 }; 562 563 #endif 564