1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 * \file error_reporter.h
 * \brief Implementation of Relay error reporting: RelayErrorStream::Raise and
 *        ErrorReporter, which annotates programs with the errors reported on them.
 */
24 
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <rang.hpp>
31 
32 namespace tvm {
33 namespace relay {
34 
Raise() const35 void RelayErrorStream::Raise() const {
36   throw Error(*this);
37 }
38 
39 template<typename T, typename U>
40 using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
41 
RenderErrors(const Module & module,bool use_color)42 void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
43   // First we pick an error reporting strategy for each error.
44   // TODO(@jroesch): Spanned errors are currently not supported.
45   for (auto err : this->errors_) {
46     CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
47   }
48 
49   NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;
50 
51   // Set control mode in order to produce colors;
52   if (use_color) {
53     rang::setControlMode(rang::control::Force);
54   }
55 
56   for (auto pair : this->node_to_gv_) {
57     auto node = pair.first;
58     auto global = Downcast<GlobalVar>(pair.second);
59 
60     auto has_errs = this->node_to_error_.find(node);
61 
62     CHECK(has_errs != this->node_to_error_.end());
63 
64     const auto& error_indicies = has_errs->second;
65 
66     std::stringstream err_msg;
67 
68     err_msg << rang::fg::red;
69     err_msg << " ";
70     for (auto index : error_indicies) {
71       err_msg << this->errors_[index].what() << "; ";
72     }
73     err_msg << rang::fg::reset;
74 
75     // Setup error map.
76     auto it = error_maps.find(global);
77     if (it != error_maps.end()) {
78       it->second.insert({ node, err_msg.str() });
79     } else {
80       error_maps.insert({ global, { { node, err_msg.str() }}});
81     }
82   }
83 
84   // Now we will construct the fully-annotated program to display to
85   // the user.
86   std::stringstream annotated_prog;
87 
88   // First we output a header for the errors.
89   annotated_prog <<
90   rang::style::bold << std::endl <<
91   "Error(s) have occurred. The program has been annotated with them:"
92   << std::endl << std::endl << rang::style::reset;
93 
94   // For each global function which contains errors, we will
95   // construct an annotated function.
96   for (auto pair : error_maps) {
97     auto global = pair.first;
98     auto err_map = pair.second;
99     auto func = module->Lookup(global);
100 
101     // We output the name of the function before displaying
102     // the annotated program.
103     annotated_prog <<
104       rang::style::bold <<
105       "In `" << global->name_hint << "`: " <<
106       std::endl <<
107       rang::style::reset;
108 
109     // We then call into the Relay printer to generate the program.
110     //
111     // The annotation callback will annotate the error messages
112     // contained in the map.
113     annotated_prog << AsText(func, false, [&err_map](tvm::relay::Expr expr) {
114       auto it = err_map.find(expr);
115       if (it != err_map.end()) {
116         CHECK_NE(it->second.size(), 0);
117         return it->second;
118       } else {
119         return std::string("");
120       }
121     });
122   }
123 
124   auto msg = annotated_prog.str();
125 
126   if (use_color) {
127     rang::setControlMode(rang::control::Auto);
128   }
129 
130   // Finally we report the error, currently we do so to LOG(FATAL),
131   // it may be good to instead report it to std::cout.
132   LOG(FATAL) << annotated_prog.str() << std::endl;
133 }
134 
ReportAt(const GlobalVar & global,const NodeRef & node,const Error & err)135 void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
136   size_t index_to_insert = this->errors_.size();
137   this->errors_.push_back(err);
138   auto it = this->node_to_error_.find(node);
139   if (it != this->node_to_error_.end()) {
140     it->second.push_back(index_to_insert);
141   } else {
142     this->node_to_error_.insert({ node, { index_to_insert }});
143   }
144   this->node_to_gv_.insert({ node, global });
145 }
146 
147 }  // namespace relay
148 }  // namespace tvm
149