1 #include "HalideRuntime.h"
2 #include "printer.h"
3
4 #ifndef MX_API_VER
5 #define MX_API_VER 0x07040000
6 #endif
7
8 struct mxArray;
9
10 // It is important to have the mex function pointer definitions in a
11 // namespace to avoid silently conflicting symbols with matlab at
12 // runtime.
13 namespace Halide {
14 namespace Runtime {
15 namespace mex {
16
17 // Define a few things from mex.h that we need to grab the mex APIs
18 // from matlab.
19
// Name-length limits. mxMAXNAM is defined as an alias of
// TMW_NAME_LENGTH_MAX so both spellings carry the same value (64).
enum { TMW_NAME_LENGTH_MAX = 64 };
enum { mxMAXNAM = TMW_NAME_LENGTH_MAX };

// Local stand-ins for the mex.h element typedefs of the same names.
// NOTE(review): presumably these must match the widths Matlab uses for
// logical and char array elements -- confirm against mex.h.
typedef bool mxLogical;
typedef int16_t mxChar;
25
// Matlab array class identifiers. The numeric values are spelled out
// explicitly because they are compared against the IDs returned by
// mxGetClassID; the ordering is part of the contract with Matlab.
enum mxClassID {
    mxUNKNOWN_CLASS = 0,
    mxCELL_CLASS = 1,
    mxSTRUCT_CLASS = 2,
    mxLOGICAL_CLASS = 3,
    mxCHAR_CLASS = 4,
    mxVOID_CLASS = 5,
    mxDOUBLE_CLASS = 6,
    mxSINGLE_CLASS = 7,
    mxINT8_CLASS = 8,
    mxUINT8_CLASS = 9,
    mxINT16_CLASS = 10,
    mxUINT16_CLASS = 11,
    mxINT32_CLASS = 12,
    mxUINT32_CLASS = 13,
    mxINT64_CLASS = 14,
    mxUINT64_CLASS = 15,
    mxFUNCTION_CLASS = 16,
    mxOPAQUE_CLASS = 17,
    mxOBJECT_CLASS = 18,

    // The index class aliases the unsigned integer class whose width
    // matches the target: 32 bits when BITS_32 is defined, 64 otherwise.
#ifdef BITS_32
    mxINDEX_CLASS = mxUINT32_CLASS,
#else
    mxINDEX_CLASS = mxUINT64_CLASS,
#endif

    // Sparse arrays share the ID of the void class.
    mxSPARSE_CLASS = mxVOID_CLASS
};
54
// Selects whether an array holds real-only or complex data.
enum mxComplexity {
    mxREAL = 0,
    mxCOMPLEX = 1
};
59
// Size and index types used in the mex API signatures. When BITS_32 is
// defined all three are plain int; otherwise sizes and indices are
// size_t and the signed index is ptrdiff_t.
#ifdef BITS_32
typedef int mwSize;
typedef int mwIndex;
typedef int mwSignedIndex;
#else
typedef size_t mwSize;
typedef size_t mwIndex;
typedef ptrdiff_t mwSignedIndex;
#endif
69
70 typedef void (*mex_exit_fn)(void);
71
72 // Declare function pointers for the mex APIs.
73 #define MEX_FN(ret, func, args) ret(*func) args;
74 #include "mex_functions.h"
75
76 // Given a halide type code and bit width, find the equivalent matlab class ID.
get_class_id(int32_t type_code,int32_t type_bits)77 WEAK mxClassID get_class_id(int32_t type_code, int32_t type_bits) {
78 switch (type_code) {
79 case halide_type_int:
80 switch (type_bits) {
81 case 1:
82 return mxLOGICAL_CLASS;
83 case 8:
84 return mxINT8_CLASS;
85 case 16:
86 return mxINT16_CLASS;
87 case 32:
88 return mxINT32_CLASS;
89 case 64:
90 return mxINT64_CLASS;
91 }
92 return mxUNKNOWN_CLASS;
93 case halide_type_uint:
94 switch (type_bits) {
95 case 1:
96 return mxLOGICAL_CLASS;
97 case 8:
98 return mxUINT8_CLASS;
99 case 16:
100 return mxUINT16_CLASS;
101 case 32:
102 return mxUINT32_CLASS;
103 case 64:
104 return mxUINT64_CLASS;
105 }
106 return mxUNKNOWN_CLASS;
107 case halide_type_float:
108 switch (type_bits) {
109 case 32:
110 return mxSINGLE_CLASS;
111 case 64:
112 return mxDOUBLE_CLASS;
113 }
114 return mxUNKNOWN_CLASS;
115 }
116 return mxUNKNOWN_CLASS;
117 }
118
119 // Convert a matlab class ID to a string.
get_class_name(mxClassID id)120 WEAK const char *get_class_name(mxClassID id) {
121 switch (id) {
122 case mxCELL_CLASS:
123 return "cell";
124 case mxSTRUCT_CLASS:
125 return "struct";
126 case mxLOGICAL_CLASS:
127 return "logical";
128 case mxCHAR_CLASS:
129 return "char";
130 case mxVOID_CLASS:
131 return "void";
132 case mxDOUBLE_CLASS:
133 return "double";
134 case mxSINGLE_CLASS:
135 return "single";
136 case mxINT8_CLASS:
137 return "int8";
138 case mxUINT8_CLASS:
139 return "uint8";
140 case mxINT16_CLASS:
141 return "int16";
142 case mxUINT16_CLASS:
143 return "uint16";
144 case mxINT32_CLASS:
145 return "int32";
146 case mxUINT32_CLASS:
147 return "uint32";
148 case mxINT64_CLASS:
149 return "int64";
150 case mxUINT64_CLASS:
151 return "uint64";
152 case mxFUNCTION_CLASS:
153 return "function";
154 case mxOPAQUE_CLASS:
155 return "opaque";
156 case mxOBJECT_CLASS:
157 return "object";
158 default:
159 return "unknown";
160 }
161 }
162
163 // Get the real data pointer from an mxArray.
164 template<typename T>
get_data(mxArray * a)165 ALWAYS_INLINE T *get_data(mxArray *a) {
166 return (T *)mxGetData(a);
167 }
168 template<typename T>
get_data(const mxArray * a)169 ALWAYS_INLINE const T *get_data(const mxArray *a) {
170 return (const T *)mxGetData(a);
171 }
172
173 // Search for a symbol in the calling process (i.e. matlab).
174 template<typename T>
get_mex_symbol(void * user_context,const char * name,bool required)175 ALWAYS_INLINE T get_mex_symbol(void *user_context, const char *name, bool required) {
176 T s = (T)halide_get_symbol(name);
177 if (required && s == NULL) {
178 error(user_context) << "mex API not found: " << name << "\n";
179 return NULL;
180 }
181 return s;
182 }
183
184 // Provide Matlab API version agnostic wrappers for version specific APIs.
get_number_of_dimensions(const mxArray * a)185 ALWAYS_INLINE size_t get_number_of_dimensions(const mxArray *a) {
186 if (mxGetNumberOfDimensions_730) {
187 return mxGetNumberOfDimensions_730(a);
188 } else {
189 return mxGetNumberOfDimensions_700(a);
190 }
191 }
192
get_dimension(const mxArray * a,size_t n)193 ALWAYS_INLINE size_t get_dimension(const mxArray *a, size_t n) {
194 if (mxGetDimensions_730) {
195 return mxGetDimensions_730(a)[n];
196 } else {
197 return mxGetDimensions_700(a)[n];
198 }
199 }
200
create_numeric_matrix(size_t M,size_t N,mxClassID type,mxComplexity complexity)201 ALWAYS_INLINE mxArray *create_numeric_matrix(size_t M, size_t N, mxClassID type, mxComplexity complexity) {
202 if (mxCreateNumericMatrix_730) {
203 return mxCreateNumericMatrix_730(M, N, type, complexity);
204 } else {
205 return mxCreateNumericMatrix_700(M, N, type, complexity);
206 }
207 }
208
209 } // namespace mex
210 } // namespace Runtime
211 } // namespace Halide
212
213 using namespace Halide::Runtime::mex;
214
215 extern "C" {
216
halide_matlab_describe_pipeline(stringstream & desc,const halide_filter_metadata_t * metadata)217 WEAK void halide_matlab_describe_pipeline(stringstream &desc, const halide_filter_metadata_t *metadata) {
218 desc << "int " << metadata->name << "(";
219 for (int i = 0; i < metadata->num_arguments; i++) {
220 const halide_filter_argument_t *arg = &metadata->arguments[i];
221 if (i > 0) {
222 desc << ", ";
223 }
224 if (arg->kind == halide_argument_kind_output_buffer) {
225 desc << "out ";
226 }
227 if (arg->kind == halide_argument_kind_output_buffer ||
228 arg->kind == halide_argument_kind_input_buffer) {
229 desc << arg->dimensions << "d ";
230 } else if (arg->kind == halide_argument_kind_input_scalar) {
231 desc << "scalar ";
232 }
233 desc << get_class_name(get_class_id(arg->type.code, arg->type.bits));
234 desc << " '" << arg->name << "'";
235 }
236 desc << ")";
237 }
238
halide_matlab_note_pipeline_description(void * user_context,const halide_filter_metadata_t * metadata)239 WEAK void halide_matlab_note_pipeline_description(void *user_context, const halide_filter_metadata_t *metadata) {
240 stringstream desc(user_context);
241 desc << "Note pipeline definition:\n";
242 halide_matlab_describe_pipeline(desc, metadata);
243 halide_print(user_context, desc.str());
244 }
245
halide_matlab_error(void * user_context,const char * msg)246 WEAK void halide_matlab_error(void *user_context, const char *msg) {
247 // Note that mexErrMsg/mexErrMsgIdAndTxt crash Matlab. It seems to
248 // be a common problem, those APIs seem to be very fragile.
249 stringstream error_msg(user_context);
250 error_msg << "\nHalide Error: " << msg;
251 mexWarnMsgTxt(error_msg.str());
252 }
253
// Print hook that forwards msg to Matlab through mexWarnMsgTxt. The
// first (user context) argument is unused.
WEAK void halide_matlab_print(void *, const char *msg) {
    mexWarnMsgTxt(msg);
}
257
halide_matlab_init(void * user_context)258 WEAK int halide_matlab_init(void *user_context) {
259 // Assume that if mexWarnMsgTxt exists, we've already attempted initialization.
260 if (mexWarnMsgTxt != NULL) {
261 return halide_error_code_success;
262 }
263
264 #define MEX_FN(ret, func, args) func = get_mex_symbol<ret(*) args>(user_context, #func, true);
265 #define MEX_FN_700(ret, func, func_700, args) func_700 = get_mex_symbol<ret(*) args>(user_context, #func, false);
266 #define MEX_FN_730(ret, func, func_730, args) func_730 = get_mex_symbol<ret(*) args>(user_context, #func_730, false);
267 #include "mex_functions.h"
268
269 if (!mexWarnMsgTxt) {
270 return halide_error_code_matlab_init_failed;
271 }
272
273 // Set up Halide's printing to go through Matlab. Also, don't exit
274 // on error. We don't just replace halide_error/halide_printf,
275 // because they'd have to be weak here, and there would be no
276 // guarantee that we would get this version (and not the standard
277 // one).
278 halide_set_custom_print(halide_matlab_print);
279 halide_set_error_handler(halide_matlab_error);
280
281 return halide_error_code_success;
282 }
283
284 // Convert a matlab mxArray to a Halide halide_buffer_t, with a specific number of dimensions.
halide_matlab_array_to_halide_buffer_t(void * user_context,const mxArray * arr,const halide_filter_argument_t * arg,halide_buffer_t * buf)285 WEAK int halide_matlab_array_to_halide_buffer_t(void *user_context,
286 const mxArray *arr,
287 const halide_filter_argument_t *arg,
288 halide_buffer_t *buf) {
289
290 if (mxIsComplex(arr)) {
291 error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n";
292 return halide_error_code_matlab_bad_param_type;
293 }
294
295 int dim_count = get_number_of_dimensions(arr);
296 int expected_dims = arg->dimensions;
297
298 // Validate that the data type of a buffer matches exactly.
299 mxClassID arg_class_id = get_class_id(arg->type.code, arg->type.bits);
300 mxClassID class_id = mxGetClassID(arr);
301 if (class_id != arg_class_id) {
302 error(user_context) << "Expected type of class " << get_class_name(arg_class_id)
303 << " for argument " << arg->name
304 << ", got class " << get_class_name(class_id) << ".\n";
305 return halide_error_code_matlab_bad_param_type;
306 }
307 // Validate that the dimensionality matches. Matlab is wierd
308 // because matrices always have at least 2 dimensions, and it
309 // truncates trailing dimensions of extent 1. So, the only way
310 // to have an error here is to have more dimensions with
311 // extent != 1 than the Halide pipeline expects.
312 while (dim_count > 0 && get_dimension(arr, dim_count - 1) == 1) {
313 dim_count--;
314 }
315 if (dim_count > expected_dims) {
316 error(user_context) << "Expected array of rank " << expected_dims
317 << " for argument " << arg->name
318 << ", got array of rank " << dim_count << ".\n";
319 return halide_error_code_matlab_bad_param_type;
320 }
321
322 buf->host = (uint8_t *)mxGetData(arr);
323 buf->type = arg->type;
324 buf->dimensions = arg->dimensions;
325 buf->set_host_dirty(true);
326
327 for (int i = 0; i < dim_count && i < expected_dims; i++) {
328 buf->dim[i].extent = static_cast<int32_t>(get_dimension(arr, i));
329 }
330
331 // Add back the dimensions with extent 1.
332 for (int i = 2; i < expected_dims; i++) {
333 if (buf->dim[i].extent == 0) {
334 buf->dim[i].extent = 1;
335 }
336 }
337
338 // Compute dense strides.
339 buf->dim[0].stride = 1;
340 for (int i = 1; i < expected_dims; i++) {
341 buf->dim[i].stride = buf->dim[i - 1].extent * buf->dim[i - 1].stride;
342 }
343
344 return halide_error_code_success;
345 }
346
347 // Convert a matlab mxArray to a scalar.
halide_matlab_array_to_scalar(void * user_context,const mxArray * arr,const halide_filter_argument_t * arg,void * scalar)348 WEAK int halide_matlab_array_to_scalar(void *user_context,
349 const mxArray *arr, const halide_filter_argument_t *arg, void *scalar) {
350 if (mxIsComplex(arr)) {
351 error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n";
352 return halide_error_code_generic_error;
353 }
354
355 // Validate that the mxArray has all dimensions of extent 1.
356 int dim_count = get_number_of_dimensions(arr);
357 for (int i = 0; i < dim_count; i++) {
358 if (get_dimension(arr, i) != 1) {
359 error(user_context) << "Expected scalar argument for parameter " << arg->name << ".\n";
360 return halide_error_code_matlab_bad_param_type;
361 }
362 }
363 if (!mxIsLogical(arr) && !mxIsNumeric(arr)) {
364 error(user_context) << "Expected numeric argument for scalar parameter " << arg->name
365 << ", got " << get_class_name(mxGetClassID(arr)) << ".\n";
366 return halide_error_code_matlab_bad_param_type;
367 }
368
369 double value = mxGetScalar(arr);
370 int32_t type_code = arg->type.code;
371 int32_t type_bits = arg->type.bits;
372
373 if (type_code == halide_type_int) {
374 switch (type_bits) {
375 case 1:
376 *reinterpret_cast<bool *>(scalar) = value != 0;
377 return halide_error_code_success;
378 case 8:
379 *reinterpret_cast<int8_t *>(scalar) = static_cast<int8_t>(value);
380 return halide_error_code_success;
381 case 16:
382 *reinterpret_cast<int16_t *>(scalar) = static_cast<int16_t>(value);
383 return halide_error_code_success;
384 case 32:
385 *reinterpret_cast<int32_t *>(scalar) = static_cast<int32_t>(value);
386 return halide_error_code_success;
387 case 64:
388 *reinterpret_cast<int64_t *>(scalar) = static_cast<int64_t>(value);
389 return halide_error_code_success;
390 }
391 } else if (type_code == halide_type_uint) {
392 switch (type_bits) {
393 case 1:
394 *reinterpret_cast<bool *>(scalar) = value != 0;
395 return halide_error_code_success;
396 case 8:
397 *reinterpret_cast<uint8_t *>(scalar) = static_cast<uint8_t>(value);
398 return halide_error_code_success;
399 case 16:
400 *reinterpret_cast<uint16_t *>(scalar) = static_cast<uint16_t>(value);
401 return halide_error_code_success;
402 case 32:
403 *reinterpret_cast<uint32_t *>(scalar) = static_cast<uint32_t>(value);
404 return halide_error_code_success;
405 case 64:
406 *reinterpret_cast<uint64_t *>(scalar) = static_cast<uint64_t>(value);
407 return halide_error_code_success;
408 }
409 } else if (type_code == halide_type_float) {
410 switch (type_bits) {
411 case 32:
412 *reinterpret_cast<float *>(scalar) = static_cast<float>(value);
413 return halide_error_code_success;
414 case 64:
415 *reinterpret_cast<double *>(scalar) = static_cast<double>(value);
416 return halide_error_code_success;
417 }
418 } else if (type_code == halide_type_handle) {
419 error(user_context) << "Parameter " << arg->name << " is of a type not supported by Matlab.\n";
420 return halide_error_code_matlab_bad_param_type;
421 }
422 error(user_context) << "Halide metadata for " << arg->name << " contained invalid or unrecognized type description.\n";
423 return halide_error_code_internal_error;
424 }
425
halide_matlab_call_pipeline(void * user_context,int (* pipeline)(void ** args),const halide_filter_metadata_t * metadata,int nlhs,mxArray ** plhs,int nrhs,const mxArray ** prhs)426 WEAK int halide_matlab_call_pipeline(void *user_context,
427 int (*pipeline)(void **args), const halide_filter_metadata_t *metadata,
428 int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs) {
429
430 int init_result = halide_matlab_init(user_context);
431 if (init_result != 0) {
432 return init_result;
433 }
434
435 int32_t result_storage;
436 int32_t *result_ptr = &result_storage;
437 if (nlhs > 0) {
438 plhs[0] = create_numeric_matrix(1, 1, mxINT32_CLASS, mxREAL);
439 result_ptr = get_data<int32_t>(plhs[0]);
440 }
441 int32_t &result = *result_ptr;
442
443 // Set result to failure until proven otherwise.
444 result = halide_error_code_generic_error;
445
446 // Validate the number of arguments is correct.
447 if (nrhs != metadata->num_arguments) {
448 if (nrhs > 0) {
449 // Only report an actual error if there were any arguments at all.
450 error(user_context) << "Expected " << metadata->num_arguments
451 << " arguments for Halide pipeline " << metadata->name
452 << ", got " << nrhs << ".\n";
453 }
454 halide_matlab_note_pipeline_description(user_context, metadata);
455 return result;
456 }
457
458 // Validate the LHS has zero or one argument.
459 if (nlhs > 1) {
460 error(user_context) << "Expected zero or one return value for Halide pipeline " << metadata->name
461 << ", got " << nlhs << ".\n";
462 halide_matlab_note_pipeline_description(user_context, metadata);
463 return result;
464 }
465
466 void **args = (void **)__builtin_alloca(nrhs * sizeof(void *));
467 for (int i = 0; i < nrhs; i++) {
468 const mxArray *arg = prhs[i];
469 const halide_filter_argument_t *arg_metadata = &metadata->arguments[i];
470
471 if (arg_metadata->kind == halide_argument_kind_input_buffer ||
472 arg_metadata->kind == halide_argument_kind_output_buffer) {
473 halide_buffer_t *buf = (halide_buffer_t *)__builtin_alloca(sizeof(halide_buffer_t));
474 memset(buf, 0, sizeof(halide_buffer_t));
475 buf->dim = (halide_dimension_t *)__builtin_alloca(sizeof(halide_dimension_t) * arg_metadata->dimensions);
476 memset(buf->dim, 0, sizeof(halide_dimension_t) * arg_metadata->dimensions);
477 result = halide_matlab_array_to_halide_buffer_t(user_context, arg, arg_metadata, buf);
478 if (result != 0) {
479 halide_matlab_note_pipeline_description(user_context, metadata);
480 return result;
481 }
482 args[i] = buf;
483 } else {
484 size_t size_bytes = max(8, (arg_metadata->type.bits + 7) / 8);
485 void *scalar = __builtin_alloca(size_bytes);
486 memset(scalar, 0, size_bytes);
487 result = halide_matlab_array_to_scalar(user_context, arg, arg_metadata, scalar);
488 if (result != 0) {
489 halide_matlab_note_pipeline_description(user_context, metadata);
490 return result;
491 }
492 args[i] = scalar;
493 }
494 }
495
496 result = pipeline(args);
497
498 // Copy any GPU resident output buffers back to the CPU before returning.
499 for (int i = 0; i < nrhs; i++) {
500 const halide_filter_argument_t *arg_metadata = &metadata->arguments[i];
501
502 if (arg_metadata->kind == halide_argument_kind_output_buffer) {
503 halide_buffer_t *buf = (halide_buffer_t *)args[i];
504 halide_copy_to_host(user_context, buf);
505 }
506 if (arg_metadata->kind == halide_argument_kind_input_buffer ||
507 arg_metadata->kind == halide_argument_kind_output_buffer) {
508 halide_buffer_t *buf = (halide_buffer_t *)args[i];
509 halide_device_free(user_context, buf);
510 }
511 }
512
513 return result;
514 }
515
516 } // extern "C"
517