1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file relay/backend/graph_codegen.cc
22 * \brief Graph runtime codegen
23 */
24
25 #include <dmlc/any.h>
26 #include <dmlc/json.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/runtime/device_api.h>
29
30
31 #include <list>
32 #include <string>
33 #include <vector>
34
35 #include "utils.h"
36 #include "compile_engine.h"
37
38 namespace tvm {
39 namespace relay {
40 namespace backend {
41
42 class GraphNode;
43 class GraphInputNode;
44 class GraphOpNode;
45
46 using IntegerArray = Array<Integer>;
47 using ShapeVector = std::vector<std::vector<int64_t> >;
48 using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
49 using GraphNodePtr = std::shared_ptr<GraphNode>;
50 using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
51 using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
52 using TargetsMap = std::unordered_map<int, Target>;
53
54 /*! \brief Lowered outputs */
55 struct LoweredOutput {
56 std::string graph_json;
57 Map<std::string, Array<LoweredFunc> > lowered_funcs;
58 std::unordered_map<std::string, tvm::runtime::NDArray> params;
59 };
60
/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,        // placeholder / no-op node
  kGraphInputNode,  // graph input (function parameter or constant)
  kGraphOpNode,     // fused operator node
};
67
68 class GraphNodeRef {
69 public:
GraphNodeRef()70 GraphNodeRef() {}
GraphNodeRef(int ident,int index,int version=0)71 GraphNodeRef(int ident, int index, int version = 0)
72 : ident_(ident), index_(index), version_(version) {}
73
74
Save(dmlc::JSONWriter * writer) const75 inline void Save(dmlc::JSONWriter* writer) const {
76 writer->BeginArray();
77 writer->WriteArrayItem(ident_);
78 writer->WriteArrayItem(index_);
79 writer->WriteArrayItem(version_);
80 writer->EndArray();
81 }
82
Load(dmlc::JSONReader * reader)83 inline void Load(dmlc::JSONReader* reader) {
84 LOG(FATAL) << "Not implemented.";
85 }
86
87 protected:
88 int ident_;
89 int index_{0};
90 int version_{0};
91 };
92
93 /*! \brief Base Node class */
94 class GraphNode {
95 public:
GraphNode()96 GraphNode() {}
Save(dmlc::JSONWriter * writer) const97 virtual void Save(dmlc::JSONWriter* writer) const {}
Load(dmlc::JSONReader * reader)98 virtual void Load(dmlc::JSONReader* reader) {}
Type() const99 virtual GraphNodeType Type() const { return kGraphNop; }
~GraphNode()100 virtual ~GraphNode() {}
101
102 public:
103 int num_outputs_{1};
104 std::string name_;
105 GraphAttrs attrs_;
106 };
107
108 /*! \brief Input Node */
109 class GraphInputNode : public GraphNode {
110 public:
GraphInputNode()111 GraphInputNode() {}
GraphInputNode(const std::string & name,const GraphAttrs & attrs)112 GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
113 name_ = name;
114 attrs_ = attrs;
115 }
116
Type() const117 GraphNodeType Type() const override { return kGraphInputNode; }
118
Save(dmlc::JSONWriter * writer) const119 void Save(dmlc::JSONWriter* writer) const override {
120 const std::string op_name{"null"};
121 writer->BeginObject();
122 writer->WriteObjectKeyValue("op", op_name);
123 writer->WriteObjectKeyValue("name", this->name_);
124 writer->WriteObjectKeyValue("inputs", std::list<int>());
125 writer->EndObject();
126 }
make_node_ptr(const std::string & name,const GraphAttrs & attrs)127 static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
128 const GraphAttrs& attrs) {
129 auto ptr = std::make_shared<GraphInputNode>(name, attrs);
130 return std::dynamic_pointer_cast<GraphNode>(ptr);
131 }
132 };
133
134 /*! \brief Op Node */
135 class GraphOpNode : public GraphNode {
136 public:
GraphOpNode()137 GraphOpNode() {}
GraphOpNode(const std::string & name,const GraphAttrs & nd_attrs,const std::string & op_name,const std::vector<GraphNodeRef> & inputs,const GraphAttrs & attrs,size_t num_outputs=1)138 GraphOpNode(const std::string& name,
139 const GraphAttrs& nd_attrs,
140 const std::string& op_name,
141 const std::vector<GraphNodeRef>& inputs,
142 const GraphAttrs& attrs,
143 size_t num_outputs = 1) {
144 name_ = name;
145 attrs_ = nd_attrs;
146 op_name_ = op_name;
147 inputs_ = inputs;
148 op_attrs_ = attrs_;
149 num_outputs_ = num_outputs;
150 op_attrs_["func_name"] = op_name_;
151 op_attrs_["flatten_data"] = std::string("0");
152 op_attrs_["num_inputs"] = std::to_string(inputs_.size());
153 op_attrs_["num_outputs"] = std::to_string(num_outputs_);
154 }
155
Type() const156 GraphNodeType Type() const override { return kGraphOpNode; }
157
Save(dmlc::JSONWriter * writer) const158 void Save(dmlc::JSONWriter* writer) const override {
159 GraphAttrs attrs = op_attrs_;
160 attrs["func_name"] = this->op_name_;
161 attrs["flatten_data"] = std::string("0");
162 attrs["num_inputs"] = std::to_string(this->inputs_.size());
163 attrs["num_outputs"] = std::to_string(this->num_outputs_);
164 writer->BeginObject();
165 writer->WriteObjectKeyValue("op", op_type_name_);
166 writer->WriteObjectKeyValue("name", name_);
167 writer->WriteObjectKeyValue("attrs", attrs);
168 writer->WriteObjectKeyValue("inputs", this->inputs_);
169 writer->EndObject();
170 }
make_node_ptr(const std::string & name,const GraphAttrs & nd_attrs,const std::string & op_name,const std::vector<GraphNodeRef> & inputs,const GraphAttrs & attrs,size_t num_outputs=1)171 static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
172 const GraphAttrs& nd_attrs,
173 const std::string& op_name,
174 const std::vector<GraphNodeRef>& inputs,
175 const GraphAttrs& attrs,
176 size_t num_outputs = 1) {
177 auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
178 return std::dynamic_pointer_cast<GraphNode>(ptr);
179 }
180
181 public:
182 std::string op_name_;
183 std::vector<GraphNodeRef> inputs_;
184 GraphAttrs op_attrs_;
185
186 private:
187 const std::string op_type_name_{"tvm_op"};
188 };
189
190 /*! \brief Code generator for graph runtime */
191 class GraphRuntimeCodegen
192 : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
193 public:
GraphRuntimeCodegen(runtime::Module * mod,const TargetsMap & targets)194 GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
195 : mod_(mod) {
196 compile_engine_ = CompileEngine::Global();
197 targets_ = targets;
198 }
199
Codegen(relay::Function func)200 LoweredOutput Codegen(relay::Function func) {
201 auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
202 storage_device_map_ = (*pf)(func);
203 // First we convert all the parameters into input nodes.
204 for (auto param : func->params) {
205 auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
206 var_map_[param.get()] = AddNode(node_ptr, param);
207 }
208 heads_ = VisitExpr(func->body);
209 std::ostringstream os;
210 dmlc::JSONWriter writer(&os);
211 GetJSON(&writer);
212 LoweredOutput ret;
213 ret.graph_json = os.str();
214 ret.params = params_;
215 for (auto& kv : lowered_funcs_) {
216 if (ret.lowered_funcs.count(kv.first) == 0) {
217 ret.lowered_funcs.Set(kv.first, Array<LoweredFunc>());
218 }
219 auto& vec = ret.lowered_funcs[kv.first];
220 Array<LoweredFunc> tmp;
221 for (auto f : kv.second) {
222 tmp.push_back(f);
223 }
224 for (auto f : vec) {
225 tmp.push_back(f);
226 }
227 ret.lowered_funcs.Set(kv.first, tmp);
228 }
229 return ret;
230 }
231
232 protected:
233 /*!
234 * \brief Extract shape from expr to vector<int64_t>
235 *
236 * \param shape
237 * \return std::vector<int64_t>
238 */
_ShapeToJSON(tvm::Array<IndexExpr> shape)239 std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
240 std::vector<int64_t> ret;
241 for (IndexExpr dim : shape) {
242 const int64_t* pval = as_const_int(dim);
243 ret.push_back(*pval);
244 }
245 return ret;
246 }
247
248 /*!
249 * \brief Add node to graph
250 *
251 * \param node
252 * \param expr
253 * \return std::vector<_NodeRef>
254 */
AddNode(GraphNodePtr node,Expr expr)255 std::vector<GraphNodeRef> AddNode(GraphNodePtr node, Expr expr) {
256 auto checked_type = expr->checked_type();
257 size_t count = storage_device_map_.count(expr);
258 CHECK_GT(count, 0) << "Expr is not existing in storage plan";
259 auto storage_device_info = storage_device_map_[expr];
260 CHECK_EQ(storage_device_info.size(), 2);
261 // storage
262 std::vector<int64_t> storage_info;
263 for (auto& v : storage_device_info[0]) {
264 storage_info.push_back(v->value);
265 }
266 node->attrs_["storage_id"] = std::move(storage_info);
267 // type
268 std::vector<int64_t> device_types;
269 for (auto& v : storage_device_info[1]) {
270 device_types.push_back(v->value);
271 }
272 size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
273 if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
274 LOG(FATAL) << "The graph contains not annotated nodes for "
275 << "heterogeneous execution. All nodes must be "
276 << "annotated.";
277 }
278 if (num_unknown_devices == 0) {
279 node->attrs_["device_index"] = device_types;
280 }
281 auto node_id = nodes_.size();
282 nodes_.push_back(node);
283 // Tuple return value, flatten as tuple
284 if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
285 std::vector<GraphNodeRef> ret;
286 ShapeVector shape;
287 std::vector<std::string> dtype;
288 for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
289 if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
290 ret.push_back(GraphNodeRef(node_id, i));
291 shape.emplace_back(_ShapeToJSON(typ->shape));
292 dtype.emplace_back(DType2String(typ->dtype));
293 } else {
294 LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
295 }
296 }
297 CHECK_EQ(node->Type(), kGraphOpNode);
298 auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
299 op_nd->attrs_["shape"] = shape;
300 op_nd->attrs_["dtype"] = dtype;
301 op_nd->num_outputs_ = tuple_type->fields.size();
302 return ret;
303 }
304 // Normal tensor return type
305 if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
306 ShapeVector shape;
307 std::vector<std::string> dtype;
308 shape.emplace_back(_ShapeToJSON(tensor_type->shape));
309 dtype.emplace_back(DType2String(tensor_type->dtype));
310 node->attrs_["shape"] = shape;
311 node->attrs_["dtype"] = dtype;
312 } else {
313 LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
314 }
315 return {GraphNodeRef(node_id, 0)};
316 }
317
318 /*! \brief Visitors */
319 std::unordered_map<Expr, std::vector<GraphNodeRef>, NodeHash, NodeEqual> visitor_cache_;
320
VisitExpr(const Expr & expr)321 std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
322 if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
323 std::vector<GraphNodeRef> res;
324 if (expr.as<ConstantNode>()) {
325 res = VisitExpr_(expr.as<ConstantNode>());
326 } else if (expr.as<TupleNode>()) {
327 res = VisitExpr_(expr.as<TupleNode>());
328 } else if (expr.as<VarNode>()) {
329 res = VisitExpr_(expr.as<VarNode>());
330 } else if (expr.as<GlobalVarNode>()) {
331 res = VisitExpr_(expr.as<GlobalVarNode>());
332 } else if (expr.as<FunctionNode>()) {
333 res = VisitExpr_(expr.as<FunctionNode>());
334 } else if (expr.as<CallNode>()) {
335 res = VisitExpr_(expr.as<CallNode>());
336 } else if (expr.as<LetNode>()) {
337 res = VisitExpr_(expr.as<LetNode>());
338 } else if (expr.as<IfNode>()) {
339 res = VisitExpr_(expr.as<IfNode>());
340 } else if (expr.as<OpNode>()) {
341 res = VisitExpr_(expr.as<OpNode>());
342 } else if (expr.as<TupleGetItemNode>()) {
343 res = VisitExpr_(expr.as<TupleGetItemNode>());
344 } else if (expr.as<RefCreateNode>()) {
345 res = VisitExpr_(expr.as<RefCreateNode>());
346 } else if (expr.as<RefReadNode>()) {
347 res = VisitExpr_(expr.as<RefReadNode>());
348 } else if (expr.as<RefWriteNode>()) {
349 res = VisitExpr_(expr.as<RefWriteNode>());
350 } else if (expr.as<ConstructorNode>()) {
351 res = VisitExpr_(expr.as<ConstructorNode>());
352 } else if (expr.as<MatchNode>()) {
353 res = VisitExpr_(expr.as<MatchNode>());
354 }
355 visitor_cache_[expr] = res;
356 return res;
357 }
358
VisitExpr_(const VarNode * op)359 std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
360 Expr expr = GetRef<Expr>(op);
361 return var_map_[expr.get()];
362 }
363
VisitExpr_(const ConstantNode * op)364 std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
365 Expr expr = GetRef<Expr>(op);
366 size_t index = params_.size();
367 std::string name = "p" + std::to_string(index);
368 params_[name] = op->data;
369 auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
370 return AddNode(node, expr);
371 }
372
VisitExpr_(const TupleNode * op)373 std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
374 std::vector<GraphNodeRef> fields;
375 for (auto field : op->fields) {
376 auto ref_vec = VisitExpr(field);
377 for (auto ref : ref_vec) {
378 fields.push_back(ref);
379 }
380 }
381 return fields;
382 }
VisitExpr_(const CallNode * op)383 std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
384 Expr expr = GetRef<Expr>(op);
385 Function func;
386 if (op->op.as<OpNode>()) {
387 LOG(FATAL) << "Operators should be transformed away; try applying"
388 << "the fuse_ops transformation to the expression.";
389 } else if (op->op.as<GlobalVarNode>()) {
390 LOG(FATAL) << "Not implemented";
391 } else if (op->op.as<FunctionNode>()) {
392 func = GetRef<Function>(op->op.as<FunctionNode>());
393 } else {
394 LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
395 }
396 if (!func->IsPrimitive()) {
397 LOG(FATAL) << "TVM only support calls to primitive functions "
398 << "(i.e functions composed of fusable operator invocations)";
399 }
400
401 CHECK_GE(storage_device_map_.count(expr), 0);
402 auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
403 auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
404 auto &device_type = storage_device_map_[expr][1];
405 auto call_dev_type = device_type[0]->value;
406 Target target;
407 if (targets_.size() == 1) {
408 // homogeneous execution.
409 for (auto kv : targets_) {
410 target = kv.second;
411 }
412 } else {
413 // heterogeneous execution.
414 std::string call_dev_name;
415 if (call_dev_type == 0) {
416 call_dev_name = "llvm";
417 } else {
418 call_dev_name = runtime::DeviceName(call_dev_type);
419 }
420 if (targets_.count(call_dev_type) == 0) {
421 LOG(FATAL) << "No target is provided for device "
422 << call_dev_name;
423 }
424 target = targets_[call_dev_type];
425 }
426 CCacheKey key = (*pf0)(func, target);
427 CachedFunc lowerd_func = (*pf1)(compile_engine_, key);
428 if (!lowered_funcs_.count(target->str())) {
429 lowered_funcs_[target->str()] = {};
430 }
431 for (auto f : lowerd_func->funcs) {
432 lowered_funcs_[target->str()].insert(f);
433 }
434
435 std::vector<GraphNodeRef> inputs;
436 for (auto arg : op->args) {
437 auto res = VisitExpr(arg);
438 for (auto nr : res) {
439 inputs.push_back(nr);
440 }
441 }
442 auto& op_name = lowerd_func->func_name;
443 auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
444 GraphAttrs(),
445 op_name,
446 inputs,
447 GraphAttrs());
448 return AddNode(node, expr);
449 }
450
VisitExpr_(const LetNode * op)451 std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
452 CHECK_EQ(var_map_.count(op->var.get()), 0);
453 var_map_[op->var.get()] = VisitExpr(op->value);
454 return VisitExpr(op->body);
455 }
VisitExpr_(const TupleGetItemNode * op)456 std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
457 auto vtuple = VisitExpr(op->tuple);
458 return {vtuple[op->index]};
459 }
VisitExpr_(const OpNode * op)460 std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
461 throw std::runtime_error("can not compile op in non-eta expanded form");
462 return {};
463 }
VisitExpr_(const GlobalVarNode * op)464 std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
465 throw std::runtime_error("");
466 return {};
467 }
VisitExpr_(const IfNode * op)468 std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
469 throw std::invalid_argument("if not supported");
470 return {};
471 }
VisitExpr_(const FunctionNode * op)472 std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
473 throw std::invalid_argument("function not supported");
474 return {};
475 }
VisitExpr_(const RefCreateNode * op)476 std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
477 throw std::invalid_argument("reference not supported");
478 return {};
479 }
VisitExpr_(const RefReadNode * op)480 std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
481 throw std::invalid_argument("reference not supported");
482 return {};
483 }
VisitExpr_(const RefWriteNode * op)484 std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
485 throw std::invalid_argument("reference not supported");
486 return {};
487 }
VisitExpr_(const ConstructorNode * op)488 std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
489 throw std::invalid_argument("ADT constructor case not yet implemented");
490 return {};
491 }
VisitExpr_(const MatchNode * op)492 std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
493 throw std::invalid_argument("match case not yet implemented");
494 return {};
495 }
496 /*!
497 * \brief Generate Graph JSON
498 *
499 * \param writer json writer
500 */
GetJSON(dmlc::JSONWriter * writer)501 void GetJSON(dmlc::JSONWriter* writer) {
502 std::vector<size_t> arg_nodes;
503 for (size_t i = 0; i < nodes_.size(); ++i) {
504 auto node = nodes_[i];
505 if (node->Type() == kGraphInputNode) {
506 arg_nodes.push_back(i);
507 }
508 }
509 size_t num_entry = 0;
510 ShapeVector shapes;
511 std::vector<size_t> storage_ids;
512 std::vector<size_t> device_types;
513 std::vector<std::string> dltypes;
514 std::vector<size_t> node_row_ptr{0};
515 for (auto node : nodes_) {
516 const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
517 const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
518 const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);
519
520 CHECK_EQ(node->num_outputs_, shape_vec.size());
521 num_entry += node->num_outputs_;
522
523 shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
524 dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
525 storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
526 if (node->attrs_.count("device_index")) {
527 const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
528 device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
529 }
530 node_row_ptr.push_back(num_entry);
531 }
532 writer->BeginObject();
533 writer->WriteObjectKeyValue("nodes", nodes_);
534 writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
535 writer->WriteObjectKeyValue("heads", heads_);
536 std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
537 attrs["shape"].emplace_back(std::string("list_shape"));
538 attrs["shape"].emplace_back(shapes);
539 attrs["storage_id"].emplace_back(std::string("list_int"));
540 attrs["storage_id"].emplace_back(storage_ids);
541 if (device_types.size()) {
542 attrs["device_index"].emplace_back(std::string("list_int"));
543 attrs["device_index"].emplace_back(device_types);
544 }
545 attrs["dltype"].emplace_back(std::string("list_str"));
546 attrs["dltype"].emplace_back(dltypes);
547 writer->WriteObjectKeyValue("attrs", attrs);
548 writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
549 writer->EndObject();
550 }
551
552 /*!
553 * \brief Get unique name for func
554 *
555 * \param name
556 * \return std::string
557 */
_GetUniqueName(const std::string & name)558 std::string _GetUniqueName(const std::string& name) {
559 if (!name_map_.count(name)) {
560 name_map_[name] = 1;
561 return name;
562 }
563 auto index = name_map_[name];
564 name_map_[name] += 1;
565 return _GetUniqueName(name + std::to_string(index));
566 }
567
568 protected:
569 /*! \brief nodes */
570 std::vector<GraphNodePtr> nodes_;
571 /*! \brief output of graph */
572 std::vector<GraphNodeRef> heads_;
573 /*! \brief mod */
574 runtime::Module* mod_;
575 /*! \brief variable map */
576 std::unordered_map<const Node*, std::vector<GraphNodeRef>> var_map_;
577 /*! \brief target device */
578 TargetsMap targets_;
579 /*! \brief params */
580 std::unordered_map<std::string, runtime::NDArray> params_;
581 /*! \brief plan memory of device result */
582 Map<Expr, Array<IntegerArray>> storage_device_map_;
583 /*! \brief lowered funcs */
584 std::unordered_map<std::string, std::unordered_set<LoweredFunc, NodeHash, NodeEqual>>
585 lowered_funcs_;
586 /*! \brief name map */
587 std::unordered_map<std::string, size_t> name_map_;
588 /*! \brief compile engine */
589 CompileEngine compile_engine_;
590 };
591
592 class GraphRuntimeCodegenModule : public runtime::ModuleNode {
593 public:
GraphRuntimeCodegenModule()594 GraphRuntimeCodegenModule() {}
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)595 virtual PackedFunc GetFunction(const std::string& name,
596 const ObjectPtr<Object>& sptr_to_self) {
597 if (name == "init") {
598 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
599 CHECK_EQ(args.num_args, 2)
600 << "The expected of arguments are: "
601 << "runtime::Module mod and Map<int, Target> targets";
602 void* mod = args[0];
603 Map<Integer, tvm::Target> tmp = args[1];
604 TargetsMap targets;
605 for (const auto& it : tmp) {
606 auto dev_type = it.first.as<ir::IntImm>();
607 CHECK(dev_type);
608 targets[dev_type->value] = it.second;
609 }
610 codegen_ = std::make_shared<GraphRuntimeCodegen>(
611 reinterpret_cast<runtime::Module*>(mod), targets);
612 });
613 } else if (name == "codegen") {
614 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
615 Function func = args[0];
616 this->output_ = this->codegen_->Codegen(func);
617 });
618 } else if (name == "get_graph_json") {
619 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
620 *rv = this->output_.graph_json;
621 });
622 } else if (name == "list_params_name") {
623 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
624 Array<tvm::Expr> ret;
625 for (const auto &kv : this->output_.params) {
626 tvm::Expr name = ir::StringImm::make(kv.first);
627 ret.push_back(name);
628 }
629 *rv = ret;
630 });
631
632 } else if (name == "get_param_by_name") {
633 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
634 std::string key = args[0];
635 CHECK_GT(this->output_.params.count(key), 0);
636 *rv = this->output_.params[key];
637 });
638 } else if (name == "get_lowered_funcs") {
639 return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
640 *rv = this->output_.lowered_funcs;
641 });
642 } else {
643 return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
644 }
645 }
646
type_key() const647 const char* type_key() const final {
648 return "RelayGraphRuntimeCodegenModule";
649 }
650
651 private:
652 std::shared_ptr<GraphRuntimeCodegen> codegen_;
653 LoweredOutput output_;
654 };
655
CreateGraphCodegenMod()656 runtime::Module CreateGraphCodegenMod() {
657 auto ptr = make_object<GraphRuntimeCodegenModule>();
658 return runtime::Module(ptr);
659 }
660
661 TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
__anon8b7667920802(TVMArgs args, TVMRetValue* rv) 662 .set_body([](TVMArgs args, TVMRetValue* rv) {
663 *rv = CreateGraphCodegenMod();
664 });
665
666 } // namespace backend
667 } // namespace relay
668 } // namespace tvm
669
670 namespace dmlc {
671 namespace json {
672 // JSON utils
673 template <typename T>
SameType(const dmlc::any & data)674 inline bool SameType(const dmlc::any& data) {
675 return std::type_index(data.type()) == std::type_index(typeid(T));
676 }
677
678 template <>
679 struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
Writedmlc::json::Handler680 inline static void Write(dmlc::JSONWriter* writer,
681 const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
682 data->Save(writer);
683 }
Readdmlc::json::Handler684 inline static void Read(dmlc::JSONReader* reader,
685 std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
686 LOG(FATAL) << "Not implemented.";
687 }
688 };
689
690 template <>
691 struct Handler<std::unordered_map<std::string, dmlc::any>> {
Writedmlc::json::Handler692 inline static void Write(dmlc::JSONWriter* writer,
693 const std::unordered_map<std::string, dmlc::any>& data) {
694 writer->BeginObject();
695 for (const auto& kv : data) {
696 auto k = kv.first;
697 const dmlc::any& v = kv.second;
698 if (SameType<std::string>(v)) {
699 writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
700 } else if (SameType<int>(v)) {
701 writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
702 } else if (SameType<std::vector<size_t>>(v)) {
703 writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
704 } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
705 writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
706 } else if (SameType<std::vector<std::string>>(v)) {
707 writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
708 } else {
709 LOG(FATAL) << "Not supported";
710 }
711 }
712 writer->EndObject();
713 }
Readdmlc::json::Handler714 inline static void Read(dmlc::JSONReader* reader,
715 std::unordered_map<std::string, dmlc::any>* data) {
716 LOG(FATAL) << "Not implemented.";
717 }
718 };
719
720 template <>
721 struct Handler<std::vector<dmlc::any>> {
Writedmlc::json::Handler722 inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) {
723 writer->BeginArray();
724 for (const auto& v : data) {
725 if (SameType<std::string>(v)) {
726 writer->WriteArrayItem(dmlc::get<std::string>(v));
727 } else if (SameType<int>(v)) {
728 writer->WriteArrayItem(dmlc::get<int>(v));
729 } else if (SameType<std::vector<size_t>>(v)) {
730 writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v));
731 } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
732 writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v));
733 } else if (SameType<std::vector<std::string>>(v)) {
734 writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v));
735 } else {
736 LOG(FATAL) << "Not supported";
737 }
738 }
739 writer->EndArray();
740 }
Readdmlc::json::Handler741 inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) {
742 LOG(FATAL) << "Not implemented.";
743 }
744 };
745 } // namespace json
746 } // namespace dmlc
747