1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/arithmetic/pattern_match.h
22 *
23 * \brief Internal tool for expression-template based pattern matching.
24 *
25 * It helps to simplify pattern matching and rewrites.
26 * All the patterns are generated via expression template during compile time,
27 * so the result code should be as efficient as manually written pattern match code.
28 *
29 * The code below shows how to use the pattern matcher.
30 *
31 * \code
32 *
33 * // max(x + z, y + z) => max(x, y) + z
34 * arith::PVar<Expr> x, y, z;
35 *
36 * // The following code tries to match the declared pattern.
37 * // Match will fill the result of match into PVar if successful.
38 * // Note that z occurs twice in the pattern,
39 * // an equality check is performed to ensure each occurrence of z
40 * // is equivalent to each other.
41 * if (max(x + z, y + z).Match(expr)) {
42 * // Eval evaluates a pattern with the current matched value.
43 * // The filled value is valid until the next call to Match.
44 * return (max(x, y) + z).Eval();
45 * }
46 *
47 * tvm::Var tx, ty;
48 * arith::PVar<Integer> c;
49 * arith::PVar<Var> v;
50 * // We can match Integer and Var, both of which are
51 * // special-case containers of Expr
52 * CHECK((v * c).Match(tx * 3));
53 * CHECK_EQ(c.Eval()->value, 3);
54 * // cannot match c to ty
55 * CHECK(!(v * c).Match(tx * ty));
56 *
57 * \endcode
58 *
59 * \note The pattern matcher is not threadsafe,
60 * do not use the same PVar in multiple threads.
61 *
62 * Please be aware that the filled value in a PVar
63 * can be overridden in the next call to Match.
64 */
65 #ifndef TVM_ARITHMETIC_PATTERN_MATCH_H_
66 #define TVM_ARITHMETIC_PATTERN_MATCH_H_
67
68 #include <tvm/ir_pass.h>
69 #include <tuple>
70 #include "const_fold.h"
71
72 namespace tvm {
73 namespace arith {
/*!
 * \brief CRTP root shared by every pattern node.
 *
 * Each concrete pattern supplies three members:
 *  - InitMatch_(): clear any state left over from a previous match;
 *  - Match_(value): test value against the pattern, binding PVars;
 *  - Eval(): rebuild a value from the currently bound PVars.
 *
 * Match() below is the public entry point that runs InitMatch_ and then
 * Match_ on the concrete pattern.
 *
 * \tparam Derived The concrete pattern class deriving from this one.
 */
template<typename Derived>
class Pattern {
 public:
  /*!
   * \brief How this pattern is stored when nested inside a larger pattern.
   *
   * The default nests by value, which suits short-lived intermediate
   * expression nodes. A derived class may shadow this alias to nest by
   * const reference instead (PVar does so, so that repeated occurrences of
   * one PVar share a single binding). The idea comes from Eigen.
   */
  using Nested = Derived;

  /*!
   * \brief Test whether value matches this pattern.
   *
   * Calls InitMatch_ on the concrete pattern to reset per-match state, then
   * Match_ to perform the match and fill PVars. Bindings stay valid until
   * the next call to Match.
   *
   * \param value The value to test.
   * \return true if value matches the pattern.
   */
  template<typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    self.InitMatch_();
    return self.Match_(value);
  }

  /*! \return This object viewed as its concrete Derived type. */
  const Derived& derived() const {
    const Derived* self = static_cast<const Derived*>(this);
    return *self;
  }
};
120
/*!
 * \brief Equality predicate used when a PVar is matched more than once.
 *
 * The generic version simply defers to T's operator==.
 *
 * \tparam T The type of the values being compared.
 */
template<typename T>
class PEqualChecker {
 public:
  /*! \return Whether lhs and rhs compare equal via operator==. */
  bool operator()(const T& lhs, const T& rhs) const {
    const bool same = (lhs == rhs);
    return same;
  }
};
132
133 template<>
134 class PEqualChecker<Expr> {
135 public:
operator()136 bool operator()(const Expr& lhs, const Expr& rhs) const {
137 if (lhs.same_as(rhs)) return true;
138 return ir::Equal(lhs, rhs);
139 }
140 };
141
142 template<>
143 class PEqualChecker<Integer> {
144 public:
operator()145 bool operator()(const Integer& lhs, const Integer& rhs) const {
146 return lhs->value == rhs->value;
147 }
148 };
149
150 template<>
151 class PEqualChecker<Var> {
152 public:
operator()153 bool operator()(const Var& lhs, const Var& rhs) const {
154 return lhs.same_as(rhs);
155 }
156 };
157
158 /*!
159 * \brief Pattern variable container.
160 *
161 * PVar is used as a "hole" in the pattern that can be matched.
162 *
163 * \tparam T the type of the hole.
164 *
165 * \note PVar is not thread safe.
166 * Do not use the same PVar in multiple threads.
167 */
168 template<typename T>
169 class PVar : public Pattern<PVar<T> > {
170 public:
171 // Store PVars by reference in the expression.
172 using Nested = const PVar<T>&;
173
InitMatch_()174 void InitMatch_() const {
175 filled_ = false;
176 }
177
Match_(const T & value)178 bool Match_(const T& value) const {
179 if (!filled_) {
180 value_ = value;
181 filled_ = true;
182 return true;
183 } else {
184 return PEqualChecker<T>()(value_, value);
185 }
186 }
187
188 template<typename NodeRefType,
189 typename = typename std::enable_if<
190 std::is_base_of<NodeRefType, T>::value>::type>
Match_(const NodeRefType & value)191 bool Match_(const NodeRefType& value) const {
192 if (const auto* ptr = value.template as<typename T::ContainerType>()) {
193 return Match_(GetRef<T>(ptr));
194 } else {
195 return false;
196 }
197 }
198
Eval()199 T Eval() const {
200 CHECK(filled_);
201 return value_;
202 }
203
204 protected:
205 /*! \brief The matched value */
206 mutable T value_;
207 /*! \brief whether the variable has been filled */
208 mutable bool filled_{false};
209 };
210
211 /*!
212 * \brief Constant Pattern variable container.
213 *
214 * \tparam T the type of the hole.
215 */
216 template<typename T>
217 class PConst : public Pattern<PConst<T> > {
218 public:
PConst(T value)219 PConst(T value) // NOLINT(*)
220 : value_(value) {}
221
InitMatch_()222 void InitMatch_() const {}
223
Match_(const T & value)224 bool Match_(const T& value) const {
225 return PEqualChecker<T>()(value_, value);
226 }
227
Eval()228 T Eval() const {
229 return value_;
230 }
231
232 private:
233 const T value_;
234 };
235
236 /*!
237 * \brief Pattern binary expression.
238 * \tparam NodeType The AST node type.
239 * \tparam TA The pattern type of the first operand.
240 * \tparam TB The pattern type of the second operand.
241 */
242 template<typename NodeType, typename TA, typename TB>
243 class PBinaryExpr :
244 public Pattern<PBinaryExpr<NodeType, TA, TB> > {
245 public:
PBinaryExpr(const TA & a,const TB & b)246 PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}
247
InitMatch_()248 void InitMatch_() const {
249 a_.InitMatch_();
250 b_.InitMatch_();
251 }
252
Match_(const NodeRef & node)253 bool Match_(const NodeRef& node) const {
254 if (const NodeType* ptr = node.as<NodeType>()) {
255 if (!a_.Match_(ptr->a)) return false;
256 if (!b_.Match_(ptr->b)) return false;
257 return true;
258 } else {
259 return false;
260 }
261 }
262
Eval()263 Expr Eval() const {
264 Expr lhs = a_.Eval();
265 Expr rhs = b_.Eval();
266 Expr ret = TryConstFold<NodeType>(lhs, rhs);
267 if (ret.defined()) return ret;
268 return NodeType::make(lhs, rhs);
269 }
270
271 private:
272 typename TA::Nested a_;
273 typename TB::Nested b_;
274 };
275
276 template<typename TA>
277 class PConstWithTypeLike :
278 public Pattern<PConstWithTypeLike<TA> > {
279 public:
PConstWithTypeLike(const TA & ref,int64_t value)280 PConstWithTypeLike(const TA& ref, int64_t value)
281 : ref_(ref), value_(value) {}
282
InitMatch_()283 void InitMatch_() const {}
284
Match_(const NodeRef & node)285 bool Match_(const NodeRef& node) const {
286 if (const ir::IntImm* ptr = node.as<ir::IntImm>()) {
287 return ptr->value == value_;
288 } else {
289 return false;
290 }
291 }
292
Eval()293 Expr Eval() const {
294 return make_const(ref_.Eval().type(), value_);
295 }
296
297 private:
298 typename TA::Nested ref_;
299 int64_t value_;
300 };
301
302
/*!
 * \brief Generate the overload set for a binary pattern operator.
 *
 * TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) defines three
 * inline function templates named FuncName:
 *  - (Pattern<TA>, Pattern<TB>) -> PBinaryExpr<NodeName, TA, TB>;
 *  - (Pattern<TA>, int64_t)     -> the integer becomes a
 *                                  PConstWithTypeLike<TA> right operand;
 *  - (int64_t, Pattern<TA>)     -> the integer becomes a
 *                                  PConstWithTypeLike<TA> left operand.
 * CheckStep is a statement run at the top of every overload. It may refer to
 * the pattern parameter, which is always named `a`.
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)         \
  template<typename TA, typename TB>                                    \
  inline PBinaryExpr<NodeName, TA, TB>                                  \
  FuncName(const Pattern<TA>& a, const Pattern<TB>& b) {                \
    CheckStep;                                                          \
    return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived());     \
  }                                                                     \
  template<typename TA>                                                 \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> >             \
  FuncName(const Pattern<TA>& a, int64_t b) {                           \
    CheckStep;                                                          \
    const PConstWithTypeLike<TA> rhs_const(a.derived(), b);             \
    return FuncName(a, rhs_const);                                      \
  }                                                                     \
  template<typename TA>                                                 \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA>              \
  FuncName(int64_t b, const Pattern<TA>& a) {                           \
    CheckStep;                                                          \
    const PConstWithTypeLike<TA> lhs_const(a.derived(), b);             \
    return FuncName(lhs_const, a);                                      \
  }

/*! \brief TVM_PATTERN_BINARY_OP_EX with an empty check step. */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName)                       \
  TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
