1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file error_reporter.h
22 * \brief The set of errors raised by Relay.
23 */
24
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <rang.hpp>
31
32 namespace tvm {
33 namespace relay {
34
Raise() const35 void RelayErrorStream::Raise() const {
36 throw Error(*this);
37 }
38
39 template<typename T, typename U>
40 using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
41
RenderErrors(const Module & module,bool use_color)42 void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
43 // First we pick an error reporting strategy for each error.
44 // TODO(@jroesch): Spanned errors are currently not supported.
45 for (auto err : this->errors_) {
46 CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
47 }
48
49 NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;
50
51 // Set control mode in order to produce colors;
52 if (use_color) {
53 rang::setControlMode(rang::control::Force);
54 }
55
56 for (auto pair : this->node_to_gv_) {
57 auto node = pair.first;
58 auto global = Downcast<GlobalVar>(pair.second);
59
60 auto has_errs = this->node_to_error_.find(node);
61
62 CHECK(has_errs != this->node_to_error_.end());
63
64 const auto& error_indicies = has_errs->second;
65
66 std::stringstream err_msg;
67
68 err_msg << rang::fg::red;
69 err_msg << " ";
70 for (auto index : error_indicies) {
71 err_msg << this->errors_[index].what() << "; ";
72 }
73 err_msg << rang::fg::reset;
74
75 // Setup error map.
76 auto it = error_maps.find(global);
77 if (it != error_maps.end()) {
78 it->second.insert({ node, err_msg.str() });
79 } else {
80 error_maps.insert({ global, { { node, err_msg.str() }}});
81 }
82 }
83
84 // Now we will construct the fully-annotated program to display to
85 // the user.
86 std::stringstream annotated_prog;
87
88 // First we output a header for the errors.
89 annotated_prog <<
90 rang::style::bold << std::endl <<
91 "Error(s) have occurred. The program has been annotated with them:"
92 << std::endl << std::endl << rang::style::reset;
93
94 // For each global function which contains errors, we will
95 // construct an annotated function.
96 for (auto pair : error_maps) {
97 auto global = pair.first;
98 auto err_map = pair.second;
99 auto func = module->Lookup(global);
100
101 // We output the name of the function before displaying
102 // the annotated program.
103 annotated_prog <<
104 rang::style::bold <<
105 "In `" << global->name_hint << "`: " <<
106 std::endl <<
107 rang::style::reset;
108
109 // We then call into the Relay printer to generate the program.
110 //
111 // The annotation callback will annotate the error messages
112 // contained in the map.
113 annotated_prog << AsText(func, false, [&err_map](tvm::relay::Expr expr) {
114 auto it = err_map.find(expr);
115 if (it != err_map.end()) {
116 CHECK_NE(it->second.size(), 0);
117 return it->second;
118 } else {
119 return std::string("");
120 }
121 });
122 }
123
124 auto msg = annotated_prog.str();
125
126 if (use_color) {
127 rang::setControlMode(rang::control::Auto);
128 }
129
130 // Finally we report the error, currently we do so to LOG(FATAL),
131 // it may be good to instead report it to std::cout.
132 LOG(FATAL) << annotated_prog.str() << std::endl;
133 }
134
ReportAt(const GlobalVar & global,const NodeRef & node,const Error & err)135 void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
136 size_t index_to_insert = this->errors_.size();
137 this->errors_.push_back(err);
138 auto it = this->node_to_error_.find(node);
139 if (it != this->node_to_error_.end()) {
140 it->second.push_back(index_to_insert);
141 } else {
142 this->node_to_error_.insert({ node, { index_to_insert }});
143 }
144 this->node_to_gv_.insert({ node, global });
145 }
146
147 } // namespace relay
148 } // namespace tvm
149