1 //===- TestMatchers.cpp - Pass to test matchers ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/StandardOps/IR/Ops.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/Matchers.h" 12 #include "mlir/Pass/Pass.h" 13 14 using namespace mlir; 15 16 namespace { 17 /// This is a test pass for verifying matchers. 18 struct TestMatchers : public PassWrapper<TestMatchers, FunctionPass> { 19 void runOnFunction() override; 20 StringRef getArgument() const final { return "test-matchers"; } 21 StringRef getDescription() const final { 22 return "Test C++ pattern matchers."; 23 } 24 }; 25 } // end anonymous namespace 26 27 // This could be done better but is not worth the variadic template trouble. 28 template <typename Matcher> 29 static unsigned countMatches(FuncOp f, Matcher &matcher) { 30 unsigned count = 0; 31 f.walk([&count, &matcher](Operation *op) { 32 if (matcher.match(op)) 33 ++count; 34 }); 35 return count; 36 } 37 38 using mlir::matchers::m_Any; 39 using mlir::matchers::m_Val; 40 static void test1(FuncOp f) { 41 assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args"); 42 43 auto a = m_Val(f.getArgument(0)); 44 auto b = m_Val(f.getArgument(1)); 45 auto c = m_Val(f.getArgument(2)); 46 47 auto p0 = m_Op<AddFOp>(); // using 0-arity matcher 48 llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0) 49 << " times\n"; 50 51 auto p1 = m_Op<MulFOp>(); // using 0-arity matcher 52 llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1) 53 << " times\n"; 54 55 auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_Any()); 56 llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2) 57 << " times\n"; 58 59 auto p3 = m_Op<AddFOp>(m_Any(), m_Op<AddFOp>()); 60 llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3) 61 << " times\n"; 62 63 auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_Any()); 64 llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4) 65 << " times\n"; 66 67 auto p5 = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 68 llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5) 69 << " times\n"; 70 71 auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Any()); 72 llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6) 73 << " times\n"; 74 75 auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 76 llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7) 77 << " times\n"; 78 79 auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>()); 80 auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul); 81 llvm::outs() 82 << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched " 83 << countMatches(f, p8) << " times\n"; 84 85 // clang-format off 86 auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>()); 87 auto mul_of_anyadd = m_Op<MulFOp>(m_Any(), m_Op<AddFOp>()); 88 auto p9 = m_Op<MulFOp>(m_Op<MulFOp>( 89 mul_of_muladd, m_Op<MulFOp>()), 90 m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd)); 91 // clang-format on 92 llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, " 93 "add(*)), mul(*, add(*)))) matched " 94 << countMatches(f, p9) << " times\n"; 95 96 auto p10 = m_Op<AddFOp>(a, b); 97 llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10) 98 << " times\n"; 99 100 auto p11 = m_Op<AddFOp>(a, c); 101 llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11) 102 << " times\n"; 103 104 auto p12 = m_Op<AddFOp>(b, a); 105 llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12) 106 << " times\n"; 107 108 auto p13 = m_Op<AddFOp>(c, a); 109 llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13) 110 << " times\n"; 111 112 auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b)); 113 llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14) 114 << " times\n"; 115 116 auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c)); 117 llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15) 118 << " times\n"; 119 120 auto mul_of_aany = m_Op<MulFOp>(a, m_Any()); 121 auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c)); 122 llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched " 123 << countMatches(f, p16) << " times\n"; 124 125 auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b)); 126 llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched " 127 << countMatches(f, p17) << " times\n"; 128 } 129 130 void test2(FuncOp f) { 131 auto a = m_Val(f.getArgument(0)); 132 FloatAttr floatAttr; 133 auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr))); 134 auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant())); 135 // Last operation that is not the terminator. 136 Operation *lastOp = f.getBody().front().back().getPrevNode(); 137 if (p.match(lastOp)) 138 llvm::outs() 139 << "Pattern add(add(a, constant), a) matched and bound constant to: " 140 << floatAttr.getValueAsDouble() << "\n"; 141 if (p1.match(lastOp)) 142 llvm::outs() << "Pattern add(add(a, constant), a) matched\n"; 143 } 144 145 void TestMatchers::runOnFunction() { 146 auto f = getFunction(); 147 llvm::outs() << f.getName() << "\n"; 148 if (f.getName() == "test1") 149 test1(f); 150 if (f.getName() == "test2") 151 test2(f); 152 } 153 154 namespace mlir { 155 void registerTestMatchers() { PassRegistration<TestMatchers>(); } 156 } // namespace mlir 157