/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/graph_codegen.cc
 * \brief Graph runtime codegen
 */

#include <dmlc/any.h>
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>

#include <list>
#include <memory>
#include <string>
#include <vector>

#include "utils.h"
#include "compile_engine.h"

namespace tvm {
namespace relay {
namespace backend {

class GraphNode;
class GraphInputNode;
class GraphOpNode;

using IntegerArray = Array<Integer>;
using ShapeVector = std::vector<std::vector<int64_t>>;
using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphNodePtr = std::shared_ptr<GraphNode>;
using GraphInputNodePtr = std::shared_ptr<GraphInputNode>;
using GraphOpNodePtr = std::shared_ptr<GraphOpNode>;
using TargetsMap = std::unordered_map<int, Target>;

/*! \brief Lowered outputs */
struct LoweredOutput {
  std::string graph_json;
  Map<std::string, Array<LoweredFunc>> lowered_funcs;
  std::unordered_map<std::string, tvm::runtime::NDArray> params;
};

/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,
  kGraphInputNode,
  kGraphOpNode,
};

class GraphNodeRef {
 public:
  GraphNodeRef() {}
  GraphNodeRef(int ident, int index, int version = 0)
      : ident_(ident), index_(index), version_(version) {}

  inline void Save(dmlc::JSONWriter* writer) const {
    writer->BeginArray();
    writer->WriteArrayItem(ident_);
    writer->WriteArrayItem(index_);
    writer->WriteArrayItem(version_);
    writer->EndArray();
  }

  inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; }

 protected:
  int ident_;
  int index_{0};
  int version_{0};
};

/*! \brief Base Node class */
class GraphNode {
 public:
  GraphNode() {}
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  virtual ~GraphNode() {}

 public:
  int num_outputs_{1};
  std::string name_;
  GraphAttrs attrs_;
};

/*! \brief Input Node */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }

  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};
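// Illustrative only: for an input parameter whose name_hint is "x" (the name
// is assumed for the example), GraphInputNode::Save above emits
//   {"op": "null", "name": "x", "inputs": []}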
/*! \brief Op Node */
class GraphOpNode : public GraphNode {
 public:
  GraphOpNode() {}
  GraphOpNode(const std::string& name,
              const GraphAttrs& nd_attrs,
              const std::string& op_name,
              const std::vector<GraphNodeRef>& inputs,
              const GraphAttrs& attrs,
              size_t num_outputs = 1) {
    name_ = name;
    attrs_ = nd_attrs;
    op_name_ = op_name;
    inputs_ = inputs;
    op_attrs_ = attrs;
    num_outputs_ = num_outputs;
    op_attrs_["func_name"] = op_name_;
    op_attrs_["flatten_data"] = std::string("0");
    op_attrs_["num_inputs"] = std::to_string(inputs_.size());
    op_attrs_["num_outputs"] = std::to_string(num_outputs_);
  }

  GraphNodeType Type() const override { return kGraphOpNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    GraphAttrs attrs = op_attrs_;
    attrs["func_name"] = this->op_name_;
    attrs["flatten_data"] = std::string("0");
    attrs["num_inputs"] = std::to_string(this->inputs_.size());
    attrs["num_outputs"] = std::to_string(this->num_outputs_);
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_type_name_);
    writer->WriteObjectKeyValue("name", name_);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("inputs", this->inputs_);
    writer->EndObject();
  }

  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& nd_attrs,
                                                  const std::string& op_name,
                                                  const std::vector<GraphNodeRef>& inputs,
                                                  const GraphAttrs& attrs,
                                                  size_t num_outputs = 1) {
    auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }

 public:
  std::string op_name_;
  std::vector<GraphNodeRef> inputs_;
  GraphAttrs op_attrs_;

 private:
  const std::string op_type_name_{"tvm_op"};
};
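// Illustrative only: a GraphOpNode with two inputs serializes roughly as
//   {"op": "tvm_op", "name": "fused_add", "inputs": [[0, 0, 0], [1, 0, 0]],
//    "attrs": {"func_name": "fused_add", "flatten_data": "0",
//              "num_inputs": "2", "num_outputs": "1"}}
// where the function name "fused_add" and the input references are assumed
// values for the example.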
/*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen
    : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
 public:
  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
    compile_engine_ = CompileEngine::Global();
    targets_ = targets;
  }

  LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;
    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, Array<LoweredFunc>());
      }
      auto& vec = ret.lowered_funcs[kv.first];
      Array<LoweredFunc> tmp;
      for (auto f : kv.second) {
        tmp.push_back(f);
      }
      for (auto f : vec) {
        tmp.push_back(f);
      }
      ret.lowered_funcs.Set(kv.first, tmp);
    }
    return ret;
  }

 protected:
  /*!
   * \brief Extract shape from expr to vector<int64_t>
   *
   * \param shape
   * \return std::vector<int64_t>
   */
  std::vector<int64_t> _ShapeToJSON(tvm::Array<IndexExpr> shape) {
    std::vector<int64_t> ret;
    for (IndexExpr dim : shape) {
      const int64_t* pval = as_const_int(dim);
      // The graph runtime supports only static shapes.
      CHECK(pval != nullptr) << "Cannot serialize symbolic shape dimension";
      ret.push_back(*pval);
    }
    return ret;
  }

  /*!
   * \brief Add node to graph
   *
   * \param node
   * \param expr
   * \return std::vector<GraphNodeRef>
   */
  std::vector<GraphNodeRef> AddNode(GraphNodePtr node, Expr expr) {
    auto checked_type = expr->checked_type();
    size_t count = storage_device_map_.count(expr);
    CHECK_GT(count, 0) << "Expr does not exist in the storage plan";
    auto storage_device_info = storage_device_map_[expr];
    CHECK_EQ(storage_device_info.size(), 2);
    // storage
    std::vector<int64_t> storage_info;
    for (auto& v : storage_device_info[0]) {
      storage_info.push_back(v->value);
    }
    node->attrs_["storage_id"] = std::move(storage_info);
    // type
    std::vector<int64_t> device_types;
    for (auto& v : storage_device_info[1]) {
      device_types.push_back(v->value);
    }
    size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
    if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
      LOG(FATAL) << "The graph contains unannotated nodes for "
                 << "heterogeneous execution. All nodes must be "
                 << "annotated.";
    }
    if (num_unknown_devices == 0) {
      node->attrs_["device_index"] = device_types;
    }
    auto node_id = nodes_.size();
    nodes_.push_back(node);
    // Tuple return value, flatten as tuple
    if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
      std::vector<GraphNodeRef> ret;
      ShapeVector shape;
      std::vector<std::string> dtype;
      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
        if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
          ret.push_back(GraphNodeRef(node_id, i));
          shape.emplace_back(_ShapeToJSON(typ->shape));
          dtype.emplace_back(DType2String(typ->dtype));
        } else {
          LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
        }
      }
      CHECK_EQ(node->Type(), kGraphOpNode);
      auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
      op_nd->attrs_["shape"] = shape;
      op_nd->attrs_["dtype"] = dtype;
      op_nd->num_outputs_ = tuple_type->fields.size();
      return ret;
    }
    // Normal tensor return type
    if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
      ShapeVector shape;
      std::vector<std::string> dtype;
      shape.emplace_back(_ShapeToJSON(tensor_type->shape));
      dtype.emplace_back(DType2String(tensor_type->dtype));
      node->attrs_["shape"] = shape;
      node->attrs_["dtype"] = dtype;
    } else {
      LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
    }
    return {GraphNodeRef(node_id, 0)};
  }
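  // Illustrative only: for a node producing a single float32 tensor of shape
  // (1, 224) with storage id 3 on device type 1 (all values assumed), AddNode
  // records
  //   attrs_["shape"]        = [[1, 224]]
  //   attrs_["dtype"]        = ["float32"]
  //   attrs_["storage_id"]   = [3]
  //   attrs_["device_index"] = [1]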
  /*! \brief Visitors */
  std::unordered_map<Expr, std::vector<GraphNodeRef>, NodeHash, NodeEqual> visitor_cache_;

  std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
    if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
    std::vector<GraphNodeRef> res;
    if (expr.as<ConstantNode>()) {
      res = VisitExpr_(expr.as<ConstantNode>());
    } else if (expr.as<TupleNode>()) {
      res = VisitExpr_(expr.as<TupleNode>());
    } else if (expr.as<VarNode>()) {
      res = VisitExpr_(expr.as<VarNode>());
    } else if (expr.as<GlobalVarNode>()) {
      res = VisitExpr_(expr.as<GlobalVarNode>());
    } else if (expr.as<FunctionNode>()) {
      res = VisitExpr_(expr.as<FunctionNode>());
    } else if (expr.as<CallNode>()) {
      res = VisitExpr_(expr.as<CallNode>());
    } else if (expr.as<LetNode>()) {
      res = VisitExpr_(expr.as<LetNode>());
    } else if (expr.as<IfNode>()) {
      res = VisitExpr_(expr.as<IfNode>());
    } else if (expr.as<OpNode>()) {
      res = VisitExpr_(expr.as<OpNode>());
    } else if (expr.as<TupleGetItemNode>()) {
      res = VisitExpr_(expr.as<TupleGetItemNode>());
    } else if (expr.as<RefCreateNode>()) {
      res = VisitExpr_(expr.as<RefCreateNode>());
    } else if (expr.as<RefReadNode>()) {
      res = VisitExpr_(expr.as<RefReadNode>());
    } else if (expr.as<RefWriteNode>()) {
      res = VisitExpr_(expr.as<RefWriteNode>());
    } else if (expr.as<ConstructorNode>()) {
      res = VisitExpr_(expr.as<ConstructorNode>());
    } else if (expr.as<MatchNode>()) {
      res = VisitExpr_(expr.as<MatchNode>());
    }
    visitor_cache_[expr] = res;
    return res;
  }

  std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    return var_map_[expr.get()];
  }

  std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
    Expr expr = GetRef<Expr>(op);
    size_t index = params_.size();
    std::string name = "p" + std::to_string(index);
    params_[name] = op->data;
    auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
    return AddNode(node, expr);
  }

  std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
    std::vector<GraphNodeRef> fields;
    for (auto field : op->fields) {
      auto ref_vec = VisitExpr(field);
      for (auto ref : ref_vec) {
        fields.push_back(ref);
      }
    }
    return fields;
  }

  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
    Expr expr = GetRef<Expr>(op);
    Function func;
    if (op->op.as<OpNode>()) {
      LOG(FATAL) << "Operators should be transformed away; try applying "
                 << "the fuse_ops transformation to the expression.";
    } else if (op->op.as<GlobalVarNode>()) {
      LOG(FATAL) << "Not implemented";
    } else if (op->op.as<FunctionNode>()) {
      func = GetRef<Function>(op->op.as<FunctionNode>());
    } else {
      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
    }
    if (!func->IsPrimitive()) {
      LOG(FATAL) << "TVM only supports calls to primitive functions "
                 << "(i.e. functions composed of fusable operator invocations)";
    }

    CHECK_GT(storage_device_map_.count(expr), 0);
    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
    auto& device_type = storage_device_map_[expr][1];
    auto call_dev_type = device_type[0]->value;
    Target target;
    if (targets_.size() == 1) {
      // homogeneous execution.
      for (auto kv : targets_) {
        target = kv.second;
      }
    } else {
      // heterogeneous execution.
      std::string call_dev_name;
      if (call_dev_type == 0) {
        call_dev_name = "llvm";
      } else {
        call_dev_name = runtime::DeviceName(call_dev_type);
      }
      if (targets_.count(call_dev_type) == 0) {
        LOG(FATAL) << "No target is provided for device " << call_dev_name;
      }
      target = targets_[call_dev_type];
    }
    CCacheKey key = (*pf0)(func, target);
    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
    if (!lowered_funcs_.count(target->str())) {
      lowered_funcs_[target->str()] = {};
    }
    for (auto f : lowered_func->funcs) {
      lowered_funcs_[target->str()].insert(f);
    }

    std::vector<GraphNodeRef> inputs;
    for (auto arg : op->args) {
      auto res = VisitExpr(arg);
      for (auto nr : res) {
        inputs.push_back(nr);
      }
    }
    auto& op_name = lowered_func->func_name;
    auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
                                           GraphAttrs(),
                                           op_name,
                                           inputs,
                                           GraphAttrs());
    return AddNode(node, expr);
  }
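  // Illustrative only: with targets_ = {1: llvm, 2: cuda} (mapping assumed), a
  // call whose planned device type is 2 is lowered for the cuda target; device
  // type 0 means the call was left unannotated and is reported under the
  // conventional "llvm" device name. With a single configured target, that
  // target is used for every call.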
  std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
    CHECK_EQ(var_map_.count(op->var.get()), 0);
    var_map_[op->var.get()] = VisitExpr(op->value);
    return VisitExpr(op->body);
  }

  std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
    auto vtuple = VisitExpr(op->tuple);
    return {vtuple[op->index]};
  }

  std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
    throw std::runtime_error("can not compile op in non-eta expanded form");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
    throw std::runtime_error("global variables not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
    throw std::invalid_argument("if not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
    throw std::invalid_argument("function not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
    throw std::invalid_argument("reference not supported");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
    throw std::invalid_argument("ADT constructor case not yet implemented");
    return {};
  }

  std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
    throw std::invalid_argument("match case not yet implemented");
    return {};
  }
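  // Illustrative only: GetJSON below emits the graph-runtime JSON layout,
  // roughly of the form (node names, shapes, and ids assumed for the example):
  //   {"nodes": [...], "arg_nodes": [0], "heads": [[1, 0, 0]],
  //    "attrs": {"shape": ["list_shape", [[1, 224], [1, 224]]],
  //              "storage_id": ["list_int", [0, 1]],
  //              "dltype": ["list_str", ["float32", "float32"]]},
  //    "node_row_ptr": [0, 1, 2]}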
  /*!
   * \brief Generate Graph JSON
   *
   * \param writer json writer
   */
  void GetJSON(dmlc::JSONWriter* writer) {
    std::vector<size_t> arg_nodes;
    for (size_t i = 0; i < nodes_.size(); ++i) {
      auto node = nodes_[i];
      if (node->Type() == kGraphInputNode) {
        arg_nodes.push_back(i);
      }
    }
    size_t num_entry = 0;
    ShapeVector shapes;
    std::vector<size_t> storage_ids;
    std::vector<size_t> device_types;
    std::vector<std::string> dltypes;
    std::vector<size_t> node_row_ptr{0};
    for (auto node : nodes_) {
      const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
      const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
      const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

      CHECK_EQ(node->num_outputs_, shape_vec.size());
      num_entry += node->num_outputs_;

      shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
      dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
      storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
      if (node->attrs_.count("device_index")) {
        const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
        device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
      }
      node_row_ptr.push_back(num_entry);
    }
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes_);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
    writer->WriteObjectKeyValue("heads", heads_);
    std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
    attrs["shape"].emplace_back(std::string("list_shape"));
    attrs["shape"].emplace_back(shapes);
    attrs["storage_id"].emplace_back(std::string("list_int"));
    attrs["storage_id"].emplace_back(storage_ids);
    if (device_types.size()) {
      attrs["device_index"].emplace_back(std::string("list_int"));
      attrs["device_index"].emplace_back(device_types);
    }
    attrs["dltype"].emplace_back(std::string("list_str"));
    attrs["dltype"].emplace_back(dltypes);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
    writer->EndObject();
  }

  /*!
   * \brief Get unique name for func
   *
   * \param name
   * \return std::string
   */
  std::string _GetUniqueName(const std::string& name) {
    if (!name_map_.count(name)) {
      name_map_[name] = 1;
      return name;
    }
    auto index = name_map_[name];
    name_map_[name] += 1;
    return _GetUniqueName(name + std::to_string(index));
  }
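  // For example, requesting "fused_add" three times yields "fused_add",
  // "fused_add1", and "fused_add2".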
 protected:
  /*! \brief nodes */
  std::vector<GraphNodePtr> nodes_;
  /*! \brief output of graph */
  std::vector<GraphNodeRef> heads_;
  /*! \brief mod */
  runtime::Module* mod_;
  /*! \brief variable map */
  std::unordered_map<const Object*, std::vector<GraphNodeRef>> var_map_;
  /*! \brief target device */
  TargetsMap targets_;
  /*! \brief params */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief plan memory of device result */
  Map<Expr, Array<IntegerArray>> storage_device_map_;
  /*! \brief lowered funcs */
  std::unordered_map<std::string, std::unordered_set<LoweredFunc, NodeHash, NodeEqual>>
      lowered_funcs_;
  /*! \brief name map */
  std::unordered_map<std::string, size_t> name_map_;
  /*! \brief compile engine */
  CompileEngine compile_engine_;
};

class GraphRuntimeCodegenModule : public runtime::ModuleNode {
 public:
  GraphRuntimeCodegenModule() {}

  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        CHECK_EQ(args.num_args, 2) << "The expected arguments are: "
                                   << "runtime::Module mod and Map<int, Target> targets";
        void* mod = args[0];
        Map<Integer, tvm::Target> tmp = args[1];
        TargetsMap targets;
        for (const auto& it : tmp) {
          auto dev_type = it.first.as<ir::IntImm>();
          CHECK(dev_type);
          targets[dev_type->value] = it.second;
        }
        codegen_ = std::make_shared<GraphRuntimeCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                         targets);
      });
    } else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Function func = args[0];
        this->output_ = this->codegen_->Codegen(func);
      });
    } else if (name == "get_graph_json") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.graph_json;
      });
    } else if (name == "list_params_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Array<tvm::Expr> ret;
        for (const auto& kv : this->output_.params) {
          tvm::Expr name = ir::StringImm::make(kv.first);
          ret.push_back(name);
        }
        *rv = ret;
      });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        std::string key = args[0];
        CHECK_GT(this->output_.params.count(key), 0);
        *rv = this->output_.params[key];
      });
    } else if (name == "get_lowered_funcs") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.lowered_funcs;
      });
    } else {
      return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }

 private:
  std::shared_ptr<GraphRuntimeCodegen> codegen_;
  LoweredOutput output_;
};

runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphRuntimeCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });

}  // namespace backend
}  // namespace relay
}  // namespace tvm
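// A minimal driving sketch (illustrative; this module is normally invoked from
// the Python build pipeline, and "tgt_module", "targets_map", and "relay_func"
// are assumed to be prepared by the caller):
//
//   runtime::Module m = CreateGraphCodegenMod();
//   m.GetFunction("init")(&tgt_module, targets_map);  // bind module + targets
//   m.GetFunction("codegen")(relay_func);             // run graph codegen
//   std::string graph_json = m.GetFunction("get_graph_json")();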
namespace dmlc {
namespace json {

// JSON utils
template <typename T>
inline bool SameType(const dmlc::any& data) {
  return std::type_index(data.type()) == std::type_index(typeid(T));
}

template <>
struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
    data->Save(writer);
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::unordered_map<std::string, dmlc::any>& data) {
    writer->BeginObject();
    for (const auto& kv : data) {
      auto k = kv.first;
      const dmlc::any& v = kv.second;
      if (SameType<std::string>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndObject();
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::unordered_map<std::string, dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

template <>
struct Handler<std::vector<dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) {
    writer->BeginArray();
    for (const auto& v : data) {
      if (SameType<std::string>(v)) {
        writer->WriteArrayItem(dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteArrayItem(dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndArray();
  }
  inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

}  // namespace json
}  // namespace dmlc
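// Illustrative only: through the handlers above, a GraphAttrs map such as
// {"func_name": "fused_add", "num_inputs": "2"} (values assumed) is written
// via the std::string branch, while GetJSON's "attrs" table flows through the
// std::vector<dmlc::any> handler.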