1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 // LINT_C_FILE
21
22 /*!
23 * \file graph_runtime.c
24 * \brief implement graph runtime in pure C
25 */
26
27 #include <tvm/runtime/c_runtime_api.h>
28 #include <tvm/runtime/crt/internal/graph_runtime/graph_runtime.h>
29 #include <tvm/runtime/crt/logging.h>
30 #include <tvm/runtime/crt/memory.h>
31 #include <tvm/runtime/crt/module.h>
32 #include <tvm/runtime/crt/packed_func.h>
33
34 #include "crt_config.h"
35
#ifndef MAX
// Larger of the two operands. Each operand appears twice in the expansion, so
// an argument with side effects may be evaluated more than once.
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif // MAX
39
/*!
 * \brief Multiply together the leading extents of a shape.
 *
 * Walks the first \p ndim entries of \p shape and stops at the first extent
 * that equals zero; that zero is not folded into the product.
 *
 * \param shape Array holding at least \p ndim extents.
 * \param ndim Number of extents to consider.
 * \return The 64-bit product converted to uint32_t (1 when no extent is consumed).
 */
uint32_t Shape_Accumulate(int64_t* shape, uint32_t ndim) {
  int64_t product = 1;
  uint32_t dim = 0;
  while (dim < ndim && shape[dim] != 0) {
    product *= shape[dim];
    ++dim;
  }
  return product;
}
51
/*!
 * \brief Parse one node entry written as the JSON array `[node_id, index]` or
 *        `[node_id, index, version]`.
 *
 * Parsing stops at the first structural error so that no further values are
 * pulled from a stream that is already known to be malformed.
 *
 * \param entry Destination entry; `version` is set to 0 when the array has no
 *        third element.
 * \param reader JSON reader positioned at the start of the entry array.
 * \return 0 on success, -1 when an element is missing or extra elements follow
 *         `version`.
 */
int NodeEntry_Load(TVMGraphRuntimeNodeEntry* entry, JSONReader* reader) {
  reader->BeginArray(reader);
  if (!(reader->NextArrayItem(reader))) {
    fprintf(stderr, "invalid json format: failed to parse `node_id`\n");
    return -1;
  }
  reader->ReadUnsignedInteger(reader, &(entry->node_id));
  if (!(reader->NextArrayItem(reader))) {
    fprintf(stderr, "invalid json format: failed to parse `index`\n");
    return -1;
  }
  reader->ReadUnsignedInteger(reader, &(entry->index));
  if (!(reader->NextArrayItem(reader))) {
    // Two-element form: the version is optional.
    entry->version = 0;
    return 0;
  }
  reader->ReadUnsignedInteger(reader, &(entry->version));
  if (reader->NextArrayItem(reader)) {
    // Anything after `version` is unexpected.
    fprintf(stderr, "invalid json format: failed to parse `version`\n");
    return -1;
  }
  return 0;
}
76
/*!
 * \brief Parse a node's attribute object into \p param.
 *
 * \p param is cleared first. Every value is read as a string; the keys
 * `func_name`, `num_inputs`, `num_outputs` and `flatten_data` are recognised,
 * and the three numeric ones are converted with strtoul (base 10). Parsing
 * stops at the first value that cannot be read. A diagnostic is printed when
 * an unknown key is met or when any of the four keys is missing; the function
 * has no return value, so the caller is not notified.
 *
 * \param node Node that owns the attributes (not used by this function).
 * \param reader JSON reader positioned at the start of the attribute object.
 * \param param Destination for the parsed attributes.
 */
void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode* node, JSONReader* reader,
                                   TVMOpParam* param) {
  int bitmask = 0;
  char key[20], value[120];
  (void)node;
  memset(param, 0, sizeof(TVMOpParam));
  memset(key, 0, sizeof(key));
  memset(value, 0, sizeof(value));
  reader->BeginObject(reader);
  while (reader->NextObjectItem(reader, key, sizeof(key))) {
    int status = reader->ReadString(reader, value, sizeof(value));
    if (status != 0) {
      fprintf(stderr, "error reading value for key: %s\n", key);
      break;
    }
    if (!strcmp(key, "func_name")) {
      // Bound the copy by the destination buffer, not by the source buffer.
      snprintf(param->func_name, sizeof(param->func_name), "%s", value);
      bitmask |= 1;
    } else if (!strcmp(key, "num_inputs")) {
      param->num_inputs = strtoul(value, 0, 10);
      bitmask |= 2;
    } else if (!strcmp(key, "num_outputs")) {
      param->num_outputs = strtoul(value, 0, 10);
      bitmask |= 4;
    } else if (!strcmp(key, "flatten_data")) {
      param->flatten_data = strtoul(value, 0, 10);
      bitmask |= 8;
    } else {
      fprintf(stderr, "do not support key %s", key);
    }
  }
  if (bitmask != (1 | 2 | 4 | 8)) {
    fprintf(stderr, "invalid format\n");
  }
}
111
/*!
 * \brief Parse one graph node object into \p node.
 *
 * Recognised keys are `op`, `name`, `inputs` and `attr`/`attrs`. Input entries
 * are appended after the `node->inputs_count` entries already present, so the
 * node must be initialised before the call. Any other key, including
 * `control_deps`, is rejected.
 *
 * \param node Destination node.
 * \param reader JSON reader positioned at the start of the node object.
 * \return 0 on success; non-zero when a read fails, an unsupported key is
 *         found, or one of `op`, `name`, `inputs` is missing.
 */
