1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/relay/adt.h
22  * \brief Algebraic data types for Relay
23  */
24 #ifndef TVM_RELAY_ADT_H_
25 #define TVM_RELAY_ADT_H_
26 
27 #include <tvm/ir/adt.h>
28 #include <tvm/ir/attrs.h>
29 #include <tvm/relay/base.h>
30 #include <tvm/relay/expr.h>
31 #include <tvm/relay/type.h>
32 
33 #include <functional>
34 #include <string>
35 #include <utility>
36 
37 namespace tvm {
38 namespace relay {
39 
40 using Constructor = tvm::Constructor;
41 using ConstructorNode = tvm::ConstructorNode;
42 
43 using TypeData = tvm::TypeData;
44 using TypeDataNode = tvm::TypeDataNode;
45 
46 /*! \brief Base type for declaring relay pattern. */
47 class PatternNode : public RelayNode {
48  public:
49   static constexpr const char* _type_key = "relay.Pattern";
50   static constexpr const bool _type_has_method_sequal_reduce = true;
51   static constexpr const bool _type_has_method_shash_reduce = true;
52   TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
53 };
54 
55 /*!
56  * \brief Pattern is the base type for an ADT match pattern in Relay.
57  *
58  * Given an ADT value, a pattern might accept it and bind the pattern variable to some value
59  * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
60  *
61  * ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
62  */
63 class Pattern : public ObjectRef {
64  public:
Pattern()65   Pattern() {}
Pattern(ObjectPtr<tvm::Object> p)66   explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
67 
68   using ContainerType = PatternNode;
69 };
70 
71 /*! \brief A wildcard pattern: Accepts all input and binds nothing. */
72 class PatternWildcard;
73 /*! \brief PatternWildcard container node */
74 class PatternWildcardNode : public PatternNode {
75  public:
VisitAttrs(tvm::AttrVisitor * v)76   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
77 
SEqualReduce(const PatternNode * other,SEqualReducer equal)78   bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; }
79 
SHashReduce(SHashReducer hash_reduce)80   void SHashReduce(SHashReducer hash_reduce) const {}
81 
82   static constexpr const char* _type_key = "relay.PatternWildcard";
83   TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
84 };
85 
86 class PatternWildcard : public Pattern {
87  public:
88   /* \brief Overload the default constructors. */
89   TVM_DLL PatternWildcard();
PatternWildcard(ObjectPtr<Object> n)90   explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}
91   /* \brief Copy constructor. */
PatternWildcard(const PatternWildcard & pat)92   PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
93   /* \brief Move constructor. */
PatternWildcard(PatternWildcard && pat)94   PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
95   /* \brief Copy assignment. */
96   PatternWildcard& operator=(const PatternWildcard& other) {
97     (*this).data_ = other.data_;
98     return *this;
99   }
100   /* \brief Move assignment. */
101   PatternWildcard& operator=(PatternWildcard&& other) {
102     (*this).data_ = std::move(other.data_);
103     return *this;
104   }
105 
106   const PatternWildcardNode* operator->() const {
107     return static_cast<const PatternWildcardNode*>(get());
108   }
109 
110   using ContainerType = PatternWildcardNode;
111 };
112 
113 /*! \brief A var pattern. Accept all input and bind to a var. */
114 class PatternVar;
115 /*! \brief PatternVar container node */
116 class PatternVarNode : public PatternNode {
117  public:
118   /*! \brief Variable that stores the matched value. */
119   tvm::relay::Var var;
120 
VisitAttrs(tvm::AttrVisitor * v)121   void VisitAttrs(tvm::AttrVisitor* v) {
122     v->Visit("var", &var);
123     v->Visit("span", &span);
124   }
125 
SEqualReduce(const PatternVarNode * other,SEqualReducer equal)126   bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
127     return equal.DefEqual(var, other->var);
128   }
129 
SHashReduce(SHashReducer hash_reduce)130   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); }
131 
132   static constexpr const char* _type_key = "relay.PatternVar";
133   TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
134 };
135 
136 class PatternVar : public Pattern {
137  public:
138   /*!
139    * \brief Constructor
140    * \param var The var to construct a pattern
141    */
142   TVM_DLL explicit PatternVar(tvm::relay::Var var);
143 
144   TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
145 };
146 
147 /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
148 class PatternConstructor;
149 /*! \brief PatternVar container node */
150 class PatternConstructorNode : public PatternNode {
151  public:
152   /*! Constructor matched by the pattern. */
153   Constructor constructor;
154   /*! Sub-patterns to match against each input to the constructor. */
155   tvm::Array<Pattern> patterns;
156 
VisitAttrs(tvm::AttrVisitor * v)157   void VisitAttrs(tvm::AttrVisitor* v) {
158     v->Visit("constructor", &constructor);
159     v->Visit("patterns", &patterns);
160     v->Visit("span", &span);
161   }
162 
SEqualReduce(const PatternConstructorNode * other,SEqualReducer equal)163   bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
164     return equal(constructor, other->constructor) && equal(patterns, other->patterns);
165   }
166 
SHashReduce(SHashReducer hash_reduce)167   void SHashReduce(SHashReducer hash_reduce) const {
168     hash_reduce(constructor);
169     hash_reduce(patterns);
170   }
171 
172   static constexpr const char* _type_key = "relay.PatternConstructor";
173   TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
174 };
175 
176 class PatternConstructor : public Pattern {
177  public:
178   /*!
179    * \brief Constructor
180    * \param constructor The constructor of a pattern
181    * \param patterns The sub-patterns for matching
182    */
183   TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
184 
185   TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
186 };
187 
188 /*! \brief A tuple pattern. Matches a tuple, binds recursively. */
189 class PatternTuple;
190 /*! \brief PatternVar container node */
191 class PatternTupleNode : public PatternNode {
192  public:
193   /*! Sub-patterns to match against each value of the tuple. */
194   tvm::Array<Pattern> patterns;
195 
VisitAttrs(tvm::AttrVisitor * v)196   void VisitAttrs(tvm::AttrVisitor* v) {
197     v->Visit("patterns", &patterns);
198     v->Visit("span", &span);
199   }
200 
SEqualReduce(const PatternTupleNode * other,SEqualReducer equal)201   bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
202     return equal(patterns, other->patterns);
203   }
204 
SHashReduce(SHashReducer hash_reduce)205   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); }
206 
207   static constexpr const char* _type_key = "relay.PatternTuple";
208   TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
209 };
210 
211 class PatternTuple : public Pattern {
212  public:
213   /*!
214    * \brief Constructor
215    * \param patterns The sub-patterns to match against each value of the tuple
216    */
217   TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
218 
219   TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
220 };
221 
222 /*! \brief A clause in a match expression. */
223 class Clause;
224 /*! \brief Clause container node. */
225 class ClauseNode : public Object {
226  public:
227   /*! \brief The pattern the clause matches. */
228   Pattern lhs;
229   /*! \brief The resulting value. */
230   Expr rhs;
231 
VisitAttrs(tvm::AttrVisitor * v)232   void VisitAttrs(tvm::AttrVisitor* v) {
233     v->Visit("lhs", &lhs);
234     v->Visit("rhs", &rhs);
235   }
236 
SEqualReduce(const ClauseNode * other,SEqualReducer equal)237   bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
238     return equal(lhs, other->lhs) && equal(rhs, other->rhs);
239   }
240 
SHashReduce(SHashReducer hash_reduce)241   void SHashReduce(SHashReducer hash_reduce) const {
242     hash_reduce(lhs);
243     hash_reduce(rhs);
244   }
245 
246   static constexpr const char* _type_key = "relay.Clause";
247   static constexpr const bool _type_has_method_sequal_reduce = true;
248   static constexpr const bool _type_has_method_shash_reduce = true;
249   TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
250 };
251 
252 class Clause : public ObjectRef {
253  public:
254   /*!
255    * \brief Constructor
256    * \param lhs The pattern matched by the clause.
257    * \param rhs The resulting value
258    */
259   TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
260 
261   TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
262 };
263 
264 /*! \brief ADT pattern matching exression. */
265 class Match;
266 /*! \brief Match container node. */
267 class MatchNode : public ExprNode {
268  public:
269   /*! \brief The input being deconstructed. */
270   Expr data;
271 
272   /*! \brief The match node clauses. */
273   tvm::Array<Clause> clauses;
274 
275   /*! \brief Should this match be complete (cover all cases)?
276    *  If yes, the type checker will generate an error if there are any missing cases.
277    */
278   bool complete;
279 
VisitAttrs(tvm::AttrVisitor * v)280   void VisitAttrs(tvm::AttrVisitor* v) {
281     v->Visit("data", &data);
282     v->Visit("clauses", &clauses);
283     v->Visit("complete", &complete);
284     v->Visit("span", &span);
285     v->Visit("_checked_type_", &checked_type_);
286   }
287 
SEqualReduce(const MatchNode * other,SEqualReducer equal)288   bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
289     equal->MarkGraphNode();
290     return equal(data, other->data) && equal(clauses, other->clauses) &&
291            equal(complete, other->complete);
292   }
293 
SHashReduce(SHashReducer hash_reduce)294   void SHashReduce(SHashReducer hash_reduce) const {
295     hash_reduce->MarkGraphNode();
296     hash_reduce(data);
297     hash_reduce(clauses);
298     hash_reduce(complete);
299   }
300 
301   static constexpr const char* _type_key = "relay.Match";
302   TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
303 };
304 
305 class Match : public Expr {
306  public:
307   /*!
308    * \brief Constructor
309    * \param data the input being deconstructed.
310    * \param clauses The clauses for matching.
311    * \param complete Indicate if this match is complete.
312    * \param span The span of the expression.
313    */
314   TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
315 
316   TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
317 };
318 
319 }  // namespace relay
320 }  // namespace tvm
321 
322 #endif  // TVM_RELAY_ADT_H_
323