// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

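// memref.cast ops that only erase static information fold into the consuming
// linalg op; the matmul below is expected to recover its static 16x16 types.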
// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c8 = constant 8 : index
  %c16 = constant 16 : index
  %1 = memref.alloc (%b) : memref<?xi8>
  %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
  %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32>

  // CHECK:  linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
               outs(%3: memref<?x?xf32>)
  return %3: memref<?x?xf32>
}

// -----

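// The same folding applies to linalg.tiled_loop: the cast to a strided layout
// on the output is absorbed and the loop keeps operating on memref<192xf32>.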
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// CHECK-LABEL: func @memref_cast_into_tiled_loop(
func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
  %0 = memref.cast %arg0
    : memref<192xf32> to memref<192xf32, #map>
  %cst = constant 0.000000e+00 : f32
  %c24 = constant 24 : index
  %c0 = constant 0 : index
  %c192 = constant 192 : index
  // CHECK: linalg.tiled_loop
  // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
  linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
    outs (%out = %0: memref<192xf32, #map>) {
    %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
    %16 = memref.subview %out[%arg3] [%14] [1]
      : memref<192xf32, #map> to memref<?xf32, #map>
    linalg.fill(%cst, %16) : f32, memref<?xf32, #map>
    linalg.yield
  }
  return
}

// -----

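// A reshape chain that ends back at the original tensor<f32> type folds to
// the function argument.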
// CHECK-LABEL: zero_rank_reshape_multi
func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: return %arg0
  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
  %2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  return %2 : tensor<f32>
}

// -----

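// Consecutive tensor_collapse_shape (and, below, tensor_expand_shape) ops
// compose into a single reshape with merged reassociation indices.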
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: collapsing_tensor_reshapes
//       CHECK:   linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   linalg.tensor_collapse_shape

// -----

func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
                                             -> tensor<f32> {
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]]
      : tensor<1x1x1xf32> into tensor<1xf32>
  %1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
  return %1 : tensor<f32>
}
// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
//       CHECK:   linalg.tensor_collapse_shape %{{.*}} []
//  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>

// -----

func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4]]
      : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
  return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes
//       CHECK:   linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   linalg.tensor_expand_shape

// -----

func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
                                             -> tensor<1x1x1xf32> {
  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2]]
      : tensor<1xf32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes_to_zero
//       CHECK:   linalg.tensor_expand_shape %{{.*}} []
//  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>

// -----

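// An expand_shape followed by its exact inverse collapse_shape cancels
// completely, for both static and dynamic shapes.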
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<3x4x4xf32> into tensor<12x4xf32>
  return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_tensor_reshape
//   CHECK-NOT:   linalg.{{.*}}shape

// -----

func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_tensor_reshape_dynamic
//   CHECK-NOT:   linalg.{{.*}}_shape

// -----

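// A collapse/expand pair that does not cancel exactly is still expected to be
// replaced by a single reshape when the static sizes admit one.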
func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]]
      : tensor<40320xf32> into tensor<24x5x42x8xf32>
  return %1 : tensor<24x5x42x8xf32>
}
//      CHECK: func @reshape_collapse
// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
//      CHECK:   %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
//      CHECK:   return %[[RESULT]]

// -----

func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]]
      : tensor<24x5x42x8xf32> into tensor<40320xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
  return %1 : tensor<2x3x4x5x6x7x8xf32>
}
//      CHECK: func @reshape_expand
// CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
//      CHECK:   %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
//      CHECK:   return %[[RESULT]]

// -----

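// The next tests fold unit dimensions that one reshape introduces and the
// following reshape removes again, including dynamic cases.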
func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]]
    : tensor<2048xf32> into tensor<1x4x1x512xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
  return %1 : tensor<4x512xf32>
}
//       CHECK: func @expand_reshape_1D
//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
//  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>

// -----

func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]]
    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
    : tensor<1x4x1x512xf32> into tensor<2048xf32>
  return %1 : tensor<2048xf32>
}
//       CHECK: func @fold_reshape_1D
//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
//  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>

// -----

func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
    : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
    : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
  return %1 : tensor<4x512x1x1xf32>
}
//       CHECK: func @fold_reshape_unit_dims
//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
//  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>

// -----

func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
    : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
  return %1 : tensor<4x512x1x512x4xf32>
}
//       CHECK: func @expand_reshape_unit_dims
//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
//  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>

// -----

func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
      : tensor<2xf32> into tensor<2x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}
//       CHECK: func @fold_reshape_trailing_unit_dims
//       CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
//  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>

// -----

func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
    : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
    : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}
//       CHECK: func @collapse_reshape_unit_dims_dynamic
//       CHECK: linalg.tensor_collapse_shape
//  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
//  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>

// -----

func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
      : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
      : tensor<?x1x1x1xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}
//       CHECK: func @fold_reshape_trailing_unit_dims_dynamic
//       CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
//  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>

// -----

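// Negative tests: no folding when dynamic dimensions would have to be
// regrouped or the reassociations are incompatible.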
func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
      : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]]
      : tensor<?x?x1x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @no_fold_reshapes
//       CHECK:   linalg.tensor_expand_shape
//       CHECK:   linalg.tensor_collapse_shape

// -----

func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]]
      : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]]
      : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
  return %1 : tensor<2x6x16xf32>
}
// CHECK-LABEL: func @no_fold_reshape_incompatible
//       CHECK:   linalg.tensor_expand_shape
//       CHECK:   linalg.tensor_collapse_shape

// -----

func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
      : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
      : tensor<3x2x2x1xf32> into tensor<12x1xf32>
  return %1 : tensor<12x1xf32>
}
//      CHECK: func @no_fold_reshape_empty_expr
// CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
//      CHECK:    %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME:      [0], [1], [2, 3]
//      CHECK:    %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]]
// CHECK-SAME:      [0, 1, 2], [3]
//      CHECK:    return %[[RES]] : tensor<12x1xf32>

// -----

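// A self-copy on a zero-element buffer is dead; the tensor-producing generic
// must survive because its result is returned.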
#accesses = [
  affine_map<(i) -> (i)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"]
}

func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
  // The self-copy on memref<0xf32> is expected to be DCE'd.
  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>

  // The generic on tensor<0xf32> cannot be DCE'd: its result is used.
  %1 = linalg.generic #trait outs(%arg1 : tensor<0xf32>) {
  ^bb(%0: f32) :
    linalg.yield %0 : f32
  } -> tensor<0xf32>

  return %1: tensor<0xf32>
}
// CHECK-LABEL: @dce_zero_memref
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
//   CHECK-NOT:   linalg.copy
//  CHECK-NEXT:   return %[[ARG1]]

// -----

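// Expanding a splat constant folds into a new splat constant of the expanded
// type; checked for i32, i16, f32 and f64 payloads.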
func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
{
  %c0 = constant dense<42> : tensor<2x8xi32>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xi32> into tensor<2x4x2xi32>
  return %0 : tensor<2x4x2xi32>
}
// CHECK-LABEL: @reshape_splat_constant_int32
//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
//   CHECK-NOT:   linalg.tensor_expand_shape
//       CHECK:   return %[[CST]]

func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
{
  %c0 = constant dense<42> : tensor<2x8xi16>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xi16> into tensor<2x4x2xi16>
  return %0 : tensor<2x4x2xi16>
}
// CHECK-LABEL: @reshape_splat_constant_int16
//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
//   CHECK-NOT:   linalg.tensor_expand_shape
//       CHECK:   return %[[CST]]

func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
{
  %c0 = constant dense<42.0> : tensor<2x8xf32>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xf32> into tensor<2x4x2xf32>
  return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: @reshape_splat_constant_float32
//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
//   CHECK-NOT:   linalg.tensor_expand_shape
//       CHECK:   return %[[CST]]

func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
{
  %c0 = constant dense<42.0> : tensor<2x8xf64>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xf64> into tensor<2x4x2xf64>
  return %0 : tensor<2x4x2xf64>
}
// CHECK-LABEL: @reshape_splat_constant_float64
//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
//   CHECK-NOT:   linalg.tensor_expand_shape
//       CHECK:   return %[[CST]]

// -----

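// tensor.cast ops that erase static information fold into linalg.matmul on
// tensors; a cast back to the more static type is inserted after the op.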
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
  -> tensor<3x?xf32>
{
  %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
  %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
  %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>

  //      CHECK:  linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
  // CHECK-SAME:    outs({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>

  return %1: tensor<3x?xf32>
}

// -----

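// Effect-based DCE: the first matmul only reads its operands and its tensor
// result is unused, so it is erased; the second writes a memref and stays.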
// CHECK-LABEL: func @linalg_effects(
//  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>
//  CHECK-SAME:     %[[C:[a-z0-9]*]]: tensor<?x?xf32>
func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
  // CHECK-NOT:   %{{.*}} = linalg.matmul
  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>

  // CHECK:   linalg.matmul
  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
               outs(%b : memref<?x?xf32>)
  return
}

// -----

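// Constant sizes fold into linalg.init_tensor; the static result is cast back
// to the original dynamic type.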
func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
  %c6 = constant 6 : index
  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
  return %0 : tensor<4x5x?xf32>
}
// CHECK: func @init_tensor_canonicalize
// CHECK:   %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
// CHECK:   %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK:   return %[[T1]]

// -----

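// Reshapes of linalg.init_tensor fold into a new init_tensor whose dynamic
// sizes are recomputed with affine.apply (a floordiv when expanding, a
// multiplication when collapsing).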
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  return %1 : tensor<2x3x5x4x?x7xf32>
}
//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
//      CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME:     %[[ARG0:.+]]: index
// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK-NEXT:   %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7]
// CHECK-NEXT:   return %[[INIT]]

// -----

func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  return %1 : tensor<6x5x?xf32>
}
//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
//      CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME:     %[[ARG0:.+]]: index
// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK-NEXT:   %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]]
// CHECK-NEXT:   return %[[INIT]]

// -----

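// A linalg.generic that merely forwards (here: swaps) its inputs is a no-op
// and is replaced by the inputs themselves.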
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
  %4, %5 = linalg.generic {
    indexing_maps = [#map, #map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
    linalg.yield %arg3, %arg2 : f32, f32
  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @remove_no_op
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//       CHECK:     return %[[ARG1]], %[[ARG0]]

// -----

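// The two tests below must keep their linalg.generic: a yielded value defined
// outside the body (the ^bb1 argument) means the op is not a pure forward.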
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %cst = constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
  br ^bb1(%cst : f32)

^bb1(%arg1 : f32):
  %3 = linalg.generic
    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3 : f32):
      linalg.yield %arg1 : f32
    } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}
// CHECK-LABEL: func @keep_not_noop
//       CHECK:   %[[RESULT:.+]] = linalg.generic
//       CHECK:   return %[[RESULT]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %cst = constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
  br ^bb1(%cst : f32)

^bb1(%arg2 : f32):
  %3:2 = linalg.generic
    {indexing_maps = [#map, #map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
      linalg.yield %arg2, %arg4 : f32, f32
    } -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @keep_not_noop
//       CHECK:   %[[RESULT:.+]]:2 = linalg.generic
//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1

// -----

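// tensor.extract_slice of a linalg.init_tensor folds to a smaller init_tensor
// built from the slice sizes.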
func @fold_init_tensor_with_slice
  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
{
  %0 = linalg.init_tensor [%arg0, 10, 40] : tensor<?x10x40xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
  return %1 : tensor<5x?x20xf32>
}
//      CHECK: func @fold_init_tensor_with_slice
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//      CHECK:   %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
//      CHECK:   return %[[T0]]

// -----

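// Tensor-valued linalg ops with unused results have no side effects and are
// all expected to be erased.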
#accesses = [
  affine_map<(i, j) -> (i, j)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"]
}

// CHECK-LABEL: func @dead_linalg_tensor
//   CHECK-NOT:   linalg.fill
//   CHECK-NOT:   linalg.matmul
//   CHECK-NOT:   linalg.generic
//   CHECK-NOT:   linalg.pad_tensor
//       CHECK:   return
func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
                         %arg2: tensor<?x?xf32>, %high : index) {
  %c0_i32 = constant 0 : i32
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.fill(%c0_i32, %arg0) : i32, tensor<7x7xi32> -> tensor<7x7xi32>
  %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>)
                     outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32>
  %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) {
  ^bb(%3: i32) :
    linalg.yield %3 : i32
  } -> tensor<7x7xi32>
  %3 = linalg.pad_tensor %arg2 low[%c0, %c0] high[%high, %high] {
        ^bb0(%arg9: index, %arg10: index):  // no predecessors
          linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>
  return
}

// -----

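// Padding with dynamic amounts still folds to the source when source and
// result share the same static shape, which forces the padding to be zero.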
// CHECK-LABEL: func @pad_tensor_same_static_shape(
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
//   CHECK-NOT:   linalg.pad_tensor
//       CHECK:   return %[[ARG0]]
func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] {
        ^bb0(%arg1: index, %arg2: index):
          linalg.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}
func @propagate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
    %arg3 : index) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c21 = constant 21 : index
  %c42 = constant 42 : index
  %0 = linalg.init_tensor [%c21, %c42] : tensor<?x?xf32>
  %1 = linalg.fill(%arg1, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
  %2 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %3 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func @propagate_casts
//       CHECK:   %[[INIT:.+]] = linalg.init_tensor [21, 42]
//       CHECK:   %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
//       CHECK:   %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]]
//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
//       CHECK:   return %[[RESULT]]

// -----

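// Copying a buffer onto itself is a no-op and is erased.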
// CHECK-LABEL: @self_copy
func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
  // CHECK-NOT: linalg.copy
  linalg.copy(%arg0, %arg0): memref<2x3x?x4xf32>, memref<2x3x?x4xf32>

  // CHECK: return
  return
}

// -----

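// A reshape of a linalg.fill folds by reshaping the init tensor first and
// filling afterwards, so the fill is created directly in the collapsed type.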
// CHECK-LABEL: func @fold_fill_reshape()
func @fold_fill_reshape() -> tensor<6x4xf32> {
  %zero = constant 0.0 : f32
  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
  %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
  // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<6x4xf32> -> tensor<6x4xf32>
  %fill = linalg.fill(%zero, %init) : f32, tensor<1x2x3x4xf32> -> tensor<1x2x3x4xf32>
  %reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]]
      : tensor<1x2x3x4xf32> into tensor<6x4xf32>
  // CHECK: return %[[FILL]] : tensor<6x4xf32>
  return %reshape : tensor<6x4xf32>
}

// -----

//       CHECK: func @fold_fill_reshape_dynamic
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
  %zero = constant 0.0 : f32
  // CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
  %0 = linalg.fill(%zero, %arg0) : f32, tensor<?x?x?x?x?xf32> -> tensor<?x?x?x?x?xf32>
  // CHECK: %[[RESULT:.+]] = linalg.fill(%{{.+}}, %[[RESHAPE]])
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
  return %1 : tensor<?x?xf32>
}

// -----

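// linalg.tiled_loop results that just forward an output tensor are folded
// away together with the corresponding output operands.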
func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
                  %C: memref<48xf32>) -> (tensor<48xf32>)

func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
    %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c48 = constant 48 : index
  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
      ins (%A_ = %A: memref<48xf32>)
      outs (%B_ = %B: tensor<48xf32>,
            %CT_ = %C_tensor: tensor<48xf32>,
            %C_ = %C: memref<48xf32>) {
    %result = call @foo(%A_, %B_, %C_)
      : (memref<48xf32>, tensor<48xf32>, memref<48xf32>) -> (tensor<48xf32>)
    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
  }
  return %useful : tensor<48xf32>
}

// CHECK-LABEL: func @fold_tiled_loop_results(
// CHECK-SAME:   %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
// CHECK-SAME:   %[[C:.*]]: [[BUF_TY]], %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {

// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
// CHECK-DAG:  %[[C24:.*]] = constant 24 : index
// CHECK-DAG:  %[[C48:.*]] = constant 48 : index

// CHECK-NOT: %{{.*}} = linalg.tiled_loop
// CHECK:  %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
// CHECK-NEXT:   %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
// CHECK-NEXT:   linalg.yield %[[RES]] :

// CHECK: return %[[RESULT]]

// -----

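// Unused tensor inputs of linalg.tiled_loop are likewise pruned from the
// ins list.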
func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>

func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c192 = constant 192 : index
  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
      outs (%BT_ = %B_tensor: tensor<192xf32>) {
    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
    linalg.yield %0 : tensor<192xf32>
  }
  return %result : tensor<192xf32>
}

// CHECK-LABEL: func @fold_tiled_loop_inputs
// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)

// CHECK: return %[[RESULT]]

// -----

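// A pad with all-zero low/high amounts folds away once the source's static
// type is recovered through the cast.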
func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %c0 = constant 0 : index
  %cst = constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %c0] {
    ^bb0(%arg1: index, %arg2: index):  // no predecessors
      linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}
// CHECK-LABEL: @tensor_pad_cast_fold
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]

// -----

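// A cast that erases static information on the pad source folds into
// linalg.pad_tensor itself.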
// CHECK-LABEL: func @fold_pad_tensor_source_cast(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
//   CHECK-NOT:   tensor.cast
//       CHECK:   %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]]
func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
  %cst = constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low[0, 0] high[0, 1] {
    ^bb0(%arg1: index, %arg2: index):  // no predecessors
      linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}

// -----

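// When every pad amount is statically zero, the pad reduces to a tensor.cast
// to the result type.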
// CHECK-LABEL: func @pad_static_zero_cast(
//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
//   CHECK-NOT:   linalg.pad_tensor
//       CHECK:   %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
//       CHECK:   return %[[RESULT]]
func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
  %c0 = constant 0 : index
  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index):
      linalg.yield %pad_value : f32
    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>

  return %0 : tensor<2x3x4xf32>
}

// -----

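// tensor.dim of a linalg.init_tensor folds to the matching size operand or
// constant: here %i for dim 0 and a constant 42 for dim 1.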
func private @some_use(%i : index, %j : index)

// CHECK-LABEL: func @init_canonicalize
//  CHECK-SAME:   %[[I:.*]]: index
func @init_canonicalize(%i : index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // CHECK-NOT: init_tensor
  %0 = linalg.init_tensor [%i, 42] : tensor<?x42xf32>

  // CHECK-NOT: tensor.dim
  %1 = tensor.dim %0, %c0: tensor<?x42xf32>
  %2 = tensor.dim %0, %c1: tensor<?x42xf32>

  // CHECK: %[[c42:.*]] = constant 42 : index
  // CHECK: call @some_use(%[[I]], %[[c42]])
  call @some_use(%1, %2) : (index, index) -> ()

  return
}

// -----

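// A rank-reducing extract_slice with static sizes lets the dynamic
// init_tensor fold to a static tensor of the slice shape.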
// CHECK-LABEL: func @rank_reducing_init_extract
func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
  // CHECK: linalg.init_tensor [2] : tensor<2xf32>
  %a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>

  // CHECK-NOT: extract
  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
  return %r: tensor<2xf32>
}

// -----

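// With the 'into' form, the pad result has the shape of the destination, so
// tensor.dim of the result is rewritten to query %arg1 directly.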
// CHECK-LABEL: func @dim_of_pad_tensor(
//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
//       CHECK:     %[[C0:.*]] = constant 0 : index
//       CHECK:     %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
//       CHECK:     return %[[RESULT]]
func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                        %pad_value: f32) -> index {
  %c0 = constant 0 : index
  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
    ^bb0(%arg2: index, %arg3: index):
      linalg.yield %pad_value : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
  return %r : index
}