int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode* node, JSONReader* reader) {
  int status = 0;
  reader->BeginObject(reader);
  int bitmask = 0;
  char key[20];
  while (reader->NextObjectItem(reader, key, sizeof(key))) {
    if (!strcmp(key, "op")) {
      status = reader->ReadString(reader, node->op_type, sizeof(node->op_type));
      if (status != 0) {
        fprintf(stderr, "error reading op\n");
        break;
      }
      bitmask |= 1;
    } else if (!strcmp(key, "name")) {
      status = reader->ReadString(reader, node->name, sizeof(node->name));
      if (status != 0) {
        fprintf(stderr, "error reading name\n");
        break;
      }
      bitmask |= 2;
    } else if (!strcmp(key, "inputs")) {
      size_t count = node->inputs_count;
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        // Grow the input list by one entry and parse `[node_id, index(, version)]` into it.
        node->inputs = vrealloc(node->inputs, sizeof(TVMGraphRuntimeNodeEntry) * (count + 1));
        TVMGraphRuntimeNodeEntry* inputs = node->inputs + count;
        reader->BeginArray(reader);
        if (!reader->NextArrayItem(reader)) {
          fprintf(stderr, "invalid json format\n");
          status = -1;
          break;
        }
        reader->ReadUnsignedInteger(reader, &(inputs->node_id));
        if (!reader->NextArrayItem(reader)) {
          fprintf(stderr, "invalid json format\n");
          status = -1;
          break;
        }
        reader->ReadUnsignedInteger(reader, &(inputs->index));
        if (reader->NextArrayItem(reader)) {
          reader->ReadUnsignedInteger(reader, &(inputs->version));
          if (reader->NextArrayItem(reader)) {
            // More than three elements in an entry.
            fprintf(stderr, "invalid json format\n");
            status = -1;
            break;
          }
        } else {
          inputs->version = 0;
        }
        count++;
      }
      // Only fully parsed entries are counted.
      node->inputs_count = count;
      bitmask |= 4;
    } else if (!strcmp(key, "attr") || !strcmp(key, "attrs")) {
      // LoadAttrs clears its destination first, so the node's own param can be filled directly.
      TVMGraphRuntimeNode_LoadAttrs(node, reader, &node->param);
    } else {
      // `control_deps` and every other key are unsupported.
      fprintf(stderr, "do not support key %s", key);
      status = -1;
    }
    if (status != 0) {
      break;
    }
  }
  if (bitmask != (1 | 2 | 4)) {
    fprintf(stderr, "invalid format\n");
    status = -1;
  }
  return status;
}
187
TVMGraphRuntimeNodeCreate()188 TVMGraphRuntimeNode TVMGraphRuntimeNodeCreate() {
189 TVMGraphRuntimeNode node;
190 memset(&node, 0, sizeof(TVMGraphRuntimeNode));
191 node.LoadAttrs = TVMGraphRuntimeNode_LoadAttrs;
192 node.Load = TVMGraphRuntimeNode_Load;
193 return node;
194 }
195
/*!
 * \brief Free the input-entry array owned by a node.
 * \param node Node to clean up; a null pointer is ignored. The node itself is
 *        not freed.
 */
void TVMGraphRuntimeNodeRelease(TVMGraphRuntimeNode* node) {
  if (node != 0 && node->inputs != 0) {
    vfree(node->inputs);
    node->inputs = 0;
  }
}
205
/*!
 * \brief Parse the graph-level attribute object into \p attr.
 *
 * Each attribute is a two-element array `[type_tag, payload]`. The keys
 * `dltype` (`list_str`), `storage_id` (`list_int`), `shape` (`list_shape`) and
 * `device_index` (`list_int`) are stored; any other key is skipped when its
 * tag is `list_int` or `size_t` and rejected otherwise. The destination arrays
 * grow through vrealloc, so \p attr must hold null or previously allocated
 * pointers on entry.
 *
 * \param attr Destination attribute set.
 * \param reader JSON reader positioned at the start of the attribute object.
 * \return 0 on success; non-zero on a malformed attribute or when any of
 *         `dltype`, `storage_id`, `shape` is missing.
 */
