1 /*
2  * Copyright (c) 2016-2021, The OSKAR Developers.
3  * See the LICENSE file at the top-level directory of this distribution.
4  */
5 
6 #include "imager/private_imager.h"
7 #include "imager/oskar_imager.h"
8 
9 #include "imager/define_grid_tile_grid.h"
10 #include "imager/private_imager_update_plane_wproj.h"
11 #include "imager/oskar_grid_wproj2.h"
12 #include "math/oskar_prefix_sum.h"
13 #include "math/oskar_round_robin.h"
14 #include "utility/oskar_device.h"
15 #include "utility/oskar_thread.h"
16 
17 #include <assert.h>
18 
19 #ifdef __cplusplus
20 extern "C" {
21 #endif
22 
23 static void* run_subset(void* arg);
24 
25 struct ThreadArgs
26 {
27     oskar_Imager* h;
28     size_t num_vis, num_skipped;
29     const oskar_Mem *uu, *vv, *ww, *amps, *weight;
30     oskar_Mem *plane;
31     double plane_norm;
32     int grid_size, i_plane, thread_id;
33 };
34 typedef struct ThreadArgs ThreadArgs;
35 
/*
 * Grids a block of visibilities onto an image plane using W-projection.
 *
 * Two code paths exist:
 *  - CPU path (taken when h->grid_on_gpu is false or h->num_gpus is zero):
 *    the plane is resized to grid_size^2 cells and updated by
 *    oskar_grid_wproj2_d() or oskar_grid_wproj2_f(), depending on
 *    h->imager_prec.
 *  - Device path: one worker thread per GPU is started with run_subset();
 *    the per-thread normalisation and skip counts are added to *plane_norm
 *    and *num_skipped once each thread has been joined.
 *
 * Parameters:
 *  h            Imager handle.
 *  num_vis      Number of visibilities in the input arrays.
 *  uu, vv, ww   Baseline coordinates.
 *  amps         Visibility amplitudes.
 *  weight       Visibility weights.
 *  i_plane      Index of the plane to use when 'plane' is NULL.
 *  plane        Plane to update; if NULL, the imager's own plane i_plane
 *               is used (h->planes on the CPU path, the per-device planes
 *               inside run_subset() on the device path).
 *  plane_norm   Plane normalisation; handed to the gridding routine on the
 *               CPU path, incremented on the device path.
 *  num_skipped  Number of skipped points; handled like plane_norm.
 *  status       Status code. The function does nothing if it is already
 *               non-zero on entry, and sets it on failure.
 */
void oskar_imager_update_plane_wproj(oskar_Imager* h, size_t num_vis,
        const oskar_Mem* uu, const oskar_Mem* vv, const oskar_Mem* ww,
        const oskar_Mem* amps, const oskar_Mem* weight, int i_plane,
        oskar_Mem* plane, double* plane_norm, size_t* num_skipped, int* status)
{
    if (*status) return;
    if (!h->grid_on_gpu || h->num_gpus == 0)
    {
        /* CPU path: select and validate the plane to update. */
        oskar_Mem* plane_ptr = plane;
        if (!plane_ptr)
        {
            if (!h->planes)
            {
                *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
                return;
            }
            plane_ptr = h->planes[i_plane];
        }
        if (oskar_mem_location(plane_ptr) != OSKAR_CPU)
        {
            *status = OSKAR_ERR_LOCATION_MISMATCH;
            return;
        }
        if (oskar_mem_precision(plane_ptr) != h->imager_prec)
        {
            *status = OSKAR_ERR_TYPE_MISMATCH;
            return;
        }
        const int grid_size = oskar_imager_plane_size(h);
        const size_t num_cells = ((size_t) grid_size) * ((size_t) grid_size);
        oskar_mem_ensure(plane_ptr, num_cells, status);
        if (*status) return;

        /* Fetch every typed pointer first, and only call the gridding
         * routine if none of the accessors reported an error: the gridder
         * takes no status argument, so it cannot bail out by itself. */
        if (h->imager_prec == OSKAR_DOUBLE)
        {
            const int* support = oskar_mem_int_const(h->w_support, status);
            const int* kernel_start =
                    oskar_mem_int_const(h->w_kernel_start, status);
            const double* kernels =
                    oskar_mem_double_const(h->w_kernels_compact, status);
            const double* uu_ = oskar_mem_double_const(uu, status);
            const double* vv_ = oskar_mem_double_const(vv, status);
            const double* ww_ = oskar_mem_double_const(ww, status);
            const double* amps_ = oskar_mem_double_const(amps, status);
            const double* weight_ = oskar_mem_double_const(weight, status);
            double* grid = oskar_mem_double(plane_ptr, status);
            if (*status) return;
            oskar_grid_wproj2_d(h->num_w_planes, support, h->oversample,
                    kernel_start, kernels, num_vis,
                    uu_, vv_, ww_, amps_, weight_,
                    h->cellsize_rad, h->w_scale,
                    grid_size, num_skipped, plane_norm, grid);
        }
        else
        {
            const int* support = oskar_mem_int_const(h->w_support, status);
            const int* kernel_start =
                    oskar_mem_int_const(h->w_kernel_start, status);
            const float* kernels =
                    oskar_mem_float_const(h->w_kernels_compact, status);
            const float* uu_ = oskar_mem_float_const(uu, status);
            const float* vv_ = oskar_mem_float_const(vv, status);
            const float* ww_ = oskar_mem_float_const(ww, status);
            const float* amps_ = oskar_mem_float_const(amps, status);
            const float* weight_ = oskar_mem_float_const(weight, status);
            float* grid = oskar_mem_float(plane_ptr, status);
            if (*status) return;
            oskar_grid_wproj2_f(h->num_w_planes, support, h->oversample,
                    kernel_start, kernels, num_vis,
                    uu_, vv_, ww_, amps_, weight_,
                    h->cellsize_rad, h->w_scale,
                    grid_size, num_skipped, plane_norm, grid);
        }
    }
    else
    {
        int i = 0;
        oskar_Thread** threads = 0;
        ThreadArgs* args = 0;

        /* Set up worker threads: one per GPU.
         * The casts on calloc() are kept so the file also builds as C++. */
        const int num_threads = h->num_gpus;
        threads = (oskar_Thread**) calloc(
                (size_t) num_threads, sizeof(oskar_Thread*));
        args = (ThreadArgs*) calloc((size_t) num_threads, sizeof(ThreadArgs));
        if (!threads || !args)
        {
            /* Allocation failed: release whichever array was obtained
             * (free(NULL) is a no-op) and report the error. */
            free(threads);
            free(args);
            *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
            return;
        }
        for (i = 0; i < num_threads; ++i)
        {
            args[i].h = h;
            args[i].num_vis = num_vis;
            args[i].uu = uu;
            args[i].vv = vv;
            args[i].ww = ww;
            args[i].amps = amps;
            args[i].weight = weight;
            args[i].plane = plane;
            args[i].i_plane = i_plane;
            args[i].thread_id = i;
        }

        /* Set status code: the workers report errors through h->status. */
        h->status = *status;

        /* Start the worker threads. */
        for (i = 0; i < num_threads; ++i)
        {
            threads[i] = oskar_thread_create(run_subset, (void*)&args[i], 0);
        }

        /* Wait for worker threads to finish, and accumulate their results.
         * args[] was zero-initialised, so a worker that returned early
         * contributes nothing. */
        for (i = 0; i < num_threads; ++i)
        {
            oskar_thread_join(threads[i]);
            oskar_thread_free(threads[i]);
            *plane_norm += args[i].plane_norm;
            *num_skipped += args[i].num_skipped;
        }
        free(threads);
        free(args);

        /* Get status code. */
        *status = h->status;
    }
}
152 
153 
/*
 * Worker thread entry point for device-side W-projection gridding.
 *
 * 'arg' must point to a ThreadArgs structure. The thread selects the device
 * h->gpu_ids[thread_id], copies its share of the input data into the
 * per-device scratch buffers in h->d[thread_id], and launches three kernels
 * in turn: a per-tile count, a bucket sort into tiles, and the grid update.
 *
 * Outputs are written back into the ThreadArgs structure:
 *  - num_skipped: value read back from d->count_skipped.
 *  - plane_norm:  value read back from d->norm.
 * Both are left untouched if the function returns early on an error.
 *
 * Errors are reported through h->status.
 * NOTE(review): h->status is the same object for every worker thread and no
 * synchronisation is visible here -- confirm that concurrent writes are
 * acceptable.
 *
 * Always returns 0 (a null pointer).
 */
static void* run_subset(void* arg)
{
    oskar_Imager* h = 0;
    oskar_Mem *plane = 0;
    const oskar_Mem *uu = 0, *vv = 0, *ww = 0, *amps = 0, *weight = 0;
    int count_skipped = 0, num_total = 0, *status = 0;
    int start = 0, num_points = 0;
    size_t local_size[] = {1, 1, 1}, global_size[] = {1, 1, 1};
    DeviceData* d = 0;

    /* Get thread function arguments. */
    ThreadArgs* a = (ThreadArgs*) arg;
    const int i_plane = a->i_plane;
    const int thread_id = a->thread_id;
    const size_t num_vis = a->num_vis;
    h = a->h;
    status = &(h->status);
    uu = a->uu;
    vv = a->vv;
    ww = a->ww;
    amps = a->amps;
    weight = a->weight;

    /* Set the device used by the thread. */
    d = &h->d[thread_id];
    /* Use the supplied plane if there is one, otherwise fall back to this
     * device's own plane with index i_plane. */
    plane = a->plane;
    if (!plane)
    {
        if (d->planes)
        {
            plane = d->planes[i_plane];
        }
        else
        {
            *status = OSKAR_ERR_MEMORY_NOT_ALLOCATED;
            return 0;
        }
    }
    /* The plane must live in the device memory space and have the
     * imager's precision. */
    if (oskar_mem_location(plane) != h->dev_loc)
    {
        *status = OSKAR_ERR_LOCATION_MISMATCH;
        return 0;
    }
    if (oskar_mem_precision(plane) != h->imager_prec)
    {
        *status = OSKAR_ERR_TYPE_MISMATCH;
        return 0;
    }
    oskar_device_set(h->dev_loc, h->gpu_ids[thread_id], status);

    /* Scalar kernel parameters. The scale factors are kept in both double
     * and single precision, so that the variant matching the visibility
     * data type can be passed to the kernels. */
    const int location = h->dev_loc;
    const int grid_size = oskar_imager_plane_size(h);
    const int vis_type = oskar_mem_type(amps);
    const int is_dbl = oskar_type_is_double(vis_type);
    const int grid_centre = grid_size / 2;
    const double grid_scale = grid_size * h->cellsize_rad;
    const double w_scale = h->w_scale;
    const float grid_scale_f = (float) grid_scale;
    const float w_scale_f = (float) w_scale;

    /* Define the tile size and number of tiles in each direction.
     * A tile consists of SHMSZ grid cells per thread in shared memory
     * and REGSZ grid cells per thread in registers. */
    const int tile_size_u = 32;
    const int tile_size_v = (SHMSZ + REGSZ);
    const int num_tiles_u = (grid_size + tile_size_u - 1) / tile_size_u;
    const int num_tiles_v = (grid_size + tile_size_v - 1) / tile_size_v;
    const int num_tiles = num_tiles_u * num_tiles_v;

    /* Which tile contains the grid centre? */
    const int c_tile_u = grid_centre / tile_size_u;
    const int c_tile_v = grid_centre / tile_size_v;

    /* Compute difference between centre of centre tile and grid centre
     * to ensure the centre of the grid is in the centre of a tile. */
    const int top_left_u = grid_centre -
            c_tile_u * tile_size_u - tile_size_u / 2;
    const int top_left_v = grid_centre -
            c_tile_v * tile_size_v - tile_size_v / 2;
    assert(top_left_u <= 0);
    assert(top_left_v <= 0);

    /* Set up scratch memory. */
    /* oskar_round_robin() yields this thread's range of the input
     * (num_points elements beginning at index start).
     * NOTE(review): num_vis is narrowed from size_t to int here -- confirm
     * that callers never pass more than INT_MAX visibilities. */
    oskar_round_robin((int)num_vis, h->num_gpus, thread_id,
            &num_points, &start);
    oskar_mem_ensure(d->uu, num_points, status);
    oskar_mem_ensure(d->vv, num_points, status);
    oskar_mem_ensure(d->ww, num_points, status);
    oskar_mem_ensure(d->vis, num_points, status);
    oskar_mem_ensure(d->weight, num_points, status);
    /* Copy this thread's range of each input array into its scratch
     * buffer, starting at offset 0 in the destination. */
    oskar_mem_copy_contents(d->uu, uu, 0, start, num_points, status);
    oskar_mem_copy_contents(d->vv, vv, 0, start, num_points, status);
    oskar_mem_copy_contents(d->ww, ww, 0, start, num_points, status);
    oskar_mem_copy_contents(d->vis, amps, 0, start, num_points, status);
    oskar_mem_copy_contents(d->weight, weight, 0, start, num_points, status);
    oskar_mem_ensure(d->num_points_in_tiles, num_tiles, status);
    oskar_mem_ensure(d->tile_offsets, num_tiles + 1, status);
    oskar_mem_ensure(d->tile_locks, num_tiles, status);
    oskar_mem_clear_contents(d->counter, status);
    oskar_mem_clear_contents(d->count_skipped, status);
    oskar_mem_clear_contents(d->norm, status);
    oskar_mem_clear_contents(d->num_points_in_tiles, status);
    oskar_mem_clear_contents(d->tile_locks, status);
    /* Don't need to clear d->tile_offsets, as it gets overwritten. */

    /* Count the number of elements in each tile. */
    const float inv_tile_size_u = 1.0f / (float) tile_size_u;
    const float inv_tile_size_v = 1.0f / (float) tile_size_v;
    {
        /* Select the kernel by visibility precision. If the type is neither
         * single nor double, k stays null and the status code is set. */
        const char* k = 0;
        if (oskar_type_is_single(vis_type))
        {
            k = "grid_tile_count_wproj_float";
        }
        else if (oskar_type_is_double(vis_type))
        {
            k = "grid_tile_count_wproj_double";
        }
        else
        {
            *status = OSKAR_ERR_BAD_DATA_TYPE;
        }
        local_size[0] = 512;
        oskar_device_check_local_size(location, 0, local_size);
        global_size[0] = oskar_device_global_size(num_points, local_size[0]);
        /* NOTE(review): the argument order presumably has to match the
         * kernel's parameter list, which is not visible here -- do not
         * reorder without checking the kernel source. */
        const oskar_Arg args[] = {
                {INT_SZ, &h->num_w_planes},
                {PTR_SZ, oskar_mem_buffer_const(d->w_support)},
                {INT_SZ, &num_points},
                {PTR_SZ, oskar_mem_buffer_const(d->uu)},
                {PTR_SZ, oskar_mem_buffer_const(d->vv)},
                {PTR_SZ, oskar_mem_buffer_const(d->ww)},
                {INT_SZ, &grid_size},
                {INT_SZ, &grid_centre},
                {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
                        (const void*)&grid_scale :
                        (const void*)&grid_scale_f},
                {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
                        (const void*)&w_scale :
                        (const void*)&w_scale_f},
                {FLT_SZ, (const void*)&inv_tile_size_u},
                {FLT_SZ, (const void*)&inv_tile_size_v},
                {INT_SZ, &num_tiles_u},
                {INT_SZ, &top_left_u},
                {INT_SZ, &top_left_v},
                {PTR_SZ, oskar_mem_buffer(d->num_points_in_tiles)},
                {PTR_SZ, oskar_mem_buffer(d->count_skipped)}
        };
        oskar_device_launch_kernel(k, location, 1, local_size, global_size,
                sizeof(args) / sizeof(oskar_Arg), args, 0, 0, status);
    }

    /* Get the offsets for each tile using prefix sum. */
    oskar_prefix_sum(num_tiles,
            d->num_points_in_tiles, d->tile_offsets, status);

    /* Get the total number of visibilities to process. */
    /* The total is taken from the final element (index num_tiles) of the
     * offsets array, which holds num_tiles + 1 elements. */
    oskar_mem_read_element(d->tile_offsets, num_tiles, &num_total, status);
    oskar_mem_read_element(d->count_skipped, 0, &count_skipped, status);
    a->num_skipped = (size_t) count_skipped;
    /*printf("Total points: %d (factor %.3f increase)\n", num_total,
            (float)num_total / (float)num_points);*/

    /* Bucket sort the data into tiles. */
    oskar_mem_ensure(d->sorted_uu, num_total, status);
    oskar_mem_ensure(d->sorted_vv, num_total, status);
    oskar_mem_ensure(d->sorted_ww, num_total, status);
    oskar_mem_ensure(d->sorted_wt, num_total, status);
    oskar_mem_ensure(d->sorted_vis, num_total, status);
    oskar_mem_ensure(d->sorted_tile, num_total, status);
    {
        /* Select the kernel by visibility precision, as above. */
        const char* k = 0;
        if (oskar_type_is_single(vis_type))
        {
            k = "grid_tile_bucket_sort_wproj_float";
        }
        else if (oskar_type_is_double(vis_type))
        {
            k = "grid_tile_bucket_sort_wproj_double";
        }
        else
        {
            *status = OSKAR_ERR_BAD_DATA_TYPE;
        }
        local_size[0] = 128;
        oskar_device_check_local_size(location, 0, local_size);
        global_size[0] = oskar_device_global_size(num_points, local_size[0]);
        /* NOTE(review): positional arguments -- the order presumably has to
         * match the kernel's parameter list (not visible here). */
        const oskar_Arg args[] = {
                {INT_SZ, &h->num_w_planes},
                {PTR_SZ, oskar_mem_buffer_const(d->w_support)},
                {INT_SZ, &num_points},
                {PTR_SZ, oskar_mem_buffer_const(d->uu)},
                {PTR_SZ, oskar_mem_buffer_const(d->vv)},
                {PTR_SZ, oskar_mem_buffer_const(d->ww)},
                {PTR_SZ, oskar_mem_buffer_const(d->vis)},
                {PTR_SZ, oskar_mem_buffer_const(d->weight)},
                {INT_SZ, &grid_size},
                {INT_SZ, &grid_centre},
                {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
                        (const void*)&grid_scale :
                        (const void*)&grid_scale_f},
                {is_dbl ? DBL_SZ : FLT_SZ,  is_dbl ?
                        (const void*)&w_scale :
                        (const void*)&w_scale_f},
                {FLT_SZ, (const void*)&inv_tile_size_u},
                {FLT_SZ, (const void*)&inv_tile_size_v},
                {INT_SZ, &num_tiles_u},
                {INT_SZ, &top_left_u},
                {INT_SZ, &top_left_v},
                {PTR_SZ, oskar_mem_buffer(d->tile_offsets)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_uu)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_vv)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_ww)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_vis)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_wt)},
                {PTR_SZ, oskar_mem_buffer(d->sorted_tile)}
        };
        oskar_device_launch_kernel(k, location, 1, local_size, global_size,
                sizeof(args) / sizeof(oskar_Arg), args, 0, 0, status);
    }

    /* Update the grid. */
    {
        /* Select the kernel by visibility precision, as above. */
        const char* k = 0;
        if (oskar_type_is_single(vis_type))
        {
            k = "grid_tile_grid_wproj_float";
        }
        else if (oskar_type_is_double(vis_type))
        {
            k = "grid_tile_grid_wproj_double";
        }
        else
        {
            *status = OSKAR_ERR_BAD_DATA_TYPE;
        }
        /* One work-group is tile_size_u work-items wide. The number of
         * work-groups is derived from num_points and capped at 10000. */
        local_size[0] = tile_size_u;
        size_t num_blocks = (num_points + local_size[0] - 1) / local_size[0];
        if (num_blocks > 10000) num_blocks = 10000;
        global_size[0] = local_size[0] * num_blocks;
        /* Size in bytes of the local (shared) memory argument: SHMSZ
         * elements of the visibility type per work-item. */
        const size_t sh_mem_size = oskar_mem_element_size(vis_type) *
                SHMSZ * local_size[0];
        /* NOTE(review): positional arguments -- the order presumably has to
         * match the kernel's parameter list (not visible here). */
        const oskar_Arg args[] = {
                {INT_SZ, &h->num_w_planes},
                {PTR_SZ, oskar_mem_buffer_const(d->w_support)},
                {INT_SZ, &h->oversample},
                {PTR_SZ, oskar_mem_buffer_const(d->w_kernel_start)},
                {PTR_SZ, oskar_mem_buffer_const(d->w_kernels_compact)},
                {INT_SZ, &grid_size},
                {INT_SZ, &grid_centre},
                {INT_SZ, &tile_size_u},
                {INT_SZ, &tile_size_v},
                {INT_SZ, &num_tiles_u},
                {INT_SZ, &top_left_u},
                {INT_SZ, &top_left_v},
                {INT_SZ, &num_total},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_uu)},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_vv)},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_ww)},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_vis)},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_wt)},
                {PTR_SZ, oskar_mem_buffer_const(d->sorted_tile)},
                {PTR_SZ, oskar_mem_buffer(d->tile_locks)},
                {PTR_SZ, oskar_mem_buffer(d->counter)},
                {PTR_SZ, oskar_mem_buffer(d->norm)},
                {PTR_SZ, oskar_mem_buffer(plane)}
        };
        oskar_device_launch_kernel(k, location, 1, local_size, global_size,
                sizeof(args) / sizeof(oskar_Arg), args, 1, &sh_mem_size,
                status);
    }

    /* Update the normalisation value. */
    /* d->norm is read back with a temporary of matching precision and then
     * stored as a double in the thread arguments. */
    if (oskar_mem_type(d->norm) == OSKAR_SINGLE)
    {
        float temp_norm = 0.0f;
        oskar_mem_read_element(d->norm, 0, &temp_norm, status);
        a->plane_norm = temp_norm;
    }
    else
    {
        double temp_norm = 0.0;
        oskar_mem_read_element(d->norm, 0, &temp_norm, status);
        a->plane_norm = temp_norm;
    }
    return 0;
}
441 
442 
443 #ifdef __cplusplus
444 }
445 #endif
446