1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/arithmetic/pattern_match.h
22  *
23  * \brief Internal tool for expression-template based pattern matching.
24  *
25  * It helps to simplify pattern matching and rewrites.
26  * All the patterns are generated via expression template during compile time,
27  * so the result code should be as efficient as manually written pattern match code.
28  *
29  * The code below shows how to use the pattern matcher.
30  *
31  * \code
32  *
33  *  // max(x + z, y + z) => max(x, y) + z
34  *  arith::PVar<Expr> x, y, z;
35  *
36  *  // The following code tries to match the declared pattern.
37  *  // Match will fill the result of match into PVar if successful.
38  *  // Note that z occurs twice in the pattern,
 *  // an equality check is performed to ensure each occurrence of z
40  *  // is equivalent to each other.
41  *  if (max(x + z, y + z).Match(expr)) {
42  *    // Eval evaluates a pattern with the current matched value.
43  *    // The filled value is valid until the next call to Match.
44  *    return (max(x, y) + z).Eval();
45  *  }
46  *
47  *  tvm::Var tx, ty;
48  *  arith::PVar<Integer> c;
49  *  arith::PVar<Var> v;
 *  // We can match Integer and Var, both of which are
 *  // special-case containers of Expr.
52  *  CHECK((v * c).Match(tx * 3));
53  *  CHECK_EQ(c.Eval()->value, 3);
54  *  // cannot match c to ty
55  *  CHECK(!(v * c).Match(tx * ty));
56  *
57  * \endcode
58  *
 * \note The pattern matcher is not thread-safe;
 *       do not use the same PVar in multiple threads.
 *
 *       Please be aware that the filled value in a PVar
 *       can be overridden by the next call to Match.