int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* reader) {
  int status = 0;
  int bitmask = 0;
  char key[16], type[16];
  uint32_t storage_id_count = 0;
  uint32_t dltype_count = 0;
  uint32_t shape_count = 0;
  uint32_t device_index_count = 0;
  reader->BeginObject(reader);
  while (reader->NextObjectItem(reader, key, sizeof(key))) {
    if (!strcmp(key, "dltype")) {
      reader->BeginArray(reader);
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      status = reader->ReadString(reader, type, sizeof(type));
      if (status != 0) {
        fprintf(stderr, "error reading dltype type\n");
        break;
      }
      if (strcmp(type, "list_str")) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        // Each dtype name occupies a fixed TVM_CRT_STRLEN_DLTYPE-byte slot.
        attr->dltype = vrealloc(attr->dltype, TVM_CRT_STRLEN_DLTYPE * (dltype_count + 1));
        status = reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE,
                                    TVM_CRT_STRLEN_DLTYPE);
        if (status != 0) {
          fprintf(stderr, "error reading dltype array item");
          break;
        }
        dltype_count++;
      }
      attr->dltype_count = dltype_count;
      if (status != 0) {
        // Propagate the element failure instead of letting a later key overwrite it.
        break;
      }

      if (reader->NextArrayItem(reader)) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      bitmask |= 1;
    } else if (!strcmp(key, "storage_id")) {
      reader->BeginArray(reader);
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      status = reader->ReadString(reader, type, sizeof(type));
      if (status != 0) {
        fprintf(stderr, "error reading storage_id array item\n");
        break;
      }
      if (strcmp(type, "list_int")) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        attr->storage_id = vrealloc(attr->storage_id, sizeof(uint32_t) * (storage_id_count + 1));
        reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count]));
        storage_id_count++;
      }
      if (reader->NextArrayItem(reader)) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      bitmask |= 2;
    } else if (!strcmp(key, "shape")) {
      reader->BeginArray(reader);
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      status = reader->ReadString(reader, type, sizeof(type));
      if (status != 0) {
        fprintf(stderr, "error reading shape array item\n");
        break;
      }
      if (strcmp(type, "list_shape")) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        // Every shape gets a fixed row of TVM_CRT_MAX_NDIM extents plus one ndim entry.
        attr->shape =
            vrealloc(attr->shape, sizeof(attr->shape[0]) * (shape_count + 1) * TVM_CRT_MAX_NDIM);
        attr->ndim = vrealloc(attr->ndim, sizeof(attr->ndim[0]) * (shape_count + 1));
        reader->BeginArray(reader);
        int64_t* attr_shape_ptr = attr->shape + shape_count * TVM_CRT_MAX_NDIM;
        // NOTE(review): the first extent is read before any NextArrayItem call on the inner
        // array, and the call order below is kept exactly as it was; it appears to depend on
        // how the reader counts array items -- confirm before restructuring.
        reader->ReadInteger(reader, attr_shape_ptr + 0);
        uint32_t ndim = 1;
        if (reader->NextArrayItem(reader)) {
          for (ndim = 1; ndim < TVM_CRT_MAX_NDIM; ndim++) {
            if (reader->NextArrayItem(reader)) {
              reader->ReadInteger(reader, attr_shape_ptr + ndim);
            } else {
              break;
            }
          }
          if (ndim == TVM_CRT_MAX_NDIM) {
            reader->NextArrayItem(reader);
          }
        }
        attr->ndim[shape_count] = ndim;
        shape_count++;
      }
      attr->shape_count = shape_count;
      if (reader->NextArrayItem(reader)) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      bitmask |= 4;
    } else if (!strcmp(key, "device_index")) {
      reader->BeginArray(reader);
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      status = reader->ReadString(reader, type, sizeof(type));
      if (status != 0) {
        fprintf(stderr, "error reading device_index array item");
        break;
      }
      if (strcmp(type, "list_int")) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      // Enter the nested integer list, as the `storage_id` branch does for the same layout.
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        attr->device_index =
            vrealloc(attr->device_index, sizeof(uint32_t) * (device_index_count + 1));
        reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count]));
        device_index_count++;
      }
      if (reader->NextArrayItem(reader)) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
    } else {
      // Unknown attribute: consume it when its payload layout is understood.
      reader->BeginArray(reader);
      if (!(reader->NextArrayItem(reader))) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
      status = reader->ReadString(reader, type, sizeof(type));
      if (status != 0) {
        fprintf(stderr, "error reading type of graph attr %s\n", key);
        break;
      }
      if (!strcmp(type, "list_int")) {
        if (!(reader->NextArrayItem(reader))) {
          fprintf(stderr, "Invalid json format\n");
          status = -1;
          break;
        }
        uint32_t* temp = 0;
        uint32_t temp_count = 0;
        reader->BeginArray(reader);
        while (reader->NextArrayItem(reader)) {
          temp = vrealloc(temp, sizeof(uint32_t) * (temp_count + 1));
          reader->ReadUnsignedInteger(reader, &(temp[temp_count]));
          temp_count++;
        }
        // The values are only read to advance the reader; discard them.
        if (temp) {
          vfree(temp);
          temp = 0;
        }
      } else if (!strcmp(type, "size_t")) {
        if (!(reader->NextArrayItem(reader))) {
          fprintf(stderr, "Invalid json format\n");
          status = -1;
          break;
        }
        uint32_t temp;
        reader->ReadUnsignedInteger(reader, &temp);
      } else {
        fprintf(stderr, "cannot skip graph attr %s", key);
        status = -1;
        break;
      }
      if (reader->NextArrayItem(reader)) {
        fprintf(stderr, "Invalid json format\n");
        status = -1;
        break;
      }
    }
  }
  if (bitmask != (1 | 2 | 4)) {
    fprintf(stderr, "invalid format\n");
    status = -1;
  }
  return status;
}
428
/*!
 * \brief Free every array owned by a graph attribute set and null the
 *        corresponding pointers.
 * \param attr Attribute set to clean up; a null pointer is ignored. The
 *        structure itself is not freed.
 */
void TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr* attr) {
  if (attr == 0) {
    return;
  }
  if (attr->storage_id != 0) {
    vfree(attr->storage_id);
  }
  attr->storage_id = 0;

  if (attr->device_index != 0) {
    vfree(attr->device_index);
  }
  attr->device_index = 0;

  if (attr->dltype != 0) {
    vfree(attr->dltype);
  }
  attr->dltype = 0;

  if (attr->shape != 0) {
    vfree(attr->shape);
  }
  attr->shape = 0;

  if (attr->ndim != 0) {
    vfree(attr->ndim);
  }
  attr->ndim = 0;
}
454
/*!
 * \brief Parse a graph JSON object into \p runtime.
 *
 * The keys `nodes`, `arg_nodes`, `node_row_ptr`, `heads` and `attrs` are
 * loaded and all five are required. Parsing stops without error as soon as a
 * `metadata` key is seen; any other key is an error. Arrays grow through
 * vrealloc starting from the counts already stored in \p runtime.
 *
 * \param runtime Destination runtime.
 * \param reader JSON reader positioned at the start of the graph object.
 * \return 0 on success, non-zero on a parse error or a missing required key.
 */
int TVMGraphRuntime_Load(TVMGraphRuntime* runtime, JSONReader* reader) {
  int status = 0;
  reader->BeginObject(reader);
  int bitmask = 0;
  char key[20];
  while (reader->NextObjectItem(reader, key, sizeof(key))) {
    if (!strcmp(key, "nodes")) {
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        runtime->nodes =
            vrealloc(runtime->nodes, sizeof(TVMGraphRuntimeNode) * (runtime->nodes_count + 1));
        TVMGraphRuntimeNode* node = runtime->nodes + runtime->nodes_count;
        // The node parser appends to node->inputs starting at node->inputs_count, so the new
        // slot must start out zeroed; nothing here shows that vrealloc clears the memory.
        memset(node, 0, sizeof(*node));
        status = TVMGraphRuntimeNode_Load(node, reader);
        if (status != 0) {
          fprintf(stderr, "failed to load an element in `nodes` field in graph runtime node.\n");
          break;
        }
#if TVM_CRT_DEBUG
        printf("loading: node (%u) %s loaded.\n", runtime->nodes_count, node->name);
#endif  // TVM_CRT_DEBUG
        runtime->nodes_count++;
      }
      bitmask |= 1;
    } else if (!strcmp(key, "arg_nodes")) {
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        runtime->input_nodes =
            vrealloc(runtime->input_nodes, sizeof(uint32_t) * (runtime->input_nodes_count + 1));
        uint32_t* node = runtime->input_nodes + runtime->input_nodes_count;
        reader->ReadUnsignedInteger(reader, node);
        runtime->input_nodes_count++;
      }
      bitmask |= 2;
    } else if (!strcmp(key, "node_row_ptr")) {
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        runtime->node_row_ptr =
            vrealloc(runtime->node_row_ptr, sizeof(uint32_t) * (runtime->node_row_ptr_count + 1));
        uint32_t count = runtime->node_row_ptr_count;
        uint32_t* node = runtime->node_row_ptr + count;
        reader->ReadUnsignedInteger(reader, node);
        runtime->node_row_ptr_count++;
      }
      bitmask |= 4;
    } else if (!strcmp(key, "heads")) {
      reader->BeginArray(reader);
      while (reader->NextArrayItem(reader)) {
        runtime->outputs = vrealloc(
            runtime->outputs, sizeof(TVMGraphRuntimeNodeEntry) * (runtime->outputs_count + 1));
        TVMGraphRuntimeNodeEntry* entry = runtime->outputs + runtime->outputs_count;
        status = NodeEntry_Load(entry, reader);
        if (status != 0) {
          fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n");
          break;
        }
        runtime->outputs_count++;
      }
      bitmask |= 8;
    } else if (!strcmp(key, "attrs")) {
      status = TVMGraphRuntimeGraphAttr_Load(&(runtime->attrs), reader);
      if (status != 0) {
        fprintf(stderr, "Fail to load the `attrs` field in graph runtime node.\n");
        break;
      }
      bitmask |= 16;
    } else if (!strcmp(key, "metadata")) {
      // Stop parsing here; the value of this key is not read.
      break;
    } else {
      fprintf(stderr, "key %s is not supported\n", key);
      status = -1;
    }
    if (status != 0) {
      break;
    }
  }
  if (!(bitmask == (1 | 2 | 4 | 8 | 16))) {
    fprintf(stderr, "invalid format\n");
    status = -1;
  }
  return status;
}
537
/*!
 * \brief Map a (node id, output index) pair to its data-entry id.
 * \param runtime The graph runtime.
 * \param nid Node id, used to index `node_row_ptr` without a bounds check.
 * \param index Output index within the node.
 * \return `node_row_ptr[nid] + index`.
 */
uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime* runtime, uint32_t nid, uint32_t index) {
  const uint32_t row_start = runtime->node_row_ptr[nid];
  return row_start + index;
}
541
542 /*!
543 * \brief Get the input index given the name of input.
544 * \param runtime The graph runtime.
545 * \param name The name of the input.
546 * \return The index of input.
547 */
TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime,const char * name)548 int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name) {
549 uint32_t i;
550 int32_t rv = -1;
551 for (i = 0; i < runtime->input_nodes_count; ++i) {
552 uint32_t nid = runtime->input_nodes[i];
553 if (!strcmp(runtime->nodes[nid].name, name)) {
554 rv = i;
555 break;
556 }
557 }
558 CHECK_GE(rv, 0, "cannot find '%s' among input.", name);
559 return rv;
560 }
561
562 /*!
563 * \brief set input to the graph based on name.
564 * \param runtime The graph runtime.
565 * \param name The name of the input.
566 * \param data_in The input data.
567 */
TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime,const char * name,DLTensor * data_in)568 void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in) {
569 uint32_t index = TVMGraphRuntime_GetInputIndex(runtime, name);
570 if (index >= runtime->input_nodes_count) {
571 fprintf(stderr, "given index is greater than num of input nodes.\n");
572 }
573 uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, runtime->input_nodes[index], 0);
574 runtime->data_entry[eid].dl_tensor.data = data_in->data;
575 }
576
577 /*!
578 * \brief Load parameters from parameter blob.
579 * \param runtime The graph runtime.
580 * \param param_blob A binary blob of parameter.
581 * \param param_size The parameter size.
582 * \return The result of this function execution.
583 */
TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime,const char * param_blob,const uint32_t param_size)584 int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob,
585 const uint32_t param_size) {
586 int status = 0;
587 const char* bptr = param_blob;
588 uint64_t header, reserved;
589 header = ((uint64_t*)bptr)[0]; // NOLINT(*)
590 bptr += sizeof(header);
591 if (header != kTVMNDArrayListMagic) {
592 fprintf(stderr, "Invalid parameters file format");
593 status = -1;
594 }
595 reserved = ((uint64_t*)bptr)[0]; // NOLINT(*)
596 bptr += sizeof(reserved);
597
598 // read names
599 char* names = vmalloc(TVM_CRT_STRLEN_NAME * runtime->nodes_count);
600 memset(names, 0, TVM_CRT_STRLEN_NAME * runtime->nodes_count);
601 uint64_t names_count;
602 int idx;
603 names_count = ((uint64_t*)bptr)[0]; // NOLINT(*)
604 bptr += sizeof(names_count);
605 for (idx = 0; idx < names_count; idx++) {
606 uint64_t name_length;
607 name_length = ((uint64_t*)bptr)[0]; // NOLINT(*)
608 bptr += sizeof(name_length);
609 if (name_length >= TVM_CRT_STRLEN_NAME) {
610 fprintf(stderr, "Error: function name longer than expected.\n");
611 status = -1;
612 }
613 memcpy(names + TVM_CRT_STRLEN_NAME * idx, bptr, name_length);
614 bptr += name_length;
615 }
616
617 // read sizes
618 uint64_t sz;
619 sz = ((uint64_t*)bptr)[0]; // NOLINT(*)
620 bptr += sizeof(sz);
621 uint32_t size = sz;
622 if (size != names_count) {
623 fprintf(stderr, "Invalid parameters file format\n");
624 status = -1;
625 }
626
627 for (idx = 0; idx < size; idx++) {
628 int32_t in_idx = TVMGraphRuntime_GetInputIndex(runtime, names + TVM_CRT_STRLEN_NAME * idx);
629 CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n",
630 names + TVM_CRT_STRLEN_NAME * idx);
631 uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, runtime->input_nodes[in_idx], 0);
632 if (!(eid < runtime->data_entry_count)) {
633 fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid,
634 runtime->data_entry_count);
635 status = -1;
636 }
637
638 if (runtime->data_entry[eid].dl_tensor.shape) {
639 vfree(runtime->data_entry[eid].dl_tensor.shape);
640 runtime->data_entry[eid].dl_tensor.shape = 0;
641 }
642 if (runtime->data_entry[eid].dl_tensor.data) {
643 vfree(runtime->data_entry[eid].dl_tensor.data);
644 runtime->data_entry[eid].dl_tensor.data = 0;
645 }
646 status |= TVMNDArray_Load(&(runtime->data_entry[eid]), &bptr);
647 #if TVM_CRT_DEBUG
648 TVMNDArray* entry = &(runtime->data_entry[eid]);
649 printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n",
650 names + TVM_CRT_STRLEN_NAME * idx, in_idx, eid, entry->dl_tensor.ndim,
651 ((float*)entry->dl_tensor.data)[0]); // NOLINT(*)
652 #endif // TVM_CRT_DEBUG
653 }
654
655 // Release memory
656 vfree(names);
657
658 return status;
659 }
660
661 /*!
662 * \brief Run all the operations one by one.
663 * \param runtime The graph runtime.
664 */
TVMGraphRuntime_Run(TVMGraphRuntime * runtime)665 void TVMGraphRuntime_Run(TVMGraphRuntime* runtime) {
666 // setup the array and requirements.
667 uint32_t idx;
668 for (idx = 0; idx < runtime->op_execs_count; ++idx) {
669 if (runtime->op_execs[idx].fexec) {
670 #if TVM_CRT_DEBUG
671 printf("calling: %s (%d)\n", runtime->op_execs[idx].name, idx);
672 #endif // TVM_CRT_DEBUG
673 runtime->op_execs[idx].Call(&(runtime->op_execs[idx]));
674 }
675 }
676 }
677
/*!
 * \brief Copy one graph output into a caller-provided tensor.
 *
 * The destination must agree with the stored output in number of dimensions,
 * dtype bit width and element count; these are enforced with CHECK. The number
 * of bytes copied is the element count times `out->dtype.bits / 8`.
 *
 * \param runtime The graph runtime.
 * \param idx Index into the list of graph outputs.
 * \param out Destination tensor; `out->data` must already be allocated.
 * \return 0 on success, -1 when \p idx is not a valid output index.
 */
