1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 // LINT_C_FILE
21 
22 /*!
23  * \file graph_runtime.c
24  * \brief implement graph runtime in pure C
25  */
26 
27 #include <tvm/runtime/c_runtime_api.h>
28 #include <tvm/runtime/crt/internal/graph_runtime/graph_runtime.h>
29 #include <tvm/runtime/crt/logging.h>
30 #include <tvm/runtime/crt/memory.h>
31 #include <tvm/runtime/crt/module.h>
32 #include <tvm/runtime/crt/packed_func.h>
33 
34 #include "crt_config.h"
35 
36 #ifndef MAX
37 #define MAX(a, b) (((a) > (b)) ? (a) : (b))
38 #endif  // MAX
39 
Shape_Accumulate(int64_t * shape,uint32_t ndim)40 uint32_t Shape_Accumulate(int64_t* shape, uint32_t ndim) {
41   int64_t accum = 1;
42   uint32_t idx;
43   for (idx = 0; idx < ndim; idx++) {
44     if (shape[idx] == 0) {
45       break;
46     }
47     accum *= shape[idx];
48   }
49   return accum;
50 }
51 
NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry,JSONReader * reader)52 int NodeEntry_Load(TVMGraphRuntimeNodeEntry* entry, JSONReader* reader) {
53   int status = 0;
54   reader->BeginArray(reader);
55   if (!(reader->NextArrayItem(reader))) {
56     fprintf(stderr, "invalid json format: failed to parse `node_id`\n");
57     status = -1;
58   }
59   reader->ReadUnsignedInteger(reader, &(entry->node_id));
60   if (!(reader->NextArrayItem(reader))) {
61     fprintf(stderr, "invalid json format: failed to parse `index`\n");
62     status = -1;
63   }
64   reader->ReadUnsignedInteger(reader, &(entry->index));
65   if (reader->NextArrayItem(reader)) {
66     reader->ReadUnsignedInteger(reader, &(entry->version));
67     if (reader->NextArrayItem(reader)) {
68       fprintf(stderr, "invalid json format: failed to parse `version`\n");
69       status = -1;
70     }
71   } else {
72     entry->version = 0;
73   }
74   return status;
75 }
76 
TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode * node,JSONReader * reader,TVMOpParam * param)77 void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode* node, JSONReader* reader,
78                                    TVMOpParam* param) {
79   int bitmask = 0;
80   char key[20], value[120];
81   memset(param, 0, sizeof(TVMOpParam));
82   memset(key, 0, sizeof(key));
83   memset(value, 0, sizeof(value));
84   reader->BeginObject(reader);
85   while (reader->NextObjectItem(reader, key, sizeof(key))) {
86     int status = reader->ReadString(reader, value, sizeof(value));
87     if (status != 0) {
88       fprintf(stderr, "error reading value for key: %s\n", key);
89       break;
90     }
91     if (!strcmp(key, "func_name")) {
92       snprintf(param->func_name, sizeof(value), "%s", value);
93       bitmask |= 1;
94     } else if (!strcmp(key, "num_inputs")) {
95       param->num_inputs = strtoul(value, 0, 10);
96       bitmask |= 2;
97     } else if (!strcmp(key, "num_outputs")) {
98       param->num_outputs = strtoul(value, 0, 10);
99       bitmask |= 4;
100     } else if (!strcmp(key, "flatten_data")) {
101       param->flatten_data = strtoul(value, 0, 10);
102       bitmask |= 8;
103     } else {
104       fprintf(stderr, "do not support key %s", key);
105     }
106   }
107   if (bitmask != (1 | 2 | 4 | 8)) {
108     fprintf(stderr, "invalid format\n");
109   }
110 }
111 
TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node,JSONReader * reader)112 int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode* node, JSONReader* reader) {
113   int status = 0;
114   reader->BeginObject(reader);
115   int bitmask = 0;
116   char key[20];
117   while (reader->NextObjectItem(reader, key, sizeof(key))) {
118     if (!strcmp(key, "op")) {
119       status = reader->ReadString(reader, node->op_type, sizeof(node->op_type));
120       if (status != 0) {
121         fprintf(stderr, "error reading op\n");
122         break;
123       }
124       bitmask |= 1;
125     } else if (!strcmp(key, "name")) {
126       status = reader->ReadString(reader, node->name, sizeof(node->name));
127       if (status != 0) {
128         fprintf(stderr, "error reading name\n");
129         break;
130       }
131       bitmask |= 2;
132     } else if (!strcmp(key, "inputs")) {
133       size_t count = node->inputs_count;
134       reader->BeginArray(reader);
135       while (reader->NextArrayItem(reader)) {
136         node->inputs = vrealloc(node->inputs, sizeof(TVMGraphRuntimeNodeEntry) * (count + 1));
137         TVMGraphRuntimeNodeEntry* inputs = node->inputs + count;
138         reader->BeginArray(reader);
139         if (!reader->NextArrayItem(reader)) {
140           fprintf(stderr, "invalid json format\n");
141           status = -1;
142           break;
143         }
144         reader->ReadUnsignedInteger(reader, &(inputs->node_id));
145         if (!reader->NextArrayItem(reader)) {
146           fprintf(stderr, "invalid json format\n");
147           status = -1;
148           break;
149         }
150         reader->ReadUnsignedInteger(reader, &(inputs->index));
151         if (reader->NextArrayItem(reader)) {
152           reader->ReadUnsignedInteger(reader, &(inputs->version));
153           if (reader->NextArrayItem(reader)) {
154             fprintf(stderr, "invalid json format\n");
155             status = -1;
156             break;
157           }
158         } else {
159           inputs->version = 0;
160         }
161         count++;
162       }
163       node->inputs_count = count;
164       bitmask |= 4;
165     } else if (!strcmp(key, "attr") || !strcmp(key, "attrs")) {
166       TVMOpParam param;
167 
168       TVMGraphRuntimeNode_LoadAttrs(node, reader, &param);
169       memcpy(&node->param, &param, sizeof(param));
170     } else if (!strcmp(key, "control_deps")) {
171       fprintf(stderr, "do not support key %s", key);
172       status = -1;
173     } else {
174       fprintf(stderr, "do not support key %s", key);
175       status = -1;
176     }
177     if (status != 0) {
178       break;
179     }
180   }
181   if (bitmask != (1 | 2 | 4)) {
182     fprintf(stderr, "invalid format\n");
183     status = -1;
184   }
185   return status;
186 }
187 
TVMGraphRuntimeNodeCreate()188 TVMGraphRuntimeNode TVMGraphRuntimeNodeCreate() {
189   TVMGraphRuntimeNode node;
190   memset(&node, 0, sizeof(TVMGraphRuntimeNode));
191   node.LoadAttrs = TVMGraphRuntimeNode_LoadAttrs;
192   node.Load = TVMGraphRuntimeNode_Load;
193   return node;
194 }
195 
TVMGraphRuntimeNodeRelease(TVMGraphRuntimeNode * node)196 void TVMGraphRuntimeNodeRelease(TVMGraphRuntimeNode* node) {
197   if (!node) {
198     return;
199   }
200   if (node->inputs) {
201     vfree(node->inputs);
202     node->inputs = 0;
203   }
204 }
205 
TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr,JSONReader * reader)206 int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* reader) {
207   int status = 0;
208   int bitmask = 0;
209   char key[16], type[16];
210   uint32_t storage_id_count = 0;
211   uint32_t dltype_count = 0;
212   uint32_t shape_count = 0;
213   uint32_t device_index_count = 0;
214   reader->BeginObject(reader);
215   while (reader->NextObjectItem(reader, key, sizeof(key))) {
216     if (!strcmp(key, "dltype")) {
217       reader->BeginArray(reader);
218       if (!(reader->NextArrayItem(reader))) {
219         fprintf(stderr, "Invalid json format\n");
220         status = -1;
221         break;
222       }
223       status = reader->ReadString(reader, type, sizeof(type));
224       if (status != 0) {
225         fprintf(stderr, "error reading dltype type\n");
226         break;
227       }
228       if (strcmp(type, "list_str")) {
229         fprintf(stderr, "Invalid json format\n");
230         status = -1;
231         break;
232       }
233       if (!(reader->NextArrayItem(reader))) {
234         fprintf(stderr, "Invalid json format\n");
235         status = -1;
236         break;
237       }
238       reader->BeginArray(reader);
239       while (reader->NextArrayItem(reader)) {
240         attr->dltype = vrealloc(attr->dltype, TVM_CRT_STRLEN_DLTYPE * (dltype_count + 1));
241         status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE,
242                                     TVM_CRT_STRLEN_DLTYPE);
243         if (status != 0) {
244           fprintf(stderr, "error reading dltype array item");
245           break;
246         }
247         dltype_count++;
248       }
249       attr->dltype_count = dltype_count;
250 
251       if (reader->NextArrayItem(reader)) {
252         fprintf(stderr, "Invalid json format\n");
253         status = -1;
254         break;
255       }
256       bitmask |= 1;
257     } else if (!strcmp(key, "storage_id")) {
258       reader->BeginArray(reader);
259       if (!(reader->NextArrayItem(reader))) {
260         fprintf(stderr, "Invalid json format\n");
261         status = -1;
262         break;
263       }
264       status = reader->ReadString(reader, type, sizeof(type));
265       if (status != 0) {
266         fprintf(stderr, "error reading device_index array item");
267       }
268       if (strcmp(type, "list_int")) {
269         fprintf(stderr, "Invalid json format\n");
270         status = -1;
271         break;
272       }
273       if (!(reader->NextArrayItem(reader))) {
274         fprintf(stderr, "Invalid json format\n");
275         status = -1;
276         break;
277       }
278       reader->BeginArray(reader);
279       while (reader->NextArrayItem(reader)) {
280         attr->storage_id = vrealloc(attr->storage_id, sizeof(uint32_t) * (storage_id_count + 1));
281         reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count]));
282         storage_id_count++;
283       }
284       if (reader->NextArrayItem(reader)) {
285         fprintf(stderr, "Invalid json format\n");
286         status = -1;
287         break;
288       }
289       bitmask |= 2;
290     } else if (!strcmp(key, "shape")) {
291       reader->BeginArray(reader);
292       if (!(reader->NextArrayItem(reader))) {
293         fprintf(stderr, "Invalid json format\n");
294         status = -1;
295         break;
296       }
297       status = reader->ReadString(reader, type, sizeof(type));
298       if (status != 0) {
299         fprintf(stderr, "error reading shape array item\n");
300         break;
301       }
302       if (strcmp(type, "list_shape")) {
303         fprintf(stderr, "Invalid json format\n");
304         status = -1;
305         break;
306       }
307       if (!(reader->NextArrayItem(reader))) {
308         fprintf(stderr, "Invalid json format\n");
309         status = -1;
310         break;
311       }
312       reader->BeginArray(reader);
313       while (reader->NextArrayItem(reader)) {
314         attr->shape =
315             vrealloc(attr->shape, sizeof(attr->shape[0]) * (shape_count + 1) * TVM_CRT_MAX_NDIM);
316         attr->ndim = vrealloc(attr->ndim, sizeof(attr->ndim[0]) * (shape_count + 1));
317         reader->BeginArray(reader);
318         int64_t* attr_shape_ptr = attr->shape + shape_count * TVM_CRT_MAX_NDIM;
319         reader->ReadInteger(reader, attr_shape_ptr + 0);
320         uint32_t ndim = 1;
321         if (reader->NextArrayItem(reader)) {
322           for (ndim = 1; ndim < TVM_CRT_MAX_NDIM; ndim++) {
323             if (reader->NextArrayItem(reader)) {
324               reader->ReadInteger(reader, attr_shape_ptr + ndim);
325             } else {
326               break;
327             }
328           }
329           if (ndim == TVM_CRT_MAX_NDIM) {
330             reader->NextArrayItem(reader);
331           }
332         }
333         attr->ndim[shape_count] = ndim;
334         shape_count++;
335       }
336       attr->shape_count = shape_count;
337       if (reader->NextArrayItem(reader)) {
338         fprintf(stderr, "Invalid json format\n");
339         status = -1;
340         break;
341       }
342       bitmask |= 4;
343     } else if (!strcmp(key, "device_index")) {
344       reader->BeginArray(reader);
345       if (!(reader->NextArrayItem(reader))) {
346         fprintf(stderr, "Invalid json format\n");
347         status = -1;
348         break;
349       }
350       status = reader->ReadString(reader, type, sizeof(type));
351       if (status != 0) {
352         fprintf(stderr, "error reading device_index array item");
353         break;
354       }
355       if (strcmp(type, "list_int")) {
356         fprintf(stderr, "Invalid json format\n");
357         status = -1;
358         break;
359       }
360       if (!(reader->NextArrayItem(reader))) {
361         fprintf(stderr, "Invalid json format\n");
362         status = -1;
363         break;
364       }
365       while (reader->NextArrayItem(reader)) {
366         attr->device_index =
367             vrealloc(attr->device_index, sizeof(uint32_t) * (device_index_count + 1));
368         reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count]));
369         device_index_count++;
370       }
371       if (reader->NextArrayItem(reader)) {
372         fprintf(stderr, "Invalid json format\n");
373         status = -1;
374         break;
375       }
376     } else {
377       reader->BeginArray(reader);
378       if (!(reader->NextArrayItem(reader))) {
379         fprintf(stderr, "Invalid json format\n");
380         status = -1;
381         break;
382       }
383       reader->ReadString(reader, type, sizeof(type));
384       if (!strcmp(type, "list_int")) {
385         if (!(reader->NextArrayItem(reader))) {
386           fprintf(stderr, "Invalid json format\n");
387           status = -1;
388           break;
389         }
390         uint32_t* temp = 0;
391         uint32_t temp_count = 0;
392         reader->BeginArray(reader);
393         while (reader->NextArrayItem(reader)) {
394           temp = vrealloc(temp, sizeof(uint32_t) * (temp_count + 1));
395           reader->ReadUnsignedInteger(reader, &(temp[temp_count]));
396           temp_count++;
397         }
398         if (temp) {
399           vfree(temp);
400           temp = 0;
401         }
402       } else if (!strcmp(type, "size_t")) {
403         if (!(reader->NextArrayItem(reader))) {
404           fprintf(stderr, "Invalid json format\n");
405           status = -1;
406           break;
407         }
408         uint32_t temp;
409         reader->ReadUnsignedInteger(reader, &temp);
410       } else {
411         fprintf(stderr, "cannot skip graph attr %s", key);
412         status = -1;
413         break;
414       }
415       if (reader->NextArrayItem(reader)) {
416         fprintf(stderr, "Invalid json format\n");
417         status = -1;
418         break;
419       }
420     }
421   }
422   if (bitmask != (1 | 2 | 4)) {
423     fprintf(stderr, "invalid format\n");
424     status = -1;
425   }
426   return status;
427 }
428 
TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr * attr)429 void TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr* attr) {
430   if (!attr) {
431     return;
432   }
433   if (attr->storage_id) {
434     vfree(attr->storage_id);
435     attr->storage_id = 0;
436   }
437   if (attr->device_index) {
438     vfree(attr->device_index);
439     attr->device_index = 0;
440   }
441   if (attr->dltype) {
442     vfree(attr->dltype);
443     attr->dltype = 0;
444   }
445   if (attr->shape) {
446     vfree(attr->shape);
447     attr->shape = 0;
448   }
449   if (attr->ndim) {
450     vfree(attr->ndim);
451     attr->ndim = 0;
452   }
453 }
454 
TVMGraphRuntime_Load(TVMGraphRuntime * runtime,JSONReader * reader)455 int TVMGraphRuntime_Load(TVMGraphRuntime* runtime, JSONReader* reader) {
456   int status = 0;
457   reader->BeginObject(reader);
458   int bitmask = 0;
459   char key[20];
460   while (reader->NextObjectItem(reader, key, sizeof(key))) {
461     if (!strcmp(key, "nodes")) {
462       reader->BeginArray(reader);
463       while (reader->NextArrayItem(reader)) {
464         runtime->nodes =
465             vrealloc(runtime->nodes, sizeof(TVMGraphRuntimeNode) * (runtime->nodes_count + 1));
466         TVMGraphRuntimeNode* node = runtime->nodes + runtime->nodes_count;
467         status = TVMGraphRuntimeNode_Load(node, reader);
468         if (status != 0) {
469           fprintf(stderr, "failed to load an element in `nodes` field in graph runtime node.\n");
470           break;
471 #if TVM_CRT_DEBUG
472         } else {
473           printf("loading: node (%u) %s loaded.\n", runtime->nodes_count, node->name);
474 #endif  // TVM_CRT_DEBUG
475         }
476         runtime->nodes_count++;
477       }
478       bitmask |= 1;
479     } else if (!strcmp(key, "arg_nodes")) {
480       reader->BeginArray(reader);
481       while (reader->NextArrayItem(reader)) {
482         runtime->input_nodes =
483             vrealloc(runtime->input_nodes, sizeof(uint32_t) * (runtime->input_nodes_count + 1));
484         uint32_t* node = runtime->input_nodes + runtime->input_nodes_count;
485         reader->ReadUnsignedInteger(reader, node);
486         runtime->input_nodes_count++;
487       }
488       bitmask |= 2;
489     } else if (!strcmp(key, "node_row_ptr")) {
490       reader->BeginArray(reader);
491       while (reader->NextArrayItem(reader)) {
492         runtime->node_row_ptr =
493             vrealloc(runtime->node_row_ptr, sizeof(uint32_t) * (runtime->node_row_ptr_count + 1));
494         uint32_t count = runtime->node_row_ptr_count;
495         uint32_t* node = runtime->node_row_ptr + count;
496         reader->ReadUnsignedInteger(reader, node);
497         runtime->node_row_ptr_count++;
498       }
499       bitmask |= 4;
500     } else if (!strcmp(key, "heads")) {
501       reader->BeginArray(reader);
502       while (reader->NextArrayItem(reader)) {
503         runtime->outputs = vrealloc(
504             runtime->outputs, sizeof(TVMGraphRuntimeNodeEntry) * (runtime->outputs_count + 1));
505         TVMGraphRuntimeNodeEntry* entry = runtime->outputs + runtime->outputs_count;
506         status = NodeEntry_Load(entry, reader);
507         if (status != 0) {
508           fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n");
509           break;
510         }
511         runtime->outputs_count++;
512       }
513       bitmask |= 8;
514     } else if (!strcmp(key, "attrs")) {
515       status = TVMGraphRuntimeGraphAttr_Load(&(runtime->attrs), reader);
516       if (status != 0) {
517         fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n");
518         break;
519       }
520       bitmask |= 16;
521     } else if (!strcmp(key, "metadata")) {
522       break;
523     } else {
524       fprintf(stderr, "key %s is not supported\n", key);
525       status = -1;
526     }
527     if (status != 0) {
528       break;
529     }
530   }
531   if (!(bitmask == (1 | 2 | 4 | 8 | 16))) {
532     fprintf(stderr, "invalid format\n");
533     status = -1;
534   }
535   return status;
536 }
537 
TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime,uint32_t nid,uint32_t index)538 uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime* runtime, uint32_t nid, uint32_t index) {
539   return runtime->node_row_ptr[nid] + index;
540 }
541 
542 /*!
543  * \brief Get the input index given the name of input.
544  * \param runtime The graph runtime.
545  * \param name The name of the input.
546  * \return The index of input.
547  */
TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime,const char * name)548 int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name) {
549   uint32_t i;
550   int32_t rv = -1;
551   for (i = 0; i < runtime->input_nodes_count; ++i) {
552     uint32_t nid = runtime->input_nodes[i];
553     if (!strcmp(runtime->nodes[nid].name, name)) {
554       rv = i;
555       break;
556     }
557   }
558   CHECK_GE(rv, 0, "cannot find '%s' among input.", name);
559   return rv;
560 }
561 
562 /*!
563  * \brief set input to the graph based on name.
564  * \param runtime The graph runtime.
565  * \param name The name of the input.
566  * \param data_in The input data.
567  */
TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime,const char * name,DLTensor * data_in)568 void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in) {
569   uint32_t index = TVMGraphRuntime_GetInputIndex(runtime, name);
570   if (index >= runtime->input_nodes_count) {
571     fprintf(stderr, "given index is greater than num of input nodes.\n");
572   }
573   uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, runtime->input_nodes[index], 0);
574   runtime->data_entry[eid].dl_tensor.data = data_in->data;
575 }
576 
577 /*!
578  * \brief Load parameters from parameter blob.
579  * \param runtime The graph runtime.
580  * \param param_blob A binary blob of parameter.
581  * \param param_size The parameter size.
582  * \return The result of this function execution.
583  */
TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime,const char * param_blob,const uint32_t param_size)584 int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob,
585                                const uint32_t param_size) {
586   int status = 0;
587   const char* bptr = param_blob;
588   uint64_t header, reserved;
589   header = ((uint64_t*)bptr)[0];  // NOLINT(*)
590   bptr += sizeof(header);
591   if (header != kTVMNDArrayListMagic) {
592     fprintf(stderr, "Invalid parameters file format");
593     status = -1;
594   }
595   reserved = ((uint64_t*)bptr)[0];  // NOLINT(*)
596   bptr += sizeof(reserved);
597 
598   // read names
599   char* names = vmalloc(TVM_CRT_STRLEN_NAME * runtime->nodes_count);
600   memset(names, 0, TVM_CRT_STRLEN_NAME * runtime->nodes_count);
601   uint64_t names_count;
602   int idx;
603   names_count = ((uint64_t*)bptr)[0];  // NOLINT(*)
604   bptr += sizeof(names_count);
605   for (idx = 0; idx < names_count; idx++) {
606     uint64_t name_length;
607     name_length = ((uint64_t*)bptr)[0];  // NOLINT(*)
608     bptr += sizeof(name_length);
609     if (name_length >= TVM_CRT_STRLEN_NAME) {
610       fprintf(stderr, "Error: function name longer than expected.\n");
611       status = -1;
612     }
613     memcpy(names + TVM_CRT_STRLEN_NAME * idx, bptr, name_length);
614     bptr += name_length;
615   }
616 
617   // read sizes
618   uint64_t sz;
619   sz = ((uint64_t*)bptr)[0];  // NOLINT(*)
620   bptr += sizeof(sz);
621   uint32_t size = sz;
622   if (size != names_count) {
623     fprintf(stderr, "Invalid parameters file format\n");
624     status = -1;
625   }
626 
627   for (idx = 0; idx < size; idx++) {
628     int32_t in_idx = TVMGraphRuntime_GetInputIndex(runtime, names + TVM_CRT_STRLEN_NAME * idx);
629     CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n",
630              names + TVM_CRT_STRLEN_NAME * idx);
631     uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, runtime->input_nodes[in_idx], 0);
632     if (!(eid < runtime->data_entry_count)) {
633       fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid,
634               runtime->data_entry_count);
635       status = -1;
636     }
637 
638     if (runtime->data_entry[eid].dl_tensor.shape) {
639       vfree(runtime->data_entry[eid].dl_tensor.shape);
640       runtime->data_entry[eid].dl_tensor.shape = 0;
641     }
642     if (runtime->data_entry[eid].dl_tensor.data) {
643       vfree(runtime->data_entry[eid].dl_tensor.data);
644       runtime->data_entry[eid].dl_tensor.data = 0;
645     }
646     status |= TVMNDArray_Load(&(runtime->data_entry[eid]), &bptr);
647 #if TVM_CRT_DEBUG
648     TVMNDArray* entry = &(runtime->data_entry[eid]);
649     printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n",
650            names + TVM_CRT_STRLEN_NAME * idx, in_idx, eid, entry->dl_tensor.ndim,
651            ((float*)entry->dl_tensor.data)[0]);  // NOLINT(*)
652 #endif                                           // TVM_CRT_DEBUG
653   }
654 
655   // Release memory
656   vfree(names);
657 
658   return status;
659 }
660 
661 /*!
662  * \brief Run all the operations one by one.
663  * \param runtime The graph runtime.
664  */
TVMGraphRuntime_Run(TVMGraphRuntime * runtime)665 void TVMGraphRuntime_Run(TVMGraphRuntime* runtime) {
666   // setup the array and requirements.
667   uint32_t idx;
668   for (idx = 0; idx < runtime->op_execs_count; ++idx) {
669     if (runtime->op_execs[idx].fexec) {
670 #if TVM_CRT_DEBUG
671       printf("calling: %s (%d)\n", runtime->op_execs[idx].name, idx);
672 #endif  // TVM_CRT_DEBUG
673       runtime->op_execs[idx].Call(&(runtime->op_execs[idx]));
674     }
675   }
676 }
677 
TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime,const int32_t idx,DLTensor * out)678 int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out) {
679   int status = 0;
680   uint32_t nid = runtime->outputs[idx].node_id;
681   uint32_t index = runtime->outputs[idx].index;
682   uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, nid, index);
683 
684   // copy data section to allocated output tensor
685   int32_t elem_bytes = out->dtype.bits / 8;
686   int64_t size = Shape_Accumulate(out->shape, out->ndim);
687   DLTensor* tensor = &(runtime->data_entry[eid].dl_tensor);
688   CHECK(out->ndim == tensor->ndim);
689   CHECK(out->dtype.bits == tensor->dtype.bits);
690   CHECK(Shape_Accumulate(out->shape, out->ndim) == Shape_Accumulate(tensor->shape, tensor->ndim));
691   memcpy(out->data, tensor->data, size * elem_bytes);
692   return status;
693 }
694 
TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime)695 void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) {
696   uint32_t idx;
697 
698   // Grab saved optimization plan from graph.
699   TVMGraphRuntimeGraphAttr* attrs = &(runtime->attrs);
700   DLDataType* vtype = vmalloc(sizeof(DLDataType) * attrs->dltype_count);
701   for (idx = 0; idx < attrs->dltype_count; idx++) {
702     vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_STRLEN_DLTYPE);
703   }
704 
705   // Size and device type of each storage pool entry.
706   TVMGraphRuntimePoolEntry* pool_entry =
707       vmalloc(sizeof(TVMGraphRuntimePoolEntry) * runtime->nodes_count);
708   memset(pool_entry, 0, sizeof(TVMGraphRuntimePoolEntry) * runtime->nodes_count);
709   uint32_t pool_entry_count = 0;
710   // Find the maximum space size.
711   for (idx = 0; idx < attrs->shape_count; idx++) {
712     int storage_id = attrs->storage_id[idx];
713     // Use the fallback device if no device index is available.
714     int device_type = runtime->ctxs[0].device_type;
715     uint32_t size = Shape_Accumulate(attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx]);
716     DLDataType t = vtype[idx];
717     uint32_t bits = t.bits * t.lanes;
718     size_t bytes = ((bits + 7U) / 8U) * size;
719 
720     uint32_t sid = storage_id;
721     if (sid >= pool_entry_count) {
722       pool_entry_count = sid + 1;
723     }
724     pool_entry[sid].size = MAX(pool_entry[sid].size, bytes);
725     pool_entry[sid].device_type = device_type;
726   }
727 
728   // Allocate the space.
729   for (idx = 0; idx < pool_entry_count; idx++) {
730     runtime->storage_pool =
731         vrealloc(runtime->storage_pool, sizeof(TVMNDArray) * (runtime->storage_pool_count + 1));
732     TVMGraphRuntimePoolEntry pit = pool_entry[idx];
733     int64_t shape[TVM_CRT_MAX_NDIM] = {
734         0,
735     };
736     TVMContext ctx = runtime->ctxs[0];
737     DLDataType dtype = {kDLFloat, 32, 1};
738     shape[0] = (pit.size + 3) / 4;
739     runtime->storage_pool[runtime->storage_pool_count] = TVMNDArray_Empty(1, shape, dtype, ctx);
740     CHECK_NE(runtime->storage_pool[runtime->storage_pool_count].dl_tensor.data, 0,
741              "fail to create storage_pool with idx=%d\n", idx);
742     runtime->storage_pool_count++;
743   }
744 
745   // Assign the pooled entries. A unified memory pool is used to simplifiy
746   // memory assignment for each node entry. The allocated memory on each device
747   // is mapped to this pool.
748   runtime->data_entry_count = runtime->node_row_ptr[runtime->node_row_ptr_count - 1];
749   runtime->data_entry = vmalloc(sizeof(TVMNDArray) * runtime->data_entry_count);
750   for (idx = 0; idx < runtime->data_entry_count; ++idx) {
751     uint32_t storage_id = attrs->storage_id[idx];
752     CHECK(storage_id < runtime->storage_pool_count);
753     runtime->data_entry[idx] =
754         TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]),
755                               attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx]);
756     CHECK_NE(runtime->data_entry[idx].dl_tensor.data, 0,
757              "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id);
758   }
759 
760   // Release memory
761   vfree(vtype);
762   vfree(pool_entry);
763 }
764 
TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime * runtime)765 int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime* runtime) {
766   int status = 0;
767   uint32_t nid, idx;
768   runtime->op_execs_count = runtime->nodes_count;
769   runtime->op_execs = vmalloc(sizeof(TVMPackedFunc) * runtime->op_execs_count);
770   for (nid = 0; nid < runtime->nodes_count; nid++) {
771     const TVMGraphRuntimeNode* inode = runtime->nodes + nid;
772     if (strcmp(inode->op_type, "null")) {
773       DLTensorPtr args[TVM_CRT_MAX_ARGS];
774       uint32_t args_count = 0;
775       for (idx = 0; idx < inode->inputs_count; idx++) {
776         const TVMGraphRuntimeNodeEntry* entry = inode->inputs + idx;
777         uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, entry->node_id, entry->index);
778         args[idx] = &(runtime->data_entry[eid].dl_tensor);
779         args_count++;
780       }
781       for (idx = 0; idx < inode->param.num_outputs; idx++) {
782         uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, nid, idx);
783         args[args_count] = &(runtime->data_entry[eid].dl_tensor);
784         args_count++;
785       }
786       if (strcmp(inode->op_type, "tvm_op")) {
787         fprintf(stderr, "Can only take tvm_op as op, but \"%s\" is found.\n", inode->op_type);
788         status = -1;
789         break;
790       }
791       if (args_count >= TVM_CRT_MAX_ARGS) {
792         fprintf(stderr, "too many arguments: expected less than %d args, but got %d.\n",
793                 TVM_CRT_MAX_ARGS, args_count);
794         status = -1;
795         break;
796       }
797 #if TVM_CRT_DEBUG
798       printf("tvm_op: creating %s with node_id=%d\n", inode->param.func_name, nid);
799 #endif  // TVM_CRT_DEBUG
800       TVMPackedFunc pf;
801       TVMGraphRuntime_CreateTVMOp(runtime, &(inode->param), args, args_count, inode->inputs_count,
802                                   &pf);
803       runtime->op_execs[nid] = pf;
804     }
805   }
806   return status;
807 }
808 
809 typedef struct TVMOpArgs {
810   DLTensor args[TVM_CRT_MAX_ARGS];
811   uint32_t args_count;
812   TVMValue arg_values[TVM_CRT_MAX_ARGS];
813   uint32_t arg_values_count;
814   uint32_t arg_tcodes[TVM_CRT_MAX_ARGS];
815   uint32_t arg_tcodes_count;
816   int64_t shape_data[TVM_CRT_MAX_ARGS];
817   uint32_t shape_data_count;
818 } TVMOpArgs;
819 
TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime,const TVMOpParam * param,DLTensorPtr * args,const uint32_t args_count,uint32_t num_inputs,TVMPackedFunc * pf)820 int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime* runtime, const TVMOpParam* param,
821                                     DLTensorPtr* args, const uint32_t args_count,
822                                     uint32_t num_inputs, TVMPackedFunc* pf) {
823   int status = 0;
824   uint32_t idx;
825   TVMOpArgs arg_ptr;
826   memset(&arg_ptr, 0, sizeof(TVMOpArgs));
827   arg_ptr.args_count = args_count;
828   if (param->flatten_data) {
829     arg_ptr.shape_data_count = arg_ptr.args_count;
830   }
831   for (idx = 0; idx < arg_ptr.args_count; ++idx) {
832     TVMValue v;
833     memset(&v, 0, sizeof(v));
834     DLTensor* t = &(arg_ptr.args[idx]);
835     /* v.v_handle = &((*args)[idx]); */
836     v.v_handle = args[idx];
837     arg_ptr.arg_values[idx] = v;
838     arg_ptr.arg_values_count++;
839     arg_ptr.arg_tcodes[idx] = kTVMNDArrayHandle;
840     arg_ptr.arg_tcodes_count++;
841     if (param->flatten_data) {
842       arg_ptr.shape_data[idx] = Shape_Accumulate(t->shape, t->ndim);
843       t->ndim = 1;
844       t->shape[0] = arg_ptr.shape_data[idx];
845     }
846   }
847   if (!strcmp(param->func_name, "__nop") || !strcmp(param->func_name, "__copy")) {
848     fprintf(stderr, "%s function is not yet supported.", param->func_name);
849     status = -1;
850   }
851 
852   TVMArgs targs = TVMArgs_Create(arg_ptr.arg_values, arg_ptr.arg_tcodes, arg_ptr.arg_values_count);
853   status = TVMPackedFunc_InitModuleFunc(pf, runtime->module_handle, param->func_name, &targs);
854 
855   return status;
856 }
857 
858 /*!
859  * \brief Initialize the graph executor with graph and context.
860  * \param graph_json The execution graph.
861  * \param module The module containing the compiled functions for the host
862  * processor.
863  * \param ctxs The context of the host and devices where graph nodes will be
864  * executed on.
865  */
TVMGraphRuntime_Init(TVMGraphRuntime * runtime,const char * graph_json,const TVMModule * module,const TVMContext * ctxs)866 void TVMGraphRuntime_Init(TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module,
867                           const TVMContext* ctxs) {
868   JSONReader reader = JSONReader_Create(graph_json);
869   TVMGraphRuntime_Load(runtime, &reader);
870   JSONReader_Release(&reader);
871   runtime->ctxs[0] = ctxs[0];
872   TVMGraphRuntime_SetupStorage(runtime);
873   TVMGraphRuntime_SetupOpExecs(runtime);
874 }
875 
TVMGraphRuntime_Create(const char * sym_json,const TVMModule * m,const TVMContext * ctxs)876 TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, const TVMModule* m,
877                                         const TVMContext* ctxs) {
878   CHECK_EQ(vleak_size, 1, "memory leak checking won't work with concurrent CRT use");
879   TVMGraphRuntime* runtime = (TVMGraphRuntime*)vmalloc(sizeof(TVMGraphRuntime));  // NOLINT(*)
880   memset(runtime, 0, sizeof(TVMGraphRuntime));
881   // init
882   TVMGraphRuntime_Init(runtime, sym_json, m, ctxs);
883   return runtime;
884 }
885 
TVMGraphRuntime_Release(TVMGraphRuntime ** pptr)886 void TVMGraphRuntime_Release(TVMGraphRuntime** pptr) {
887   int32_t idx;
888   TVMGraphRuntime* runtime = (TVMGraphRuntime*)(*pptr);
889   for (idx = 0; idx < runtime->nodes_count; ++idx) {
890     TVMGraphRuntimeNodeRelease(&(runtime->nodes[idx]));
891   }
892   vfree(runtime->nodes);
893   TVMGraphRuntimeGraphAttr_Release(&(runtime->attrs));
894   for (idx = 0; idx < runtime->storage_pool_count; ++idx) {
895     TVMNDArray_Release(&(runtime->storage_pool[idx]));
896   }
897   for (idx = 0; idx < runtime->data_entry_count; ++idx) {
898     vfree(runtime->data_entry[idx].dl_tensor.shape);
899   }
900   vfree(runtime->input_nodes);
901   vfree(runtime->node_row_ptr);
902   vfree(runtime->outputs);
903   vfree(runtime->storage_pool);
904   vfree(runtime->data_entry);
905   vfree(runtime->op_execs);
906   vfree(*pptr);
907 
908   if (g_fexecs) {
909     vfree(g_fexecs);
910     g_fexecs = 0;
911   }
912 
913   CHECK_EQ(vleak_size, 1, "found memory leak, leak size=%d", vleak_size - 1);
914 }
915