64  */
65 #ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_
66 #define TVM_ARITHMETIC_PATTERN_MATCH_H_
67 
68 #include <tvm/ir_pass.h>
69 #include <tuple>
70 #include "const_fold.h"
71 
72 namespace tvm {
73 namespace arith {
/*!
 * \brief CRTP root shared by every pattern node.
 *
 * Each pattern provides two operations:
 * - Match: test a value against the pattern, binding PVars on success.
 * - Eval: rebuild a value from the values currently bound in the PVars.
 *
 * Patterns are assembled as expression templates through the curiously
 * recurring template pattern, so a whole pattern tree is one static type.
 *
 * \tparam Derived The concrete pattern type that derives from this class.
 */
template<typename Derived>
class Pattern {
 public:
  /*!
   * \brief How a pattern of this type is stored inside an enclosing pattern.
   *
   *  The default is to nest by value (Nested = Derived). A derived class may
   *  shadow this alias with const Derived& to be nested by reference instead
   *  (PVar does this). The Nested trick is borrowed from Eigen.
   *
   * \note Intermediate expressions nest by value; PVars nest by reference.
   */
  using Nested = Derived;

  /*! \return This object viewed as its concrete Derived type. */
  const Derived& derived() const {
    return *static_cast<const Derived*>(this);
  }

  /*!
   * \brief Test whether value matches this pattern.
   *
   * InitMatch_ is called first to reset the match state, then Match_ runs.
   * PVars bound by this call stay valid until the next call to Match.
   *
   * \param value The value to match against.
   * \return Whether value matches the pattern.
   */
  template<typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    self.InitMatch_();
    return self.Match_(value);
  }
};
120 
/*!
 * \brief Fallback equality check used when a pattern compares two values.
 *
 * Values are compared with operator==. Specializations below override this
 * for node reference types.
 *
 * \tparam T The type of the values being compared.
 */
template<typename T>
class PEqualChecker {
 public:
  /*!
   * \param lhs The first value.
   * \param rhs The second value.
   * \return Whether lhs == rhs.
   */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool same = (lhs == rhs);
    return same;
  }
};
132 
133 template<>
134 class PEqualChecker<Expr> {
135  public:
operator()136   bool operator()(const Expr& lhs, const Expr& rhs) const {
137     if (lhs.same_as(rhs)) return true;
138     return ir::Equal(lhs, rhs);
139   }
140 };
141 
142 template<>
143 class PEqualChecker<Integer> {
144  public:
operator()145   bool operator()(const Integer& lhs, const Integer& rhs) const {
146     return lhs->value == rhs->value;
147   }
148 };
149 
150 template<>
151 class PEqualChecker<Var> {
152  public:
operator()153   bool operator()(const Var& lhs, const Var& rhs) const {
154     return lhs.same_as(rhs);
155   }
156 };
157 
158 /*!
159  * \brief Pattern variable container.
160  *
161  * PVar is used as a "hole" in the pattern that can be matched.
162  *
163  * \tparam T the type of the hole.
164  *
165  * \note PVar is not thread safe.
166  *       Do not use the same PVar in multiple threads.
167  */
168 template<typename T>
169 class PVar : public Pattern<PVar<T> > {
170  public:
171   // Store PVars by reference in the expression.
172   using Nested = const PVar<T>&;
173 
InitMatch_()174   void InitMatch_() const {
175     filled_ = false;
176   }
177 
Match_(const T & value)178   bool Match_(const T& value) const {
179     if (!filled_) {
180       value_ = value;
181       filled_ = true;
182       return true;
183     } else {
184       return PEqualChecker<T>()(value_, value);
185     }
186   }
187 
188   template<typename NodeRefType,
189            typename = typename std::enable_if<
190              std::is_base_of<NodeRefType, T>::value>::type>
Match_(const NodeRefType & value)191   bool Match_(const NodeRefType& value) const {
192     if (const auto* ptr = value.template as<typename T::ContainerType>()) {
193       return Match_(GetRef<T>(ptr));
194     } else {
195       return false;
196     }
197   }
198 
Eval()199   T Eval() const {
200     CHECK(filled_);
201     return value_;
202   }
203 
204  protected:
205   /*! \brief The matched value */
206   mutable T value_;
207   /*! \brief whether the variable has been filled */
208   mutable bool filled_{false};
209 };
210 
211 /*!
212  * \brief Constant Pattern variable container.
213  *
214  * \tparam T the type of the hole.
215  */
216 template<typename T>
217 class PConst : public Pattern<PConst<T> > {
218  public:
PConst(T value)219   PConst(T value)  // NOLINT(*)
220       : value_(value) {}
221 
InitMatch_()222   void InitMatch_() const {}
223 
Match_(const T & value)224   bool Match_(const T& value) const {
225     return PEqualChecker<T>()(value_, value);
226   }
227 
Eval()228   T Eval() const {
229     return value_;
230   }
231 
232  private:
233   const T value_;
234 };
235 
236 /*!
237  * \brief Pattern binary expression.
238  * \tparam NodeType The AST node type.
239  * \tparam TA The pattern type of the first operand.
240  * \tparam TB The pattern type of the second operand.
241  */
242 template<typename NodeType, typename TA, typename TB>
243 class PBinaryExpr :
244       public Pattern<PBinaryExpr<NodeType, TA, TB> > {
245  public:
PBinaryExpr(const TA & a,const TB & b)246   PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}
247 
InitMatch_()248   void InitMatch_() const {
249     a_.InitMatch_();
250     b_.InitMatch_();
251   }
252 
Match_(const NodeRef & node)253   bool Match_(const NodeRef& node) const {
254     if (const NodeType* ptr = node.as<NodeType>()) {
255       if (!a_.Match_(ptr->a)) return false;
256       if (!b_.Match_(ptr->b)) return false;
257       return true;
258     } else {
259       return false;
260     }
261   }
262 
Eval()263   Expr Eval() const {
264     Expr lhs = a_.Eval();
265     Expr rhs = b_.Eval();
266     Expr ret = TryConstFold<NodeType>(lhs, rhs);
267     if (ret.defined()) return ret;
268     return NodeType::make(lhs, rhs);
269   }
270 
271  private:
272   typename TA::Nested a_;
273   typename TB::Nested b_;
274 };
275 
276 template<typename TA>
277 class PConstWithTypeLike :
278       public Pattern<PConstWithTypeLike<TA> > {
279  public:
PConstWithTypeLike(const TA & ref,int64_t value)280   PConstWithTypeLike(const TA& ref, int64_t value)
281       : ref_(ref), value_(value) {}
282 
InitMatch_()283   void InitMatch_() const {}
284 
Match_(const NodeRef & node)285   bool Match_(const NodeRef& node) const {
286     if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
287       return ptr->value == value_;
288     } else {
289       return false;
290     }
291   }
292 
Eval()293   Expr Eval() const {
294     return make_const(ref_.Eval().type(), value_);
295   }
296 
297  private:
298   typename TA::Nested ref_;
299   int64_t value_;
300 };
301 
302 
// TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) declares three
// overloads of FuncName that build a PBinaryExpr<NodeName, ...>:
//   - FuncName(pattern, pattern)
//   - FuncName(pattern, int64_t)
//   - FuncName(int64_t, pattern)
// In the mixed forms the integer is wrapped in a PConstWithTypeLike so that
// its type follows the pattern operand. CheckStep is pasted as a statement at
// the top of every overload and may refer to the pattern operand as `a`.
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)     \
  template<typename TA, typename TB>                                \
  inline PBinaryExpr<NodeName, TA, TB>                              \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {            \
    CheckStep;                                                      \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >         \
  FuncName(const Pattern<TA>& a, int64_t b) {                       \
    CheckStep;                                                      \
    const PConstWithTypeLike<TA> rhs_const(a.derived(), b);         \
    return FuncName(a, rhs_const);                                  \
  }                                                                 \
  template<typename TA>                                             \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>          \
  FuncName(int64_t b, const Pattern<TA>& a) {                       \
    CheckStep;                                                      \
    const PConstWithTypeLike<TA> lhs_const(a.derived(), b);         \
    return FuncName(lhs_const, a);                                  \
  }

// TVM_PATTERN_BINARY_OP_EX with an empty CheckStep.
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
  TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
325 
326 
327 // raise ambiguity error for operator overload of / and %
328 TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
329 TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
330 
331 // arithmetic expressions
332 TVM_PATTERN_BINARY_OP(operator+, ir::Add);
333 TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
334 TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
335 TVM_PATTERN_BINARY_OP(min, ir::Min);
336 TVM_PATTERN_BINARY_OP(max, ir::Max);
337 TVM_PATTERN_BINARY_OP(div, ir::Div);
338 TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
339 TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
340 TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
341 TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
342 
343 // logical expressions
344 TVM_PATTERN_BINARY_OP(operator>, ir::GT);
345 TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
346 TVM_PATTERN_BINARY_OP(operator<, ir::LT);
347 TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
348 TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
349 TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
350 TVM_PATTERN_BINARY_OP(operator&&, ir::And);
351 TVM_PATTERN_BINARY_OP(operator||, ir::Or);
352 
353 /*!
354  * \brief Pattern not expression.
355  * \tparam TA The pattern type of the true operand.
356  */
357 template<typename TA>
358 class PNotExpr : public Pattern<PNotExpr<TA> > {
359  public:
PNotExpr(const TA & value)360   explicit PNotExpr(const TA& value)
361       : value_(value) {}
362 
InitMatch_()363   void InitMatch_() const {
364     value_.InitMatch_();
365   }
366 
Match_(const NodeRef & node)367   bool Match_(const NodeRef& node) const {
368     if (const ir::Not* ptr = node.as<ir::Not>()) {
369       if (!value_.Match_(ptr->a)) return false;
370       return true;
371     } else {
372       return false;
373     }
374   }
375 
Eval()376   Expr Eval() const {
377     return ir::Not::make(value_.Eval());
378   }
379 
380  private:
381   typename TA::Nested value_;
382 };
383 
384 template<typename TA>
385 inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
386   return PNotExpr<TA>(value.derived());
387 }
388 
389 // select
390 /*!
391  * \brief Pattern select expression.
392  * \tparam TCond The pattern type of the condition.
393  * \tparam TA The pattern type of the true operand.
394  * \tparam TB The pattern type of the false operand.
395  */
396 template<typename TCond, typename TA, typename TB>
397 class PSelectExpr :
398       public Pattern<PSelectExpr<TCond, TA, TB> > {
399  public:
PSelectExpr(const TCond & condition,const TA & true_value,const TB & false_value)400   PSelectExpr(const TCond& condition,
401               const TA& true_value,
402               const TB& false_value)
403       : condition_(condition),
404         true_value_(true_value),
405         false_value_(false_value) {}
406 
InitMatch_()407   void InitMatch_() const {
408     condition_.InitMatch_();
409     true_value_.InitMatch_();
410     false_value_.InitMatch_();
411   }
412 
Match_(const NodeRef & node)413   bool Match_(const NodeRef& node) const {
414     if (const ir::Select* ptr = node.as<ir::Select>()) {
415       if (!condition_.Match_(ptr->condition)) return false;
416       if (!true_value_.Match_(ptr->true_value)) return false;
417       if (!false_value_.Match_(ptr->false_value)) return false;
418       return true;
419     } else {
420       return false;
421     }
422   }
423 
Eval()424   Expr Eval() const {
425     return ir::Select::make(
426         condition_.Eval(), true_value_.Eval(), false_value_.Eval());
427   }
428 
429  private:
430   typename TCond::Nested condition_;
431   typename TA::Nested true_value_;
432   typename TB::Nested false_value_;
433 };
434 
435 /*!
436  * \brief Construct a select pattern.
437  *
438  * \param condition The condition expression.
439  * \param true_value The value when condition is true.
440  * \param true_value The value when condition is false.
441  *
442  * \return The result pattern.
443  *
444  * \tparam TCond The pattern type of the condition.
445  * \tparam TA The pattern type of the true operand.
446  * \tparam TB The pattern type of the false operand.
447  */
448 template<typename TCond, typename TA, typename TB>
449 inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond> & condition,const Pattern<TA> & true_value,const Pattern<TB> & false_value)450 select(const Pattern<TCond>& condition,
451        const Pattern<TA>& true_value,
452        const Pattern<TB>& false_value) {
453   return PSelectExpr<TCond, TA, TB>(
454       condition.derived(), true_value.derived(), false_value.derived());
455 }
456 
457 /*!
458  * \brief Pattern cast expression.
459  * \tparam DType The Pattern type of dtype.
460  * \tparam TA The pattern type of the first operand.
461  */
462 template<typename DType, typename TA>
463 class PCastExpr :
464       public Pattern<PCastExpr<DType, TA> > {
465  public:
PCastExpr(const DType & dtype,const TA & value)466   PCastExpr(const DType& dtype, const TA& value)
467       : dtype_(dtype), value_(value) {
468   }
469 
InitMatch_()470   void InitMatch_() const {
471     dtype_.InitMatch_();
472     value_.InitMatch_();
473   }
474 
Match_(const NodeRef & node)475   bool Match_(const NodeRef& node) const {
476     if (const ir::Cast* ptr = node.as<ir::Cast>()) {
477       if (!dtype_.Match_(ptr->type)) return false;
478       if (!value_.Match_(ptr->value)) return false;
479       return true;
480     } else {
481       return false;
482     }
483   }
484 
Eval()485   Expr Eval() const {
486     return ir::Cast::make(dtype_.Eval(), value_.Eval());
487   }
488 
489  private:
490   typename DType::Nested dtype_;
491   typename TA::Nested value_;
492 };
493 
494 /*!
495  * \brief Construct a cast pattern.
496  *
497  * \param dtype The target data type, can be PVar<Type> or PConst<Type>.
498  * \param value The input type.
499  *
500  * \return The result pattern.
501  *
502  * \tparam DType The pattern type of type.
503  * \tparam TA The pattern type of value.
504  */
505 template<typename DType, typename TA>
506 inline PCastExpr<DType, TA>
cast(const Pattern<DType> & dtype,const Pattern<TA> & value)507 cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
508   return PCastExpr<DType, TA>(dtype.derived(), value.derived());
509 }
510 
511 /*!
512  * \brief Pattern ramp expression.
513  * \tparam TBase The pattern type of the base.
514  * \tparam TStride The pattern type of the stride.
515  * \tparam TLanes The pattern type of the lanes.
516  */
517 template<typename TBase, typename TStride, typename TLanes>
518 class PRampExpr :
519       public Pattern<PRampExpr<TBase, TStride, TLanes> > {
520  public:
PRampExpr(const TBase & base,const TStride & stride,const TLanes & lanes)521   PRampExpr(const TBase& base,
522             const TStride& stride,
523             const TLanes& lanes)
524       : base_(base), stride_(stride), lanes_(lanes) {
525   }
526 
InitMatch_()527   void InitMatch_() const {
528     base_.InitMatch_();
529     stride_.InitMatch_();
530     lanes_.InitMatch_();
531   }
532 
Match_(const NodeRef & node)533   bool Match_(const NodeRef& node) const {
534     if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
535       if (!base_.Match_(ptr->base)) return false;
536       if (!stride_.Match_(ptr->stride)) return false;
537       if (!lanes_.Match_(ptr->lanes)) return false;
538       return true;
539     } else {
540       return false;
541     }
542   }
543 
Eval()544   Expr Eval() const {
545     return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
546   }
547 
548  private:
549   typename TBase::Nested base_;
550   typename TStride::Nested stride_;
551   typename TLanes::Nested lanes_;
552 };
553 
554 /*!
555  * \brief Construct a ramp pattern.
556  *
557  * \param base The base pattern.
558  * \param stride The stride pattern.
559  * \param lanes The lanes pattern.
560  *
561  * \return The result pattern.
562  *
563  * \tparam TBase The pattern type of the base.
564  * \tparam TStride The pattern type of the stride.
565  * \tparam TLanes The pattern type of the lanes.
566  */
567 template<typename TBase, typename TStride, typename TLanes>
568 inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase> & base,const Pattern<TStride> & stride,const Pattern<TLanes> & lanes)569 ramp(const Pattern<TBase>& base,
570      const Pattern<TStride>& stride,
571      const Pattern<TLanes>& lanes) {
572   return PRampExpr<TBase, TStride, TLanes>(
573       base.derived(), stride.derived(), lanes.derived());
574 }
575 
576 /*!
577  * \brief Pattern broadcast expression.
578  * \tparam TA The pattern type of the value.
579  * \tparam TLanes The pattern type of the lanes.
580  */
581 template<typename TA, typename TLanes>
582 class PBroadcastExpr :
583       public Pattern<PBroadcastExpr<TA, TLanes> > {
584  public:
PBroadcastExpr(const TA & value,const TLanes & lanes)585   PBroadcastExpr(const TA& value,
586                  const TLanes& lanes)
587       : value_(value), lanes_(lanes) {
588   }
589 
InitMatch_()590   void InitMatch_() const {
591     value_.InitMatch_();
592     lanes_.InitMatch_();
593   }
594 
Match_(const NodeRef & node)595   bool Match_(const NodeRef& node) const {
596     if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
597       if (!value_.Match_(ptr->value)) return false;
598       if (!lanes_.Match_(ptr->lanes)) return false;
599       return true;
600     } else {
601       return false;
602     }
603   }
604 
Eval()605   Expr Eval() const {
606     return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
607   }
608 
609  private:
610   typename TA::Nested value_;
611   typename TLanes::Nested lanes_;
612 };
613 
614 /*!
615  * \brief Construct a broadcast pattern.
616  *
617  * \param value The value pattern.
618  * \param lanes The lanes pattern.
619  *
620  * \return The result pattern.
621  *
622  * \tparam TA The pattern type of the value.
623  * \tparam TLanes The pattern type of the lanes.
624  */
625 template<typename TA, typename TLanes>
626 inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA> & value,const Pattern<TLanes> & lanes)627 broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
628   return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived());
629 }
630 
631 // internal namespace
632 namespace detail {
633 // implementation details for  CallExpr
634 template<bool stop, std::size_t I, typename F>
635 struct tuple_for_each_dispatcher {
636   template<typename TTuple>
runtuple_for_each_dispatcher637   static void run(F& f, const TTuple& tuple) { // NOLINT(*)
638     f(I, std::get<I>(tuple));
639     tuple_for_each_dispatcher<
640       (I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>
641         ::run(f, tuple);
642   }
643 };
644 
645 template<std::size_t I, typename F>
646 struct tuple_for_each_dispatcher<true, I, F> {
647   template<typename TTuple>
648   static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
649 };
650 
651 template<typename F, typename TTuple>
652 inline void tuple_for_each(F& f, const TTuple& tuple) {  // NOLINT(*)
653   tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>
654       ::run(f, tuple);
655 }
656 
657 struct PCallExprInitMatchFunctor {
658   template<typename T>
659   void operator()(size_t i, const T& pattern) const {
660     pattern.InitMatch_();
661   }
662 };
663 
664 struct PCallExprMatchFunctor {
665   const ir::Call* call_;
666   bool matched_{true};
667 
668   explicit PCallExprMatchFunctor(const ir::Call* call)
669       : call_(call) {}
670 
671   template<typename T>
672   void operator()(size_t i, const T& pattern) {
673     matched_ = matched_ && pattern.Match_(call_->args[i]);
674   }
675 };
676 
677 struct PCallExprEvalArgsFunctor {
678   Array<Expr> args_;
679 
680   template<typename T>
681   void operator()(size_t i, const T& pattern) {
682     args_.push_back(pattern.Eval());
683   }
684 };
685 }  // namespace detail
686 
687 /*!
688  * \brief Pattern CallExpr expression.
689  * \tparam Op The operator functor class.
690  * \tparam TArgs The arguments.
691  * \note Op functor contains the name of the function and
692  *          the implementation of Eval.
693  */
694 template<typename Op, typename ...TArgs>
695 class PCallExpr :
696       public Pattern<PCallExpr<Op, TArgs...> > {
697  public:
698   explicit PCallExpr(const TArgs&... args)
699       : args_(args...) {
700   }
701 
702   void InitMatch_() const {
703     detail::PCallExprInitMatchFunctor finit;
704     detail::tuple_for_each(finit, args_);
705   }
706 
707   bool Match_(const NodeRef& node) const {
708     if (const ir::Call* ptr = node.as<ir::Call>()) {
709       if (ptr->args.size() != sizeof...(TArgs)) return false;
710       if (ptr->name != Op::kName) return false;
711       detail::PCallExprMatchFunctor fmatch(ptr);
712       detail::tuple_for_each(fmatch, args_);
713       return fmatch.matched_;
714     } else {
715       return false;
716     }
717   }
718 
719   Expr Eval() const {
720     detail::PCallExprEvalArgsFunctor feval_args;
721     detail::tuple_for_each(feval_args, args_);
722     return Op::Eval(feval_args.args_);
723   }
724 
725  private:
726   std::tuple<typename TArgs::Nested...> args_;
727 };
728 
729 // arithemetic intrinsics
730 #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)        \
731   struct OpName {                                                     \
732     static Expr Eval(Array<Expr> args) {                              \
733       return ir::Call::make(args[0].type(), kName, args,              \
734                             ir::Call::PureIntrinsic);                 \
735     }                                                                 \
736     static constexpr const char* kName = IntrinStr;                   \
737   };                                                                  \
738   template<typename TA, typename TB>                                  \
739   inline PCallExpr<OpName, TA, TB>                                    \
740   FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {              \
741     return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());             \
742   }
743 
744 TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
745 TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
746 TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
747 TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
748 TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
749 
750 // unary intrinsics
751 #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)         \
752   struct OpName {                                                     \
753     static Expr Eval(Array<Expr> args) {                              \
754       return ir::Call::make(args[0].type(), kName, args,              \
755                             ir::Call::PureIntrinsic);                 \
756     }                                                                 \
757     static constexpr const char* kName = IntrinStr;                   \
758   };                                                                  \
759   template<typename TA>                                               \
760   inline PCallExpr<OpName, TA>                                        \
761   FuncName(const Pattern<TA>& a) {                                    \
762     return PCallExpr<OpName, TA>(a.derived());                           \
763   }
764 
765 TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
766 
767 // if_then_else
768 struct PIfThenElseOp {
769   static Expr Eval(Array<Expr> args) {
770     return ir::Call::make(
771         args[1].type(), kName, args,
772         ir::Call::PureIntrinsic);
773   }
774   static constexpr const char* kName = "tvm_if_then_else";
775 };
776 
777 /*!
778  * \brief Construct a if_then_else pattern.
779  *
780  * \param cond The condition expression.
781  * \param true_value The value when condition is true.
782  * \param true_value The value when condition is false.
783  *
784  * \return The result pattern.
785  *
786  * \tparam TCond The pattern type of the condition.
787  * \tparam TA The pattern type of the true operand.
788  * \tparam TB The pattern type of the false operand.
789  */
790 template<typename TCond, typename TA, typename TB>
791 inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
792 if_then_else(const Pattern<TCond>& cond,
793              const Pattern<TA>& true_value,
794              const Pattern<TB>& false_value) {
795   return PCallExpr<PIfThenElseOp, TCond, TA, TB>(
796       cond.derived(), true_value.derived(), false_value.derived());
797 }
798 
799 }  // namespace arith
800 }  // namespace tvm
801 #endif  // TVM_ARITHMETIC_PATTERN_MATCH_H_
802