int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out) {
  if (idx < 0 || (uint32_t)idx >= runtime->outputs_count) {
    fprintf(stderr, "output index %d is out of range.\n", idx);
    return -1;
  }
  uint32_t nid = runtime->outputs[idx].node_id;
  uint32_t index = runtime->outputs[idx].index;
  uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, nid, index);

  // copy data section to allocated output tensor
  DLTensor* tensor = &(runtime->data_entry[eid].dl_tensor);
  CHECK(out->ndim == tensor->ndim);
  CHECK(out->dtype.bits == tensor->dtype.bits);
  CHECK(Shape_Accumulate(out->shape, out->ndim) == Shape_Accumulate(tensor->shape, tensor->ndim));
  int32_t elem_bytes = out->dtype.bits / 8;
  int64_t size = Shape_Accumulate(out->shape, out->ndim);
  memcpy(out->data, tensor->data, size * elem_bytes);
  return 0;
}
694
/*!
 * \brief Allocate the storage pool and create one tensor view per data entry.
 *
 * Steps:
 *  1. decode the dtype strings of the graph attributes;
 *  2. for every storage id, find the largest byte size of any entry mapped to it;
 *  3. allocate one float32 array per storage id, rounded up to whole 4-byte words;
 *  4. create, for every data entry, a view into its storage array with the
 *     entry's own shape and dtype.
 *
 * Inconsistent attribute list lengths or a failed allocation are reported
 * through the CHECK macros.
 *
 * \param runtime The graph runtime; its graph must already be loaded and
 *        `ctxs[0]` set.
 */
void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) {
  uint32_t idx;

  // Grab saved optimization plan from graph.
  TVMGraphRuntimeGraphAttr* attrs = &(runtime->attrs);
  // The loops below index the dtype list with shape indices and read the last row pointer.
  CHECK(attrs->shape_count <= attrs->dltype_count);
  CHECK(runtime->node_row_ptr_count > 0);
  DLDataType* vtype = vmalloc(sizeof(DLDataType) * attrs->dltype_count);
  for (idx = 0; idx < attrs->dltype_count; idx++) {
    vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_STRLEN_DLTYPE);
  }

  // The pool table is indexed by storage id, so size it to cover the largest id in use
  // (and never smaller than one slot per node, as before).
  uint32_t pool_capacity = runtime->nodes_count;
  for (idx = 0; idx < attrs->shape_count; idx++) {
    uint32_t sid = attrs->storage_id[idx];
    if (sid >= pool_capacity) {
      pool_capacity = sid + 1;
    }
  }

  // Size and device type of each storage pool entry.
  TVMGraphRuntimePoolEntry* pool_entry = vmalloc(sizeof(TVMGraphRuntimePoolEntry) * pool_capacity);
  memset(pool_entry, 0, sizeof(TVMGraphRuntimePoolEntry) * pool_capacity);
  uint32_t pool_entry_count = 0;
  // Find the maximum space size.
  for (idx = 0; idx < attrs->shape_count; idx++) {
    uint32_t sid = attrs->storage_id[idx];
    // Use the fallback device if no device index is available.
    int device_type = runtime->ctxs[0].device_type;
    uint32_t size = Shape_Accumulate(attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx]);
    DLDataType t = vtype[idx];
    uint32_t bits = t.bits * t.lanes;
    size_t bytes = ((bits + 7U) / 8U) * size;

    if (sid >= pool_entry_count) {
      pool_entry_count = sid + 1;
    }
    pool_entry[sid].size = MAX(pool_entry[sid].size, bytes);
    pool_entry[sid].device_type = device_type;
  }

  // Allocate the space.
  for (idx = 0; idx < pool_entry_count; idx++) {
    runtime->storage_pool =
        vrealloc(runtime->storage_pool, sizeof(TVMNDArray) * (runtime->storage_pool_count + 1));
    TVMGraphRuntimePoolEntry pit = pool_entry[idx];
    int64_t shape[TVM_CRT_MAX_NDIM] = {
        0,
    };
    TVMContext ctx = runtime->ctxs[0];
    DLDataType dtype = {kDLFloat, 32, 1};
    // Storage is allocated as 32-bit floats, so round the byte size up to whole words.
    shape[0] = (pit.size + 3) / 4;
    runtime->storage_pool[runtime->storage_pool_count] = TVMNDArray_Empty(1, shape, dtype, ctx);
    CHECK_NE(runtime->storage_pool[runtime->storage_pool_count].dl_tensor.data, 0,
             "fail to create storage_pool with idx=%d\n", idx);
    runtime->storage_pool_count++;
  }

  // Assign the pooled entries. A unified memory pool is used to simplify
  // memory assignment for each node entry. The allocated memory on each device
  // is mapped to this pool.
  runtime->data_entry_count = runtime->node_row_ptr[runtime->node_row_ptr_count - 1];
  // Every data entry needs its own shape, ndim and dtype record.
  CHECK(runtime->data_entry_count <= attrs->shape_count);
  runtime->data_entry = vmalloc(sizeof(TVMNDArray) * runtime->data_entry_count);
  for (idx = 0; idx < runtime->data_entry_count; ++idx) {
    uint32_t storage_id = attrs->storage_id[idx];
    CHECK(storage_id < runtime->storage_pool_count);
    runtime->data_entry[idx] =
        TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]),
                              attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx]);
    CHECK_NE(runtime->data_entry[idx].dl_tensor.data, 0,
             "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id);
  }

  // Release memory
  vfree(vtype);
  vfree(pool_entry);
}
764
/*!
 * \brief Create one packed-function slot per graph node.
 *
 * Nodes whose op type is "null" keep a zeroed slot. For every other node the
 * argument list is the node's input tensors followed by its output tensors,
 * and the slot is filled by TVMGraphRuntime_CreateTVMOp. Only the "tvm_op" op
 * type is accepted, and the argument count must stay below TVM_CRT_MAX_ARGS;
 * setup stops at the first node that violates either rule.
 *
 * \param runtime The graph runtime; storage must already be set up.
 * \return 0 on success, -1 when a node is rejected.
 */
