1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file error.h
22  * \brief The set of errors raised by Relay.
23  */
24 #ifndef TVM_RELAY_ERROR_H_
25 #define TVM_RELAY_ERROR_H_
26 
27 #include <string>
28 #include <vector>
29 #include <sstream>
30 #include <unordered_map>
31 #include "./base.h"
32 #include "./expr.h"
33 #include "./module.h"
34 
35 namespace tvm {
36 namespace relay {
37 
/*!
 * \brief Start a RelayErrorStream and stream `msg` into it.
 *
 * `msg` is intentionally not wrapped in parentheses: the argument is spliced
 * directly after the first `<<`, so `RELAY_ERROR("a" << b)` expands to
 * `(RelayErrorStream() << "a" << b)` and every fragment is inserted into the
 * RelayErrorStream.
 */
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
39 
// Forward declaration for RelayErrorStream.
41 struct Error;
42 
43 /*! \brief A wrapper around std::stringstream.
44  *
45  * This is designed to avoid platform specific
46  * issues compiling and using std::stringstream
47  * for error reporting.
48  */
struct RelayErrorStream {
  /*! \brief Buffer collecting every fragment streamed into this object. */
  std::stringstream ss;

  /*!
   * \brief Append a value to the message being built.
   *
   * \param t The value to insert; it is formatted by the underlying
   *          std::stringstream's own operator<<.
   * \return This stream, so that insertions can be chained.
   */
  template <typename ValueT>
  RelayErrorStream& operator<<(const ValueT& t) {
    this->ss << t;
    return *this;
  }

  /*! \brief Return a copy of the text accumulated so far. */
  std::string str() const { return this->ss.str(); }

  /*!
   * \brief Declared here, defined out of line.
   *
   * NOTE(review): presumably throws the accumulated message as a relay
   * Error - confirm against the implementation.
   */
  void Raise() const;
};
64 
65 struct Error : public dmlc::Error {
66   Span sp;
ErrorError67   explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
ErrorError68   Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
ErrorError69   Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
ErrorError70   Error() : dmlc::Error(""), sp(nullptr) {}
71 };
72 
73 /*! \brief An abstraction around how errors are stored and reported.
74  * Designed to be opaque to users, so we can support a robust and simpler
75  * error reporting mode, as well as a more complex mode.
76  *
77  * The first mode is the most accurate: we report a Relay error at a specific
78  * Span, and then render the error message directly against a textual representation
79  * of the program, highlighting the exact lines in which it occurs. This mode is not
80  * implemented in this PR and will not work.
81  *
82  * The second mode is a general-purpose mode, which attempts to annotate the program's
83  * textual format with errors.
84  *
85  * The final mode represents the old mode, if we report an error that has no span or
86  * expression, we will default to throwing an exception with a textual representation
87  * of the error and no indication of where it occurred in the original program.
88  *
89  * The latter mode is not ideal, and the goal of the new error reporting machinery is
90  * to avoid ever reporting errors in this style.
91  */
92 class ErrorReporter {
93  public:
ErrorReporter()94   ErrorReporter() : errors_(), node_to_error_() {}
95 
96   /*! \brief Report a tvm::relay::Error.
97    *
98    * This API is useful for reporting spanned errors.
99    *
100    * \param err The error to report.
101    */
Report(const Error & err)102   void Report(const Error& err) {
103     if (!err.sp.defined()) {
104       throw err;
105     }
106 
107     this->errors_.push_back(err);
108   }
109 
110   /*! \brief Report an error against a program, using the full program
111    * error reporting strategy.
112    *
113    * This error reporting method requires the global function in which
114    * to report an error, the expression to report the error on,
115    * and the error object.
116    *
117    * \param global The global function in which the expression is contained.
118    * \param node The expression or type to report the error at.
119    * \param err The error message to report.
120    */
ReportAt(const GlobalVar & global,const NodeRef & node,std::stringstream & err)121   inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
122     std::string err_msg = err.str();
123     this->ReportAt(global, node, Error(err_msg));
124   }
125 
126   /*! \brief Report an error against a program, using the full program
127    * error reporting strategy.
128    *
129    * This error reporting method requires the global function in which
130    * to report an error, the expression to report the error on,
131    * and the error object.
132    *
133    * \param global The global function in which the expression is contained.
134    * \param node The expression or type to report the error at.
135    * \param err The error to report.
136    */
137   void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
138 
139   /*! \brief Render all reported errors and exit the program.
140    *
141    * This function should be used after executing a pass to render reported errors.
142    *
143    * It will build an error message from the set of errors, depending on the error
144    * reporting strategy.
145    *
146    * \param module The module to report errors on.
147    * \param use_color Controls whether to colorize the output.
148    */
149   void RenderErrors(const Module& module, bool use_color = true);
150 
AnyErrors()151   inline bool AnyErrors() {
152     return errors_.size() != 0;
153   }
154 
155  private:
156   std::vector<Error> errors_;
157   std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
158   std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
159 };
160 
161 }  // namespace relay
162 }  // namespace tvm
163 
164 #endif  // TVM_RELAY_ERROR_H_
165