1 #include "metal_objc_platform_dependent.h"
2 #include "HalideRuntime.h"
3 #include "objc_support.h"
4
5 namespace Halide {
6 namespace Runtime {
7 namespace Internal {
8 namespace Metal {
9
dispatch_threadgroups(mtl_compute_command_encoder * encoder,int32_t blocks_x,int32_t blocks_y,int32_t blocks_z,int32_t threads_x,int32_t threads_y,int32_t threads_z)10 WEAK void dispatch_threadgroups(mtl_compute_command_encoder *encoder,
11 int32_t blocks_x, int32_t blocks_y, int32_t blocks_z,
12 int32_t threads_x, int32_t threads_y, int32_t threads_z) {
13 #if BITS_64
14 struct MTLSize {
15 unsigned long width;
16 unsigned long height;
17 unsigned long depth;
18 };
19
20 MTLSize threadgroupsPerGrid;
21 threadgroupsPerGrid.width = blocks_x;
22 threadgroupsPerGrid.height = blocks_y;
23 threadgroupsPerGrid.depth = blocks_z;
24
25 MTLSize threadsPerThreadgroup;
26 threadsPerThreadgroup.width = threads_x;
27 threadsPerThreadgroup.height = threads_y;
28 threadsPerThreadgroup.depth = threads_z;
29
30 #if ARM_COMPILE
31 typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
32 MTLSize * threadgroupsPerGrid, MTLSize * threadsPerThreadgroup);
33 dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
34 (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
35 &threadgroupsPerGrid, &threadsPerThreadgroup);
36 #elif X86_COMPILE
37 typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
38 MTLSize threadgroupsPerGrid, MTLSize threadsPerThreadgroup);
39 dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
40 (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
41 threadgroupsPerGrid, threadsPerThreadgroup);
42 #endif
43 #else
44 typedef void (*dispatch_threadgroups_method)(objc_id encoder, objc_sel sel,
45 int32_t blocks_x, int32_t blocks_y, int32_t blocks_z,
46 int32_t threads_x, int32_t threads_y, int32_t threads_z);
47 dispatch_threadgroups_method method = (dispatch_threadgroups_method)&objc_msgSend;
48 (*method)(encoder, sel_getUid("dispatchThreadgroups:threadsPerThreadgroup:"),
49 blocks_x, blocks_y, blocks_z, threads_x, threads_y, threads_z);
50 #endif
51 }
52
53 } // namespace Metal
54 } // namespace Internal
55 } // namespace Runtime
56 } // namespace Halide
57