// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s

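// Verifies that f32 matmul generalizes to a generic op with a mulf/addf body.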
func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

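// Verifies that i32 matmul generalizes to a muli/addi body.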
func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

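// Verifies the f32 multiply-accumulate body of the generalized depthwise convolution.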
func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

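// Verifies the i32 multiply-accumulate body of the generalized depthwise convolution.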
func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

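// Verifies that f32 max pooling generalizes to a cmpf/select reduction.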
func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

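// Verifies that i32 max pooling compares with a signed predicate.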
func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[MAX]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

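// Verifies that f32 min pooling uses the olt predicate.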
func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

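// Verifies that i32 min pooling uses the slt predicate.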
func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

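// Verifies that f32 sum pooling reduces with addf.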
func @generalize_pooling_nhwc_sum_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

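// Verifies that i32 sum pooling reduces with addi.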
func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

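// Verifies the linear congruential generator body and the scaling into the [min, max) range.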
func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_fill_rng_2d_f32
// CHECK-DAG:  ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: f32
// CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
// CHECK-DAG:    %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
// CHECK-DAG:    %[[VAL0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
// CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i32
// CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i32
// CHECK-DAG:    %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
// CHECK-DAG:    %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
// Skip random number computation for the second index.
// CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
// CHECK-DAG:    %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
// CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
// CHECK-DAG:    %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
// CHECK-DAG:    %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN]] : f64
// CHECK-DAG:    %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
// CHECK-NEXT:   linalg.yield %[[VAL6]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

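// Verifies the i32 variant of fill_rng_2d.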
func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_fill_rng_2d_i32
// CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: i32
// Verifies floating point to integer cast.
// CHECK:        %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
// CHECK-NEXT:   linalg.yield %[[VAL6]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

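// Verifies that soft plus generalizes to log(1 + exp(x)), with the f64 constant cast to the output type.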
func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_soft_plus_2d_f32
//      CHECK: %[[C1:.+]] = constant 1.000000e+00 : f64
//      CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
  return %0: tensor<16x32xi16>
}

// CHECK-LABEL: @generalize_matmul_tensor_f32_f32_i16
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT:   %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
// CHECK-NEXT:   %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>

// -----
// Verifies sign extension cast.
func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----
// Verifies that mixed operand types are legal.
func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
                          outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i16_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----
// Somewhat nonsensical, but verifies the integer truncation cast.
func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                          outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
  return %0: tensor<16x32xi16>
}

// CHECK-LABEL: @generalize_matmul_tensor_i32_i32_i16
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT:   %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
// CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>

// -----
// Verifies integer to floating point cast.
func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies floating point extension cast.
func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16_f16_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies the floating point truncation cast.
func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
                          outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f64_f64_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>