1 /*
2  * Copyright (c) 2016-2021, The OSKAR Developers.
3  * See the LICENSE file at the top-level directory of this distribution.
4  */
5 
6 #include "imager/private_imager.h"
7 #include "imager/oskar_imager.h"
8 
9 #include "imager/define_grid_tile_grid.h"
10 #include "imager/private_imager_update_plane_fft.h"
11 #include "imager/oskar_grid_simple.h"
12 #include "math/oskar_prefix_sum.h"
13 #include "math/oskar_round_robin.h"
14 #include "utility/oskar_device.h"
15 #include "utility/oskar_thread.h"
16 
17 #include <assert.h>
18 
19 #ifdef __cplusplus
20 extern "C" {
21 #endif
22 
23 static void* run_subset(void* arg);
24 
25 struct ThreadArgs
26 {
27     oskar_Imager* h;
28     size_t num_vis, num_skipped;
29     const oskar_Mem *uu, *vv, *amps, *weight;
30     oskar_Mem *plane;
31     double plane_norm;
32     int grid_size, i_plane, thread_id;
33 };
34 typedef struct ThreadArgs ThreadArgs;
35 
/*
 * Grids a block of visibilities onto an image plane (FFT algorithm).
 *
 * If GPU gridding is disabled or no GPUs are in use, the data are gridded
 * on the host using oskar_grid_simple_d() or oskar_grid_simple_f(),
 * depending on the imager precision. Otherwise one worker thread is
 * started per GPU, each running run_subset(), and the normalisation and
 * skipped-point count reported by each worker are added to the caller's
 * totals.
 *
 * h            Imager handle.
 * num_vis      Number of visibilities supplied.
 * uu, vv       Visibility coordinates.
 * amps         Visibility amplitudes.
 * weight       Visibility weights.
 * i_plane      Index of the plane to use when 'plane' is NULL.
 * plane        Plane to update, or NULL to use the imager's own plane.
 * plane_norm   Updated plane normalisation (accumulated on the GPU path).
 * num_skipped  Updated count of skipped points (accumulated on the GPU path).
 * status       Status return code; the function does nothing if it is
 *              already non-zero on entry.
 */
void oskar_imager_update_plane_fft(oskar_Imager* h, size_t num_vis,
        const oskar_Mem* uu, const oskar_Mem* vv, const oskar_Mem* amps,
        const oskar_Mem* weight, int i_plane, oskar_Mem* plane,
        double* plane_norm, size_t* num_skipped, int* status)
{
    if (*status) return;
    if (!h->grid_on_gpu || h->num_gpus == 0)
    {
        /* Host path: resolve and validate the plane to write to. */
        oskar_Mem* plane_ptr = plane;
        if (!plane_ptr)
        {
            if (!h->planes)
            {
                *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
                return;
            }
            plane_ptr = h->planes[i_plane];
        }
        if (oskar_mem_location(plane_ptr) != OSKAR_CPU)
        {
            *status = OSKAR_ERR_LOCATION_MISMATCH;
            return;
        }
        if (oskar_mem_precision(plane_ptr) != h->imager_prec)
        {
            *status = OSKAR_ERR_TYPE_MISMATCH;
            return;
        }

        /* Make sure the plane can hold the whole grid. */
        const int grid_size = oskar_imager_plane_size(h);
        const size_t num_cells = ((size_t) grid_size) * ((size_t) grid_size);
        oskar_mem_ensure(plane_ptr, num_cells, status);
        if (*status) return;

        /* Fetch all typed pointers first and check the status before
         * gridding: the accessors take the status pointer, so they can
         * presumably report a problem (e.g. a type mismatch), in which
         * case the gridder must not run on those pointers. */
        if (h->imager_prec == OSKAR_DOUBLE)
        {
            const double* conv_func_d =
                    oskar_mem_double_const(h->conv_func, status);
            const double* uu_d = oskar_mem_double_const(uu, status);
            const double* vv_d = oskar_mem_double_const(vv, status);
            const double* amps_d = oskar_mem_double_const(amps, status);
            const double* weight_d = oskar_mem_double_const(weight, status);
            double* grid_d = oskar_mem_double(plane_ptr, status);
            if (*status) return;
            oskar_grid_simple_d(h->support, h->oversample,
                    conv_func_d, num_vis, uu_d, vv_d, amps_d, weight_d,
                    h->cellsize_rad,
                    grid_size, num_skipped, plane_norm, grid_d);
        }
        else
        {
            const float* conv_func_f =
                    oskar_mem_float_const(h->conv_func, status);
            const float* uu_f = oskar_mem_float_const(uu, status);
            const float* vv_f = oskar_mem_float_const(vv, status);
            const float* amps_f = oskar_mem_float_const(amps, status);
            const float* weight_f = oskar_mem_float_const(weight, status);
            float* grid_f = oskar_mem_float(plane_ptr, status);
            if (*status) return;
            oskar_grid_simple_f(h->support, h->oversample,
                    conv_func_f, num_vis, uu_f, vv_f, amps_f, weight_f,
                    (float) (h->cellsize_rad),
                    grid_size, num_skipped, plane_norm, grid_f);
        }
    }
    else
    {
        int i = 0;

        /* Set up worker threads: one per GPU. */
        const int num_threads = h->num_gpus;
        oskar_Thread** threads = (oskar_Thread**) calloc(
                (size_t) num_threads, sizeof(oskar_Thread*));
        ThreadArgs* args = (ThreadArgs*) calloc(
                (size_t) num_threads, sizeof(ThreadArgs));
        if (!threads || !args)
        {
            /* Allocation failed: release whichever array was obtained
             * (free(NULL) is a no-op) and report the error rather than
             * dereferencing a null pointer below. */
            free(threads);
            free(args);
            *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
            return;
        }
        for (i = 0; i < num_threads; ++i)
        {
            args[i].h = h;
            args[i].num_vis = num_vis;
            args[i].uu = uu;
            args[i].vv = vv;
            args[i].amps = amps;
            args[i].weight = weight;
            args[i].plane = plane;
            args[i].i_plane = i_plane;
            args[i].thread_id = i;
        }

        /* Workers report errors through the status code in the handle. */
        h->status = *status;

        /* Start the worker threads. */
        for (i = 0; i < num_threads; ++i)
        {
            threads[i] = oskar_thread_create(run_subset, (void*)&args[i], 0);
        }

        /* Wait for worker threads to finish, and collect their results. */
        for (i = 0; i < num_threads; ++i)
        {
            oskar_thread_join(threads[i]);
            oskar_thread_free(threads[i]);
            *plane_norm += args[i].plane_norm;
            *num_skipped += args[i].num_skipped;
        }
        free(threads);
        free(args);

        /* Get status code. */
        *status = h->status;
    }
}
143 
run_subset(void * arg)144 static void* run_subset(void* arg)
145 {
146     oskar_Imager* h = 0;
147     oskar_Mem *plane = 0;
148     const oskar_Mem *uu = 0, *vv = 0, *amps = 0, *weight = 0;
149     int count_skipped = 0, num_total = 0, *status = 0;
150     int start = 0, num_points = 0;
151     size_t local_size[] = {1, 1, 1}, global_size[] = {1, 1, 1};
152     DeviceData* d = 0;
153 
154     /* Get thread function arguments. */
155     ThreadArgs* a = (ThreadArgs*) arg;
156     const int i_plane = a->i_plane;
157     const int thread_id = a->thread_id;
158     const size_t num_vis = a->num_vis;
159     h = a->h;
160     status = &(h->status);
161     uu = a->uu;
162     vv = a->vv;
163     amps = a->amps;
164     weight = a->weight;
165 
166     /* Set the device used by the thread. */
167     d = &h->d[thread_id];
168     plane = a->plane;
169     if (!plane)
170     {
171         if (d->planes)
172         {
173             plane = d->planes[i_plane];
174         }
175         else
176         {
177             *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
178             return 0;
179         }
180     }
181     if (oskar_mem_location(plane) != h->dev_loc)
182     {
183         *status = OSKAR_ERR_LOCATION_MISMATCH;
184         return 0;
185     }
186     if (oskar_mem_precision(plane) != h->imager_prec)
187     {
188         *status = OSKAR_ERR_TYPE_MISMATCH;
189         return 0;
190     }
191     oskar_device_set(h->dev_loc, h->gpu_ids[thread_id], status);
192 
193     const int location = h->dev_loc;
194     const int grid_size = oskar_imager_plane_size(h);
195     const int vis_type = oskar_mem_type(amps);
196     const int is_dbl = oskar_type_is_double(vis_type);
197     const int grid_centre = grid_size / 2;
198     const double grid_scale = grid_size * h->cellsize_rad;
199     const float grid_scale_f = (float) grid_scale;
200 
201     /* Define the tile size and number of tiles in each direction.
202      * A tile consists of SHMSZ grid cells per thread in shared memory
203      * and REGSZ grid cells per thread in registers. */
204     const int tile_size_u = 32;
205     const int tile_size_v = (SHMSZ + REGSZ);
206     const int num_tiles_u = (grid_size + tile_size_u - 1) / tile_size_u;
207     const int num_tiles_v = (grid_size + tile_size_v - 1) / tile_size_v;
208     const int num_tiles = num_tiles_u * num_tiles_v;
209 
210     /* Which tile contains the grid centre? */
211     const int c_tile_u = grid_centre / tile_size_u;
212     const int c_tile_v = grid_centre / tile_size_v;
213 
214     /* Compute difference between centre of centre tile and grid centre
215      * to ensure the centre of the grid is in the centre of a tile. */
216     const int top_left_u = grid_centre -
217             c_tile_u * tile_size_u - tile_size_u / 2;
218     const int top_left_v = grid_centre -
219             c_tile_v * tile_size_v - tile_size_v / 2;
220     assert(top_left_u <= 0);
221     assert(top_left_v <= 0);
222 
223     /* Set up scratch memory. */
224     oskar_round_robin((int)num_vis, h->num_gpus, thread_id,
225             &num_points, &start);
226     oskar_mem_ensure(d->uu, num_points, status);
227     oskar_mem_ensure(d->vv, num_points, status);
228     oskar_mem_ensure(d->vis, num_points, status);
229     oskar_mem_ensure(d->weight, num_points, status);
230     oskar_mem_copy_contents(d->uu, uu, 0, start, num_points, status);
231     oskar_mem_copy_contents(d->vv, vv, 0, start, num_points, status);
232     oskar_mem_copy_contents(d->vis, amps, 0, start, num_points, status);
233     oskar_mem_copy_contents(d->weight, weight, 0, start, num_points, status);
234     oskar_mem_ensure(d->num_points_in_tiles, num_tiles, status);
235     oskar_mem_ensure(d->tile_offsets, num_tiles + 1, status);
236     oskar_mem_ensure(d->tile_locks, num_tiles, status);
237     oskar_mem_clear_contents(d->counter, status);
238     oskar_mem_clear_contents(d->count_skipped, status);
239     oskar_mem_clear_contents(d->norm, status);
240     oskar_mem_clear_contents(d->num_points_in_tiles, status);
241     oskar_mem_clear_contents(d->tile_locks, status);
242     /* Don't need to clear d->tile_offsets, as it gets overwritten. */
243 
244     /* Count the number of elements in each tile. */
245     const float inv_tile_size_u = 1.0f / (float) tile_size_u;
246     const float inv_tile_size_v = 1.0f / (float) tile_size_v;
247     {
248         const char* k = 0;
249         if (oskar_type_is_single(vis_type))
250         {
251             k = "grid_tile_count_simple_float";
252         }
253         else if (oskar_type_is_double(vis_type))
254         {
255             k = "grid_tile_count_simple_double";
256         }
257         else
258         {
259             *status = OSKAR_ERR_BAD_DATA_TYPE;
260         }
261         local_size[0] = 512;
262         oskar_device_check_local_size(location, 0, local_size);
263         global_size[0] = oskar_device_global_size(num_points, local_size[0]);
264         const oskar_Arg args[] = {
265                 {INT_SZ, &h->support},
266                 {INT_SZ, &num_points},
267                 {PTR_SZ, oskar_mem_buffer_const(d->uu)},
268                 {PTR_SZ, oskar_mem_buffer_const(d->vv)},
269                 {INT_SZ, &grid_size},
270                 {INT_SZ, &grid_centre},
271                 {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
272                         (const void*)&grid_scale :
273                         (const void*)&grid_scale_f},
274                 {FLT_SZ, (const void*)&inv_tile_size_u},
275                 {FLT_SZ, (const void*)&inv_tile_size_v},
276                 {INT_SZ, &num_tiles_u},
277                 {INT_SZ, &top_left_u},
278                 {INT_SZ, &top_left_v},
279                 {PTR_SZ, oskar_mem_buffer(d->num_points_in_tiles)},
280                 {PTR_SZ, oskar_mem_buffer(d->count_skipped)}
281         };
282         oskar_device_launch_kernel(k, location, 1, local_size, global_size,
283                 sizeof(args) / sizeof(oskar_Arg), args, 0, 0, status);
284     }
285 
286     /* Get the offsets for each tile using prefix sum. */
287     oskar_prefix_sum(num_tiles,
288             d->num_points_in_tiles, d->tile_offsets, status);
289 
290     /* Get the total number of visibilities to process. */
291     oskar_mem_read_element(d->tile_offsets, num_tiles, &num_total, status);
292     oskar_mem_read_element(d->count_skipped, 0, &count_skipped, status);
293     a->num_skipped = (size_t) count_skipped;
294     /*printf("Total points: %d (factor %.3f increase)\n", num_total,
295             (float)num_total / (float)num_points);*/
296 
297     /* Bucket sort the data into tiles. */
298     oskar_mem_ensure(d->sorted_uu, num_total, status);
299     oskar_mem_ensure(d->sorted_vv, num_total, status);
300     oskar_mem_ensure(d->sorted_wt, num_total, status);
301     oskar_mem_ensure(d->sorted_vis, num_total, status);
302     oskar_mem_ensure(d->sorted_tile, num_total, status);
303     {
304         const char* k = 0;
305         if (oskar_type_is_single(vis_type))
306         {
307             k = "grid_tile_bucket_sort_simple_float";
308         }
309         else if (oskar_type_is_double(vis_type))
310         {
311             k = "grid_tile_bucket_sort_simple_double";
312         }
313         else
314         {
315             *status = OSKAR_ERR_BAD_DATA_TYPE;
316         }
317         local_size[0] = 128;
318         oskar_device_check_local_size(location, 0, local_size);
319         global_size[0] = oskar_device_global_size(num_points, local_size[0]);
320         const oskar_Arg args[] = {
321                 {INT_SZ, &h->support},
322                 {INT_SZ, &num_points},
323                 {PTR_SZ, oskar_mem_buffer_const(d->uu)},
324                 {PTR_SZ, oskar_mem_buffer_const(d->vv)},
325                 {PTR_SZ, oskar_mem_buffer_const(d->vis)},
326                 {PTR_SZ, oskar_mem_buffer_const(d->weight)},
327                 {INT_SZ, &grid_size},
328                 {INT_SZ, &grid_centre},
329                 {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
330                         (const void*)&grid_scale :
331                         (const void*)&grid_scale_f},
332                 {FLT_SZ, (const void*)&inv_tile_size_u},
333                 {FLT_SZ, (const void*)&inv_tile_size_v},
334                 {INT_SZ, &num_tiles_u},
335                 {INT_SZ, &top_left_u},
336                 {INT_SZ, &top_left_v},
337                 {PTR_SZ, oskar_mem_buffer(d->tile_offsets)},
338                 {PTR_SZ, oskar_mem_buffer(d->sorted_uu)},
339                 {PTR_SZ, oskar_mem_buffer(d->sorted_vv)},
340                 {PTR_SZ, oskar_mem_buffer(d->sorted_vis)},
341                 {PTR_SZ, oskar_mem_buffer(d->sorted_wt)},
342                 {PTR_SZ, oskar_mem_buffer(d->sorted_tile)}
343         };
344         oskar_device_launch_kernel(k, location, 1, local_size, global_size,
345                 sizeof(args) / sizeof(oskar_Arg), args, 0, 0, status);
346     }
347 
348     /* Update the grid. */
349     {
350         const char* k = 0;
351         if (oskar_type_is_single(vis_type))
352         {
353             k = "grid_tile_grid_simple_float";
354         }
355         else if (oskar_type_is_double(vis_type))
356         {
357             k = "grid_tile_grid_simple_double";
358         }
359         else
360         {
361             *status = OSKAR_ERR_BAD_DATA_TYPE;
362         }
363         local_size[0] = tile_size_u;
364         size_t num_blocks = (num_points + local_size[0] - 1) / local_size[0];
365         if (num_blocks > 10000) num_blocks = 10000;
366         global_size[0] = local_size[0] * num_blocks;
367         const size_t sh_mem_size = oskar_mem_element_size(vis_type) *
368                 SHMSZ * local_size[0];
369         const oskar_Arg args[] = {
370                 {INT_SZ, &h->support},
371                 {INT_SZ, &h->oversample},
372                 {PTR_SZ, oskar_mem_buffer_const(d->conv_func)},
373                 {INT_SZ, &grid_size},
374                 {INT_SZ, &grid_centre},
375                 {INT_SZ, &tile_size_u},
376                 {INT_SZ, &tile_size_v},
377                 {INT_SZ, &num_tiles_u},
378                 {INT_SZ, &top_left_u},
379                 {INT_SZ, &top_left_v},
380                 {INT_SZ, &num_total},
381                 {PTR_SZ, oskar_mem_buffer_const(d->sorted_uu)},
382                 {PTR_SZ, oskar_mem_buffer_const(d->sorted_vv)},
383                 {PTR_SZ, oskar_mem_buffer_const(d->sorted_vis)},
384                 {PTR_SZ, oskar_mem_buffer_const(d->sorted_wt)},
385                 {PTR_SZ, oskar_mem_buffer_const(d->sorted_tile)},
386                 {PTR_SZ, oskar_mem_buffer(d->tile_locks)},
387                 {PTR_SZ, oskar_mem_buffer(d->counter)},
388                 {PTR_SZ, oskar_mem_buffer(d->norm)},
389                 {PTR_SZ, oskar_mem_buffer(plane)}
390         };
391         oskar_device_launch_kernel(k, location, 1, local_size, global_size,
392                 sizeof(args) / sizeof(oskar_Arg), args, 1, &sh_mem_size,
393                 status);
394     }
395 
396     /* Update the normalisation value. */
397     if (oskar_mem_type(d->norm) == OSKAR_SINGLE)
398     {
399         float temp_norm = 0.0f;
400         oskar_mem_read_element(d->norm, 0, &temp_norm, status);
401         a->plane_norm = temp_norm;
402     }
403     else
404     {
405         double temp_norm = 0.0;
406         oskar_mem_read_element(d->norm, 0, &temp_norm, status);
407         a->plane_norm = temp_norm;
408     }
409     return 0;
410 }
411 
412 #ifdef __cplusplus
413 }
414 #endif
415