#include "HalideRuntime.h"
#include "device_buffer_utils.h"
#include "printer.h"
#include "scoped_mutex_lock.h"

namespace Halide {
namespace Runtime {
namespace Internal {

#define CACHE_DEBUGGING 0

#if CACHE_DEBUGGING
WEAK void debug_print_buffer(void *user_context, const char *buf_name, const halide_buffer_t &buf) {
    debug(user_context) << buf_name << ": elem_size " << buf.type.bytes() << " dimensions " << buf.dimensions << ", ";
    for (int i = 0; i < buf.dimensions; i++) {
        debug(user_context) << "(" << buf.dim[i].min
                            << ", " << buf.dim[i].extent
                            << ", " << buf.dim[i].stride << ") ";
    }
    debug(user_context) << "\n";
}

WEAK char to_hex_char(int val) {
    if (val < 10) {
        return '0' + val;
    }
    return 'A' + (val - 10);
}

WEAK void debug_print_key(void *user_context, const char *msg, const uint8_t *cache_key, int32_t key_size) {
    debug(user_context) << "Key for " << msg << "\n";
    char buf[1024];
    bool append_ellipses = false;
    if ((size_t)key_size > (sizeof(buf) / 2) - 1) {  // Each byte in key can take two bytes in output
        append_ellipses = true;
        key_size = (sizeof(buf) / 2) - 4;  // room for NUL and "..."
    }
    char *buf_ptr = buf;
    for (int i = 0; i < key_size; i++) {
        if (cache_key[i] >= 32 && cache_key[i] <= '~') {
            *buf_ptr++ = cache_key[i];
        } else {
            *buf_ptr++ = to_hex_char((cache_key[i] >> 4));
            *buf_ptr++ = to_hex_char((cache_key[i] & 0xf));
        }
    }
    if (append_ellipses) {
        *buf_ptr++ = '.';
        *buf_ptr++ = '.';
        *buf_ptr++ = '.';
    }
    *buf_ptr++ = '\0';
    debug(user_context) << buf << "\n";
}
#endif

WEAK bool keys_equal(const uint8_t *key1, const uint8_t *key2, size_t key_size) {
    return memcmp(key1, key2, key_size) == 0;
}

WEAK bool buffer_has_shape(const halide_buffer_t *buf, const halide_dimension_t *shape) {
    for (int i = 0; i < buf->dimensions; i++) {
        if (buf->dim[i] != shape[i]) return false;
    }
    return true;
}

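// A cache entry belongs to two intrusive structures at once: a singly
// linked chain of entries hanging off one hash-table bucket (via 'next'),
// and the global LRU list used for eviction (via 'more_recent' and
// 'less_recent'). All of its variable-sized metadata (the tuple
// halide_buffer_t array, the shapes, and a copy of the key) lives in the
// single 'metadata_storage' allocation set up by init().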
struct CacheEntry {
    CacheEntry *next;
    CacheEntry *more_recent;
    CacheEntry *less_recent;
    uint8_t *metadata_storage;
    size_t key_size;
    uint8_t *key;
    uint32_t hash;
    uint32_t in_use_count;  // 0 if none returned from halide_cache_lookup
    uint32_t tuple_count;
    // The shape of the computed data. There may be more data allocated than this.
    int32_t dimensions;
    halide_dimension_t *computed_bounds;
    // The actual stored data.
    halide_buffer_t *buf;

    bool init(const uint8_t *cache_key, size_t cache_key_size,
              uint32_t key_hash,
              const halide_buffer_t *computed_bounds_buf,
              int32_t tuples, halide_buffer_t **tuple_buffers);
    void destroy();
    halide_buffer_t &buffer(int32_t i);
};

struct CacheBlockHeader {
    CacheEntry *entry;
    uint32_t hash;
};
// Each host block has extra space to store a header just before the
// contents. This block must respect the same alignment as
// halide_malloc, because it offsets the return value from
// halide_malloc. The header holds the cache key hash and pointer to
// the hash entry.
WEAK __attribute((always_inline)) size_t header_bytes() {
    size_t s = sizeof(CacheBlockHeader);
    size_t mask = halide_malloc_alignment() - 1;
    return (s + mask) & ~mask;
}

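// Given the host pointer handed back to the pipeline (which points just
// past the header), recover the CacheBlockHeader stored at the start of
// the underlying halide_malloc allocation.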
WEAK CacheBlockHeader *get_pointer_to_header(uint8_t *host) {
    return (CacheBlockHeader *)(host - header_bytes());
}

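// Fill in the entry and allocate its metadata block. The block is laid
// out as: tuple_count halide_buffer_t structs, then (tuple_count + 1)
// shapes of 'dimensions' halide_dimension_t each (the computed bounds
// followed by each tuple's allocated bounds), then a copy of the key.
// Returns false if the allocation fails.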
WEAK bool CacheEntry::init(const uint8_t *cache_key, size_t cache_key_size,
                           uint32_t key_hash, const halide_buffer_t *computed_bounds_buf,
                           int32_t tuples, halide_buffer_t **tuple_buffers) {
    next = NULL;
    more_recent = NULL;
    less_recent = NULL;
    key_size = cache_key_size;
    hash = key_hash;
    in_use_count = 0;
    tuple_count = tuples;
    dimensions = computed_bounds_buf->dimensions;

    // Allocate all the necessary space (or die)
    size_t storage_bytes = 0;

    // First storage for the tuple halide_buffer_t's
    storage_bytes += sizeof(halide_buffer_t) * tuple_count;

    // Then storage for the computed shape, and the allocated shape for
    // each tuple buffer. These may all be distinct.
    size_t shape_offset = storage_bytes;
    storage_bytes += sizeof(halide_dimension_t) * dimensions * (tuple_count + 1);

    // Then storage for the key
    size_t key_offset = storage_bytes;
    storage_bytes += key_size;

    // Do the single malloc call
    metadata_storage = (uint8_t *)halide_malloc(NULL, storage_bytes);
    if (!metadata_storage) {
        return false;
    }

    // Set up the pointers into the allocated metadata space
    buf = (halide_buffer_t *)metadata_storage;
    computed_bounds = (halide_dimension_t *)(metadata_storage + shape_offset);
    key = metadata_storage + key_offset;

    // Copy over the key
    for (size_t i = 0; i < key_size; i++) {
        key[i] = cache_key[i];
    }

    // Copy over the shape of the computed region
    for (int i = 0; i < dimensions; i++) {
        computed_bounds[i] = computed_bounds_buf->dim[i];
    }

    // Copy over the tuple buffers and the shapes of the allocated regions
    for (uint32_t i = 0; i < tuple_count; i++) {
        buf[i] = *tuple_buffers[i];
        buf[i].dim = computed_bounds + (i + 1) * dimensions;
        for (int j = 0; j < dimensions; j++) {
            buf[i].dim[j] = tuple_buffers[i]->dim[j];
        }
    }
    return true;
}

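// Free the stored data (device and host allocations for every tuple
// buffer) and then the metadata block itself. The host pointer was
// advanced past the cache block header, so the original allocation is
// recovered with get_pointer_to_header() before freeing.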
WEAK void CacheEntry::destroy() {
    for (uint32_t i = 0; i < tuple_count; i++) {
        halide_device_free(NULL, &buf[i]);
        halide_free(NULL, get_pointer_to_header(buf[i].host));
    }
    halide_free(NULL, metadata_storage);
}

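// The classic djb2 string hash: start from 5381 and fold in each key byte
// as h = h * 33 + byte. Cheap, and adequate for spreading keys across the
// fixed-size bucket array below.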
WEAK uint32_t djb_hash(const uint8_t *key, size_t key_size) {
    uint32_t h = 5381;
    for (size_t i = 0; i < key_size; i++) {
        h = (h << 5) + h + key[i];
    }
    return h;
}

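// Global cache state: a fixed 256-bucket hash table of entry chains, the
// most/least recently used ends of the LRU list, and the running total of
// stored bytes. The entry points below take memoization_lock before
// touching this state.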
WEAK halide_mutex memoization_lock = {{0}};

const size_t kHashTableSize = 256;

WEAK CacheEntry *cache_entries[kHashTableSize];

WEAK CacheEntry *most_recently_used = NULL;
WEAK CacheEntry *least_recently_used = NULL;

const uint64_t kDefaultCacheSize = 1 << 20;
WEAK int64_t max_cache_size = kDefaultCacheSize;
WEAK int64_t current_cache_size = 0;

#if CACHE_DEBUGGING
WEAK void validate_cache() {
    print(NULL) << "validating cache, "
                << "current size " << current_cache_size
                << " of maximum " << max_cache_size << "\n";
    int entries_in_hash_table = 0;
    for (size_t i = 0; i < kHashTableSize; i++) {
        CacheEntry *entry = cache_entries[i];
        while (entry != NULL) {
            entries_in_hash_table++;
            if (entry->more_recent == NULL && entry != most_recently_used) {
                halide_print(NULL, "cache invalid case 1\n");
                __builtin_trap();
            }
            if (entry->less_recent == NULL && entry != least_recently_used) {
                halide_print(NULL, "cache invalid case 2\n");
                __builtin_trap();
            }
            entry = entry->next;
        }
    }
    int entries_from_mru = 0;
    CacheEntry *mru_chain = most_recently_used;
    while (mru_chain != NULL) {
        entries_from_mru++;
        mru_chain = mru_chain->less_recent;
    }
    int entries_from_lru = 0;
    CacheEntry *lru_chain = least_recently_used;
    while (lru_chain != NULL) {
        entries_from_lru++;
        lru_chain = lru_chain->more_recent;
    }
    print(NULL) << "hash entries " << entries_in_hash_table
                << ", mru entries " << entries_from_mru
                << ", lru entries " << entries_from_lru << "\n";
    if (entries_in_hash_table != entries_from_mru) {
        halide_print(NULL, "cache invalid case 3\n");
        __builtin_trap();
    }
    if (entries_in_hash_table != entries_from_lru) {
        halide_print(NULL, "cache invalid case 4\n");
        __builtin_trap();
    }
    if (current_cache_size < 0) {
        halide_print(NULL, "cache size is negative\n");
        __builtin_trap();
    }
}
#endif

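// Evict entries until the cache fits under max_cache_size again. The walk
// starts at the least recently used end of the LRU list and moves toward
// more recent entries, skipping anything whose buffers are still checked
// out by a running pipeline (in_use_count != 0). Callers are expected to
// hold memoization_lock.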
WEAK void prune_cache() {
#if CACHE_DEBUGGING
    validate_cache();
#endif
    CacheEntry *prune_candidate = least_recently_used;
    while (current_cache_size > max_cache_size &&
           prune_candidate != NULL) {
        CacheEntry *more_recent = prune_candidate->more_recent;

        if (prune_candidate->in_use_count == 0) {
            uint32_t h = prune_candidate->hash;
            uint32_t index = h % kHashTableSize;

            // Remove from hash table
            CacheEntry *prev_hash_entry = cache_entries[index];
            if (prev_hash_entry == prune_candidate) {
                cache_entries[index] = prune_candidate->next;
            } else {
                while (prev_hash_entry != NULL && prev_hash_entry->next != prune_candidate) {
                    prev_hash_entry = prev_hash_entry->next;
                }
                halide_assert(NULL, prev_hash_entry != NULL);
                prev_hash_entry->next = prune_candidate->next;
            }

            // Remove from less recent chain.
            if (least_recently_used == prune_candidate) {
                least_recently_used = more_recent;
            }
            if (more_recent != NULL) {
                more_recent->less_recent = prune_candidate->less_recent;
            }

            // Remove from more recent chain.
            if (most_recently_used == prune_candidate) {
                most_recently_used = prune_candidate->less_recent;
            }
            if (prune_candidate->less_recent != NULL) {
                prune_candidate->less_recent->more_recent = more_recent;
            }

            // Decrease cache used amount.
            for (uint32_t i = 0; i < prune_candidate->tuple_count; i++) {
                current_cache_size -= prune_candidate->buf[i].size_in_bytes();
            }

            // Deallocate the entry.
            prune_candidate->destroy();
            halide_free(NULL, prune_candidate);
        }

        prune_candidate = more_recent;
    }
#if CACHE_DEBUGGING
    validate_cache();
#endif
}

}  // namespace Internal
}  // namespace Runtime
}  // namespace Halide

extern "C" {

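// Set the maximum number of bytes the cache may retain. Passing zero
// restores the default (kDefaultCacheSize); the cache is pruned
// immediately if the new limit is smaller than what is currently stored.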
WEAK void halide_memoization_cache_set_size(int64_t size) {
    if (size == 0) {
        size = kDefaultCacheSize;
    }

    ScopedMutexLock lock(&memoization_lock);

    max_cache_size = size;
    prune_cache();
}

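// Look up a cached result by key. On a hit, the cached buffers are copied
// into tuple_buffers, the entry is moved to the head of the LRU list, its
// in_use_count is bumped by tuple_count, and 0 is returned. On a miss,
// host storage (plus room for a CacheBlockHeader) is allocated for each
// tuple buffer so the caller can compute and then store the result, and 1
// is returned. Returns -1 if that allocation fails.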
WEAK int halide_memoization_cache_lookup(void *user_context, const uint8_t *cache_key, int32_t size,
                                         halide_buffer_t *computed_bounds, int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    uint32_t h = djb_hash(cache_key, size);
    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_lookup", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            // Check all the tuple buffers have the same bounds (they should).
            bool all_bounds_equal = true;
            for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
            }

            if (all_bounds_equal) {
                if (entry != most_recently_used) {
                    halide_assert(user_context, entry->more_recent != NULL);
                    if (entry->less_recent != NULL) {
                        entry->less_recent->more_recent = entry->more_recent;
                    } else {
                        halide_assert(user_context, least_recently_used == entry);
                        least_recently_used = entry->more_recent;
                    }
                    halide_assert(user_context, entry->more_recent != NULL);
                    entry->more_recent->less_recent = entry->less_recent;

                    entry->more_recent = NULL;
                    entry->less_recent = most_recently_used;
                    if (most_recently_used != NULL) {
                        most_recently_used->more_recent = entry;
                    }
                    most_recently_used = entry;
                }

                for (int32_t i = 0; i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    *buf = entry->buf[i];
                }

                entry->in_use_count += tuple_count;

                return 0;
            }
        }
        entry = entry->next;
    }

    for (int32_t i = 0; i < tuple_count; i++) {
        halide_buffer_t *buf = tuple_buffers[i];

        buf->host = ((uint8_t *)halide_malloc(user_context, buf->size_in_bytes() + header_bytes()));
        if (buf->host == NULL) {
            for (int32_t j = i; j > 0; j--) {
                halide_free(user_context, get_pointer_to_header(tuple_buffers[j - 1]->host));
                tuple_buffers[j - 1]->host = NULL;
            }
            return -1;
        }
        buf->host += header_bytes();
        CacheBlockHeader *header = get_pointer_to_header(buf->host);
        header->hash = h;
        header->entry = NULL;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif

    return 1;
}

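// Store the buffers computed after a cache miss. If an equivalent entry is
// already present (for example, another thread may have stored the same
// result first), or if allocating the new CacheEntry fails, the buffers
// stay owned by the caller: their headers are marked with a NULL entry so
// halide_memoization_cache_release frees them. Otherwise a new entry is
// inserted at the head of the LRU list with in_use_count set to
// tuple_count. Always returns 0.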
WEAK int halide_memoization_cache_store(void *user_context, const uint8_t *cache_key, int32_t size,
                                        halide_buffer_t *computed_bounds,
                                        int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    debug(user_context) << "halide_memoization_cache_store\n";

    uint32_t h = get_pointer_to_header(tuple_buffers[0]->host)->hash;

    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_store", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            bool all_bounds_equal = true;
            bool no_host_pointers_equal = true;
            {
                for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
                    if (entry->buf[i].host == buf->host) {
                        no_host_pointers_equal = false;
                    }
                }
            }
            if (all_bounds_equal) {
                halide_assert(user_context, no_host_pointers_equal);
                // This entry is still in use by the caller. Mark it as having no cache entry
                // so halide_memoization_cache_release can free the buffer.
                for (int32_t i = 0; i < tuple_count; i++) {
                    get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
                }
                return 0;
            }
        }
        entry = entry->next;
    }

    uint64_t added_size = 0;
    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            added_size += buf->size_in_bytes();
        }
    }
    current_cache_size += added_size;
    prune_cache();

    CacheEntry *new_entry = (CacheEntry *)halide_malloc(NULL, sizeof(CacheEntry));
    bool inited = false;
    if (new_entry) {
        inited = new_entry->init(cache_key, size, h, computed_bounds, tuple_count, tuple_buffers);
    }
    if (!inited) {
        current_cache_size -= added_size;

        // This entry is still in use by the caller. Mark it as having no cache entry
        // so halide_memoization_cache_release can free the buffer.
        for (int32_t i = 0; i < tuple_count; i++) {
            get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
        }

        if (new_entry) {
            halide_free(user_context, new_entry);
        }
        return 0;
    }

    new_entry->next = cache_entries[index];
    new_entry->less_recent = most_recently_used;
    if (most_recently_used != NULL) {
        most_recently_used->more_recent = new_entry;
    }
    most_recently_used = new_entry;
    if (least_recently_used == NULL) {
        least_recently_used = new_entry;
    }
    cache_entries[index] = new_entry;

    new_entry->in_use_count = tuple_count;

    for (int32_t i = 0; i < tuple_count; i++) {
        get_pointer_to_header(tuple_buffers[i]->host)->entry = new_entry;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif
    debug(user_context) << "Exiting halide_memoization_cache_store\n";

    return 0;
}

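// Called for each tuple buffer when the pipeline is done with it. 'host'
// is the pointer handed out by the lookup above. If the buffer never made
// it into the cache (header->entry is NULL), its storage is freed here;
// otherwise the owning entry's in_use_count is decremented so the pruner
// may evict it later.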
WEAK void halide_memoization_cache_release(void *user_context, void *host) {
    CacheBlockHeader *header = get_pointer_to_header((uint8_t *)host);
    debug(user_context) << "halide_memoization_cache_release\n";
    CacheEntry *entry = header->entry;

    if (entry == NULL) {
        halide_free(user_context, header);
    } else {
        ScopedMutexLock lock(&memoization_lock);

        halide_assert(user_context, entry->in_use_count > 0);
        entry->in_use_count--;
#if CACHE_DEBUGGING
        validate_cache();
#endif
    }

    debug(user_context) << "Exited halide_memoization_cache_release.\n";
}

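// Free every cache entry and reset the bookkeeping. Registered as a
// destructor below so the cache is torn down at program exit.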
WEAK void halide_memoization_cache_cleanup() {
    debug(NULL) << "halide_memoization_cache_cleanup\n";
    for (size_t i = 0; i < kHashTableSize; i++) {
        CacheEntry *entry = cache_entries[i];
        cache_entries[i] = NULL;
        while (entry != NULL) {
            CacheEntry *next = entry->next;
            entry->destroy();
            halide_free(NULL, entry);
            entry = next;
        }
    }
    current_cache_size = 0;
    most_recently_used = NULL;
    least_recently_used = NULL;
}

namespace {

WEAK __attribute__((destructor)) void halide_cache_cleanup() {
    halide_memoization_cache_cleanup();
}

}  // namespace
}  // extern "C"