325
326
327 // raise ambiguity error for operator overload of / and %
328 TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
329 TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));
330
331 // arithmetic expressions
332 TVM_PATTERN_BINARY_OP(operator+, ir::Add);
333 TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
334 TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
335 TVM_PATTERN_BINARY_OP(min, ir::Min);
336 TVM_PATTERN_BINARY_OP(max, ir::Max);
337 TVM_PATTERN_BINARY_OP(div, ir::Div);
338 TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
339 TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
340 TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
341 TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod);
342
343 // logical expressions
344 TVM_PATTERN_BINARY_OP(operator>, ir::GT);
345 TVM_PATTERN_BINARY_OP(operator>=, ir::GE);
346 TVM_PATTERN_BINARY_OP(operator<, ir::LT);
347 TVM_PATTERN_BINARY_OP(operator<=, ir::LE);
348 TVM_PATTERN_BINARY_OP(operator==, ir::EQ);
349 TVM_PATTERN_BINARY_OP(operator!=, ir::NE);
350 TVM_PATTERN_BINARY_OP(operator&&, ir::And);
351 TVM_PATTERN_BINARY_OP(operator||, ir::Or);
352
353 /*!
354 * \brief Pattern not expression.
355 * \tparam TA The pattern type of the true operand.
356 */
357 template<typename TA>
358 class PNotExpr : public Pattern<PNotExpr<TA> > {
359 public:
PNotExpr(const TA & value)360 explicit PNotExpr(const TA& value)
361 : value_(value) {}
362
InitMatch_()363 void InitMatch_() const {
364 value_.InitMatch_();
365 }
366
Match_(const NodeRef & node)367 bool Match_(const NodeRef& node) const {
368 if (const ir::Not* ptr = node.as<ir::Not>()) {
369 if (!value_.Match_(ptr->a)) return false;
370 return true;
371 } else {
372 return false;
373 }
374 }
375
Eval()376 Expr Eval() const {
377 return ir::Not::make(value_.Eval());
378 }
379
380 private:
381 typename TA::Nested value_;
382 };
383
384 template<typename TA>
385 inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
386 return PNotExpr<TA>(value.derived());
387 }
388
389 // select
390 /*!
391 * \brief Pattern select expression.
392 * \tparam TCond The pattern type of the condition.
393 * \tparam TA The pattern type of the true operand.
394 * \tparam TB The pattern type of the false operand.
395 */
396 template<typename TCond, typename TA, typename TB>
397 class PSelectExpr :
398 public Pattern<PSelectExpr<TCond, TA, TB> > {
399 public:
PSelectExpr(const TCond & condition,const TA & true_value,const TB & false_value)400 PSelectExpr(const TCond& condition,
401 const TA& true_value,
402 const TB& false_value)
403 : condition_(condition),
404 true_value_(true_value),
405 false_value_(false_value) {}
406
InitMatch_()407 void InitMatch_() const {
408 condition_.InitMatch_();
409 true_value_.InitMatch_();
410 false_value_.InitMatch_();
411 }
412
Match_(const NodeRef & node)413 bool Match_(const NodeRef& node) const {
414 if (const ir::Select* ptr = node.as<ir::Select>()) {
415 if (!condition_.Match_(ptr->condition)) return false;
416 if (!true_value_.Match_(ptr->true_value)) return false;
417 if (!false_value_.Match_(ptr->false_value)) return false;
418 return true;
419 } else {
420 return false;
421 }
422 }
423
Eval()424 Expr Eval() const {
425 return ir::Select::make(
426 condition_.Eval(), true_value_.Eval(), false_value_.Eval());
427 }
428
429 private:
430 typename TCond::Nested condition_;
431 typename TA::Nested true_value_;
432 typename TB::Nested false_value_;
433 };
434
435 /*!
436 * \brief Construct a select pattern.
437 *
438 * \param condition The condition expression.
439 * \param true_value The value when condition is true.
440 * \param true_value The value when condition is false.
441 *
442 * \return The result pattern.
443 *
444 * \tparam TCond The pattern type of the condition.
445 * \tparam TA The pattern type of the true operand.
446 * \tparam TB The pattern type of the false operand.
447 */
448 template<typename TCond, typename TA, typename TB>
449 inline PSelectExpr<TCond, TA, TB>
select(const Pattern<TCond> & condition,const Pattern<TA> & true_value,const Pattern<TB> & false_value)450 select(const Pattern<TCond>& condition,
451 const Pattern<TA>& true_value,
452 const Pattern<TB>& false_value) {
453 return PSelectExpr<TCond, TA, TB>(
454 condition.derived(), true_value.derived(), false_value.derived());
455 }
456
457 /*!
458 * \brief Pattern cast expression.
459 * \tparam DType The Pattern type of dtype.
460 * \tparam TA The pattern type of the first operand.
461 */
462 template<typename DType, typename TA>
463 class PCastExpr :
464 public Pattern<PCastExpr<DType, TA> > {
465 public:
PCastExpr(const DType & dtype,const TA & value)466 PCastExpr(const DType& dtype, const TA& value)
467 : dtype_(dtype), value_(value) {
468 }
469
InitMatch_()470 void InitMatch_() const {
471 dtype_.InitMatch_();
472 value_.InitMatch_();
473 }
474
Match_(const NodeRef & node)475 bool Match_(const NodeRef& node) const {
476 if (const ir::Cast* ptr = node.as<ir::Cast>()) {
477 if (!dtype_.Match_(ptr->type)) return false;
478 if (!value_.Match_(ptr->value)) return false;
479 return true;
480 } else {
481 return false;
482 }
483 }
484
Eval()485 Expr Eval() const {
486 return ir::Cast::make(dtype_.Eval(), value_.Eval());
487 }
488
489 private:
490 typename DType::Nested dtype_;
491 typename TA::Nested value_;
492 };
493
494 /*!
495 * \brief Construct a cast pattern.
496 *
497 * \param dtype The target data type, can be PVar<Type> or PConst<Type>.
498 * \param value The input type.
499 *
500 * \return The result pattern.
501 *
502 * \tparam DType The pattern type of type.
503 * \tparam TA The pattern type of value.
504 */
505 template<typename DType, typename TA>
506 inline PCastExpr<DType, TA>
cast(const Pattern<DType> & dtype,const Pattern<TA> & value)507 cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
508 return PCastExpr<DType, TA>(dtype.derived(), value.derived());
509 }
510
511 /*!
512 * \brief Pattern ramp expression.
513 * \tparam TBase The pattern type of the base.
514 * \tparam TStride The pattern type of the stride.
515 * \tparam TLanes The pattern type of the lanes.
516 */
517 template<typename TBase, typename TStride, typename TLanes>
518 class PRampExpr :
519 public Pattern<PRampExpr<TBase, TStride, TLanes> > {
520 public:
PRampExpr(const TBase & base,const TStride & stride,const TLanes & lanes)521 PRampExpr(const TBase& base,
522 const TStride& stride,
523 const TLanes& lanes)
524 : base_(base), stride_(stride), lanes_(lanes) {
525 }
526
InitMatch_()527 void InitMatch_() const {
528 base_.InitMatch_();
529 stride_.InitMatch_();
530 lanes_.InitMatch_();
531 }
532
Match_(const NodeRef & node)533 bool Match_(const NodeRef& node) const {
534 if (const ir::Ramp* ptr = node.as<ir::Ramp>()) {
535 if (!base_.Match_(ptr->base)) return false;
536 if (!stride_.Match_(ptr->stride)) return false;
537 if (!lanes_.Match_(ptr->lanes)) return false;
538 return true;
539 } else {
540 return false;
541 }
542 }
543
Eval()544 Expr Eval() const {
545 return ir::Ramp::make(base_.Eval(), stride_.Eval(), lanes_.Eval());
546 }
547
548 private:
549 typename TBase::Nested base_;
550 typename TStride::Nested stride_;
551 typename TLanes::Nested lanes_;
552 };
553
554 /*!
555 * \brief Construct a ramp pattern.
556 *
557 * \param base The base pattern.
558 * \param stride The stride pattern.
559 * \param lanes The lanes pattern.
560 *
561 * \return The result pattern.
562 *
563 * \tparam TBase The pattern type of the base.
564 * \tparam TStride The pattern type of the stride.
565 * \tparam TLanes The pattern type of the lanes.
566 */
567 template<typename TBase, typename TStride, typename TLanes>
568 inline PRampExpr<TBase, TStride, TLanes>
ramp(const Pattern<TBase> & base,const Pattern<TStride> & stride,const Pattern<TLanes> & lanes)569 ramp(const Pattern<TBase>& base,
570 const Pattern<TStride>& stride,
571 const Pattern<TLanes>& lanes) {
572 return PRampExpr<TBase, TStride, TLanes>(
573 base.derived(), stride.derived(), lanes.derived());
574 }
575
576 /*!
577 * \brief Pattern broadcast expression.
578 * \tparam TA The pattern type of the value.
579 * \tparam TLanes The pattern type of the lanes.
580 */
581 template<typename TA, typename TLanes>
582 class PBroadcastExpr :
583 public Pattern<PBroadcastExpr<TA, TLanes> > {
584 public:
PBroadcastExpr(const TA & value,const TLanes & lanes)585 PBroadcastExpr(const TA& value,
586 const TLanes& lanes)
587 : value_(value), lanes_(lanes) {
588 }
589
InitMatch_()590 void InitMatch_() const {
591 value_.InitMatch_();
592 lanes_.InitMatch_();
593 }
594
Match_(const NodeRef & node)595 bool Match_(const NodeRef& node) const {
596 if (const ir::Broadcast* ptr = node.as<ir::Broadcast>()) {
597 if (!value_.Match_(ptr->value)) return false;
598 if (!lanes_.Match_(ptr->lanes)) return false;
599 return true;
600 } else {
601 return false;
602 }
603 }
604
Eval()605 Expr Eval() const {
606 return ir::Broadcast::make(value_.Eval(), lanes_.Eval());
607 }
608
609 private:
610 typename TA::Nested value_;
611 typename TLanes::Nested lanes_;
612 };
613
614 /*!
615 * \brief Construct a broadcast pattern.
616 *
617 * \param value The value pattern.
618 * \param lanes The lanes pattern.
619 *
620 * \return The result pattern.
621 *
622 * \tparam TA The pattern type of the value.
623 * \tparam TLanes The pattern type of the lanes.
624 */
625 template<typename TA, typename TLanes>
626 inline PBroadcastExpr<TA, TLanes>
broadcast(const Pattern<TA> & value,const Pattern<TLanes> & lanes)627 broadcast(const Pattern<TA>& value, const Pattern<TLanes>& lanes) {
628 return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived());
629 }
630
631 // internal namespace
632 namespace detail {
633 // implementation details for CallExpr
634 template<bool stop, std::size_t I, typename F>
635 struct tuple_for_each_dispatcher {
636 template<typename TTuple>
runtuple_for_each_dispatcher637 static void run(F& f, const TTuple& tuple) { // NOLINT(*)
638 f(I, std::get<I>(tuple));
639 tuple_for_each_dispatcher<
640 (I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>
641 ::run(f, tuple);
642 }
643 };
644
645 template<std::size_t I, typename F>
646 struct tuple_for_each_dispatcher<true, I, F> {
647 template<typename TTuple>
648 static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
649 };
650
651 template<typename F, typename TTuple>
652 inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*)
653 tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>
654 ::run(f, tuple);
655 }
656
657 struct PCallExprInitMatchFunctor {
658 template<typename T>
659 void operator()(size_t i, const T& pattern) const {
660 pattern.InitMatch_();
661 }
662 };
663
664 struct PCallExprMatchFunctor {
665 const ir::Call* call_;
666 bool matched_{true};
667
668 explicit PCallExprMatchFunctor(const ir::Call* call)
669 : call_(call) {}
670
671 template<typename T>
672 void operator()(size_t i, const T& pattern) {
673 matched_ = matched_ && pattern.Match_(call_->args[i]);
674 }
675 };
676
677 struct PCallExprEvalArgsFunctor {
678 Array<Expr> args_;
679
680 template<typename T>
681 void operator()(size_t i, const T& pattern) {
682 args_.push_back(pattern.Eval());
683 }
684 };
685 } // namespace detail
686
687 /*!
688 * \brief Pattern CallExpr expression.
689 * \tparam Op The operator functor class.
690 * \tparam TArgs The arguments.
691 * \note Op functor contains the name of the function and
692 * the implementation of Eval.
693 */
694 template<typename Op, typename ...TArgs>
695 class PCallExpr :
696 public Pattern<PCallExpr<Op, TArgs...> > {
697 public:
698 explicit PCallExpr(const TArgs&... args)
699 : args_(args...) {
700 }
701
702 void InitMatch_() const {
703 detail::PCallExprInitMatchFunctor finit;
704 detail::tuple_for_each(finit, args_);
705 }
706
707 bool Match_(const NodeRef& node) const {
708 if (const ir::Call* ptr = node.as<ir::Call>()) {
709 if (ptr->args.size() != sizeof...(TArgs)) return false;
710 if (ptr->name != Op::kName) return false;
711 detail::PCallExprMatchFunctor fmatch(ptr);
712 detail::tuple_for_each(fmatch, args_);
713 return fmatch.matched_;
714 } else {
715 return false;
716 }
717 }
718
719 Expr Eval() const {
720 detail::PCallExprEvalArgsFunctor feval_args;
721 detail::tuple_for_each(feval_args, args_);
722 return Op::Eval(feval_args.args_);
723 }
724
725 private:
726 std::tuple<typename TArgs::Nested...> args_;
727 };
728
729 // arithemetic intrinsics
730 #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \
731 struct OpName { \
732 static Expr Eval(Array<Expr> args) { \
733 return ir::Call::make(args[0].type(), kName, args, \
734 ir::Call::PureIntrinsic); \
735 } \
736 static constexpr const char* kName = IntrinStr; \
737 }; \
738 template<typename TA, typename TB> \
739 inline PCallExpr<OpName, TA, TB> \
740 FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
741 return PCallExpr<OpName, TA, TB>(a.derived(), b.derived()); \
742 }
743
744 TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
745 TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
746 TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
747 TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
748 TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
749
750 // unary intrinsics
751 #define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \
752 struct OpName { \
753 static Expr Eval(Array<Expr> args) { \
754 return ir::Call::make(args[0].type(), kName, args, \
755 ir::Call::PureIntrinsic); \
756 } \
757 static constexpr const char* kName = IntrinStr; \
758 }; \
759 template<typename TA> \
760 inline PCallExpr<OpName, TA> \
761 FuncName(const Pattern<TA>& a) { \
762 return PCallExpr<OpName, TA>(a.derived()); \
763 }
764
765 TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
766
767 // if_then_else
768 struct PIfThenElseOp {
769 static Expr Eval(Array<Expr> args) {
770 return ir::Call::make(
771 args[1].type(), kName, args,
772 ir::Call::PureIntrinsic);
773 }
774 static constexpr const char* kName = "tvm_if_then_else";
775 };
776
777 /*!
778 * \brief Construct a if_then_else pattern.
779 *
780 * \param cond The condition expression.
781 * \param true_value The value when condition is true.
782 * \param true_value The value when condition is false.
783 *
784 * \return The result pattern.
785 *
786 * \tparam TCond The pattern type of the condition.
787 * \tparam TA The pattern type of the true operand.
788 * \tparam TB The pattern type of the false operand.
789 */
790 template<typename TCond, typename TA, typename TB>
791 inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
792 if_then_else(const Pattern<TCond>& cond,
793 const Pattern<TA>& true_value,
794 const Pattern<TB>& false_value) {
795 return PCallExpr<PIfThenElseOp, TCond, TA, TB>(
796 cond.derived(), true_value.derived(), false_value.derived());
797 }
798
799 } // namespace arith
800 } // namespace tvm
801 #endif // TVM_ARITHMETIC_PATTERN_MATCH_H_
802