1 #include "HalideBuffer.h"
2 #include "HalideRuntime.h"
3 
4 // Grab the internal device_interface functions
5 #define WEAK
6 #include "device_interface.h"
7 
8 #include <stdio.h>
9 #include <stdlib.h>
10 
11 #include "cleanup_on_error.h"
12 
13 using namespace Halide::Runtime;
14 
// Extent passed to the output Buffer constructed in main().
const int size = 64;

// Counters bumped by the instrumented hooks below and inspected by main().
int successful_mallocs = 0;  // my_halide_malloc calls that handed back memory
int failed_mallocs = 0;      // my_halide_malloc calls that returned nullptr
int frees = 0;               // my_halide_free calls made with a non-null pointer
int errors = 0;              // my_halide_error invocations
int device_mallocs = 0;      // halide_device_malloc calls on a buffer whose device field was zero
int device_frees = 0;        // halide_device_free calls
18 
my_halide_malloc(void * user_context,size_t x)19 void *my_halide_malloc(void *user_context, size_t x) {
20     // Only the first malloc succeeds
21     if (successful_mallocs) {
22         failed_mallocs++;
23         return nullptr;
24     }
25     successful_mallocs++;
26 
27     void *orig = malloc(x + 40);
28     // Round up to next multiple of 32. Should add at least 8 bytes so we can fit the original pointer.
29     void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5);
30     ((void **)ptr)[-1] = orig;
31     return ptr;
32 }
33 
my_halide_free(void * user_context,void * ptr)34 void my_halide_free(void *user_context, void *ptr) {
35     if (!ptr) return;
36     frees++;
37     free(((void **)ptr)[-1]);
38 }
39 
my_halide_error(void * user_context,const char * msg)40 void my_halide_error(void *user_context, const char *msg) {
41     errors++;
42 }
43 
44 #ifndef _WIN32
// These two functions can't be overridden on Windows, so there we can only
// check that the number of calls to free matches the number of calls to malloc.
halide_device_free(void * user_context,struct halide_buffer_t * buf)47 extern "C" int halide_device_free(void *user_context, struct halide_buffer_t *buf) {
48     device_frees++;
49     return buf->device_interface->impl->device_free(user_context, buf);
50 }
51 
halide_device_malloc(void * user_context,struct halide_buffer_t * buf,const halide_device_interface_t * interface)52 extern "C" int halide_device_malloc(void *user_context, struct halide_buffer_t *buf,
53                                     const halide_device_interface_t *interface) {
54     if (!buf->device) {
55         device_mallocs++;
56     }
57     return interface->impl->device_malloc(user_context, buf);
58 }
59 #endif
60 
main(int argc,char ** argv)61 int main(int argc, char **argv) {
62 
63     halide_set_custom_malloc(&my_halide_malloc);
64     halide_set_custom_free(&my_halide_free);
65     halide_set_error_handler(&my_halide_error);
66 
67     Buffer<int32_t> output(size);
68     int result = cleanup_on_error(output);
69 
70     if (result != halide_error_code_out_of_memory &&
71         result != halide_error_code_device_malloc_failed) {
72         printf("The exit status was %d instead of %d or %d\n",
73                result,
74                halide_error_code_out_of_memory,
75                halide_error_code_device_malloc_failed);
76         return -1;
77     }
78 
79     if (failed_mallocs != 1) {
80         printf("One of the mallocs was supposed to fail\n");
81         return -1;
82     }
83 
84     if (successful_mallocs != 1) {
85         printf("One of the mallocs was supposed to succeed\n");
86         return -1;
87     }
88 
89     if (frees != 1) {
90         printf("The successful malloc should have been freed\n");
91         return -1;
92     }
93 
94     if (errors != 1) {
95         printf("%d errors. There was supposed to be one error\n", errors);
96         return -1;
97     }
98 
99     if (device_mallocs != device_frees) {
100         printf("There were a different number of device mallocs (%d) and frees (%d)\n", device_mallocs, device_frees);
101         return -1;
102     }
103 
104     printf("Success!\n");
105     return 0;
106 }
107