int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime* runtime) {
  int status = 0;
  uint32_t nid, idx;
  runtime->op_execs_count = runtime->nodes_count;
  runtime->op_execs = vmalloc(sizeof(TVMPackedFunc) * runtime->op_execs_count);
  if (runtime->op_execs != 0) {
    // Slots of "null" nodes are never assigned; clear them so that their `fexec` member,
    // which TVMGraphRuntime_Run tests, does not hold uninitialised memory.
    memset(runtime->op_execs, 0, sizeof(TVMPackedFunc) * runtime->op_execs_count);
  }
  for (nid = 0; nid < runtime->nodes_count; nid++) {
    const TVMGraphRuntimeNode* inode = runtime->nodes + nid;
    if (!strcmp(inode->op_type, "null")) {
      continue;
    }
    if (strcmp(inode->op_type, "tvm_op")) {
      fprintf(stderr, "Can only take tvm_op as op, but \"%s\" is found.\n", inode->op_type);
      status = -1;
      break;
    }
    // Validate the argument count before anything is written into the fixed-size array.
    uint64_t total_args = (uint64_t)inode->inputs_count + (uint64_t)inode->param.num_outputs;
    if (total_args >= TVM_CRT_MAX_ARGS) {
      fprintf(stderr, "too many arguments: expected less than %d args, but got %d.\n",
              TVM_CRT_MAX_ARGS, (int)total_args);
      status = -1;
      break;
    }
    DLTensorPtr args[TVM_CRT_MAX_ARGS];
    uint32_t args_count = 0;
    for (idx = 0; idx < inode->inputs_count; idx++) {
      const TVMGraphRuntimeNodeEntry* entry = inode->inputs + idx;
      uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, entry->node_id, entry->index);
      args[args_count] = &(runtime->data_entry[eid].dl_tensor);
      args_count++;
    }
    for (idx = 0; idx < inode->param.num_outputs; idx++) {
      uint32_t eid = TVMGraphRuntime_GetEntryId(runtime, nid, idx);
      args[args_count] = &(runtime->data_entry[eid].dl_tensor);
      args_count++;
    }
#if TVM_CRT_DEBUG
    printf("tvm_op: creating %s with node_id=%d\n", inode->param.func_name, nid);
#endif  // TVM_CRT_DEBUG
    TVMPackedFunc pf;
    TVMGraphRuntime_CreateTVMOp(runtime, &(inode->param), args, args_count, inode->inputs_count,
                                &pf);
    runtime->op_execs[nid] = pf;
  }
  return status;
}
808
809 typedef struct TVMOpArgs {
810 DLTensor args[TVM_CRT_MAX_ARGS];
811 uint32_t args_count;
812 TVMValue arg_values[TVM_CRT_MAX_ARGS];
813 uint32_t arg_values_count;
814 uint32_t arg_tcodes[TVM_CRT_MAX_ARGS];
815 uint32_t arg_tcodes_count;
816 int64_t shape_data[TVM_CRT_MAX_ARGS];
817 uint32_t shape_data_count;
818 } TVMOpArgs;
819
TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime,const TVMOpParam * param,DLTensorPtr * args,const uint32_t args_count,uint32_t num_inputs,TVMPackedFunc * pf)820 int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime* runtime, const TVMOpParam* param,
821 DLTensorPtr* args, const uint32_t args_count,
822 uint32_t num_inputs, TVMPackedFunc* pf) {
823 int status = 0;
824 uint32_t idx;
825 TVMOpArgs arg_ptr;
826 memset(&arg_ptr, 0, sizeof(TVMOpArgs));
827 arg_ptr.args_count = args_count;
828 if (param->flatten_data) {
829 arg_ptr.shape_data_count = arg_ptr.args_count;
830 }
831 for (idx = 0; idx < arg_ptr.args_count; ++idx) {
832 TVMValue v;
833 memset(&v, 0, sizeof(v));
834 DLTensor* t = &(arg_ptr.args[idx]);
835 /* v.v_handle = &((*args)[idx]); */
836 v.v_handle = args[idx];
837 arg_ptr.arg_values[idx] = v;
838 arg_ptr.arg_values_count++;
839 arg_ptr.arg_tcodes[idx] = kTVMNDArrayHandle;
840 arg_ptr.arg_tcodes_count++;
841 if (param->flatten_data) {
842 arg_ptr.shape_data[idx] = Shape_Accumulate(t->shape, t->ndim);
843 t->ndim = 1;
844 t->shape[0] = arg_ptr.shape_data[idx];
845 }
846 }
847 if (!strcmp(param->func_name, "__nop") || !strcmp(param->func_name, "__copy")) {
848 fprintf(stderr, "%s function is not yet supported.", param->func_name);
849 status = -1;
850 }
851
852 TVMArgs targs = TVMArgs_Create(arg_ptr.arg_values, arg_ptr.arg_tcodes, arg_ptr.arg_values_count);
853 status = TVMPackedFunc_InitModuleFunc(pf, runtime->module_handle, param->func_name, &targs);
854
855 return status;
856 }
857
858 /*!
859 * \brief Initialize the graph executor with graph and context.
860 * \param graph_json The execution graph.
861 * \param module The module containing the compiled functions for the host
862 * processor.
863 * \param ctxs The context of the host and devices where graph nodes will be
864 * executed on.
865 */
TVMGraphRuntime_Init(TVMGraphRuntime * runtime,const char * graph_json,const TVMModule * module,const TVMContext * ctxs)866 void TVMGraphRuntime_Init(TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module,
867 const TVMContext* ctxs) {
868 JSONReader reader = JSONReader_Create(graph_json);
869 TVMGraphRuntime_Load(runtime, &reader);
870 JSONReader_Release(&reader);
871 runtime->ctxs[0] = ctxs[0];
872 TVMGraphRuntime_SetupStorage(runtime);
873 TVMGraphRuntime_SetupOpExecs(runtime);
874 }
875
TVMGraphRuntime_Create(const char * sym_json,const TVMModule * m,const TVMContext * ctxs)876 TVMGraphRuntime* TVMGraphRuntime_Create(const char* sym_json, const TVMModule* m,
877 const TVMContext* ctxs) {
878 CHECK_EQ(vleak_size, 1, "memory leak checking won't work with concurrent CRT use");
879 TVMGraphRuntime* runtime = (TVMGraphRuntime*)vmalloc(sizeof(TVMGraphRuntime)); // NOLINT(*)
880 memset(runtime, 0, sizeof(TVMGraphRuntime));
881 // init
882 TVMGraphRuntime_Init(runtime, sym_json, m, ctxs);
883 return runtime;
884 }
885
/*!
 * \brief Free a graph runtime and everything it owns.
 *
 * Releases the nodes, graph attributes, storage pool arrays, the shape of
 * every data entry, the index tables and the runtime itself, then frees the
 * global `g_fexecs` table when it is set and checks `vleak_size` for leaks.
 * A null handle is ignored.
 *
 * \param pptr Address of the runtime pointer; `*pptr` is set to null once the
 *        runtime has been freed.
 */
