1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file pretty_printer.cc
22 * \brief Pretty printer for Relay programs
23 * Supports ANF, GNF, and metadata.
24 *
25 * Inlining heuristics:
26 * - Always inline:
27 * - GlobalVar
28 * - Constant
29 * - Op
30 * - Var
31 * - Otherwise, inline if the node is at the end of a scope and is used at most once.
32 */
33
34 #include <tvm/node/serialization.h>
35 #include <tvm/relay/expr_functor.h>
36 #include <tvm/relay/module.h>
37 #include <tvm/relay/pattern_functor.h>
38 #include "doc.h"
39 #include "type_functor.h"
40 #include "../pass/dependency_graph.h"
41 #include "../../lang/attr_functor.h"
42
43 namespace tvm {
44 namespace relay {
45
/*! \brief Version tag of the text format; PrettyPrint_ emits it as the first line. */
static constexpr const char* kSemVer = "v0.0.4";
47
/*!
 * \brief Wrap a doc in a pair of delimiters, placing the doc on its own indented lines.
 * \param d The doc to wrap.
 * \param open The opening delimiter.
 * \param close The closing delimiter.
 * \param indent The indentation applied to the wrapped doc.
 * \return open, a new line, the indented doc, a new line, then close.
 */
Doc Brace(const Doc& d,
          const std::string& open = "{",
          const std::string& close = "}",
          int indent = 2) {
  // the wrapped content starts on a fresh line inside the delimiters
  Doc inner = PrintNewLine();
  inner << d;

  Doc wrapped;
  wrapped << open << Indent(indent, inner) << PrintNewLine() << close;
  return wrapped;
}
58
59 /*!
60 * \brief Meta data context for PrettyPrinter.
61 *
62 * This is an important part to enable bi-directional serializability.
63 * We use tvm's Node system to build the current IR.
64 * It can be hard to design a text format for all the possible nodes
65 * as the set of nodes can grow when we do more extensions.
66 *
67 * Instead of trying to design readable text format for every node,
68 * we support a meta data section in the text format.
69 * We allow the text format to refer to a node in the meta data section.
70 *
 * The meta data section is a json serialized string of a Map<string, Array<NodeRef>>.
72 * Each element in the meta data section can be referenced by the text format.
73 * Each meta data node is printed in the following format.
74 *
 *  meta[<type-key-of-node>][<index-in-meta-section>]
76 *
77 * Specifically, consider the following IR(constructed by python).
78 *
79 * \code
80 *
81 * n = tvm.var("n")
82 * x = tvm.relay.var("x", shape=(n, 1))
83 * f = tvm.relay.Function([x], x)
84 * print(f.astext())
85 *
86 * \endcode
87 *
88 * The corresponding text format is shown in the following code block.
89 *
90 * \code
91 *
92 * fn (%x: Tensor[(meta[Variable][0],), float32]) {
93 * %x
94 * }
95 * # Meta data section is a json-serialized string
96 * # of the following array.
97 * # [tvm.var("n")]
98 *
99 * \endcode
100 *
101 * Note that we store tvm.var("n") in the meta data section.
102 * Since it is stored in the index-0 in the meta data section,
103 * we print it as meta[Variable][0].
104 *
105 * The text parser can recover this object by loading from the corresponding
106 * location in the meta data section.
107 *
 * This is a design trade-off.
 * It allows us to embed any meta data in the text format,
110 * while still being able to tweak the text part of the printed IR easily.
111 */
112 class TextMetaDataContext {
113 public:
114 /*!
115 * \brief Get text representation of meta node.
116 * \param node The node to be converted to meta node.
117 * \return A string representation of the meta node.
118 */
GetMetaNode(const NodeRef & node)119 Doc GetMetaNode(const NodeRef& node) {
120 auto it = meta_repr_.find(node);
121 if (it != meta_repr_.end()) {
122 return it->second;
123 }
124 std::string type_key = node->GetTypeKey();
125 CHECK(!type_key.empty());
126 Array<NodeRef>& mvector =
127 meta_data_[type_key];
128 int64_t index = static_cast<int64_t>(mvector.size());
129 mvector.push_back(node);
130 Doc doc;
131 doc << "meta[" << type_key << "][" << index << "]";
132 meta_repr_[node] = doc;
133 return meta_repr_[node];
134 }
135
PrintKeyValue(const std::string & str,const Doc & v) const136 Doc PrintKeyValue(const std::string& str, const Doc& v) const {
137 return Doc("\"") << str << "\": " << v;
138 }
139
140 /*!
141 * \brief Get the metadata section in json format.
142 * \return the meta data string.
143 */
GetMetaSection() const144 Doc GetMetaSection() const {
145 if (meta_data_.size() == 0) return Doc();
146 return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
147 }
148
149 /*! \return whether the meta data context is empty. */
empty() const150 bool empty() const {
151 return meta_data_.empty();
152 }
153
154 private:
155 /*! \brief additional metadata stored in TVM json format */
156 std::unordered_map<std::string, Array<NodeRef> > meta_data_;
157 /*! \brief map from meta data into its string representation */
158 std::unordered_map<NodeRef, Doc, NodeHash, NodeEqual> meta_repr_;
159 };
160
161 class PrettyPrinter :
162 public ExprFunctor<Doc(const Expr&)>,
163 public PatternFunctor<Doc(const Pattern&)>,
164 public TypeFunctor<Doc(const Type&)>,
165 public AttrFunctor<Doc(const ObjectRef&)> {
166 public:
PrettyPrinter(bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)167 explicit PrettyPrinter(bool show_meta_data,
168 runtime::TypedPackedFunc<std::string(Expr)> annotate) :
169 show_meta_data_(show_meta_data),
170 annotate_(annotate) {}
171
172 /*!
173 * \brief Print additional info about expr in comment.
174 * \param expr The expression.
175 */
PrintOptionalInfo(const Expr & expr)176 Doc PrintOptionalInfo(const Expr& expr) {
177 Doc doc;
178 // default annotations
179 if (annotate_ == nullptr) {
180 if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
181 doc << " /* ty=" << Print(expr->checked_type()) << " */";
182 }
183 } else {
184 std::string annotated_expr = annotate_(expr);
185 if (annotated_expr != "") {
186 doc << annotated_expr;
187 }
188 }
189
190 return doc;
191 }
192
193 // indent a new body
PrintBody(const NodeRef & node,int indent=2)194 Doc PrintBody(const NodeRef& node, int indent = 2) {
195 Doc doc;
196 Doc body;
197 doc << "{";
198 doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
199 doc << "}";
200 return doc;
201 }
202
203 // create a new scope by creating a new printer object. This allows temp var
204 // numbers to be reused and prevents hoisted vars from escaping too far
PrintScope(const NodeRef & node)205 Doc PrintScope(const NodeRef& node) {
206 // print in a new scope
207 doc_stack_.push_back(Doc());
208 // must print first so doc_stack_.back() reference doesn't become stale
209 Doc doc = Print(node, false, true);
210 doc = doc_stack_.back() << doc;
211 doc_stack_.pop_back();
212 return doc;
213 }
214
PrintFinal(const NodeRef & node)215 Doc PrintFinal(const NodeRef& node) {
216 if (node.as<ExprNode>()) {
217 Expr expr = Downcast<Expr>(node);
218 dg_ = DependencyGraph::Create(&arena_, expr);
219 }
220
221 Doc doc;
222 doc << PrintScope(node);
223 if (!meta_.empty()) {
224 doc << PrintNewLine();
225 if (show_meta_data_) {
226 // append meta data in the end.
227 doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
228 } else {
229 doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
230 }
231 }
232 return doc;
233 }
234
235 std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
236 std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
237
Print(const NodeRef & node,bool meta=false,bool try_inline=false)238 Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
239 if (node.as<ExprNode>()) {
240 return PrintExpr(Downcast<Expr>(node), meta, try_inline);
241 } else if (node.as<TypeNode>()) {
242 return PrintType(Downcast<Type>(node), meta);
243 } else if (node.as<PatternNode>()) {
244 return PrintPattern(Downcast<Pattern>(node), meta);
245 } else if (node.as<ModuleNode>()) {
246 return PrintMod(Downcast<Module>(node));
247 } else {
248 Doc doc;
249 return doc << node;
250 }
251 }
252
TempVar(int n)253 Doc TempVar(int n) {
254 Doc doc;
255 return doc << "%" << n;
256 }
257
AllocTemp()258 Doc AllocTemp() {
259 return TempVar(temp_var_counter_++);
260 }
261
262 /*!
263 * \brief get a unique name with the corresponding prefix
264 * \param prefix The prefix of the name
265 * \return The returned name.
266 */
GetUniqueName(const std::string & prefix)267 Doc GetUniqueName(const std::string& prefix) {
268 std::string unique_prefix = prefix;
269 auto it = name_alloc_map_.find(prefix);
270 if (it != name_alloc_map_.end()) {
271 while (true) {
272 std::ostringstream os;
273 os << prefix << (++it->second);
274 std::string name = os.str();
275 if (name_alloc_map_.count(name) == 0) {
276 unique_prefix = name;
277 break;
278 }
279 }
280 }
281 name_alloc_map_[unique_prefix] = 0;
282 return Doc(unique_prefix);
283 }
284
Print(Kind k)285 Doc Print(Kind k) {
286 switch (k) {
287 case kType:
288 return Doc("Type");
289 case kShapeVar:
290 return Doc("Shape");
291 case kBaseType:
292 return Doc("BaseType");
293 case kConstraint:
294 return Doc("Constraint");
295 case kAdtHandle:
296 return Doc("AdtHandle");
297 case kTypeData:
298 return Doc("TypeData");
299 default:
300 LOG(ERROR) << "Unknown Kind";
301 throw;
302 }
303 }
304 /*!
305 * \brief Allocate name to a type variable.
306 * \param var The input type variable.
307 * \return The corresponding name.
308 */
AllocTypeVar(const TypeVar & var)309 Doc AllocTypeVar(const TypeVar& var) {
310 if (memo_type_.count(var)) {
311 Doc val = memo_type_[var];
312 val << "-malformed-ir";
313 return val;
314 }
315 std::string name = var->var->name_hint;
316 if (name.length() == 0 || !std::isalpha(name[0])) {
317 name = "t" + name;
318 }
319 Doc val = GetUniqueName(name);
320 memo_type_[var] = val;
321 if (var->kind != kType) {
322 val << ": " << Print(var->kind);
323 }
324 return val;
325 }
326
327 /*!
328 * \brief Allocate name to a variable.
329 * \param var The input variable.
330 * \return The corresponding name.
331 */
AllocVar(const Var & var)332 Doc AllocVar(const Var& var) {
333 // still print if ir is malformed, but show the error.
334 if (memo_.count(var)) {
335 Doc val = memo_[var];
336 val << "-malformed-ir";
337 return val;
338 }
339 std::string name = var->name_hint();
340 // always make sure first name is alpha
341 if (name.length() == 0 || !std::isalpha(name[0])) {
342 name = "v" + name;
343 }
344 Doc val = GetUniqueName("%" + name);
345 memo_[var] = val;
346 if (var->type_annotation.defined()) {
347 val << ": " << Print(var->type_annotation);
348 }
349 return val;
350 }
351
IsUnique(const Expr & expr)352 bool IsUnique(const Expr& expr) {
353 auto it = dg_.expr_node.find(expr);
354 if (it == dg_.expr_node.end()) {
355 return true;
356 } else {
357 return !(it->second->parents.head && it->second->parents.head->next);
358 }
359 }
360
AlwaysInline(const Expr & expr)361 bool AlwaysInline(const Expr& expr) {
362 return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
363 expr.as<VarNode>() || expr.as<ConstructorNode>();
364 }
365
366 //------------------------------------
367 // Overload of Expr printing functions
368 //------------------------------------
PrintExpr(const Expr & expr,bool meta,bool try_inline)369 Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
370 // Exploit memoization to print GNF.
371 // The first time we visit an expression, we need to allocate a temp var
372 // for it. Every subsequent time we can just use its assigned variable.
373 // This works since hashing uses pointer equality.
374
375 // determine whether to inline
376 bool inline_expr = AlwaysInline(expr);
377 if (try_inline) {
378 inline_expr |= IsUnique(expr);
379 }
380
381 auto it = memo_.find(expr);
382 if (it != memo_.end()) return it->second;
383
384 Doc printed_expr;
385 if (meta) {
386 printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
387 } else if (!inline_expr && expr.as<LetNode>()) {
388 // wrap GNFed let in brackets
389 Doc body;
390 printed_expr << "(";
391 printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
392 printed_expr << ")";
393 } else {
394 printed_expr = VisitExpr(expr);
395 }
396
397 printed_expr << PrintOptionalInfo(expr);
398
399 // add expr to doc
400 if (expr.as<VarNode>()) {
401 // This is our first time visiting the var and we hit the VarNode case
402 // in the visitor. Thus the variable is free.
403 doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
404 // Memoization is done in AllocVar.
405 return memo_[expr];
406 } else if (inline_expr) {
407 memo_[expr] = printed_expr;
408 return printed_expr;
409 } else {
410 Doc temp_var = AllocTemp();
411 memo_[expr] = temp_var;
412 doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine();
413 return temp_var;
414 }
415 }
416
417 // Should only be triggered when op is a free variable being visited for the
418 // first time.
VisitExpr_(const VarNode * op)419 Doc VisitExpr_(const VarNode* op) final {
420 return AllocVar(GetRef<Var>(op));
421 }
422
VisitExpr_(const ConstantNode * op)423 Doc VisitExpr_(const ConstantNode* op) final {
424 // Print out simple scalars directly.
425 if (op->is_scalar()) {
426 std::ostringstream os;
427 DataType dtype = TVMType2Type(op->data->dtype);
428 CHECK_EQ(op->data->ctx.device_type, kDLCPU);
429 if (dtype == Int(32)) {
430 return PrintConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
431 } else if (dtype == Int(64)) {
432 return PrintConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
433 } else if (dtype == Float(32)) {
434 return PrintConstScalar(dtype, static_cast<const float*>(op->data->data));
435 } else if (dtype == Float(64)) {
436 return PrintConstScalar(dtype, static_cast<const double*>(op->data->data));
437 } else if (dtype == Bool()) {
438 return PrintConstScalar(dtype, static_cast<const uint8_t*>(op->data->data));
439 }
440 }
441 // default fall-back, record it as meta node.
442 Doc doc;
443 return doc << Print(GetRef<NodeRef>(op), true);
444 }
445
VisitExpr_(const TupleNode * op)446 Doc VisitExpr_(const TupleNode* op) final {
447 std::vector<Doc> fields;
448 for (Expr field : op->fields) {
449 fields.push_back(Print(field));
450 }
451 Doc doc;
452 doc << "(" << PrintSep(fields);
453 // conform to python tuple format (1,)
454 if (op->fields.size() == 1) {
455 doc << ",";
456 }
457 return doc << ")";
458 }
459
VisitExpr_(const TupleGetItemNode * op)460 Doc VisitExpr_(const TupleGetItemNode* op) final {
461 Doc doc;
462 return doc << Print(op->tuple) << "." << op->index;
463 }
464
VisitExpr_(const IfNode * op)465 Doc VisitExpr_(const IfNode* op) final {
466 Doc doc;
467 doc << "if (" << Print(op->cond) << ") ";
468 doc << PrintBody(op->true_branch);
469 doc << " else ";
470 doc << PrintBody(op->false_branch);
471 return doc;
472 }
473
VisitExpr_(const LetNode * op)474 Doc VisitExpr_(const LetNode* op) final {
475 Doc doc;
476 doc
477 << "let "
478 << AllocVar(op->var)
479 << " = "
480 << Print(op->value, false, true)
481 << ";"
482 << PrintNewLine();
483 // we use a scope here so GNF hoisting doesn't escape too far
484 // and nested, unique lets are not hoisted
485 doc << PrintScope(op->body);
486 return doc;
487 }
488
PrintFunc(const Doc & prefix,const Function & fn)489 Doc PrintFunc(const Doc& prefix, const Function& fn) {
490 Doc doc;
491 doc << prefix;
492 if (fn->type_params.size() > 0) {
493 doc << "[";
494 std::vector<Doc> type_params;
495 for (const TypeVar& tv : fn->type_params) {
496 type_params.push_back(Doc(tv->var->name_hint));
497 }
498 doc << PrintSep(type_params);
499 doc << "]";
500 }
501 doc << "(";
502 std::vector<Doc> params;
503 for (Var param : fn->params) {
504 params.push_back(AllocVar(param));
505 }
506 for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
507 params.push_back(d);
508 }
509 doc << PrintSep(params) << ") ";
510 if (fn->ret_type.defined()) {
511 doc << "-> " << Print(fn->ret_type) << " ";
512 }
513 doc << PrintBody(fn->body);
514 return doc;
515 }
516
PrintMod(const Module & mod)517 Doc PrintMod(const Module& mod) {
518 Doc doc;
519 int counter = 0;
520 // type definitions
521 for (const auto& kv : mod->type_definitions) {
522 if (counter++ != 0) {
523 doc << PrintNewLine();
524 }
525 doc << Print(kv.second);
526 doc << PrintNewLine();
527 }
528 // functions
529 for (const auto& kv : mod->functions) {
530 dg_ = DependencyGraph::Create(&arena_, kv.second);
531
532 if (counter++ != 0) {
533 doc << PrintNewLine();
534 }
535 std::ostringstream os;
536 os << "def @" << kv.first->name_hint;
537 doc << PrintFunc(Doc(os.str()), kv.second);
538 doc << PrintNewLine();
539 }
540 return doc;
541 }
542
VisitExpr_(const FunctionNode * op)543 Doc VisitExpr_(const FunctionNode* op) final {
544 return PrintFunc(Doc("fn "), GetRef<Function>(op));
545 }
546
VisitExpr_(const GlobalVarNode * op)547 Doc VisitExpr_(const GlobalVarNode* op) final {
548 return Doc('@' + op->name_hint);
549 }
550
VisitExpr_(const OpNode * op)551 Doc VisitExpr_(const OpNode* op) final {
552 return Doc(op->name);
553 }
554
VisitExpr_(const CallNode * op)555 Doc VisitExpr_(const CallNode* op) final {
556 Doc doc;
557 // visit args first so they are lifted before the op
558 // this places op closer to its call site
559 std::vector<Doc> args;
560 for (const Expr& arg : op->args) {
561 args.push_back(Print(arg));
562 }
563 for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
564 args.push_back(d);
565 }
566 const auto* cons_node = op->op.as<ConstructorNode>();
567 if (cons_node) {
568 doc << cons_node->name_hint;
569 } else {
570 doc << Print(op->op);
571 }
572
573 if (cons_node && cons_node->inputs.size() == 0) {
574 // don't print as a call if it's a 0-arity cons
575 return doc;
576 } else {
577 return doc << "(" << PrintSep(args) << ")";
578 }
579 }
580
VisitExpr_(const RefCreateNode * op)581 Doc VisitExpr_(const RefCreateNode* op) final {
582 Doc doc;
583 return doc << "ref(" << Print(op->value) << ")";
584 }
585
VisitExpr_(const RefReadNode * op)586 Doc VisitExpr_(const RefReadNode* op) final {
587 Doc doc;
588 return doc << Print(op->ref) << "^";
589 }
590
VisitExpr_(const RefWriteNode * op)591 Doc VisitExpr_(const RefWriteNode* op) final {
592 Doc doc;
593 return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")";
594 }
595
VisitExpr_(const MatchNode * op)596 Doc VisitExpr_(const MatchNode* op) final {
597 // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
598 Doc doc;
599 Doc body;
600 doc << "match";
601 if (!op->complete) {
602 doc << "?";
603 }
604 doc << " (" << Print(op->data) << ") {";
605 std::vector<Doc> clause_docs;
606 for (const auto& clause : op->clauses) {
607 Doc clause_doc;
608 clause_doc << PrintPattern(clause->lhs, false) << " => ";
609 Doc rhs_doc = PrintScope(clause->rhs);
610 if (clause->rhs.as<LetNode>()) {
611 // only add braces if there are multiple lines on the rhs
612 rhs_doc = Brace(rhs_doc);
613 }
614 clause_doc << rhs_doc << ",";
615 clause_docs.push_back(clause_doc);
616 }
617 doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
618 << PrintNewLine() << "}";
619 return doc;
620 }
621
PrintPattern(const Pattern & pattern,bool meta)622 Doc PrintPattern(const Pattern& pattern, bool meta) {
623 auto it = memo_pattern_.find(pattern);
624 if (it != memo_pattern_.end()) return it->second;
625 Doc printed_pattern;
626 if (meta) {
627 printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
628 } else {
629 printed_pattern = VisitPattern(pattern);
630 }
631 memo_pattern_[pattern] = printed_pattern;
632 return printed_pattern;
633 }
634
VisitPattern_(const PatternConstructorNode * p)635 Doc VisitPattern_(const PatternConstructorNode* p) final {
636 Doc doc;
637 doc << p->constructor->name_hint;
638 if (!p->patterns.empty()) {
639 doc << "(";
640 std::vector<Doc> pats;
641 for (const auto& pat : p->patterns) {
642 pats.push_back(Print(pat));
643 }
644 doc << PrintSep(pats) << ")";
645 }
646 return doc;
647 }
648
VisitPattern_(const PatternTupleNode * pt)649 Doc VisitPattern_(const PatternTupleNode* pt) final {
650 Doc doc;
651 doc << "(";
652 std::vector<Doc> pats;
653 for (const auto& pat : pt->patterns) {
654 pats.push_back(Print(pat));
655 }
656 doc << PrintSep(pats) << ")";
657 return doc;
658 }
659
VisitPattern_(const PatternWildcardNode * pw)660 Doc VisitPattern_(const PatternWildcardNode* pw) final {
661 return Doc("_");
662 }
663
VisitPattern_(const PatternVarNode * pv)664 Doc VisitPattern_(const PatternVarNode* pv) final {
665 return AllocVar(pv->var);
666 }
667
VisitExpr_(const ConstructorNode * n)668 Doc VisitExpr_(const ConstructorNode* n) final {
669 Doc doc;
670 doc << n->name_hint;
671 if (in_adt_def_ && n->inputs.size() != 0) {
672 doc << "(";
673 std::vector<Doc> inputs;
674 for (Type input : n->inputs) {
675 inputs.push_back(Print(input));
676 }
677 doc << PrintSep(inputs) << ")";
678 }
679 return doc;
680 }
681
682 //------------------------------------
683 // Overload of Type printing functions
684 //------------------------------------
PrintType(const Type & type,bool meta)685 Doc PrintType(const Type& type, bool meta) {
686 auto it = memo_type_.find(type);
687 if (it != memo_type_.end()) return it->second;
688 Doc printed_type;
689 if (meta) {
690 printed_type = meta_.GetMetaNode(GetRef<NodeRef>(type.get()));
691 } else {
692 printed_type = VisitType(type);
693 }
694 memo_type_[type] = printed_type;
695 return printed_type;
696 }
697
VisitTypeDefault_(const Node * node)698 Doc VisitTypeDefault_(const Node* node) final {
699 // by default always print as meta data
700 return Print(GetRef<NodeRef>(node), true);
701 }
702
VisitType_(const TypeVarNode * node)703 Doc VisitType_(const TypeVarNode* node) final {
704 return Doc(node->var->name_hint);
705 }
706
VisitType_(const GlobalTypeVarNode * node)707 Doc VisitType_(const GlobalTypeVarNode* node) final {
708 return Doc(node->var->name_hint);
709 }
710
VisitType_(const TypeCallNode * node)711 Doc VisitType_(const TypeCallNode* node) final {
712 Doc doc = PrintType(node->func, false);
713 std::vector<Doc> args;
714 for (const Type& t : node->args) {
715 args.push_back(PrintType(t, false));
716 }
717 doc << "[";
718 doc << PrintSep(args);
719 doc << "]";
720 return doc;
721 }
722
VisitType_(const TensorTypeNode * node)723 Doc VisitType_(const TensorTypeNode* node) final {
724 // scalar type
725 if (node->shape.size() == 0) {
726 return PrintDType(node->dtype);
727 }
728 Doc doc;
729 doc << "Tensor[(";
730 std::vector<Doc> shapes;
731 for (NodeRef shape : node->shape) {
732 shapes.push_back(PrintAttr(shape));
733 }
734 doc << PrintSep(shapes);
735 return doc << "), " << PrintDType(node->dtype) << "]";
736 }
737
VisitType_(const TupleTypeNode * node)738 Doc VisitType_(const TupleTypeNode* node) final {
739 std::vector<Doc> fields;
740 for (Type field : node->fields) {
741 fields.push_back(Print(field));
742 }
743 Doc doc;
744 doc << "(" << PrintSep(fields);
745 // conform to python tuple format (1,)
746 if (node->fields.size() == 1) {
747 doc << ",";
748 }
749 return doc << ")";
750 }
751
VisitType_(const FuncTypeNode * node)752 Doc VisitType_(const FuncTypeNode* node) final {
753 Doc doc;
754 doc << "fn ";
755 if (node->type_params.size() != 0) {
756 doc << "[";
757 std::vector<Doc> type_params;
758 for (Type type_param : node->type_params) {
759 type_params.push_back(Print(type_param));
760 }
761 doc << PrintSep(type_params);
762 doc << "]";
763 }
764 std::vector<Doc> arg_types;
765 for (Type arg_type : node->arg_types) {
766 arg_types.push_back(Print(arg_type));
767 }
768 return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
769 }
770
VisitType_(const RefTypeNode * node)771 Doc VisitType_(const RefTypeNode* node) final {
772 Doc doc;
773 return doc << "ref(" << Print(node->value) << ")";
774 }
775
VisitType_(const TypeDataNode * node)776 Doc VisitType_(const TypeDataNode* node) final {
777 in_adt_def_ = true;
778 Doc doc;
779 doc << "type " << Print(node->header);
780
781 // type vars
782 if (node->type_vars.size() != 0) {
783 doc << "[";
784 std::vector<Doc> type_vars;
785 for (Type type_var : node->type_vars) {
786 type_vars.push_back(Print(type_var));
787 }
788 doc << PrintSep(type_vars) << "]";
789 }
790 doc << " ";
791
792 std::vector<Doc> constructor_docs;
793 for (Constructor constructor : node->constructors) {
794 constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
795 }
796 Doc separator;
797 separator << "," << PrintNewLine();
798 Doc adt_body;
799 adt_body << PrintSep(constructor_docs, separator);
800 // add trailing comma if there are any constructors
801 if (!constructor_docs.empty()) {
802 adt_body << ",";
803 }
804 doc << Brace(adt_body);
805 in_adt_def_ = false;
806 return doc;
807 }
808
809 //------------------------------------
810 // Overload of Attr printing functions
811 //------------------------------------
812
PrintAttr(const ObjectRef & value,bool meta=false)813 Doc PrintAttr(const ObjectRef& value, bool meta = false) {
814 if (value.defined()) {
815 Doc printed_attr;
816 if (value.as<tvm::ir::Any>()) {
817 printed_attr << "?";
818 } else if (meta) {
819 printed_attr = meta_.GetMetaNode(Downcast<NodeRef>(value));
820 } else {
821 printed_attr = VisitAttr(value);
822 }
823 return printed_attr;
824 } else {
825 return Doc("None");
826 }
827 }
828
VisitAttrDefault_(const Object * op)829 Doc VisitAttrDefault_(const Object* op) final {
830 return PrintAttr(GetRef<ObjectRef>(op), true);
831 }
832
VisitAttr_(const ArrayNode * op)833 Doc VisitAttr_(const ArrayNode* op) final {
834 Doc doc;
835 doc << "[";
836 std::vector<Doc> arr_vals;
837 for (auto val : op->data) {
838 arr_vals.push_back(PrintAttr(val));
839 }
840 doc << PrintSep(arr_vals);
841 doc << "]";
842 return doc;
843 }
844
VisitAttr_(const ir::IntImm * op)845 Doc VisitAttr_(const ir::IntImm* op) final {
846 return PrintConstScalar(op->type, &(op->value));
847 }
848
VisitAttr_(const ir::UIntImm * op)849 Doc VisitAttr_(const ir::UIntImm* op) final {
850 return PrintConstScalar(op->type, &(op->value));
851 }
852
VisitAttr_(const ir::FloatImm * op)853 Doc VisitAttr_(const ir::FloatImm* op) final {
854 return PrintConstScalar(op->type, &(op->value));
855 }
856
VisitAttr_(const ir::StringImm * op)857 Doc VisitAttr_(const ir::StringImm* op) final {
858 return PrintString(op->value);
859 }
860
861 private:
862 /*! \brief Whether to print meta data. */
863 bool show_meta_data_;
864 /*! \brief additional comment function */
865 runtime::TypedPackedFunc<std::string(Expr)> annotate_;
866 /*! \brief Stack of docs to implement scoped GNFing. */
867 std::vector<Doc> doc_stack_{};
868 /*! \brief Map from Expr to Doc */
869 std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
870 /*! \brief Map from Type to Doc */
871 std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
872 /*! \brief Map from Type to Doc */
873 std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
874 /*! \brief name allocation map */
875 std::unordered_map<std::string, int> name_alloc_map_;
876 /*! \brief meta data context */
877 TextMetaDataContext meta_;
878 /*! \brief counter of temporary variable */
879 size_t temp_var_counter_{0};
880 /*! \brief whether the printer is currently in an ADT definition */
881 bool in_adt_def_;
882 /*! \brief arena for dependency graph */
883 common::Arena arena_;
884 /*! \brief dependency graph of the expr */
885 DependencyGraph dg_;
886 class AttrPrinter;
887 friend class AttrPrinter;
888 };
889
890 /*!
891 * \brief Attribute printer which prints the attributes in the call.
892 */
893 class PrettyPrinter::AttrPrinter : public AttrVisitor {
894 public:
AttrPrinter(std::vector<Doc> * doc,PrettyPrinter * parent)895 AttrPrinter(std::vector<Doc>* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {}
896
897 template<typename T>
PrintKV(const char * key,const T & value)898 void PrintKV(const char* key, const T& value) {
899 Doc doc;
900 doc << key << "=" << value;
901 docs->push_back(doc);
902 }
903
Visit(const char * key,double * value)904 void Visit(const char* key, double* value) final {
905 Doc doc;
906 doc << key << "=" << *value << "f";
907 docs->push_back(doc);
908 }
Visit(const char * key,int64_t * value)909 void Visit(const char* key, int64_t* value) final {
910 PrintKV(key, *value);
911 }
Visit(const char * key,uint64_t * value)912 void Visit(const char* key, uint64_t* value) final {
913 PrintKV(key, *value);
914 }
Visit(const char * key,int * value)915 void Visit(const char* key, int* value) final {
916 PrintKV(key, *value);
917 }
Visit(const char * key,bool * value)918 void Visit(const char* key, bool* value) final {
919 PrintKV(key, PrintBool(*value));
920 }
Visit(const char * key,std::string * value)921 void Visit(const char* key, std::string* value) final {
922 PrintKV(key, PrintString(*value));
923 }
Visit(const char * key,void ** value)924 void Visit(const char* key, void** value) final {
925 LOG(FATAL) << "do not allow void as argument";
926 }
Visit(const char * key,DataType * value)927 void Visit(const char* key, DataType* value) final {
928 PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
929 }
Visit(const char * key,runtime::NDArray * value)930 void Visit(const char* key, runtime::NDArray* value) final {
931 LOG(FATAL) << "do not allow NDarray as argument";
932 }
Visit(const char * key,runtime::ObjectRef * obj)933 void Visit(const char* key, runtime::ObjectRef* obj) final {
934 PrintKV(key, parent_->PrintAttr(*obj));
935 }
936
937 private:
938 std::vector<Doc>* docs;
939 PrettyPrinter* parent_;
940 };
941
PrintCallAttrs(const Attrs & attrs,const Expr & op)942 std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
943 std::vector<Doc> docs;
944 if (!attrs.defined()) return docs;
945 const auto* op_node = op.as<OpNode>();
946 if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
947 // fallback
948 Doc doc;
949 doc << meta_.GetMetaNode(attrs);
950 docs.push_back(doc);
951 return docs;
952 } else {
953 AttrPrinter printer(&docs, this);
954 const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
955 return docs;
956 }
957 }
958
PrintFuncAttrs(const Attrs & attrs)959 std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
960 std::vector<Doc> docs;
961 if (!attrs.defined()) return docs;
962 const auto* dict_attrs = attrs.as<DictAttrsNode>();
963 CHECK(dict_attrs);
964 for (const auto& k : dict_attrs->dict) {
965 Doc doc;
966 doc << k.first << "=" << Print(k.second);
967 docs.push_back(doc);
968 }
969 return docs;
970 }
971
PrettyPrint_(const NodeRef & node,bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)972 std::string PrettyPrint_(const NodeRef& node,
973 bool show_meta_data,
974 runtime::TypedPackedFunc<std::string(Expr)> annotate) {
975 Doc doc;
976 doc << kSemVer << PrintNewLine()
977 << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
978 return doc.str();
979 }
980
PrettyPrint(const NodeRef & node)981 std::string PrettyPrint(const NodeRef& node) {
982 Doc doc;
983 doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
984 return doc.str();
985 }
986
AsText(const NodeRef & node,bool show_meta_data,runtime::TypedPackedFunc<std::string (Expr)> annotate)987 std::string AsText(const NodeRef& node,
988 bool show_meta_data,
989 runtime::TypedPackedFunc<std::string(Expr)> annotate) {
990 return PrettyPrint_(node, show_meta_data, annotate);
991 }
992
// Register AsText with the TVM API registry under the name "relay._expr.AsText".
TVM_REGISTER_API("relay._expr.AsText")
.set_body_typed<std::string(const NodeRef&,
                            bool,
                            runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
997
998 } // namespace relay
999 } // namespace tvm
1000