1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import builtin
5from mlir.dialects import linalg
6from mlir.dialects import std
7
def run(f):
  """Execute test function `f` immediately, emitting a FileCheck TEST header.

  Prints "TEST: <name>" (preceded by a blank line) so the CHECK-LABEL
  directives in this file can anchor on it, then invokes `f`. Returns `f`
  unchanged so this can be used as a decorator.
  """
  print(f"\nTEST: {f.__name__}")
  f()
  return f
12
13
# CHECK-LABEL: TEST: testInitTensor
@run
def testInitTensor():
  """Builds linalg.init_tensor ops with static, dynamic, and zero-rank sizes.

  Each nested function exercises one sizing mode; the expected IR is pinned
  by the FileCheck directives directly above each builder.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # Integer sizes become static tensor dimensions.
      # CHECK-LABEL: func @static_sizes
      # CHECK: %0 = linalg.init_tensor [3, 4] : tensor<3x4xf32>
      @builtin.FuncOp.from_py_func()
      def static_sizes():
        return linalg.InitTensorOp([3, 4], f32)

      # SSA index values become dynamic (?) tensor dimensions.
      # CHECK-LABEL: func @dynamic_sizes
      # CHECK: %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
      @builtin.FuncOp.from_py_func(IndexType.get(), IndexType.get())
      def dynamic_sizes(d0, d1):
        return linalg.InitTensorOp([d0, d1], f32)

      # An empty size list yields a rank-0 tensor.
      # CHECK-LABEL: func @zero_d
      # CHECK: %0 = linalg.init_tensor [] : tensor<f32>
      @builtin.FuncOp.from_py_func()
      def zero_d():
        return linalg.InitTensorOp([], f32)

  print(module)
40
# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
@run
def testInitTensorStaticSizesAttribute():
  """Checks that InitTensorOp exposes its 'static_sizes' attribute.

  Prints the attribute directly (rather than the whole module) so FileCheck
  can match the raw attribute rendering.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      op = linalg.InitTensorOp([3, 4], f32)
      # CHECK: [3, 4]
      print(op.attributes['static_sizes'])
51
# CHECK-LABEL: TEST: testFill
@run
def testFill():
  """Builds linalg.fill on both a tensor (value-returning) and a buffer
  (in-place) operand, verifying the printed IR for each form.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      # Tensor case: fill produces a new tensor result that is returned.
      # CHECK-LABEL: func @fill_tensor
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
      #  CHECK-NEXT: %[[RES:.*]] = linalg.fill(%[[CST]], %[[OUT]]) : f32, tensor<12x?xf32> -> tensor<12x?xf32>
      #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
      @builtin.FuncOp.from_py_func(
          RankedTensorType.get((12, -1), f32))
      def fill_tensor(out):
        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        # TODO: FillOp.result is None. When len(results) == 1 we expect it to
        # be results[0] as per _linalg_ops_gen.py. This seems like an
        # orthogonal bug in the generator of _linalg_ops_gen.py.
        return linalg.FillOp(output=out, value=zero).results[0]

      # Buffer case: fill mutates the memref in place and yields no result.
      # CHECK-LABEL: func @fill_buffer
      #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
      #  CHECK-NEXT: %[[CST:.*]] = constant 0.0{{.*}} : f32
      #  CHECK-NEXT: linalg.fill(%[[CST]], %[[OUT]]) : f32, memref<12x?xf32>
      #  CHECK-NEXT: return
      @builtin.FuncOp.from_py_func(
          MemRefType.get((12, -1), f32))
      def fill_buffer(out):
        zero = std.ConstantOp(value=FloatAttr.get(f32, 0.), result=f32).result
        linalg.FillOp(output=out, value=zero)

  print(module)
85
86
# CHECK-LABEL: TEST: testStructuredOpOnTensors
@run
def testStructuredOpOnTensors():
  """Builds a linalg.matmul on tensor operands via the raw MatmulOp class.

  The function body is assembled manually (FuncOp + entry block) rather than
  through from_py_func, and the matmul result is returned by value.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    tensor_type = RankedTensorType.get((2, 3, 4), f32)
    with InsertionPoint(module.body):
      func = builtin.FuncOp(name="matmul_test",
                            type=FunctionType.get(
                                inputs=[tensor_type, tensor_type],
                                results=[tensor_type]))
      with InsertionPoint(func.add_entry_block()):
        lhs, rhs = func.entry_block.arguments
        result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
        std.ReturnOp([result])

  # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  print(module)
106
107
# CHECK-LABEL: TEST: testStructuredOpOnBuffers
@run
def testStructuredOpOnBuffers():
  """Builds a linalg.matmul on memref operands via the raw MatmulOp class.

  In buffer form the output is passed through `outputs=` and the op produces
  no results, so the function returns nothing.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    memref_type = MemRefType.get((2, 3, 4), f32)
    with InsertionPoint(module.body):
      func = builtin.FuncOp(name="matmul_test",
                            type=FunctionType.get(
                                inputs=[memref_type, memref_type, memref_type],
                                results=[]))
      with InsertionPoint(func.add_entry_block()):
        lhs, rhs, result = func.entry_block.arguments
        # TODO: properly hook up the region.
        linalg.MatmulOp([lhs, rhs], outputs=[result])
        std.ReturnOp([])

  # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
  print(module)
128
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
  """Checks the custom (pretty) printed form of the named linalg.matmul
  emitted by the OpDSL-style `linalg.matmul` builder.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # First check the named form with custom format
        #      CHECK: linalg.matmul
        #  CHECK-NOT: linalg.memoized_indexing_maps
        # CHECK-SAME:    ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
        # CHECK-SAME:   outs(%{{.*}} : tensor<4x8xf32>)
        # CHECK-SAME:   -> tensor<4x8xf32>
        # CHECK-NEXT: return
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  print(module)
150
# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
  """Checks the generic printed form of the named linalg.matmul, including
  its region body and memoized indexing-map attributes.

  Printed with print_generic_op_form=True so the op's region and attribute
  dictionary are fully visible to FileCheck.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def named_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        #      CHECK: "linalg.matmul"(%{{.*}})
        # CHECK-NEXT:  ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
        # CHECK-NEXT:    std.mulf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    std.addf{{.*}} (f32, f32) -> f32
        # CHECK-NEXT:    linalg.yield{{.*}} (f32) -> ()
        # CHECK-NEXT:    {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
        # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
        return linalg.matmul(lhs, rhs, outs=[init_result.result])

  module.operation.print(print_generic_op_form=True)
172
# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
  """Checks that emit_generic=True makes the matmul builder emit a
  linalg.generic op instead of the named linalg.matmul.
  """
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
      @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                   RankedTensorType.get((16, 8), f32))
      def generic_form(lhs, rhs):
        init_result = linalg.InitTensorOp([4, 8], f32)
        # CHECK: linalg.generic
        return linalg.matmul(lhs, rhs, outs=[init_result.result], emit_generic=True)

  print(module)