/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/graph_codegen.cc
 * \brief Graph runtime codegen
 */

#include <dmlc/any.h>
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>

#include <algorithm>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "compile_engine.h"
#include "utils.h"
namespace tvm {
namespace relay {
namespace backend {

class GraphNode;
class GraphInputNode;
class GraphOpNode;

using IntegerArray = Array<Integer>;
using ShapeVector = std::vector<std::vector<int64_t>>;
using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphNodePtr = std::shared_ptr<GraphNode>;
using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
using TargetsMap = std::unordered_map<int, Target>;

/*! \brief Lowered outputs */
struct LoweredOutput {
  std::string graph_json;
  Map<std::string, Array<LoweredFunc>> lowered_funcs;
  std::unordered_map<std::string, tvm::runtime::NDArray> params;
};

/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,
  kGraphInputNode,
  kGraphOpNode,
};
class GraphNodeRef {
 public:
  GraphNodeRef() {}
  GraphNodeRef(int ident, int index, int version = 0)
    : ident_(ident), index_(index), version_(version) {}

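  // A reference to one output of a graph node, serialized as the JSON array
  // [node_id, output_index, version], e.g. [3, 0, 0] for output 0 of node 3.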
  inline void Save(dmlc::JSONWriter* writer) const {
    writer->BeginArray();
    writer->WriteArrayItem(ident_);
    writer->WriteArrayItem(index_);
    writer->WriteArrayItem(version_);
    writer->EndArray();
  }

  inline void Load(dmlc::JSONReader* reader) {
    LOG(FATAL) << "Not implemented.";
  }

 protected:
  int ident_;
  int index_{0};
  int version_{0};
};

/*! \brief Base Node class */
class GraphNode {
 public:
  GraphNode() {}
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  virtual ~GraphNode() {}

 public:
  int num_outputs_{1};
  std::string name_;
  GraphAttrs attrs_;
};

/*! \brief Input Node */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

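  // Input nodes serialize with op "null" and an empty input list, e.g.
  //   {"op": "null", "name": "x", "inputs": []}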
  void Save(dmlc::JSONWriter* writer) const override {
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }

  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};

/*! \brief Op Node */
class GraphOpNode : public GraphNode {
 public:
  GraphOpNode() {}
  GraphOpNode(const std::string& name,
              const GraphAttrs& nd_attrs,
              const std::string& op_name,
              const std::vector<GraphNodeRef>& inputs,
              const GraphAttrs& attrs,
              size_t num_outputs = 1) {
    name_ = name;
    attrs_ = nd_attrs;
    op_name_ = op_name;
    inputs_ = inputs;
    op_attrs_ = attrs;
    num_outputs_ = num_outputs;
    op_attrs_["func_name"] = op_name_;
    op_attrs_["flatten_data"] = std::string("0");
    op_attrs_["num_inputs"] = std::to_string(inputs_.size());
    op_attrs_["num_outputs"] = std::to_string(num_outputs_);
  }

  GraphNodeType Type() const override { return kGraphOpNode; }

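  // Op nodes serialize with op "tvm_op"; "attrs" carries the lowered function
  // name plus "flatten_data"/"num_inputs"/"num_outputs", and "inputs" lists
  // the GraphNodeRefs of the argument entries.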
  void Save(dmlc::JSONWriter* writer) const override {
    GraphAttrs attrs = op_attrs_;
    attrs["func_name"] = this->op_name_;
    attrs["flatten_data"] = std::string("0");
    attrs["num_inputs"] = std::to_string(this->inputs_.size());
    attrs["num_outputs"] = std::to_string(this->num_outputs_);
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_type_name_);
    writer->WriteObjectKeyValue("name", name_);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("inputs", this->inputs_);
    writer->EndObject();
  }

  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& nd_attrs,
                                                  const std::string& op_name,
                                                  const std::vector<GraphNodeRef>& inputs,
                                                  const GraphAttrs& attrs,
                                                  size_t num_outputs = 1) {
    auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }

 public:
  std::string op_name_;
  std::vector<GraphNodeRef> inputs_;
  GraphAttrs op_attrs_;

 private:
  const std::string op_type_name_{"tvm_op"};
};

/*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen
    : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
 public:
  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
      : mod_(mod) {
    compile_engine_ = CompileEngine::Global();
    targets_ = targets;
  }

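  // Codegen pipeline: plan memory for the function, turn each parameter into
  // an input node, visit the body to build the op graph, then serialize the
  // graph to JSON together with the collected params and lowered functions.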
  LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;
    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, Array<LoweredFunc>());
      }
      auto& vec = ret.lowered_funcs[kv.first];
      Array<LoweredFunc> tmp;
      for (auto f : kv.second) {
        tmp.push_back(f);
      }
      for (auto f : vec) {
        tmp.push_back(f);
      }
      ret.lowered_funcs.Set(kv.first, tmp);
    }
    return ret;
  }

 protected:
  /*!
   * \brief Extract shape from expr to vector<int64_t>
   *
   * \param shape the shape expression to convert
   * \return std::vector<int64_t>
   */
  std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
    std::vector<int64_t> ret;
    for (IndexExpr dim : shape) {
      const int64_t* pval = as_const_int(dim);
      CHECK(pval != nullptr) << "JSON graph requires statically known shapes, got " << dim;
      ret.push_back(*pval);
    }
    return ret;
  }

  /*!
   * \brief Add node to graph, attaching storage/device/shape/dtype attributes
   *
   * \param node the node to add
   * \param expr the expression the node was generated from
   * \return std::vector<GraphNodeRef> one reference per output of the node
   */
  std::vector<GraphNodeRef> AddNode(GraphNodePtr node, Expr expr) {
    auto checked_type = expr->checked_type();
    size_t count = storage_device_map_.count(expr);
    CHECK_GT(count, 0) << "Expr does not exist in the storage plan";
    auto storage_device_info = storage_device_map_[expr];
    CHECK_EQ(storage_device_info.size(), 2);
    // storage ids
    std::vector<int64_t> storage_info;
    for (auto& v : storage_device_info[0]) {
      storage_info.push_back(v->value);
    }
    node->attrs_["storage_id"] = std::move(storage_info);
    // device types
    std::vector<int64_t> device_types;
    for (auto& v : storage_device_info[1]) {
      device_types.push_back(v->value);
    }
    size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
    if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
      LOG(FATAL) << "The graph contains unannotated nodes for "
                 << "heterogeneous execution. All nodes must be "
                 << "annotated.";
    }
    if (num_unknown_devices == 0) {
      node->attrs_["device_index"] = device_types;
    }
    auto node_id = nodes_.size();
    nodes_.push_back(node);
    // Tuple return value: flatten into one graph entry per field.
    if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
      std::vector<GraphNodeRef> ret;
      ShapeVector shape;
      std::vector<std::string> dtype;
      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
        if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
          ret.push_back(GraphNodeRef(node_id, i));
          shape.emplace_back(_ShapeToJSON(typ->shape));
          dtype.emplace_back(DType2String(typ->dtype));
        } else {
          LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
        }
      }
      CHECK_EQ(node->Type(), kGraphOpNode);
      auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
      op_nd->attrs_["shape"] = shape;
      op_nd->attrs_["dtype"] = dtype;
      op_nd->num_outputs_ = tuple_type->fields.size();
      return ret;
    }
    // Normal tensor return type
    if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
      ShapeVector shape;
      std::vector<std::string> dtype;
      shape.emplace_back(_ShapeToJSON(tensor_type->shape));
      dtype.emplace_back(DType2String(tensor_type->dtype));
      node->attrs_["shape"] = shape;
      node->attrs_["dtype"] = dtype;
    } else {
      LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
    }
    return {GraphNodeRef(node_id, 0)};
  }

  /*! \brief memoization cache for visited expressions */
  std::unordered_map<Expr, std::vector<GraphNodeRef>, NodeHash, NodeEqual> visitor_cache_;

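  // Manual dispatch over the expression node type, memoized through
  // visitor_cache_ so shared subexpressions are converted only once.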
  std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
    if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
    std::vector<GraphNodeRef> res;
    if (expr.as<ConstantNode>()) {
      res = VisitExpr_(expr.as<ConstantNode>());
    } else if (expr.as<TupleNode>()) {
      res = VisitExpr_(expr.as<TupleNode>());
    } else if (expr.as<VarNode>()) {
      res = VisitExpr_(expr.as<VarNode>());
    } else if (expr.as<GlobalVarNode>()) {
      res = VisitExpr_(expr.as<GlobalVarNode>());
    } else if (expr.as<FunctionNode>()) {
      res = VisitExpr_(expr.as<FunctionNode>());
    } else if (expr.as<CallNode>()) {
      res = VisitExpr_(expr.as<CallNode>());
    } else if (expr.as<LetNode>()) {
      res = VisitExpr_(expr.as<LetNode>());
    } else if (expr.as<IfNode>()) {
      res = VisitExpr_(expr.as<IfNode>());
    } else if (expr.as<OpNode>()) {
      res = VisitExpr_(expr.as<OpNode>());
    } else if (expr.as<TupleGetItemNode>()) {
      res = VisitExpr_(expr.as<TupleGetItemNode>());
    } else if (expr.as<RefCreateNode>()) {
      res = VisitExpr_(expr.as<RefCreateNode>());
    } else if (expr.as<RefReadNode>()) {
      res = VisitExpr_(expr.as<RefReadNode>());
    } else if (expr.as<RefWriteNode>()) {
      res = VisitExpr_(expr.as<RefWriteNode>());
    } else if (expr.as<ConstructorNode>()) {
      res = VisitExpr_(expr.as<ConstructorNode>());
    } else if (expr.as<MatchNode>()) {
      res = VisitExpr_(expr.as<MatchNode>());
    }
    visitor_cache_[expr] = res;
    return res;
  }

  std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    return var_map_[expr.get()];
  }

  std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
    Expr expr = GetRef<Expr>(op);
    size_t index = params_.size();
    std::string name = "p" + std::to_string(index);
    params_[name] = op->data;
    auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
    return AddNode(node, expr);
  }

  std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
    std::vector<GraphNodeRef> fields;
    for (auto field : op->fields) {
      auto ref_vec = VisitExpr(field);
      for (auto ref : ref_vec) {
        fields.push_back(ref);
      }
    }
    return fields;
  }

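  // Calls must target fused primitive functions at this stage. Each call is
  // lowered through the compile engine for a target selected from the device
  // plan: the single provided target in the homogeneous case, or the target
  // registered for the call's device type in the heterogeneous case.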
  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
    Expr expr = GetRef<Expr>(op);
    Function func;
    if (op->op.as<OpNode>()) {
      LOG(FATAL) << "Operators should be transformed away; try applying "
                 << "the fuse_ops transformation to the expression.";
    } else if (op->op.as<GlobalVarNode>()) {
      LOG(FATAL) << "Not implemented";
    } else if (op->op.as<FunctionNode>()) {
      func = GetRef<Function>(op->op.as<FunctionNode>());
    } else {
      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
    }
    if (!func->IsPrimitive()) {
      LOG(FATAL) << "TVM only supports calls to primitive functions "
                 << "(i.e., functions composed of fusable operator invocations)";
    }

    CHECK_GT(storage_device_map_.count(expr), 0);
    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
    auto& device_type = storage_device_map_[expr][1];
    auto call_dev_type = device_type[0]->value;
    Target target;
    if (targets_.size() == 1) {
      // homogeneous execution.
      for (auto kv : targets_) {
        target = kv.second;
      }
    } else {
      // heterogeneous execution.
      std::string call_dev_name;
      if (call_dev_type == 0) {
        call_dev_name = "llvm";
      } else {
        call_dev_name = runtime::DeviceName(call_dev_type);
      }
      if (targets_.count(call_dev_type) == 0) {
        LOG(FATAL) << "No target is provided for device "
                   << call_dev_name;
      }
      target = targets_[call_dev_type];
    }
    CCacheKey key = (*pf0)(func, target);
    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
    if (!lowered_funcs_.count(target->str())) {
      lowered_funcs_[target->str()] = {};
    }
    for (auto f : lowered_func->funcs) {
      lowered_funcs_[target->str()].insert(f);
    }

    std::vector<GraphNodeRef> inputs;
    for (auto arg : op->args) {
      auto res = VisitExpr(arg);
      for (auto nr : res) {
        inputs.push_back(nr);
      }
    }
    auto& op_name = lowered_func->func_name;
    auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
                                           GraphAttrs(),
                                           op_name,
                                           inputs,
                                           GraphAttrs());
    return AddNode(node, expr);
  }

  std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
    CHECK_EQ(var_map_.count(op->var.get()), 0);
    var_map_[op->var.get()] = VisitExpr(op->value);
    return VisitExpr(op->body);
  }

  std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
    auto vtuple = VisitExpr(op->tuple);
    return {vtuple[op->index]};
  }

  std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
    throw std::runtime_error("cannot compile op in non-eta-expanded form");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
    throw std::runtime_error("GlobalVarNode not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
    throw std::invalid_argument("if not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
    throw std::invalid_argument("function not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
    throw std::invalid_argument("ADT constructor case not yet implemented");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
    throw std::invalid_argument("match case not yet implemented");
    return {};
  }

  /*!
   * \brief Generate Graph JSON
   *
   * \param writer json writer
   */
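  // The emitted JSON has roughly this shape (illustrative):
  //   {
  //     "nodes": [...],
  //     "arg_nodes": [0, 1, ...],
  //     "heads": [[node_id, index, version], ...],
  //     "attrs": {
  //       "shape": ["list_shape", [[...], ...]],
  //       "storage_id": ["list_int", [...]],
  //       "device_index": ["list_int", [...]],  // heterogeneous graphs only
  //       "dltype": ["list_str", [...]]
  //     },
  //     "node_row_ptr": [0, ...]
  //   }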
  void GetJSON(dmlc::JSONWriter* writer) {
    std::vector<size_t> arg_nodes;
    for (size_t i = 0; i < nodes_.size(); ++i) {
      auto node = nodes_[i];
      if (node->Type() == kGraphInputNode) {
        arg_nodes.push_back(i);
      }
    }
    size_t num_entry = 0;
    ShapeVector shapes;
    std::vector<size_t> storage_ids;
    std::vector<size_t> device_types;
    std::vector<std::string> dltypes;
    std::vector<size_t> node_row_ptr{0};
    for (auto node : nodes_) {
      const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
      const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
      const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

      CHECK_EQ(node->num_outputs_, shape_vec.size());
      num_entry += node->num_outputs_;

      shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
      dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
      storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
      if (node->attrs_.count("device_index")) {
        const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
        device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
      }
      node_row_ptr.push_back(num_entry);
    }
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes_);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
    writer->WriteObjectKeyValue("heads", heads_);
    std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
    attrs["shape"].emplace_back(std::string("list_shape"));
    attrs["shape"].emplace_back(shapes);
    attrs["storage_id"].emplace_back(std::string("list_int"));
    attrs["storage_id"].emplace_back(storage_ids);
    if (device_types.size()) {
      attrs["device_index"].emplace_back(std::string("list_int"));
      attrs["device_index"].emplace_back(device_types);
    }
    attrs["dltype"].emplace_back(std::string("list_str"));
    attrs["dltype"].emplace_back(dltypes);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
    writer->EndObject();
  }

  /*!
   * \brief Get unique name for func
   *
   * \param name the base name
   * \return std::string a unique name; e.g. repeated calls with "fused_add"
   *         yield "fused_add", "fused_add1", "fused_add2", ...
   */
  std::string _GetUniqueName(const std::string& name) {
    if (!name_map_.count(name)) {
      name_map_[name] = 1;
      return name;
    }
    auto index = name_map_[name];
    name_map_[name] += 1;
    return _GetUniqueName(name + std::to_string(index));
  }

 protected:
  /*! \brief nodes of the graph, in emission order */
  std::vector<GraphNodePtr> nodes_;
  /*! \brief output entries of the graph */
  std::vector<GraphNodeRef> heads_;
  /*! \brief external module */
  runtime::Module* mod_;
  /*! \brief map from Relay vars to the graph entries they produce */
  std::unordered_map<const Node*, std::vector<GraphNodeRef>> var_map_;
  /*! \brief targets, keyed by device type */
  TargetsMap targets_;
  /*! \brief bound params, keyed by generated name */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief memory plan: Expr -> [storage ids, device types] */
  Map<Expr, Array<IntegerArray>> storage_device_map_;
  /*! \brief lowered funcs, keyed by target string */
  std::unordered_map<std::string, std::unordered_set<LoweredFunc, NodeHash, NodeEqual>>
      lowered_funcs_;
  /*! \brief name map for unique function names */
  std::unordered_map<std::string, size_t> name_map_;
  /*! \brief compile engine */
  CompileEngine compile_engine_;
};

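// A minimal usage sketch via the packed-function interface exposed below
// (argument values are illustrative):
//   runtime::Module m = CreateGraphCodegenMod();
//   m.GetFunction("init")(ext_mod, targets);  // runtime::Module*, Map<Integer, Target>
//   m.GetFunction("codegen")(relay_func);     // relay::Function
//   std::string json = m.GetFunction("get_graph_json")();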
class GraphRuntimeCodegenModule : public runtime::ModuleNode {
 public:
  GraphRuntimeCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name,
                                 const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        CHECK_EQ(args.num_args, 2)
            << "The expected arguments are: "
            << "runtime::Module mod and Map<int, Target> targets";
        void* mod = args[0];
        Map<Integer, tvm::Target> tmp = args[1];
        TargetsMap targets;
        for (const auto& it : tmp) {
          auto dev_type = it.first.as<ir::IntImm>();
          CHECK(dev_type);
          targets[dev_type->value] = it.second;
        }
        codegen_ = std::make_shared<GraphRuntimeCodegen>(
            reinterpret_cast<runtime::Module*>(mod), targets);
      });
    } else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Function func = args[0];
        this->output_ = this->codegen_->Codegen(func);
      });
    } else if (name == "get_graph_json") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.graph_json;
      });
    } else if (name == "list_params_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Array<tvm::Expr> ret;
        for (const auto& kv : this->output_.params) {
          tvm::Expr name = ir::StringImm::make(kv.first);
          ret.push_back(name);
        }
        *rv = ret;
      });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        std::string key = args[0];
        CHECK_GT(this->output_.params.count(key), 0);
        *rv = this->output_.params[key];
      });
    } else if (name == "get_lowered_funcs") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.lowered_funcs;
      });
    } else {
      return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final {
    return "RelayGraphRuntimeCodegenModule";
  }

 private:
  std::shared_ptr<GraphRuntimeCodegen> codegen_;
  LoweredOutput output_;
};

runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphRuntimeCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateGraphCodegenMod();
});

}  // namespace backend
}  // namespace relay
}  // namespace tvm

namespace dmlc {
namespace json {
// JSON serialization handlers. Graph attributes are stored as dmlc::any, so
// the Handler specializations below teach dmlc::JSONWriter how to serialize
// the node and attribute types used by the graph runtime codegen.
template <typename T>
inline bool SameType(const dmlc::any& data) {
  return std::type_index(data.type()) == std::type_index(typeid(T));
}

template <>
struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
    data->Save(writer);
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::unordered_map<std::string, dmlc::any>& data) {
    writer->BeginObject();
    for (const auto& kv : data) {
      auto k = kv.first;
      const dmlc::any& v = kv.second;
      if (SameType<std::string>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndObject();
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::unordered_map<std::string, dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

template <>
struct Handler<std::vector<dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) {
    writer->BeginArray();
    for (const auto& v : data) {
      if (SameType<std::string>(v)) {
        writer->WriteArrayItem(dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteArrayItem(dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndArray();
  }
  inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};
}  // namespace json
}  // namespace dmlc