1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/StandardOps/IR/Ops.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/Pass/Pass.h"
13
14 using namespace mlir;
15
16 namespace {
17 /// This is a test pass for verifying matchers.
18 struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> {
19 void runOnFunction() override;
getArgument__anon14bd0fdf0111::TestMatchers20 StringRef getArgument() const final { return "test-matchers"; }
getDescription__anon14bd0fdf0111::TestMatchers21 StringRef getDescription() const final {
22 return "Test C++ pattern matchers.";
23 }
24 };
25 } // end anonymous namespace
26
27 // This could be done better but is not worth the variadic template trouble.
28 template <typename Matcher>
countMatches(FuncOp f,Matcher & matcher)29 static unsigned countMatches(FuncOp f, Matcher &matcher) {
30 unsigned count = 0;
31 f.walk([&count, &matcher](Operation *op) {
32 if (matcher.match(op))
33 ++count;
34 });
35 return count;
36 }
37
38 using mlir::matchers::m_Any;
39 using mlir::matchers::m_Val;
test1(FuncOp f)40 static void test1(FuncOp f) {
41 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
42
43 auto a = m_Val(f.getArgument(0));
44 auto b = m_Val(f.getArgument(1));
45 auto c = m_Val(f.getArgument(2));
46
47 auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
48 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
49 << " times\n";
50
51 auto p1 = m_Op<MulFOp>(); // using 0-arity matcher
52 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
53 << " times\n";
54
55 auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any());
56 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
57 << " times\n";
58
59 auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>());
60 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
61 << " times\n";
62
63 auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any());
64 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
65 << " times\n";
66
67 auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
68 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
69 << " times\n";
70
71 auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any());
72 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
73 << " times\n";
74
75 auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
76 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
77 << " times\n";
78
79 auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
80 auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul);
81 llvm::outs()
82 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
83 << countMatches(f, p8) << " times\n";
84
85 // clang-format off
86 auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
87 auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>());
88 auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
89 mul_of_muladd, m_Op<MulFOp>()),
90 m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
91 // clang-format on
92 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
93 "add(*)), mul(*, add(*)))) matched "
94 << countMatches(f, p9) << " times\n";
95
96 auto p10 = m_Op<AddFOp>(a, b);
97 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
98 << " times\n";
99
100 auto p11 = m_Op<AddFOp>(a, c);
101 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
102 << " times\n";
103
104 auto p12 = m_Op<AddFOp>(b, a);
105 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
106 << " times\n";
107
108 auto p13 = m_Op<AddFOp>(c, a);
109 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
110 << " times\n";
111
112 auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b));
113 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
114 << " times\n";
115
116 auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c));
117 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
118 << " times\n";
119
120 auto mul_of_aany = m_Op<MulFOp>(a, m_Any());
121 auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
122 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
123 << countMatches(f, p16) << " times\n";
124
125 auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b));
126 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
127 << countMatches(f, p17) << " times\n";
128 }
129
test2(FuncOp f)130 void test2(FuncOp f) {
131 auto a = m_Val(f.getArgument(0));
132 FloatAttr floatAttr;
133 auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
134 auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
135 // Last operation that is not the terminator.
136 Operation *lastOp = f.getBody().front().back().getPrevNode();
137 if (p.match(lastOp))
138 llvm::outs()
139 << "Pattern add(add(a, constant), a) matched and bound constant to: "
140 << floatAttr.getValueAsDouble() << "\n";
141 if (p1.match(lastOp))
142 llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
143 }
144
runOnFunction()145 void TestMatchers::runOnFunction() {
146 auto f = getFunction();
147 llvm::outs() << f.getName() << "\n";
148 if (f.getName() == "test1")
149 test1(f);
150 if (f.getName() == "test2")
151 test2(f);
152 }
153
154 namespace mlir {
registerTestMatchers()155 void registerTestMatchers() { PassRegistration<TestMatchers>(); }
156 } // namespace mlir
157