1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/relay/adt.h 22 * \brief Algebraic data types for Relay 23 */ 24 #ifndef TVM_RELAY_ADT_H_ 25 #define TVM_RELAY_ADT_H_ 26 27 #include <tvm/ir/adt.h> 28 #include <tvm/ir/attrs.h> 29 #include <tvm/relay/base.h> 30 #include <tvm/relay/expr.h> 31 #include <tvm/relay/type.h> 32 33 #include <functional> 34 #include <string> 35 #include <utility> 36 37 namespace tvm { 38 namespace relay { 39 40 using Constructor = tvm::Constructor; 41 using ConstructorNode = tvm::ConstructorNode; 42 43 using TypeData = tvm::TypeData; 44 using TypeDataNode = tvm::TypeDataNode; 45 46 /*! \brief Base type for declaring relay pattern. */ 47 class PatternNode : public RelayNode { 48 public: 49 static constexpr const char* _type_key = "relay.Pattern"; 50 static constexpr const bool _type_has_method_sequal_reduce = true; 51 static constexpr const bool _type_has_method_shash_reduce = true; 52 TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); 53 }; 54 55 /*! 56 * \brief Pattern is the base type for an ADT match pattern in Relay. 57 * 58 * Given an ADT value, a pattern might accept it and bind the pattern variable to some value 59 * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. 60 * 61 * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. 62 */ 63 class Pattern : public ObjectRef { 64 public: Pattern()65 Pattern() {} Pattern(ObjectPtr<tvm::Object> p)66 explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {} 67 68 using ContainerType = PatternNode; 69 }; 70 71 /*! \brief A wildcard pattern: Accepts all input and binds nothing. */ 72 class PatternWildcard; 73 /*! \brief PatternWildcard container node */ 74 class PatternWildcardNode : public PatternNode { 75 public: VisitAttrs(tvm::AttrVisitor * v)76 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } 77 SEqualReduce(const PatternNode * other,SEqualReducer equal)78 bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; } 79 SHashReduce(SHashReducer hash_reduce)80 void SHashReduce(SHashReducer hash_reduce) const {} 81 82 static constexpr const char* _type_key = "relay.PatternWildcard"; 83 TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); 84 }; 85 86 class PatternWildcard : public Pattern { 87 public: 88 /* \brief Overload the default constructors. */ 89 TVM_DLL PatternWildcard(); PatternWildcard(ObjectPtr<Object> n)90 explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {} 91 /* \brief Copy constructor. */ PatternWildcard(const PatternWildcard & pat)92 PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {} 93 /* \brief Move constructor. */ PatternWildcard(PatternWildcard && pat)94 PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {} 95 /* \brief Copy assignment. */ 96 PatternWildcard& operator=(const PatternWildcard& other) { 97 (*this).data_ = other.data_; 98 return *this; 99 } 100 /* \brief Move assignment. */ 101 PatternWildcard& operator=(PatternWildcard&& other) { 102 (*this).data_ = std::move(other.data_); 103 return *this; 104 } 105 106 const PatternWildcardNode* operator->() const { 107 return static_cast<const PatternWildcardNode*>(get()); 108 } 109 110 using ContainerType = PatternWildcardNode; 111 }; 112 113 /*! \brief A var pattern. Accept all input and bind to a var. */ 114 class PatternVar; 115 /*! \brief PatternVar container node */ 116 class PatternVarNode : public PatternNode { 117 public: 118 /*! \brief Variable that stores the matched value. */ 119 tvm::relay::Var var; 120 VisitAttrs(tvm::AttrVisitor * v)121 void VisitAttrs(tvm::AttrVisitor* v) { 122 v->Visit("var", &var); 123 v->Visit("span", &span); 124 } 125 SEqualReduce(const PatternVarNode * other,SEqualReducer equal)126 bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const { 127 return equal.DefEqual(var, other->var); 128 } 129 SHashReduce(SHashReducer hash_reduce)130 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); } 131 132 static constexpr const char* _type_key = "relay.PatternVar"; 133 TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); 134 }; 135 136 class PatternVar : public Pattern { 137 public: 138 /*! 139 * \brief Constructor 140 * \param var The var to construct a pattern 141 */ 142 TVM_DLL explicit PatternVar(tvm::relay::Var var); 143 144 TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); 145 }; 146 147 /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ 148 class PatternConstructor; 149 /*! \brief PatternVar container node */ 150 class PatternConstructorNode : public PatternNode { 151 public: 152 /*! Constructor matched by the pattern. */ 153 Constructor constructor; 154 /*! Sub-patterns to match against each input to the constructor. */ 155 tvm::Array<Pattern> patterns; 156 VisitAttrs(tvm::AttrVisitor * v)157 void VisitAttrs(tvm::AttrVisitor* v) { 158 v->Visit("constructor", &constructor); 159 v->Visit("patterns", &patterns); 160 v->Visit("span", &span); 161 } 162 SEqualReduce(const PatternConstructorNode * other,SEqualReducer equal)163 bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { 164 return equal(constructor, other->constructor) && equal(patterns, other->patterns); 165 } 166 SHashReduce(SHashReducer hash_reduce)167 void SHashReduce(SHashReducer hash_reduce) const { 168 hash_reduce(constructor); 169 hash_reduce(patterns); 170 } 171 172 static constexpr const char* _type_key = "relay.PatternConstructor"; 173 TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); 174 }; 175 176 class PatternConstructor : public Pattern { 177 public: 178 /*! 179 * \brief Constructor 180 * \param constructor The constructor of a pattern 181 * \param patterns The sub-patterns for matching 182 */ 183 TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns); 184 185 TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode); 186 }; 187 188 /*! \brief A tuple pattern. Matches a tuple, binds recursively. */ 189 class PatternTuple; 190 /*! \brief PatternVar container node */ 191 class PatternTupleNode : public PatternNode { 192 public: 193 /*! Sub-patterns to match against each value of the tuple. */ 194 tvm::Array<Pattern> patterns; 195 VisitAttrs(tvm::AttrVisitor * v)196 void VisitAttrs(tvm::AttrVisitor* v) { 197 v->Visit("patterns", &patterns); 198 v->Visit("span", &span); 199 } 200 SEqualReduce(const PatternTupleNode * other,SEqualReducer equal)201 bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const { 202 return equal(patterns, other->patterns); 203 } 204 SHashReduce(SHashReducer hash_reduce)205 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); } 206 207 static constexpr const char* _type_key = "relay.PatternTuple"; 208 TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); 209 }; 210 211 class PatternTuple : public Pattern { 212 public: 213 /*! 214 * \brief Constructor 215 * \param patterns The sub-patterns to match against each value of the tuple 216 */ 217 TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns); 218 219 TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); 220 }; 221 222 /*! \brief A clause in a match expression. */ 223 class Clause; 224 /*! \brief Clause container node. */ 225 class ClauseNode : public Object { 226 public: 227 /*! \brief The pattern the clause matches. */ 228 Pattern lhs; 229 /*! \brief The resulting value. */ 230 Expr rhs; 231 VisitAttrs(tvm::AttrVisitor * v)232 void VisitAttrs(tvm::AttrVisitor* v) { 233 v->Visit("lhs", &lhs); 234 v->Visit("rhs", &rhs); 235 } 236 SEqualReduce(const ClauseNode * other,SEqualReducer equal)237 bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const { 238 return equal(lhs, other->lhs) && equal(rhs, other->rhs); 239 } 240 SHashReduce(SHashReducer hash_reduce)241 void SHashReduce(SHashReducer hash_reduce) const { 242 hash_reduce(lhs); 243 hash_reduce(rhs); 244 } 245 246 static constexpr const char* _type_key = "relay.Clause"; 247 static constexpr const bool _type_has_method_sequal_reduce = true; 248 static constexpr const bool _type_has_method_shash_reduce = true; 249 TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); 250 }; 251 252 class Clause : public ObjectRef { 253 public: 254 /*! 255 * \brief Constructor 256 * \param lhs The pattern matched by the clause. 257 * \param rhs The resulting value 258 */ 259 TVM_DLL explicit Clause(Pattern lhs, Expr rhs); 260 261 TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); 262 }; 263 264 /*! \brief ADT pattern matching exression. */ 265 class Match; 266 /*! \brief Match container node. */ 267 class MatchNode : public ExprNode { 268 public: 269 /*! \brief The input being deconstructed. */ 270 Expr data; 271 272 /*! \brief The match node clauses. */ 273 tvm::Array<Clause> clauses; 274 275 /*! \brief Should this match be complete (cover all cases)? 276 * If yes, the type checker will generate an error if there are any missing cases. 277 */ 278 bool complete; 279 VisitAttrs(tvm::AttrVisitor * v)280 void VisitAttrs(tvm::AttrVisitor* v) { 281 v->Visit("data", &data); 282 v->Visit("clauses", &clauses); 283 v->Visit("complete", &complete); 284 v->Visit("span", &span); 285 v->Visit("_checked_type_", &checked_type_); 286 } 287 SEqualReduce(const MatchNode * other,SEqualReducer equal)288 bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { 289 equal->MarkGraphNode(); 290 return equal(data, other->data) && equal(clauses, other->clauses) && 291 equal(complete, other->complete); 292 } 293 SHashReduce(SHashReducer hash_reduce)294 void SHashReduce(SHashReducer hash_reduce) const { 295 hash_reduce->MarkGraphNode(); 296 hash_reduce(data); 297 hash_reduce(clauses); 298 hash_reduce(complete); 299 } 300 301 static constexpr const char* _type_key = "relay.Match"; 302 TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); 303 }; 304 305 class Match : public Expr { 306 public: 307 /*! 308 * \brief Constructor 309 * \param data the input being deconstructed. 310 * \param clauses The clauses for matching. 311 * \param complete Indicate if this match is complete. 312 * \param span The span of the expression. 313 */ 314 TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span()); 315 316 TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); 317 }; 318 319 } // namespace relay 320 } // namespace tvm 321 322 #endif // TVM_RELAY_ADT_H_ 323