void TVMGraphRuntime_Release(TVMGraphRuntime** pptr) {
  uint32_t idx;
  if (pptr == 0 || *pptr == 0) {
    return;
  }
  TVMGraphRuntime* runtime = *pptr;
  for (idx = 0; idx < runtime->nodes_count; ++idx) {
    TVMGraphRuntimeNodeRelease(&(runtime->nodes[idx]));
  }
  vfree(runtime->nodes);
  TVMGraphRuntimeGraphAttr_Release(&(runtime->attrs));
  for (idx = 0; idx < runtime->storage_pool_count; ++idx) {
    TVMNDArray_Release(&(runtime->storage_pool[idx]));
  }
  for (idx = 0; idx < runtime->data_entry_count; ++idx) {
    vfree(runtime->data_entry[idx].dl_tensor.shape);
  }
  vfree(runtime->input_nodes);
  vfree(runtime->node_row_ptr);
  vfree(runtime->outputs);
  vfree(runtime->storage_pool);
  vfree(runtime->data_entry);
  vfree(runtime->op_execs);
  vfree(runtime);
  // Do not leave the caller holding a pointer to freed memory.
  *pptr = 0;

  if (g_fexecs) {
    vfree(g_fexecs);
    g_fexecs = 0;
  }

  CHECK_EQ(vleak_size, 1, "found memory leak, leak size=%d", vleak_size - 1);
}
915