1 #include "HalideRuntime.h"
2 #include "printer.h"
3 
4 #ifndef MX_API_VER
5 #define MX_API_VER 0x07040000
6 #endif
7 
8 struct mxArray;
9 
10 // It is important to have the mex function pointer definitions in a
11 // namespace to avoid silently conflicting symbols with matlab at
12 // runtime.
13 namespace Halide {
14 namespace Runtime {
15 namespace mex {
16 
// Define a few things from mex.h that we need to grab the mex APIs
// from matlab.
//
// These declarations mirror Matlab's own headers so this runtime can be
// built without them. Enum values declared here are exchanged with Matlab
// at runtime (e.g. the result of mxGetClassID, and the class/complexity
// arguments of mxCreateNumericMatrix), so they must stay numerically in
// sync with Matlab's definitions.

// Name-length limits mirrored from mex.h.
enum { TMW_NAME_LENGTH_MAX = 64 };
enum { mxMAXNAM = TMW_NAME_LENGTH_MAX };

typedef bool mxLogical;
// NOTE(review): Matlab's own mxChar appears to be an unsigned 16-bit
// character type; this one is signed. Only matters if character data is
// ever read through it — confirm before relying on it.
typedef int16_t mxChar;

// Matlab array class identifiers. The declaration order fixes the numeric
// values, which must match Matlab's mxClassID.
enum mxClassID {
    mxUNKNOWN_CLASS = 0,
    mxCELL_CLASS,
    mxSTRUCT_CLASS,
    mxLOGICAL_CLASS,
    mxCHAR_CLASS,
    mxVOID_CLASS,
    mxDOUBLE_CLASS,
    mxSINGLE_CLASS,
    mxINT8_CLASS,
    mxUINT8_CLASS,
    mxINT16_CLASS,
    mxUINT16_CLASS,
    mxINT32_CLASS,
    mxUINT32_CLASS,
    mxINT64_CLASS,
    mxUINT64_CLASS,
    mxFUNCTION_CLASS,
    mxOPAQUE_CLASS,
    mxOBJECT_CLASS,
    // Aliases: presumably the class matching mwIndex's width (32-bit when
    // BITS_32 is defined) — confirm against Matlab's matrix.h.
#ifdef BITS_32
    mxINDEX_CLASS = mxUINT32_CLASS,
#else
    mxINDEX_CLASS = mxUINT64_CLASS,
#endif

    mxSPARSE_CLASS = mxVOID_CLASS
};

// Real vs. complex storage flag passed to Matlab array constructors.
enum mxComplexity {
    mxREAL = 0,
    mxCOMPLEX
};

// Matlab size/index types: plain int when BITS_32 is defined, otherwise
// size_t / ptrdiff_t.
#ifdef BITS_32
typedef int mwSize;
typedef int mwIndex;
typedef int mwSignedIndex;
#else
typedef size_t mwSize;
typedef size_t mwIndex;
typedef ptrdiff_t mwSignedIndex;
#endif

typedef void (*mex_exit_fn)(void);

// Declare function pointers for the mex APIs.
// Each MEX_FN(ret, func, args) entry in mex_functions.h expands here to a
// function-pointer variable named func; halide_matlab_init later redefines
// MEX_FN and re-includes the header to assign those pointers.
#define MEX_FN(ret, func, args) ret(*func) args;
#include "mex_functions.h"
75 
76 // Given a halide type code and bit width, find the equivalent matlab class ID.
get_class_id(int32_t type_code,int32_t type_bits)77 WEAK mxClassID get_class_id(int32_t type_code, int32_t type_bits) {
78     switch (type_code) {
79     case halide_type_int:
80         switch (type_bits) {
81         case 1:
82             return mxLOGICAL_CLASS;
83         case 8:
84             return mxINT8_CLASS;
85         case 16:
86             return mxINT16_CLASS;
87         case 32:
88             return mxINT32_CLASS;
89         case 64:
90             return mxINT64_CLASS;
91         }
92         return mxUNKNOWN_CLASS;
93     case halide_type_uint:
94         switch (type_bits) {
95         case 1:
96             return mxLOGICAL_CLASS;
97         case 8:
98             return mxUINT8_CLASS;
99         case 16:
100             return mxUINT16_CLASS;
101         case 32:
102             return mxUINT32_CLASS;
103         case 64:
104             return mxUINT64_CLASS;
105         }
106         return mxUNKNOWN_CLASS;
107     case halide_type_float:
108         switch (type_bits) {
109         case 32:
110             return mxSINGLE_CLASS;
111         case 64:
112             return mxDOUBLE_CLASS;
113         }
114         return mxUNKNOWN_CLASS;
115     }
116     return mxUNKNOWN_CLASS;
117 }
118 
119 // Convert a matlab class ID to a string.
get_class_name(mxClassID id)120 WEAK const char *get_class_name(mxClassID id) {
121     switch (id) {
122     case mxCELL_CLASS:
123         return "cell";
124     case mxSTRUCT_CLASS:
125         return "struct";
126     case mxLOGICAL_CLASS:
127         return "logical";
128     case mxCHAR_CLASS:
129         return "char";
130     case mxVOID_CLASS:
131         return "void";
132     case mxDOUBLE_CLASS:
133         return "double";
134     case mxSINGLE_CLASS:
135         return "single";
136     case mxINT8_CLASS:
137         return "int8";
138     case mxUINT8_CLASS:
139         return "uint8";
140     case mxINT16_CLASS:
141         return "int16";
142     case mxUINT16_CLASS:
143         return "uint16";
144     case mxINT32_CLASS:
145         return "int32";
146     case mxUINT32_CLASS:
147         return "uint32";
148     case mxINT64_CLASS:
149         return "int64";
150     case mxUINT64_CLASS:
151         return "uint64";
152     case mxFUNCTION_CLASS:
153         return "function";
154     case mxOPAQUE_CLASS:
155         return "opaque";
156     case mxOBJECT_CLASS:
157         return "object";
158     default:
159         return "unknown";
160     }
161 }
162 
163 // Get the real data pointer from an mxArray.
164 template<typename T>
get_data(mxArray * a)165 ALWAYS_INLINE T *get_data(mxArray *a) {
166     return (T *)mxGetData(a);
167 }
168 template<typename T>
get_data(const mxArray * a)169 ALWAYS_INLINE const T *get_data(const mxArray *a) {
170     return (const T *)mxGetData(a);
171 }
172 
173 // Search for a symbol in the calling process (i.e. matlab).
174 template<typename T>
get_mex_symbol(void * user_context,const char * name,bool required)175 ALWAYS_INLINE T get_mex_symbol(void *user_context, const char *name, bool required) {
176     T s = (T)halide_get_symbol(name);
177     if (required && s == NULL) {
178         error(user_context) << "mex API not found: " << name << "\n";
179         return NULL;
180     }
181     return s;
182 }
183 
184 // Provide Matlab API version agnostic wrappers for version specific APIs.
get_number_of_dimensions(const mxArray * a)185 ALWAYS_INLINE size_t get_number_of_dimensions(const mxArray *a) {
186     if (mxGetNumberOfDimensions_730) {
187         return mxGetNumberOfDimensions_730(a);
188     } else {
189         return mxGetNumberOfDimensions_700(a);
190     }
191 }
192 
get_dimension(const mxArray * a,size_t n)193 ALWAYS_INLINE size_t get_dimension(const mxArray *a, size_t n) {
194     if (mxGetDimensions_730) {
195         return mxGetDimensions_730(a)[n];
196     } else {
197         return mxGetDimensions_700(a)[n];
198     }
199 }
200 
create_numeric_matrix(size_t M,size_t N,mxClassID type,mxComplexity complexity)201 ALWAYS_INLINE mxArray *create_numeric_matrix(size_t M, size_t N, mxClassID type, mxComplexity complexity) {
202     if (mxCreateNumericMatrix_730) {
203         return mxCreateNumericMatrix_730(M, N, type, complexity);
204     } else {
205         return mxCreateNumericMatrix_700(M, N, type, complexity);
206     }
207 }
208 
209 }  // namespace mex
210 }  // namespace Runtime
211 }  // namespace Halide
212 
213 using namespace Halide::Runtime::mex;
214 
215 extern "C" {
216 
halide_matlab_describe_pipeline(stringstream & desc,const halide_filter_metadata_t * metadata)217 WEAK void halide_matlab_describe_pipeline(stringstream &desc, const halide_filter_metadata_t *metadata) {
218     desc << "int " << metadata->name << "(";
219     for (int i = 0; i < metadata->num_arguments; i++) {
220         const halide_filter_argument_t *arg = &metadata->arguments[i];
221         if (i > 0) {
222             desc << ", ";
223         }
224         if (arg->kind == halide_argument_kind_output_buffer) {
225             desc << "out ";
226         }
227         if (arg->kind == halide_argument_kind_output_buffer ||
228             arg->kind == halide_argument_kind_input_buffer) {
229             desc << arg->dimensions << "d ";
230         } else if (arg->kind == halide_argument_kind_input_scalar) {
231             desc << "scalar ";
232         }
233         desc << get_class_name(get_class_id(arg->type.code, arg->type.bits));
234         desc << " '" << arg->name << "'";
235     }
236     desc << ")";
237 }
238 
halide_matlab_note_pipeline_description(void * user_context,const halide_filter_metadata_t * metadata)239 WEAK void halide_matlab_note_pipeline_description(void *user_context, const halide_filter_metadata_t *metadata) {
240     stringstream desc(user_context);
241     desc << "Note pipeline definition:\n";
242     halide_matlab_describe_pipeline(desc, metadata);
243     halide_print(user_context, desc.str());
244 }
245 
halide_matlab_error(void * user_context,const char * msg)246 WEAK void halide_matlab_error(void *user_context, const char *msg) {
247     // Note that mexErrMsg/mexErrMsgIdAndTxt crash Matlab. It seems to
248     // be a common problem, those APIs seem to be very fragile.
249     stringstream error_msg(user_context);
250     error_msg << "\nHalide Error: " << msg;
251     mexWarnMsgTxt(error_msg.str());
252 }
253 
halide_matlab_print(void *,const char * msg)254 WEAK void halide_matlab_print(void *, const char *msg) {
255     mexWarnMsgTxt(msg);
256 }
257 
halide_matlab_init(void * user_context)258 WEAK int halide_matlab_init(void *user_context) {
259     // Assume that if mexWarnMsgTxt exists, we've already attempted initialization.
260     if (mexWarnMsgTxt != NULL) {
261         return halide_error_code_success;
262     }
263 
264 #define MEX_FN(ret, func, args) func = get_mex_symbol<ret(*) args>(user_context, #func, true);
265 #define MEX_FN_700(ret, func, func_700, args) func_700 = get_mex_symbol<ret(*) args>(user_context, #func, false);
266 #define MEX_FN_730(ret, func, func_730, args) func_730 = get_mex_symbol<ret(*) args>(user_context, #func_730, false);
267 #include "mex_functions.h"
268 
269     if (!mexWarnMsgTxt) {
270         return halide_error_code_matlab_init_failed;
271     }
272 
273     // Set up Halide's printing to go through Matlab. Also, don't exit
274     // on error. We don't just replace halide_error/halide_printf,
275     // because they'd have to be weak here, and there would be no
276     // guarantee that we would get this version (and not the standard
277     // one).
278     halide_set_custom_print(halide_matlab_print);
279     halide_set_error_handler(halide_matlab_error);
280 
281     return halide_error_code_success;
282 }
283 
284 // Convert a matlab mxArray to a Halide halide_buffer_t, with a specific number of dimensions.
halide_matlab_array_to_halide_buffer_t(void * user_context,const mxArray * arr,const halide_filter_argument_t * arg,halide_buffer_t * buf)285 WEAK int halide_matlab_array_to_halide_buffer_t(void *user_context,
286                                                 const mxArray *arr,
287                                                 const halide_filter_argument_t *arg,
288                                                 halide_buffer_t *buf) {
289 
290     if (mxIsComplex(arr)) {
291         error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n";
292         return halide_error_code_matlab_bad_param_type;
293     }
294 
295     int dim_count = get_number_of_dimensions(arr);
296     int expected_dims = arg->dimensions;
297 
298     // Validate that the data type of a buffer matches exactly.
299     mxClassID arg_class_id = get_class_id(arg->type.code, arg->type.bits);
300     mxClassID class_id = mxGetClassID(arr);
301     if (class_id != arg_class_id) {
302         error(user_context) << "Expected type of class " << get_class_name(arg_class_id)
303                             << " for argument " << arg->name
304                             << ", got class " << get_class_name(class_id) << ".\n";
305         return halide_error_code_matlab_bad_param_type;
306     }
307     // Validate that the dimensionality matches. Matlab is wierd
308     // because matrices always have at least 2 dimensions, and it
309     // truncates trailing dimensions of extent 1. So, the only way
310     // to have an error here is to have more dimensions with
311     // extent != 1 than the Halide pipeline expects.
312     while (dim_count > 0 && get_dimension(arr, dim_count - 1) == 1) {
313         dim_count--;
314     }
315     if (dim_count > expected_dims) {
316         error(user_context) << "Expected array of rank " << expected_dims
317                             << " for argument " << arg->name
318                             << ", got array of rank " << dim_count << ".\n";
319         return halide_error_code_matlab_bad_param_type;
320     }
321 
322     buf->host = (uint8_t *)mxGetData(arr);
323     buf->type = arg->type;
324     buf->dimensions = arg->dimensions;
325     buf->set_host_dirty(true);
326 
327     for (int i = 0; i < dim_count && i < expected_dims; i++) {
328         buf->dim[i].extent = static_cast<int32_t>(get_dimension(arr, i));
329     }
330 
331     // Add back the dimensions with extent 1.
332     for (int i = 2; i < expected_dims; i++) {
333         if (buf->dim[i].extent == 0) {
334             buf->dim[i].extent = 1;
335         }
336     }
337 
338     // Compute dense strides.
339     buf->dim[0].stride = 1;
340     for (int i = 1; i < expected_dims; i++) {
341         buf->dim[i].stride = buf->dim[i - 1].extent * buf->dim[i - 1].stride;
342     }
343 
344     return halide_error_code_success;
345 }
346 
347 // Convert a matlab mxArray to a scalar.
halide_matlab_array_to_scalar(void * user_context,const mxArray * arr,const halide_filter_argument_t * arg,void * scalar)348 WEAK int halide_matlab_array_to_scalar(void *user_context,
349                                        const mxArray *arr, const halide_filter_argument_t *arg, void *scalar) {
350     if (mxIsComplex(arr)) {
351         error(user_context) << "Complex argument not supported for parameter " << arg->name << ".\n";
352         return halide_error_code_generic_error;
353     }
354 
355     // Validate that the mxArray has all dimensions of extent 1.
356     int dim_count = get_number_of_dimensions(arr);
357     for (int i = 0; i < dim_count; i++) {
358         if (get_dimension(arr, i) != 1) {
359             error(user_context) << "Expected scalar argument for parameter " << arg->name << ".\n";
360             return halide_error_code_matlab_bad_param_type;
361         }
362     }
363     if (!mxIsLogical(arr) && !mxIsNumeric(arr)) {
364         error(user_context) << "Expected numeric argument for scalar parameter " << arg->name
365                             << ", got " << get_class_name(mxGetClassID(arr)) << ".\n";
366         return halide_error_code_matlab_bad_param_type;
367     }
368 
369     double value = mxGetScalar(arr);
370     int32_t type_code = arg->type.code;
371     int32_t type_bits = arg->type.bits;
372 
373     if (type_code == halide_type_int) {
374         switch (type_bits) {
375         case 1:
376             *reinterpret_cast<bool *>(scalar) = value != 0;
377             return halide_error_code_success;
378         case 8:
379             *reinterpret_cast<int8_t *>(scalar) = static_cast<int8_t>(value);
380             return halide_error_code_success;
381         case 16:
382             *reinterpret_cast<int16_t *>(scalar) = static_cast<int16_t>(value);
383             return halide_error_code_success;
384         case 32:
385             *reinterpret_cast<int32_t *>(scalar) = static_cast<int32_t>(value);
386             return halide_error_code_success;
387         case 64:
388             *reinterpret_cast<int64_t *>(scalar) = static_cast<int64_t>(value);
389             return halide_error_code_success;
390         }
391     } else if (type_code == halide_type_uint) {
392         switch (type_bits) {
393         case 1:
394             *reinterpret_cast<bool *>(scalar) = value != 0;
395             return halide_error_code_success;
396         case 8:
397             *reinterpret_cast<uint8_t *>(scalar) = static_cast<uint8_t>(value);
398             return halide_error_code_success;
399         case 16:
400             *reinterpret_cast<uint16_t *>(scalar) = static_cast<uint16_t>(value);
401             return halide_error_code_success;
402         case 32:
403             *reinterpret_cast<uint32_t *>(scalar) = static_cast<uint32_t>(value);
404             return halide_error_code_success;
405         case 64:
406             *reinterpret_cast<uint64_t *>(scalar) = static_cast<uint64_t>(value);
407             return halide_error_code_success;
408         }
409     } else if (type_code == halide_type_float) {
410         switch (type_bits) {
411         case 32:
412             *reinterpret_cast<float *>(scalar) = static_cast<float>(value);
413             return halide_error_code_success;
414         case 64:
415             *reinterpret_cast<double *>(scalar) = static_cast<double>(value);
416             return halide_error_code_success;
417         }
418     } else if (type_code == halide_type_handle) {
419         error(user_context) << "Parameter " << arg->name << " is of a type not supported by Matlab.\n";
420         return halide_error_code_matlab_bad_param_type;
421     }
422     error(user_context) << "Halide metadata for " << arg->name << " contained invalid or unrecognized type description.\n";
423     return halide_error_code_internal_error;
424 }
425 
halide_matlab_call_pipeline(void * user_context,int (* pipeline)(void ** args),const halide_filter_metadata_t * metadata,int nlhs,mxArray ** plhs,int nrhs,const mxArray ** prhs)426 WEAK int halide_matlab_call_pipeline(void *user_context,
427                                      int (*pipeline)(void **args), const halide_filter_metadata_t *metadata,
428                                      int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs) {
429 
430     int init_result = halide_matlab_init(user_context);
431     if (init_result != 0) {
432         return init_result;
433     }
434 
435     int32_t result_storage;
436     int32_t *result_ptr = &result_storage;
437     if (nlhs > 0) {
438         plhs[0] = create_numeric_matrix(1, 1, mxINT32_CLASS, mxREAL);
439         result_ptr = get_data<int32_t>(plhs[0]);
440     }
441     int32_t &result = *result_ptr;
442 
443     // Set result to failure until proven otherwise.
444     result = halide_error_code_generic_error;
445 
446     // Validate the number of arguments is correct.
447     if (nrhs != metadata->num_arguments) {
448         if (nrhs > 0) {
449             // Only report an actual error if there were any arguments at all.
450             error(user_context) << "Expected " << metadata->num_arguments
451                                 << " arguments for Halide pipeline " << metadata->name
452                                 << ", got " << nrhs << ".\n";
453         }
454         halide_matlab_note_pipeline_description(user_context, metadata);
455         return result;
456     }
457 
458     // Validate the LHS has zero or one argument.
459     if (nlhs > 1) {
460         error(user_context) << "Expected zero or one return value for Halide pipeline " << metadata->name
461                             << ", got " << nlhs << ".\n";
462         halide_matlab_note_pipeline_description(user_context, metadata);
463         return result;
464     }
465 
466     void **args = (void **)__builtin_alloca(nrhs * sizeof(void *));
467     for (int i = 0; i < nrhs; i++) {
468         const mxArray *arg = prhs[i];
469         const halide_filter_argument_t *arg_metadata = &metadata->arguments[i];
470 
471         if (arg_metadata->kind == halide_argument_kind_input_buffer ||
472             arg_metadata->kind == halide_argument_kind_output_buffer) {
473             halide_buffer_t *buf = (halide_buffer_t *)__builtin_alloca(sizeof(halide_buffer_t));
474             memset(buf, 0, sizeof(halide_buffer_t));
475             buf->dim = (halide_dimension_t *)__builtin_alloca(sizeof(halide_dimension_t) * arg_metadata->dimensions);
476             memset(buf->dim, 0, sizeof(halide_dimension_t) * arg_metadata->dimensions);
477             result = halide_matlab_array_to_halide_buffer_t(user_context, arg, arg_metadata, buf);
478             if (result != 0) {
479                 halide_matlab_note_pipeline_description(user_context, metadata);
480                 return result;
481             }
482             args[i] = buf;
483         } else {
484             size_t size_bytes = max(8, (arg_metadata->type.bits + 7) / 8);
485             void *scalar = __builtin_alloca(size_bytes);
486             memset(scalar, 0, size_bytes);
487             result = halide_matlab_array_to_scalar(user_context, arg, arg_metadata, scalar);
488             if (result != 0) {
489                 halide_matlab_note_pipeline_description(user_context, metadata);
490                 return result;
491             }
492             args[i] = scalar;
493         }
494     }
495 
496     result = pipeline(args);
497 
498     // Copy any GPU resident output buffers back to the CPU before returning.
499     for (int i = 0; i < nrhs; i++) {
500         const halide_filter_argument_t *arg_metadata = &metadata->arguments[i];
501 
502         if (arg_metadata->kind == halide_argument_kind_output_buffer) {
503             halide_buffer_t *buf = (halide_buffer_t *)args[i];
504             halide_copy_to_host(user_context, buf);
505         }
506         if (arg_metadata->kind == halide_argument_kind_input_buffer ||
507             arg_metadata->kind == halide_argument_kind_output_buffer) {
508             halide_buffer_t *buf = (halide_buffer_t *)args[i];
509             halide_device_free(user_context, buf);
510         }
511     }
512 
513     return result;
514 }
515 
516 }  // extern "C"
517