# RUN: %PYTHON %s | FileCheck %s

# Lit/FileCheck test for the MLIR Python bindings of the linalg dialect.
# Every test function below is executed as soon as it is defined (see `run`),
# and whatever IR it prints is matched by FileCheck against the `#` directive
# comments that precede or are embedded in it. Directive order must follow
# print order, so neither the tests nor their directives may be reordered.

from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import linalg
from mlir.dialects import std

def run(f):
  """Print a `TEST: <name>` banner, call `f` immediately, and return it.

  Used as a decorator so that each test runs at module load time; the banner
  is the anchor each test's label directive matches on.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
  """Build functions returning `linalg.InitTensorOp` with static sizes,
  dynamic sizes taken from index-typed arguments, and an empty size list
  (expected to print as a rank-0 `tensor<f32>`), then print the module."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @static_sizes
      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
      @builtin.FuncOp.from_py_func()
      def static_sizes():
        return linalg.InitTensorOp([3, 4], f32)

      # CHECK-LABEL: func @dynamic_sizes
      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
      def dynamic_sizes(d0, d1):
        # Sizes are SSA values (the function's index arguments), so the
        # expected result type has `?` extents.
        return linalg.InitTensorOp([d0, d1], f32)

      # CHECK-LABEL: func @zero_d
      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
      @builtin.FuncOp.from_py_func()
      def zero_d():
        return linalg.InitTensorOp([], f32)

  print(module)

# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
  """Print the `static_sizes` attribute of an `InitTensorOp` built from
  constant sizes."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      op = linalg.InitTensorOp([3, 4], f32)
      # CHECK: [3, 4]
      print(op.attributes['static_sizes'])

# CHECK-LABEL: TEST: testFill
@run
def testFill():
  """Exercise `linalg.FillOp` on a tensor operand, where the fill yields a new
  tensor that is returned, and on a memref operand, where the expected IR
  shows no result.

  A `-1` extent in the shape tuples below appears as `?` in the expected IR,
  i.e. a dynamic dimension.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # CHECK-LABEL: func @fill_tensor
      # CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
      # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
      # CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32>
      # CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((12, -1), f32))
      def fill_tensor(out):
        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        # TODO: FillOp.result is None. When len(results) == 1 we expect it to
        # be results[0] as per _linalg_ops_gen.py. This seems like an
        # orthogonal bug in the generator of _linalg_ops_gen.py.
        return linalg.FillOp(output=out, value=zero).results[0]

      # CHECK-LABEL: func @fill_buffer
      # CHECK-SAME: %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
      # CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
      # CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32>
      # CHECK-NEXT: return
      @builtin.FuncOp.from_py_func(
          MemRefType.get((12, -1), f32))
      def fill_buffer(out):
        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        # Nothing is returned: on a memref the fill has no result to
        # forward, and the expected IR ends with a bare `return`.
        linalg.FillOp(output=out, value=zero)

  print(module)


# CHECK-LABEL: TEST: testStructuredOpOnTensors
@run
def testStructuredOpOnTensors():
  """Build a function by hand (explicit `FunctionType` and entry block) whose
  body is a `linalg.MatmulOp` on tensors. The op's result type is supplied
  explicitly via `results=`."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint(module.body):
      func = builtin.FuncOp(name="matmul_test",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]))
      with InsertionPoint(func.add_entry_block()):
        lhs, rhs = func.entry_block.arguments
        result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
        std.ReturnOp([result])

  # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  print(module)


# CHECK-LABEL: TEST: testStructuredOpOnBuffers
@run
def testStructuredOpOnBuffers():
  """Memref counterpart of the tensor matmul test. The third function argument
  is passed as the op's `outputs=` buffer, and the function returns
  nothing."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    memref_type = MemRefType.get((2, 3, 4), f32)
    with InsertionPoint(module.body):
      func = builtin.FuncOp(name="matmul_test",
                            type=FunctionType.get(
                                inputs=[memref_type, memref_type, memref_type],
                                results=[]))
      with InsertionPoint(func.add_entry_block()):
        lhs, rhs, result = func.entry_block.arguments
        # TODO: properly hook up the region.
        linalg.MatmulOp([lhs, rhs], outputs=[result])
        std.ReturnOp([])

  # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
  print(module)

# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
  """Build a matmul with the `linalg.matmul` named-op builder and print it in
  custom assembly form. The printed op is expected to carry no
  `linalg.memoized_indexing_maps` attribute."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        # Supplies the `outs` operand of the matmul below (4x8 = 4x16 * 16x8).
        init_result = linalg.InitTensorOp([4, 8], f32)
        # First check the named form with custom format
        # CHECK: linalg.matmul
        # CHECK-NOT: linalg.memoized_indexing_maps
        # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
        # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
        # CHECK-SAME: -> tensor<4x8xf32>
        # CHECK-NEXT: return
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  print(module)

# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
  """Build the same matmul as the custom-form test, but print the module in
  generic op form. The op's region body (mulf/addf/yield) and its attributes
  then become visible in the output."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: "linalg.matmul"(%{{.*}})
        # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
        # CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
        # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  # Generic form is requested explicitly; plain print(module) would use the
  # custom assembly format checked in the previous test.
  module.operation.print(print_generic_op_form=True)

# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
  """Ask the `linalg.matmul` builder for `emit_generic=True`. The printed IR is
  expected to contain a `linalg.generic` op rather than the named matmul."""
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def generic_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.generic
        return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)

  print(module)