1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file error.h 22 * \brief The set of errors raised by Relay. 23 */ 24 #ifndef TVM_RELAY_ERROR_H_ 25 #define TVM_RELAY_ERROR_H_ 26 27 #include <string> 28 #include <vector> 29 #include <sstream> 30 #include <unordered_map> 31 #include "./base.h" 32 #include "./expr.h" 33 #include "./module.h" 34 35 namespace tvm { 36 namespace relay { 37 38 #define RELAY_ERROR(msg) (RelayErrorStream() << msg) 39 40 // Forward declaratio for RelayErrorStream. 41 struct Error; 42 43 /*! \brief A wrapper around std::stringstream. 44 * 45 * This is designed to avoid platform specific 46 * issues compiling and using std::stringstream 47 * for error reporting. 48 */ 49 struct RelayErrorStream { 50 std::stringstream ss; 51 52 template<typename T> 53 RelayErrorStream& operator<<(const T& t) { 54 ss << t; 55 return *this; 56 } 57 strRelayErrorStream58 std::string str() const { 59 return ss.str(); 60 } 61 62 void Raise() const; 63 }; 64 65 struct Error : public dmlc::Error { 66 Span sp; ErrorError67 explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {} ErrorError68 Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*) ErrorError69 Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {} ErrorError70 Error() : dmlc::Error(""), sp(nullptr) {} 71 }; 72 73 /*! \brief An abstraction around how errors are stored and reported. 74 * Designed to be opaque to users, so we can support a robust and simpler 75 * error reporting mode, as well as a more complex mode. 76 * 77 * The first mode is the most accurate: we report a Relay error at a specific 78 * Span, and then render the error message directly against a textual representation 79 * of the program, highlighting the exact lines in which it occurs. This mode is not 80 * implemented in this PR and will not work. 81 * 82 * The second mode is a general-purpose mode, which attempts to annotate the program's 83 * textual format with errors. 84 * 85 * The final mode represents the old mode, if we report an error that has no span or 86 * expression, we will default to throwing an exception with a textual representation 87 * of the error and no indication of where it occurred in the original program. 88 * 89 * The latter mode is not ideal, and the goal of the new error reporting machinery is 90 * to avoid ever reporting errors in this style. 91 */ 92 class ErrorReporter { 93 public: ErrorReporter()94 ErrorReporter() : errors_(), node_to_error_() {} 95 96 /*! \brief Report a tvm::relay::Error. 97 * 98 * This API is useful for reporting spanned errors. 99 * 100 * \param err The error to report. 101 */ Report(const Error & err)102 void Report(const Error& err) { 103 if (!err.sp.defined()) { 104 throw err; 105 } 106 107 this->errors_.push_back(err); 108 } 109 110 /*! \brief Report an error against a program, using the full program 111 * error reporting strategy. 112 * 113 * This error reporting method requires the global function in which 114 * to report an error, the expression to report the error on, 115 * and the error object. 116 * 117 * \param global The global function in which the expression is contained. 118 * \param node The expression or type to report the error at. 119 * \param err The error message to report. 120 */ ReportAt(const GlobalVar & global,const NodeRef & node,std::stringstream & err)121 inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) { 122 std::string err_msg = err.str(); 123 this->ReportAt(global, node, Error(err_msg)); 124 } 125 126 /*! \brief Report an error against a program, using the full program 127 * error reporting strategy. 128 * 129 * This error reporting method requires the global function in which 130 * to report an error, the expression to report the error on, 131 * and the error object. 132 * 133 * \param global The global function in which the expression is contained. 134 * \param node The expression or type to report the error at. 135 * \param err The error to report. 136 */ 137 void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err); 138 139 /*! \brief Render all reported errors and exit the program. 140 * 141 * This function should be used after executing a pass to render reported errors. 142 * 143 * It will build an error message from the set of errors, depending on the error 144 * reporting strategy. 145 * 146 * \param module The module to report errors on. 147 * \param use_color Controls whether to colorize the output. 148 */ 149 void RenderErrors(const Module& module, bool use_color = true); 150 AnyErrors()151 inline bool AnyErrors() { 152 return errors_.size() != 0; 153 } 154 155 private: 156 std::vector<Error> errors_; 157 std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_; 158 std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_; 159 }; 160 161 } // namespace relay 162 } // namespace tvm 163 164 #endif // TVM_RELAY_ERROR_H_ 165