1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file pretty_printer.cc
22  * \brief Pretty printer for Relay programs
23  * Supports ANF, GNF, and metadata.
24  *
25  * Inlining heuristics:
26  *  - Always inline:
27  *    - GlobalVar
28  *    - Constant
29  *    - Op
 *    - Var
 *    - Constructor
31  *  - Otherwise, inline if the node is at the end of a scope and is used at most once.
32  */
33 
34 #include <tvm/node/serialization.h>
35 #include <tvm/relay/expr_functor.h>
36 #include <tvm/relay/module.h>
37 #include <tvm/relay/pattern_functor.h>
38 #include "doc.h"
39 #include "type_functor.h"
40 #include "../pass/dependency_graph.h"
41 #include "../../lang/attr_functor.h"
42 
43 namespace tvm {
44 namespace relay {
45 
/*! \brief Version string of the Relay text format, printed as the first line by PrettyPrint_. */
static constexpr const char* kSemVer = "v0.0.4";
47 
/*!
 * \brief Wrap a doc between an opening and a closing delimiter.
 *
 * The wrapped doc starts on a new line and is indented by `indent`; the
 * closing delimiter goes on its own line after it.
 * \param d The doc to wrap.
 * \param open The opening delimiter.
 * \param close The closing delimiter.
 * \param indent Indentation applied to the wrapped doc.
 * \return The braced doc.
 */
Doc Brace(const Doc& d,
          const std::string& open = "{",
          const std::string& close = "}",
          int indent = 2) {
  Doc inner;
  inner << PrintNewLine() << d;
  Doc braced;
  braced << open << Indent(indent, inner) << PrintNewLine() << close;
  return braced;
}
58 
59 /*!
60  * \brief Meta data context for PrettyPrinter.
61  *
62  * This is an important part to enable bi-directional serializability.
63  * We use tvm's Node system to build the current IR.
64  * It can be hard to design a text format for all the possible nodes
65  * as the set of nodes can grow when we do more extensions.
66  *
67  * Instead of trying to design readable text format for every node,
68  * we support a meta data section in the text format.
69  * We allow the text format to refer to a node in the meta data section.
70  *
 * The meta data section is a JSON-serialized string of a Map<string, Array<NodeRef>>.
72  * Each element in the meta data section can be referenced by the text format.
73  * Each meta data node is printed in the following format.
74  *
 * meta[<type-key-of-node>][<index-in-meta-section>]
76  *
 * Specifically, consider the following IR (constructed by Python).
78  *
79  * \code
80  *
81  * n = tvm.var("n")
82  * x = tvm.relay.var("x", shape=(n, 1))
83  * f = tvm.relay.Function([x], x)
84  * print(f.astext())
85  *
86  * \endcode
87  *
88  * The corresponding text format is shown in the following code block.
89  *
90  * \code
91  *
92  * fn (%x: Tensor[(meta[Variable][0],), float32]) {
93  *   %x
94  * }
95  * # Meta data section is a json-serialized string
96  * # of the following array.
97  * # [tvm.var("n")]
98  *
99  * \endcode
100  *
101  * Note that we store tvm.var("n") in the meta data section.
102  * Since it is stored in the index-0 in the meta data section,
103  * we print it as meta[Variable][0].
104  *
105  * The text parser can recover this object by loading from the corresponding
106  * location in the meta data section.
107  *
 * This is a design trade-off.
 * It allows us to embed any meta data in the text format,
110  * while still being able to tweak the text part of the printed IR easily.
111  */
112 class TextMetaDataContext {
113  public:
114   /*!
115    * \brief Get text representation of meta node.
116    * \param node The node to be converted to meta node.
117    * \return A string representation of the meta node.
118    */
GetMetaNode(const NodeRef & node)119   Doc GetMetaNode(const NodeRef& node) {
120     auto it = meta_repr_.find(node);
121     if (it != meta_repr_.end()) {
122       return it->second;
123     }
124     std::string type_key = node->GetTypeKey();
125     CHECK(!type_key.empty());
126     Array<NodeRef>& mvector =
127         meta_data_[type_key];
128     int64_t index = static_cast<int64_t>(mvector.size());
129     mvector.push_back(node);
130     Doc doc;
131     doc << "meta[" << type_key << "][" << index << "]";
132     meta_repr_[node] = doc;
133     return meta_repr_[node];
134   }
135 
PrintKeyValue(const std::string & str,const Doc & v) const136   Doc PrintKeyValue(const std::string& str, const Doc& v) const {
137     return Doc("\"") << str << "\": " << v;
138   }
139 
140   /*!
141    * \brief Get the metadata section in json format.
142    * \return the meta data string.
143    */
GetMetaSection() const144   Doc GetMetaSection() const {
145     if (meta_data_.size() == 0) return Doc();
146     return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
147   }
148 
149   /*! \return whether the meta data context is empty. */
empty() const150   bool empty() const {
151     return meta_data_.empty();
152   }
153 
154  private:
155   /*! \brief additional metadata stored in TVM json format */
156   std::unordered_map<std::string, Array<NodeRef> > meta_data_;
157   /*! \brief map from meta data into its string representation */
158   std::unordered_map<NodeRef, Doc, NodeHash, NodeEqual> meta_repr_;
159 };
160 
161 class PrettyPrinter :
162     public ExprFunctor<Doc(const Expr&)>,
163     public PatternFunctor<Doc(const Pattern&)>,
164     public TypeFunctor<Doc(const Type&)>,
165     public AttrFunctor<Doc(const ObjectRef&)> {
166  public:
PrettyPrinter(bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)167   explicit PrettyPrinter(bool show_meta_data,
168                          runtime::TypedPackedFunc<std::string(Expr)> annotate) :
169                          show_meta_data_(show_meta_data),
170                          annotate_(annotate) {}
171 
172   /*!
173     * \brief Print additional info about expr in comment.
174     * \param expr The expression.
175     */
PrintOptionalInfo(const Expr & expr)176   Doc PrintOptionalInfo(const Expr& expr) {
177     Doc doc;
178     // default annotations
179     if (annotate_ == nullptr) {
180       if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
181         doc << " /* ty=" << Print(expr->checked_type()) << " */";
182       }
183     } else {
184       std::string annotated_expr = annotate_(expr);
185       if (annotated_expr != "") {
186         doc << annotated_expr;
187       }
188     }
189 
190     return doc;
191   }
192 
193   // indent a new body
PrintBody(const NodeRef & node,int indent=2)194   Doc PrintBody(const NodeRef& node, int indent = 2) {
195     Doc doc;
196     Doc body;
197     doc << "{";
198     doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
199     doc << "}";
200     return doc;
201   }
202 
203   // create a new scope by creating a new printer object. This allows temp var
204   // numbers to be reused and prevents hoisted vars from escaping too far
PrintScope(const NodeRef & node)205   Doc PrintScope(const NodeRef& node) {
206     // print in a new scope
207     doc_stack_.push_back(Doc());
208     // must print first so doc_stack_.back() reference doesn't become stale
209     Doc doc = Print(node, false, true);
210     doc = doc_stack_.back() << doc;
211     doc_stack_.pop_back();
212     return doc;
213   }
214 
PrintFinal(const NodeRef & node)215   Doc PrintFinal(const NodeRef& node) {
216     if (node.as<ExprNode>()) {
217       Expr expr = Downcast<Expr>(node);
218       dg_ = DependencyGraph::Create(&arena_, expr);
219     }
220 
221     Doc doc;
222     doc << PrintScope(node);
223     if (!meta_.empty()) {
224       doc << PrintNewLine();
225       if (show_meta_data_) {
226         // append meta data in the end.
227         doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
228       } else {
229         doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
230       }
231     }
232     return doc;
233   }
234 
235   std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
236   std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
237 
Print(const NodeRef & node,bool meta=false,bool try_inline=false)238   Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
239     if (node.as<ExprNode>()) {
240       return PrintExpr(Downcast<Expr>(node), meta, try_inline);
241     } else if (node.as<TypeNode>()) {
242       return PrintType(Downcast<Type>(node), meta);
243     } else if (node.as<PatternNode>()) {
244       return PrintPattern(Downcast<Pattern>(node), meta);
245     } else if (node.as<ModuleNode>()) {
246       return PrintMod(Downcast<Module>(node));
247     } else {
248       Doc doc;
249       return doc << node;
250     }
251   }
252 
TempVar(int n)253   Doc TempVar(int n) {
254     Doc doc;
255     return doc << "%" << n;
256   }
257 
AllocTemp()258   Doc AllocTemp() {
259     return TempVar(temp_var_counter_++);
260   }
261 
262   /*!
263     * \brief get a unique name with the corresponding prefix
264     * \param prefix The prefix of the name
265     * \return The returned name.
266     */
GetUniqueName(const std::string & prefix)267   Doc GetUniqueName(const std::string& prefix) {
268     std::string unique_prefix = prefix;
269     auto it = name_alloc_map_.find(prefix);
270     if (it != name_alloc_map_.end()) {
271       while (true) {
272         std::ostringstream os;
273         os << prefix << (++it->second);
274         std::string name = os.str();
275         if (name_alloc_map_.count(name) == 0) {
276           unique_prefix = name;
277           break;
278         }
279       }
280     }
281     name_alloc_map_[unique_prefix] = 0;
282     return Doc(unique_prefix);
283   }
284 
Print(Kind k)285   Doc Print(Kind k) {
286     switch (k) {
287     case kType:
288       return Doc("Type");
289     case kShapeVar:
290       return Doc("Shape");
291     case kBaseType:
292       return Doc("BaseType");
293     case kConstraint:
294       return Doc("Constraint");
295     case kAdtHandle:
296       return Doc("AdtHandle");
297     case kTypeData:
298       return Doc("TypeData");
299     default:
300       LOG(ERROR) << "Unknown Kind";
301       throw;
302     }
303   }
304   /*!
305    * \brief Allocate name to a type variable.
306    * \param var The input type variable.
307    * \return The corresponding name.
308    */
AllocTypeVar(const TypeVar & var)309   Doc AllocTypeVar(const TypeVar& var) {
310     if (memo_type_.count(var)) {
311       Doc val = memo_type_[var];
312       val << "-malformed-ir";
313       return val;
314     }
315     std::string name = var->var->name_hint;
316     if (name.length() == 0 || !std::isalpha(name[0])) {
317       name = "t" + name;
318     }
319     Doc val = GetUniqueName(name);
320     memo_type_[var] = val;
321     if (var->kind != kType) {
322       val << ": " << Print(var->kind);
323     }
324     return val;
325   }
326 
327   /*!
328    * \brief Allocate name to a variable.
329    * \param var The input variable.
330    * \return The corresponding name.
331    */
AllocVar(const Var & var)332   Doc AllocVar(const Var& var) {
333     // still print if ir is malformed, but show the error.
334     if (memo_.count(var)) {
335       Doc val = memo_[var];
336       val << "-malformed-ir";
337       return val;
338     }
339     std::string name = var->name_hint();
340     // always make sure first name is alpha
341     if (name.length() == 0 || !std::isalpha(name[0])) {
342       name = "v" + name;
343     }
344     Doc val = GetUniqueName("%" + name);
345     memo_[var] = val;
346     if (var->type_annotation.defined()) {
347       val << ": " << Print(var->type_annotation);
348     }
349     return val;
350   }
351 
IsUnique(const Expr & expr)352   bool IsUnique(const Expr& expr) {
353     auto it = dg_.expr_node.find(expr);
354     if (it == dg_.expr_node.end()) {
355       return true;
356     } else {
357       return !(it->second->parents.head && it->second->parents.head->next);
358     }
359   }
360 
AlwaysInline(const Expr & expr)361   bool AlwaysInline(const Expr& expr) {
362     return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
363            expr.as<VarNode>() || expr.as<ConstructorNode>();
364   }
365 
366   //------------------------------------
367   // Overload of Expr printing functions
368   //------------------------------------
PrintExpr(const Expr & expr,bool meta,bool try_inline)369   Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
370     // Exploit memoization to print GNF.
371     // The first time we visit an expression, we need to allocate a temp var
372     // for it. Every subsequent time we can just use its assigned variable.
373     // This works since hashing uses pointer equality.
374 
375     // determine whether to inline
376     bool inline_expr = AlwaysInline(expr);
377     if (try_inline) {
378       inline_expr |= IsUnique(expr);
379     }
380 
381     auto it = memo_.find(expr);
382     if (it != memo_.end()) return it->second;
383 
384     Doc printed_expr;
385     if (meta) {
386       printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
387     } else if (!inline_expr && expr.as<LetNode>()) {
388       // wrap GNFed let in brackets
389       Doc body;
390       printed_expr << "(";
391       printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
392       printed_expr << ")";
393     } else {
394       printed_expr = VisitExpr(expr);
395     }
396 
397     printed_expr << PrintOptionalInfo(expr);
398 
399     // add expr to doc
400     if (expr.as<VarNode>()) {
401       // This is our first time visiting the var and we hit the VarNode case
402       // in the visitor. Thus the variable is free.
403       doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
404       // Memoization is done in AllocVar.
405       return memo_[expr];
406     } else if (inline_expr) {
407       memo_[expr] = printed_expr;
408       return printed_expr;
409     } else {
410       Doc temp_var = AllocTemp();
411       memo_[expr] = temp_var;
412       doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine();
413       return temp_var;
414     }
415   }
416 
417   // Should only be triggered when op is a free variable being visited for the
418   // first time.
VisitExpr_(const VarNode * op)419   Doc VisitExpr_(const VarNode* op) final {
420     return AllocVar(GetRef<Var>(op));
421   }
422 
VisitExpr_(const ConstantNode * op)423   Doc VisitExpr_(const ConstantNode* op) final {
424     // Print out simple scalars directly.
425     if (op->is_scalar()) {
426       std::ostringstream os;
427       DataType dtype = TVMType2Type(op->data->dtype);
428       CHECK_EQ(op->data->ctx.device_type, kDLCPU);
429       if (dtype == Int(32)) {
430         return PrintConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
431       } else if (dtype == Int(64)) {
432         return PrintConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
433       } else if (dtype == Float(32)) {
434         return PrintConstScalar(dtype, static_cast<const float*>(op->data->data));
435       } else if (dtype == Float(64)) {
436         return PrintConstScalar(dtype, static_cast<const double*>(op->data->data));
437       } else if (dtype == Bool()) {
438         return PrintConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
439       }
440     }
441     // default fall-back, record it as meta node.
442     Doc doc;
443     return doc << Print(GetRef<NodeRef>(op), true);
444   }
445 
VisitExpr_(const TupleNode * op)446   Doc VisitExpr_(const TupleNode* op) final {
447     std::vector<Doc> fields;
448     for (Expr field : op->fields) {
449       fields.push_back(Print(field));
450     }
451     Doc doc;
452     doc << "(" << PrintSep(fields);
453     // conform to python tuple format (1,)
454     if (op->fields.size() == 1) {
455       doc << ",";
456     }
457     return doc << ")";
458   }
459 
VisitExpr_(const TupleGetItemNode * op)460   Doc VisitExpr_(const TupleGetItemNode* op) final {
461     Doc doc;
462     return doc << Print(op->tuple) << "." << op->index;
463   }
464 
VisitExpr_(const IfNode * op)465   Doc VisitExpr_(const IfNode* op) final {
466     Doc doc;
467     doc << "if (" << Print(op->cond) << ") ";
468     doc << PrintBody(op->true_branch);
469     doc << " else ";
470     doc << PrintBody(op->false_branch);
471     return doc;
472   }
473 
VisitExpr_(const LetNode * op)474   Doc VisitExpr_(const LetNode* op) final {
475     Doc doc;
476     doc
477       << "let "
478       << AllocVar(op->var)
479       << " = "
480       << Print(op->value, false, true)
481       << ";"
482       << PrintNewLine();
483     // we use a scope here so GNF hoisting doesn't escape too far
484     // and nested, unique lets are not hoisted
485     doc << PrintScope(op->body);
486     return doc;
487   }
488 
PrintFunc(const Doc & prefix,const Function & fn)489   Doc PrintFunc(const Doc& prefix, const Function& fn) {
490     Doc doc;
491     doc << prefix;
492     if (fn->type_params.size() > 0) {
493       doc << "[";
494       std::vector<Doc> type_params;
495       for (const TypeVar& tv : fn->type_params) {
496         type_params.push_back(Doc(tv->var->name_hint));
497       }
498       doc << PrintSep(type_params);
499       doc << "]";
500     }
501     doc << "(";
502     std::vector<Doc> params;
503     for (Var param : fn->params) {
504       params.push_back(AllocVar(param));
505     }
506     for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
507       params.push_back(d);
508     }
509     doc << PrintSep(params) << ") ";
510     if (fn->ret_type.defined()) {
511       doc << "-> " << Print(fn->ret_type) << " ";
512     }
513     doc << PrintBody(fn->body);
514     return doc;
515   }
516 
PrintMod(const Module & mod)517   Doc PrintMod(const Module& mod) {
518     Doc doc;
519     int counter = 0;
520     // type definitions
521     for (const auto& kv : mod->type_definitions) {
522       if (counter++ != 0) {
523         doc << PrintNewLine();
524       }
525       doc << Print(kv.second);
526       doc << PrintNewLine();
527     }
528     // functions
529     for (const auto& kv : mod->functions) {
530       dg_ = DependencyGraph::Create(&arena_, kv.second);
531 
532       if (counter++ != 0) {
533         doc << PrintNewLine();
534       }
535       std::ostringstream os;
536       os << "def @" << kv.first->name_hint;
537       doc << PrintFunc(Doc(os.str()), kv.second);
538       doc << PrintNewLine();
539     }
540     return doc;
541   }
542 
VisitExpr_(const FunctionNode * op)543   Doc VisitExpr_(const FunctionNode* op) final {
544     return PrintFunc(Doc("fn "), GetRef<Function>(op));
545   }
546 
VisitExpr_(const GlobalVarNode * op)547   Doc VisitExpr_(const GlobalVarNode* op) final {
548     return Doc('@' + op->name_hint);
549   }
550 
VisitExpr_(const OpNode * op)551   Doc VisitExpr_(const OpNode* op) final {
552     return Doc(op->name);
553   }
554 
VisitExpr_(const CallNode * op)555   Doc VisitExpr_(const CallNode* op) final {
556     Doc doc;
557     // visit args first so they are lifted before the op
558     // this places op closer to its call site
559     std::vector<Doc> args;
560     for (const Expr& arg : op->args) {
561       args.push_back(Print(arg));
562     }
563     for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
564       args.push_back(d);
565     }
566     const auto* cons_node = op->op.as<ConstructorNode>();
567     if (cons_node) {
568       doc << cons_node->name_hint;
569     } else {
570       doc << Print(op->op);
571     }
572 
573     if (cons_node && cons_node->inputs.size() == 0) {
574       // don't print as a call if it's a 0-arity cons
575       return doc;
576     } else {
577       return doc << "(" << PrintSep(args) << ")";
578     }
579   }
580 
VisitExpr_(const RefCreateNode * op)581   Doc VisitExpr_(const RefCreateNode* op) final {
582     Doc doc;
583     return doc << "ref(" << Print(op->value) << ")";
584   }
585 
VisitExpr_(const RefReadNode * op)586   Doc VisitExpr_(const RefReadNode* op) final {
587     Doc doc;
588     return doc << Print(op->ref) << "^";
589   }
590 
VisitExpr_(const RefWriteNode * op)591   Doc VisitExpr_(const RefWriteNode* op) final {
592     Doc doc;
593     return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
594   }
595 
VisitExpr_(const MatchNode * op)596   Doc VisitExpr_(const MatchNode* op) final {
597     // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
598     Doc doc;
599     Doc body;
600     doc << "match";
601     if (!op->complete) {
602       doc << "?";
603     }
604     doc << " (" << Print(op->data) << ") {";
605     std::vector<Doc> clause_docs;
606     for (const auto& clause : op->clauses) {
607       Doc clause_doc;
608       clause_doc << PrintPattern(clause->lhs, false) << " => ";
609       Doc rhs_doc = PrintScope(clause->rhs);
610       if (clause->rhs.as<LetNode>()) {
611         // only add braces if there are multiple lines on the rhs
612         rhs_doc = Brace(rhs_doc);
613       }
614       clause_doc << rhs_doc << ",";
615       clause_docs.push_back(clause_doc);
616     }
617     doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
618         << PrintNewLine() << "}";
619     return doc;
620   }
621 
PrintPattern(const Pattern & pattern,bool meta)622   Doc PrintPattern(const Pattern& pattern, bool meta) {
623     auto it = memo_pattern_.find(pattern);
624     if (it != memo_pattern_.end()) return it->second;
625     Doc printed_pattern;
626     if (meta) {
627       printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
628     } else {
629       printed_pattern = VisitPattern(pattern);
630     }
631     memo_pattern_[pattern] = printed_pattern;
632     return printed_pattern;
633   }
634 
VisitPattern_(const PatternConstructorNode * p)635   Doc VisitPattern_(const PatternConstructorNode* p) final {
636     Doc doc;
637     doc << p->constructor->name_hint;
638     if (!p->patterns.empty()) {
639       doc << "(";
640       std::vector<Doc> pats;
641       for (const auto& pat : p->patterns) {
642         pats.push_back(Print(pat));
643       }
644       doc << PrintSep(pats) << ")";
645     }
646     return doc;
647   }
648 
VisitPattern_(const PatternTupleNode * pt)649   Doc VisitPattern_(const PatternTupleNode* pt) final {
650     Doc doc;
651     doc << "(";
652     std::vector<Doc> pats;
653     for (const auto& pat : pt->patterns) {
654       pats.push_back(Print(pat));
655     }
656     doc << PrintSep(pats) << ")";
657     return doc;
658   }
659 
VisitPattern_(const PatternWildcardNode * pw)660   Doc VisitPattern_(const PatternWildcardNode* pw) final {
661     return Doc("_");
662   }
663 
VisitPattern_(const PatternVarNode * pv)664   Doc VisitPattern_(const PatternVarNode* pv) final {
665     return AllocVar(pv->var);
666   }
667 
VisitExpr_(const ConstructorNode * n)668   Doc VisitExpr_(const ConstructorNode* n) final {
669     Doc doc;
670     doc << n->name_hint;
671     if (in_adt_def_ && n->inputs.size() != 0) {
672       doc << "(";
673       std::vector<Doc> inputs;
674       for (Type input : n->inputs) {
675         inputs.push_back(Print(input));
676       }
677       doc << PrintSep(inputs) << ")";
678     }
679     return doc;
680   }
681 
682   //------------------------------------
683   // Overload of Type printing functions
684   //------------------------------------
PrintType(const Type & type,bool meta)685   Doc PrintType(const Type& type, bool meta) {
686     auto it = memo_type_.find(type);
687     if (it != memo_type_.end()) return it->second;
688     Doc printed_type;
689     if (meta) {
690       printed_type = meta_.GetMetaNode(GetRef<NodeRef>(type.get()));
691     } else {
692       printed_type = VisitType(type);
693     }
694     memo_type_[type] = printed_type;
695     return printed_type;
696   }
697 
VisitTypeDefault_(const Node * node)698   Doc VisitTypeDefault_(const Node* node) final {
699     // by default always print as meta data
700     return Print(GetRef<NodeRef>(node), true);
701   }
702 
VisitType_(const TypeVarNode * node)703   Doc VisitType_(const TypeVarNode* node) final {
704     return Doc(node->var->name_hint);
705   }
706 
VisitType_(const GlobalTypeVarNode * node)707   Doc VisitType_(const GlobalTypeVarNode* node) final {
708     return Doc(node->var->name_hint);
709   }
710 
VisitType_(const TypeCallNode * node)711   Doc VisitType_(const TypeCallNode* node) final {
712     Doc doc = PrintType(node->func, false);
713     std::vector<Doc> args;
714     for (const Type& t : node->args) {
715       args.push_back(PrintType(t, false));
716     }
717     doc << "[";
718     doc << PrintSep(args);
719     doc << "]";
720     return doc;
721   }
722 
VisitType_(const TensorTypeNode * node)723   Doc VisitType_(const TensorTypeNode* node) final {
724     // scalar type
725     if (node->shape.size() == 0) {
726       return PrintDType(node->dtype);
727     }
728     Doc doc;
729     doc << "Tensor[(";
730     std::vector<Doc> shapes;
731     for (NodeRef shape : node->shape) {
732       shapes.push_back(PrintAttr(shape));
733     }
734     doc << PrintSep(shapes);
735     return doc << "), " << PrintDType(node->dtype) << "]";
736   }
737 
VisitType_(const TupleTypeNode * node)738   Doc VisitType_(const TupleTypeNode* node) final {
739     std::vector<Doc> fields;
740     for (Type field : node->fields) {
741       fields.push_back(Print(field));
742     }
743     Doc doc;
744     doc << "(" << PrintSep(fields);
745     // conform to python tuple format (1,)
746     if (node->fields.size() == 1) {
747       doc << ",";
748     }
749     return doc << ")";
750   }
751 
VisitType_(const FuncTypeNode * node)752   Doc VisitType_(const FuncTypeNode* node) final {
753     Doc doc;
754     doc << "fn ";
755     if (node->type_params.size() != 0) {
756       doc << "[";
757       std::vector<Doc> type_params;
758       for (Type type_param : node->type_params) {
759         type_params.push_back(Print(type_param));
760       }
761       doc << PrintSep(type_params);
762       doc << "]";
763     }
764     std::vector<Doc> arg_types;
765     for (Type arg_type : node->arg_types) {
766       arg_types.push_back(Print(arg_type));
767     }
768     return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
769   }
770 
VisitType_(const RefTypeNode * node)771   Doc VisitType_(const RefTypeNode* node) final {
772     Doc doc;
773     return doc << "ref(" << Print(node->value) << ")";
774   }
775 
VisitType_(const TypeDataNode * node)776   Doc VisitType_(const TypeDataNode* node) final {
777     in_adt_def_ = true;
778     Doc doc;
779     doc << "type " << Print(node->header);
780 
781     // type vars
782     if (node->type_vars.size() != 0) {
783       doc << "[";
784       std::vector<Doc> type_vars;
785       for (Type type_var : node->type_vars) {
786         type_vars.push_back(Print(type_var));
787       }
788       doc << PrintSep(type_vars) << "]";
789     }
790     doc << " ";
791 
792     std::vector<Doc> constructor_docs;
793     for (Constructor constructor : node->constructors) {
794       constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
795     }
796     Doc separator;
797     separator << "," << PrintNewLine();
798     Doc adt_body;
799     adt_body << PrintSep(constructor_docs, separator);
800     // add trailing comma if there are any constructors
801     if (!constructor_docs.empty()) {
802       adt_body << ",";
803     }
804     doc << Brace(adt_body);
805     in_adt_def_ = false;
806     return doc;
807   }
808 
809   //------------------------------------
810   // Overload of Attr printing functions
811   //------------------------------------
812 
PrintAttr(const ObjectRef & value,bool meta=false)813   Doc PrintAttr(const ObjectRef& value, bool meta = false) {
814     if (value.defined()) {
815       Doc printed_attr;
816       if (value.as<tvm::ir::Any>()) {
817         printed_attr << "?";
818       } else if (meta) {
819         printed_attr = meta_.GetMetaNode(Downcast<NodeRef>(value));
820       } else {
821         printed_attr = VisitAttr(value);
822       }
823       return printed_attr;
824     } else {
825       return Doc("None");
826     }
827   }
828 
VisitAttrDefault_(const Object * op)829   Doc VisitAttrDefault_(const Object* op) final {
830     return PrintAttr(GetRef<ObjectRef>(op), true);
831   }
832 
VisitAttr_(const ArrayNode * op)833   Doc VisitAttr_(const ArrayNode* op) final {
834     Doc doc;
835     doc << "[";
836     std::vector<Doc> arr_vals;
837     for (auto val : op->data) {
838       arr_vals.push_back(PrintAttr(val));
839     }
840     doc << PrintSep(arr_vals);
841     doc << "]";
842     return doc;
843   }
844 
VisitAttr_(const ir::IntImm * op)845   Doc VisitAttr_(const ir::IntImm* op) final {
846     return PrintConstScalar(op->type, &(op->value));
847   }
848 
VisitAttr_(const ir::UIntImm * op)849   Doc VisitAttr_(const ir::UIntImm* op) final {
850     return PrintConstScalar(op->type, &(op->value));
851   }
852 
VisitAttr_(const ir::FloatImm * op)853   Doc VisitAttr_(const ir::FloatImm* op) final {
854     return PrintConstScalar(op->type, &(op->value));
855   }
856 
VisitAttr_(const ir::StringImm * op)857   Doc VisitAttr_(const ir::StringImm* op) final {
858     return PrintString(op->value);
859   }
860 
861  private:
862   /*! \brief Whether to print meta data. */
863   bool show_meta_data_;
864   /*! \brief additional comment function */
865   runtime::TypedPackedFunc<std::string(Expr)> annotate_;
866   /*! \brief Stack of docs to implement scoped GNFing. */
867   std::vector<Doc> doc_stack_{};
868   /*! \brief Map from Expr to Doc */
869   std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
870   /*! \brief Map from Type to Doc */
871   std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
872   /*! \brief Map from Type to Doc */
873   std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
874   /*! \brief name allocation map */
875   std::unordered_map<std::string, int> name_alloc_map_;
876   /*! \brief meta data context */
877   TextMetaDataContext meta_;
878   /*! \brief counter of temporary variable */
879   size_t temp_var_counter_{0};
880   /*! \brief whether the printer is currently in an ADT definition */
881   bool in_adt_def_;
882   /*! \brief arena for dependency graph */
883   common::Arena arena_;
884   /*! \brief dependency graph of the expr */
885   DependencyGraph dg_;
886   class AttrPrinter;
887   friend class AttrPrinter;
888 };
889 
890 /*!
891  * \brief Attribute printer which prints the attributes in the call.
892  */
893 class PrettyPrinter::AttrPrinter : public AttrVisitor {
894  public:
AttrPrinter(std::vector<Doc> * doc,PrettyPrinter * parent)895   AttrPrinter(std::vector<Doc>* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {}
896 
897   template<typename T>
PrintKV(const char * key,const T & value)898   void PrintKV(const char* key, const T& value) {
899     Doc doc;
900     doc << key << "=" << value;
901     docs->push_back(doc);
902   }
903 
Visit(const char * key,double * value)904   void Visit(const char* key, double* value) final {
905     Doc doc;
906     doc << key << "=" << *value << "f";
907     docs->push_back(doc);
908   }
Visit(const char * key,int64_t * value)909   void Visit(const char* key, int64_t* value) final {
910     PrintKV(key, *value);
911   }
Visit(const char * key,uint64_t * value)912   void Visit(const char* key, uint64_t* value) final {
913     PrintKV(key, *value);
914   }
Visit(const char * key,int * value)915   void Visit(const char* key, int* value) final {
916     PrintKV(key, *value);
917   }
Visit(const char * key,bool * value)918   void Visit(const char* key, bool* value) final {
919     PrintKV(key, PrintBool(*value));
920   }
Visit(const char * key,std::string * value)921   void Visit(const char* key, std::string* value) final {
922     PrintKV(key, PrintString(*value));
923   }
Visit(const char * key,void ** value)924   void Visit(const char* key, void** value) final {
925     LOG(FATAL) << "do not allow void as argument";
926   }
Visit(const char * key,DataType * value)927   void Visit(const char* key, DataType* value) final {
928     PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
929   }
Visit(const char * key,runtime::NDArray * value)930   void Visit(const char* key, runtime::NDArray* value) final {
931     LOG(FATAL) << "do not allow NDarray as argument";
932   }
Visit(const char * key,runtime::ObjectRef * obj)933   void Visit(const char* key, runtime::ObjectRef* obj) final {
934     PrintKV(key, parent_->PrintAttr(*obj));
935   }
936 
937  private:
938   std::vector<Doc>* docs;
939   PrettyPrinter* parent_;
940 };
941 
PrintCallAttrs(const Attrs & attrs,const Expr & op)942 std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
943   std::vector<Doc> docs;
944   if (!attrs.defined()) return docs;
945   const auto* op_node = op.as<OpNode>();
946   if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
947     // fallback
948     Doc doc;
949     doc << meta_.GetMetaNode(attrs);
950     docs.push_back(doc);
951     return docs;
952   } else {
953     AttrPrinter printer(&docs, this);
954     const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
955     return docs;
956   }
957 }
958 
PrintFuncAttrs(const Attrs & attrs)959 std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
960   std::vector<Doc> docs;
961   if (!attrs.defined()) return docs;
962   const auto* dict_attrs = attrs.as<DictAttrsNode>();
963   CHECK(dict_attrs);
964   for (const auto& k : dict_attrs->dict) {
965     Doc doc;
966     doc << k.first << "=" << Print(k.second);
967     docs.push_back(doc);
968   }
969   return docs;
970 }
971 
PrettyPrint_(const NodeRef & node,bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)972 std::string PrettyPrint_(const NodeRef& node,
973                          bool show_meta_data,
974                          runtime::TypedPackedFunc<std::string(Expr)> annotate) {
975   Doc doc;
976   doc << kSemVer << PrintNewLine()
977       << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
978   return doc.str();
979 }
980 
PrettyPrint(const NodeRef & node)981 std::string PrettyPrint(const NodeRef& node) {
982   Doc doc;
983   doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
984   return doc.str();
985 }
986 
AsText(const NodeRef & node,bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)987 std::string AsText(const NodeRef& node,
988                        bool show_meta_data,
989                        runtime::TypedPackedFunc<std::string(Expr)> annotate) {
990   return PrettyPrint_(node, show_meta_data, annotate);
991 }
992 
993 TVM_REGISTER_API("relay._expr.AsText")
994 .set_body_typed<std::string(const NodeRef&,
995                             bool,
996                             runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
997 
998 }  // namespace relay
999 }  // namespace tvm
1000