#include "HalideRuntime.h"
#include "device_buffer_utils.h"
#include "printer.h"
#include "scoped_mutex_lock.h"

namespace Halide {
namespace Runtime {
namespace Internal {

#define CACHE_DEBUGGING 0

#if CACHE_DEBUGGING
WEAK void debug_print_buffer(void *user_context, const char *buf_name, const halide_buffer_t &buf) {
    debug(user_context) << buf_name << ": elem_size " << buf.type.bytes() << " dimensions " << buf.dimensions << ", ";
    for (int i = 0; i < buf.dimensions; i++) {
        debug(user_context) << "(" << buf.dim[i].min
                            << ", " << buf.dim[i].extent
                            << ", " << buf.dim[i].stride << ") ";
    }
    debug(user_context) << "\n";
}

WEAK char to_hex_char(int val) {
    if (val < 10) {
        return '0' + val;
    }
    return 'A' + (val - 10);
}

WEAK void debug_print_key(void *user_context, const char *msg, const uint8_t *cache_key, int32_t key_size) {
    debug(user_context) << "Key for " << msg << "\n";
    char buf[1024];
    bool append_ellipses = false;
    if ((size_t)key_size > (sizeof(buf) / 2) - 1) { // Each byte in key can take two bytes in output
        append_ellipses = true;
        key_size = (sizeof(buf) / 2) - 4; // room for NUL and "..."
    }
    char *buf_ptr = buf;
    for (int i = 0; i < key_size; i++) {
        if (cache_key[i] >= 32 && cache_key[i] <= '~') {
            *buf_ptr++ = cache_key[i];
        } else {
            *buf_ptr++ = to_hex_char((cache_key[i] >> 4));
            *buf_ptr++ = to_hex_char((cache_key[i] & 0xf));
        }
    }
    if (append_ellipses) {
        *buf_ptr++ = '.';
        *buf_ptr++ = '.';
        *buf_ptr++ = '.';
    }
    *buf_ptr++ = '\0';
    debug(user_context) << buf << "\n";
}
#endif

WEAK bool keys_equal(const uint8_t *key1, const uint8_t *key2, size_t key_size) {
    return memcmp(key1, key2, key_size) == 0;
}

WEAK bool buffer_has_shape(const halide_buffer_t *buf, const halide_dimension_t *shape) {
    for (int i = 0; i < buf->dimensions; i++) {
        if (buf->dim[i] != shape[i]) {
            return false;
        }
    }
    return true;
}

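// A cached result lives in a CacheEntry. Each entry is chained into a
// hash bucket via `next`, and also into a global doubly-linked LRU
// list via `more_recent`/`less_recent`, so the cache can both find a
// key quickly and evict the least recently used entries first.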
struct CacheEntry {
    CacheEntry *next;
    CacheEntry *more_recent;
    CacheEntry *less_recent;
    uint8_t *metadata_storage;
    size_t key_size;
    uint8_t *key;
    uint32_t hash;
    uint32_t in_use_count; // 0 if no buffers from this entry are held by callers of halide_memoization_cache_lookup
    uint32_t tuple_count;
    // The shape of the computed data. There may be more data allocated than this.
    int32_t dimensions;
    halide_dimension_t *computed_bounds;
    // The actual stored data.
    halide_buffer_t *buf;

    bool init(const uint8_t *cache_key, size_t cache_key_size,
              uint32_t key_hash,
              const halide_buffer_t *computed_bounds_buf,
              int32_t tuples, halide_buffer_t **tuple_buffers);
    void destroy();
    halide_buffer_t &buffer(int32_t i);
};

struct CacheBlockHeader {
    CacheEntry *entry;
    uint32_t hash;
};

// Each host block has extra space to store a header just before the
// contents. This block must respect the same alignment as
// halide_malloc, because it offsets the return value from
// halide_malloc. The header holds the cache key hash and a pointer to
// the cache entry.
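//
// Layout of a single cached host allocation:
//
//   [ CacheBlockHeader | padding to header_bytes() | buffer contents ]
//   ^                                               ^
//   result of halide_malloc                         halide_buffer_t::host
//
// For example, assuming an 8-byte pointer (so sizeof(CacheBlockHeader)
// pads to 16) and a hypothetical 64-byte halide_malloc alignment,
// header_bytes() rounds 16 up to 64: (16 + 63) & ~63 == 64.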
WEAK __attribute__((always_inline)) size_t header_bytes() {
    size_t s = sizeof(CacheBlockHeader);
    size_t mask = halide_malloc_alignment() - 1;
    return (s + mask) & ~mask;
}

WEAK CacheBlockHeader *get_pointer_to_header(uint8_t *host) {
    return (CacheBlockHeader *)(host - header_bytes());
}

WEAK bool CacheEntry::init(const uint8_t *cache_key, size_t cache_key_size,
                           uint32_t key_hash, const halide_buffer_t *computed_bounds_buf,
                           int32_t tuples, halide_buffer_t **tuple_buffers) {
    next = NULL;
    more_recent = NULL;
    less_recent = NULL;
    key_size = cache_key_size;
    hash = key_hash;
    in_use_count = 0;
    tuple_count = tuples;
    dimensions = computed_bounds_buf->dimensions;

    // Allocate all the necessary space (or die)
    size_t storage_bytes = 0;

    // First, storage for the tuple halide_buffer_t's
    storage_bytes += sizeof(halide_buffer_t) * tuple_count;

    // Then storage for the computed shape, and the allocated shape for
    // each tuple buffer. These may all be distinct.
    size_t shape_offset = storage_bytes;
    storage_bytes += sizeof(halide_dimension_t) * dimensions * (tuple_count + 1);

    // Then storage for the key
    size_t key_offset = storage_bytes;
    storage_bytes += key_size;

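    // At this point the metadata layout is:
    //   [ tuple_count x halide_buffer_t ]                          <- buf
    //   [ (tuple_count + 1) x dimensions x halide_dimension_t ]    <- computed_bounds, then one allocated shape per tuple
    //   [ key_size bytes ]                                         <- key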
    // Do the single malloc call
    metadata_storage = (uint8_t *)halide_malloc(NULL, storage_bytes);
    if (!metadata_storage) {
        return false;
    }

    // Set up the pointers into the allocated metadata space
    buf = (halide_buffer_t *)metadata_storage;
    computed_bounds = (halide_dimension_t *)(metadata_storage + shape_offset);
    key = metadata_storage + key_offset;

    // Copy over the key
    for (size_t i = 0; i < key_size; i++) {
        key[i] = cache_key[i];
    }

    // Copy over the shape of the computed region
    for (int i = 0; i < dimensions; i++) {
        computed_bounds[i] = computed_bounds_buf->dim[i];
    }

    // Copy over the tuple buffers and the shapes of the allocated regions
    for (uint32_t i = 0; i < tuple_count; i++) {
        buf[i] = *tuple_buffers[i];
        buf[i].dim = computed_bounds + (i + 1) * dimensions;
        for (int j = 0; j < dimensions; j++) {
            buf[i].dim[j] = tuple_buffers[i]->dim[j];
        }
    }
    return true;
}

WEAK void CacheEntry::destroy() {
    for (uint32_t i = 0; i < tuple_count; i++) {
        halide_device_free(NULL, &buf[i]);
        halide_free(NULL, get_pointer_to_header(buf[i].host));
    }
    halide_free(NULL, metadata_storage);
}

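// The classic djb2 string hash: h = h * 33 + c, seeded with 5381.
// (h << 5) + h is just h * 33.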
WEAK uint32_t djb_hash(const uint8_t *key, size_t key_size) {
    uint32_t h = 5381;
    for (size_t i = 0; i < key_size; i++) {
        h = (h << 5) + h + key[i];
    }
    return h;
}

WEAK halide_mutex memoization_lock = {{0}};

const size_t kHashTableSize = 256;

WEAK CacheEntry *cache_entries[kHashTableSize];

WEAK CacheEntry *most_recently_used = NULL;
WEAK CacheEntry *least_recently_used = NULL;

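// The cache is limited to 1 MiB (1 << 20 bytes) unless
// halide_memoization_cache_set_size is called with another limit.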
const uint64_t kDefaultCacheSize = 1 << 20;
WEAK int64_t max_cache_size = kDefaultCacheSize;
WEAK int64_t current_cache_size = 0;

#if CACHE_DEBUGGING
WEAK void validate_cache() {
    print(NULL) << "validating cache, "
                << "current size " << current_cache_size
                << " of maximum " << max_cache_size << "\n";
    int entries_in_hash_table = 0;
    for (size_t i = 0; i < kHashTableSize; i++) {
        CacheEntry *entry = cache_entries[i];
        while (entry != NULL) {
            entries_in_hash_table++;
            if (entry->more_recent == NULL && entry != most_recently_used) {
                halide_print(NULL, "cache invalid case 1\n");
                __builtin_trap();
            }
            if (entry->less_recent == NULL && entry != least_recently_used) {
                halide_print(NULL, "cache invalid case 2\n");
                __builtin_trap();
            }
            entry = entry->next;
        }
    }
    int entries_from_mru = 0;
    CacheEntry *mru_chain = most_recently_used;
    while (mru_chain != NULL) {
        entries_from_mru++;
        mru_chain = mru_chain->less_recent;
    }
    int entries_from_lru = 0;
    CacheEntry *lru_chain = least_recently_used;
    while (lru_chain != NULL) {
        entries_from_lru++;
        lru_chain = lru_chain->more_recent;
    }
    print(NULL) << "hash entries " << entries_in_hash_table
                << ", mru entries " << entries_from_mru
                << ", lru entries " << entries_from_lru << "\n";
    if (entries_in_hash_table != entries_from_mru) {
        halide_print(NULL, "cache invalid case 3\n");
        __builtin_trap();
    }
    if (entries_in_hash_table != entries_from_lru) {
        halide_print(NULL, "cache invalid case 4\n");
        __builtin_trap();
    }
    if (current_cache_size < 0) {
        halide_print(NULL, "cache size is negative\n");
        __builtin_trap();
    }
}
#endif

WEAK void prune_cache() {
#if CACHE_DEBUGGING
    validate_cache();
#endif
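    // Walk from least recently used towards most recently used,
    // evicting entries that are not currently in use until the cache
    // fits within max_cache_size.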
    CacheEntry *prune_candidate = least_recently_used;
    while (current_cache_size > max_cache_size &&
           prune_candidate != NULL) {
        CacheEntry *more_recent = prune_candidate->more_recent;

        if (prune_candidate->in_use_count == 0) {
            uint32_t h = prune_candidate->hash;
            uint32_t index = h % kHashTableSize;

            // Remove from hash table
            CacheEntry *prev_hash_entry = cache_entries[index];
            if (prev_hash_entry == prune_candidate) {
                cache_entries[index] = prune_candidate->next;
            } else {
                while (prev_hash_entry != NULL && prev_hash_entry->next != prune_candidate) {
                    prev_hash_entry = prev_hash_entry->next;
                }
                halide_assert(NULL, prev_hash_entry != NULL);
                prev_hash_entry->next = prune_candidate->next;
            }

            // Remove from less recent chain.
            if (least_recently_used == prune_candidate) {
                least_recently_used = more_recent;
            }
            if (more_recent != NULL) {
                more_recent->less_recent = prune_candidate->less_recent;
            }

            // Remove from more recent chain.
            if (most_recently_used == prune_candidate) {
                most_recently_used = prune_candidate->less_recent;
            }
            if (prune_candidate->less_recent != NULL) {
                prune_candidate->less_recent->more_recent = more_recent;
            }

            // Decrease cache used amount.
            for (uint32_t i = 0; i < prune_candidate->tuple_count; i++) {
                current_cache_size -= prune_candidate->buf[i].size_in_bytes();
            }

            // Deallocate the entry.
            prune_candidate->destroy();
            halide_free(NULL, prune_candidate);
        }

        prune_candidate = more_recent;
    }
#if CACHE_DEBUGGING
    validate_cache();
#endif
}

} // namespace Internal
} // namespace Runtime
} // namespace Halide

extern "C" {

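// How Halide-generated code is expected to drive this API (a sketch;
// the actual generated code and its error handling differ, and
// compute_result is a hypothetical stand-in for computing the Func):
//
//   halide_buffer_t *tuple[] = {&buf};
//   if (halide_memoization_cache_lookup(ctx, key, key_size,
//                                       &computed_bounds, 1, tuple) == 1) {
//       // Miss: lookup allocated buf.host; fill it, then publish it.
//       compute_result(&buf);
//       halide_memoization_cache_store(ctx, key, key_size,
//                                      &computed_bounds, 1, tuple);
//   }
//   // ... consume buf ...
//   halide_memoization_cache_release(ctx, buf.host);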
WEAK void halide_memoization_cache_set_size(int64_t size) {
    if (size == 0) {
        size = kDefaultCacheSize;
    }

    ScopedMutexLock lock(&memoization_lock);

    max_cache_size = size;
    prune_cache();
}

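// Returns 0 on a cache hit, in which case the tuple_buffers are filled
// in from the cached entry and the entry's in_use_count is bumped;
// returns 1 on a miss, in which case host storage (with a hidden
// CacheBlockHeader in front of the contents) is allocated for each
// tuple buffer; returns -1 if that allocation fails.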
WEAK int halide_memoization_cache_lookup(void *user_context, const uint8_t *cache_key, int32_t size,
                                         halide_buffer_t *computed_bounds, int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    uint32_t h = djb_hash(cache_key, size);
    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_lookup", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            // Check all the tuple buffers have the same bounds (they should).
            bool all_bounds_equal = true;
            for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
            }

            if (all_bounds_equal) {
                if (entry != most_recently_used) {
                    halide_assert(user_context, entry->more_recent != NULL);
                    if (entry->less_recent != NULL) {
                        entry->less_recent->more_recent = entry->more_recent;
                    } else {
                        halide_assert(user_context, least_recently_used == entry);
                        least_recently_used = entry->more_recent;
                    }
                    halide_assert(user_context, entry->more_recent != NULL);
                    entry->more_recent->less_recent = entry->less_recent;

                    entry->more_recent = NULL;
                    entry->less_recent = most_recently_used;
                    if (most_recently_used != NULL) {
                        most_recently_used->more_recent = entry;
                    }
                    most_recently_used = entry;
                }

                for (int32_t i = 0; i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    *buf = entry->buf[i];
                }

                entry->in_use_count += tuple_count;

                return 0;
            }
        }
        entry = entry->next;
    }

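    // Cache miss: allocate host storage for each tuple buffer, leaving
    // header_bytes() of room for the CacheBlockHeader in front of the
    // contents. On failure, unwind any allocations already made.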
    for (int32_t i = 0; i < tuple_count; i++) {
        halide_buffer_t *buf = tuple_buffers[i];

        buf->host = (uint8_t *)halide_malloc(user_context, buf->size_in_bytes() + header_bytes());
        if (buf->host == NULL) {
            for (int32_t j = i; j > 0; j--) {
                halide_free(user_context, get_pointer_to_header(tuple_buffers[j - 1]->host));
                tuple_buffers[j - 1]->host = NULL;
            }
            return -1;
        }
        buf->host += header_bytes();
        CacheBlockHeader *header = get_pointer_to_header(buf->host);
        header->hash = h;
        header->entry = NULL;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif

    return 1;
}

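// Stores the result of a computation after a cache miss. If an
// equivalent entry already exists (e.g. another thread stored it
// between this caller's lookup and store), or a new entry cannot be
// created, the buffers stay owned by the caller: their headers get a
// NULL entry so halide_memoization_cache_release frees them directly.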
WEAK int halide_memoization_cache_store(void *user_context, const uint8_t *cache_key, int32_t size,
                                        halide_buffer_t *computed_bounds,
                                        int32_t tuple_count, halide_buffer_t **tuple_buffers) {
    debug(user_context) << "halide_memoization_cache_store\n";

    uint32_t h = get_pointer_to_header(tuple_buffers[0]->host)->hash;

    uint32_t index = h % kHashTableSize;

    ScopedMutexLock lock(&memoization_lock);

#if CACHE_DEBUGGING
    debug_print_key(user_context, "halide_memoization_cache_store", cache_key, size);

    debug_print_buffer(user_context, "computed_bounds", *computed_bounds);

    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            debug_print_buffer(user_context, "Allocation bounds", *buf);
        }
    }
#endif

    CacheEntry *entry = cache_entries[index];
    while (entry != NULL) {
        if (entry->hash == h && entry->key_size == (size_t)size &&
            keys_equal(entry->key, cache_key, size) &&
            buffer_has_shape(computed_bounds, entry->computed_bounds) &&
            entry->tuple_count == (uint32_t)tuple_count) {

            bool all_bounds_equal = true;
            bool no_host_pointers_equal = true;
            {
                for (int32_t i = 0; all_bounds_equal && i < tuple_count; i++) {
                    halide_buffer_t *buf = tuple_buffers[i];
                    all_bounds_equal = buffer_has_shape(tuple_buffers[i], entry->buf[i].dim);
                    if (entry->buf[i].host == buf->host) {
                        no_host_pointers_equal = false;
                    }
                }
            }
            if (all_bounds_equal) {
                halide_assert(user_context, no_host_pointers_equal);
                // This entry is still in use by the caller. Mark it as having no cache entry
                // so halide_memoization_cache_release can free the buffer.
                for (int32_t i = 0; i < tuple_count; i++) {
                    get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
                }
                return 0;
            }
        }
        entry = entry->next;
    }

    uint64_t added_size = 0;
    {
        for (int32_t i = 0; i < tuple_count; i++) {
            halide_buffer_t *buf = tuple_buffers[i];
            added_size += buf->size_in_bytes();
        }
    }
    current_cache_size += added_size;
    prune_cache();

    CacheEntry *new_entry = (CacheEntry *)halide_malloc(NULL, sizeof(CacheEntry));
    bool inited = false;
    if (new_entry) {
        inited = new_entry->init(cache_key, size, h, computed_bounds, tuple_count, tuple_buffers);
    }
    if (!inited) {
        current_cache_size -= added_size;

        // This entry is still in use by the caller. Mark it as having no cache entry
        // so halide_memoization_cache_release can free the buffer.
        for (int32_t i = 0; i < tuple_count; i++) {
            get_pointer_to_header(tuple_buffers[i]->host)->entry = NULL;
        }

        if (new_entry) {
            halide_free(user_context, new_entry);
        }
        return 0;
    }

    new_entry->next = cache_entries[index];
    new_entry->less_recent = most_recently_used;
    if (most_recently_used != NULL) {
        most_recently_used->more_recent = new_entry;
    }
    most_recently_used = new_entry;
    if (least_recently_used == NULL) {
        least_recently_used = new_entry;
    }
    cache_entries[index] = new_entry;

    new_entry->in_use_count = tuple_count;

    for (int32_t i = 0; i < tuple_count; i++) {
        get_pointer_to_header(tuple_buffers[i]->host)->entry = new_entry;
    }

#if CACHE_DEBUGGING
    validate_cache();
#endif
    debug(user_context) << "Exiting halide_memoization_cache_store\n";

    return 0;
}

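// Called once for each buffer returned by a lookup. If the buffer was
// never adopted by a cache entry (header->entry is NULL), it is freed
// outright; otherwise the entry's in_use_count is decremented so that
// prune_cache may evict it later.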
WEAK void halide_memoization_cache_release(void *user_context, void *host) {
    CacheBlockHeader *header = get_pointer_to_header((uint8_t *)host);
    debug(user_context) << "halide_memoization_cache_release\n";
    CacheEntry *entry = header->entry;

    if (entry == NULL) {
        halide_free(user_context, header);
    } else {
        ScopedMutexLock lock(&memoization_lock);

        halide_assert(user_context, entry->in_use_count > 0);
        entry->in_use_count--;
#if CACHE_DEBUGGING
        validate_cache();
#endif
    }

    debug(user_context) << "Exited halide_memoization_cache_release.\n";
}

WEAK void halide_memoization_cache_cleanup() {
    debug(NULL) << "halide_memoization_cache_cleanup\n";
    for (size_t i = 0; i < kHashTableSize; i++) {
        CacheEntry *entry = cache_entries[i];
        cache_entries[i] = NULL;
        while (entry != NULL) {
            CacheEntry *next = entry->next;
            entry->destroy();
            halide_free(NULL, entry);
            entry = next;
        }
    }
    current_cache_size = 0;
    most_recently_used = NULL;
    least_recently_used = NULL;
}

namespace {

WEAK __attribute__((destructor)) void halide_cache_cleanup() {
    halide_memoization_cache_cleanup();
}

} // namespace

} // extern "C"