1 #include "metal_objc_platform_dependent.h"
2 #include "HalideRuntime.h"
3 #include "objc_support.h"
4 
5 namespace Halide {
6 namespace Runtime {
7 namespace Internal {
8 namespace Metal {
9 
dispatch_threadgroups(mtl_compute_command_encoder * encoder,int32_t blocks_x,int32_t blocks_y,int32_t blocks_z,int32_t threads_x,int32_t threads_y,int32_t threads_z)10 WEAK void dispatch_threadgroups(mtl_compute_command_encoder *encoder,
11                                 int32_t blocks_x, int32_t blocks_y, int32_t blocks_z,
12                                 int32_t threads_x, int32_t threads_y, int32_t threads_z) {
13 #if BITS_64
14     struct MTLSize {
15         unsigned long width;
16         unsigned long height;
17         unsigned long depth;
18     };
19 
20     MTLSize threadgroupsPerGrid;
21     threadgroupsPerGrid.width = blocks_x;
22     threadgroupsPerGrid.height = blocks_y;
23     threadgroupsPerGrid.depth = blocks_z;
24 
25     MTLSize threadsPerThreadgroup;
26     threadsPerThreadgroup.width = threads_x;
27     threadsPerThreadgroup.height = threads_y;
28     threadsPerThreadgroup.depth = threads_z;
29 
30 #if ARM_COMPILE
31     typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
32                                                  MTLSize * threadgroupsPerGrid, MTLSize * threadsPerThreadgroup);
33     dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
34     (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
35               &threadgroupsPerGrid, &threadsPerThreadgroup);
36 #elif X86_COMPILE
37     typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
38                                                  MTLSize threadgroupsPerGrid, MTLSize threadsPerThreadgroup);
39     dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
40     (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
41               threadgroupsPerGrid, threadsPerThreadgroup);
42 #endif
43 #else
44     typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
45                                                  int32_t blocks_x, int32_t blocks_y, int32_t blocks_z,
46                                                  int32_t threads_x, int32_t threads_y, int32_t threads_z);
47     dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
48     (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
49               blocks_x, blocks_y, blocks_z, threads_x, threads_y, threads_z);
50 #endif
51 }
52 
53 }  // namespace Metal
54 }  // namespace Internal
55 }  // namespace Runtime
56 }  // namespace Halide
57