// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s

// Test case: Basic folding of memref.tensor_load(memref.buffer_cast(t)) -> t
// CHECK-LABEL: func @tensor_load_of_buffer_cast(
//  CHECK-SAME:   %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
//       CHECK: return %[[TENSOR]]
func @tensor_load_of_buffer_cast(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = memref.buffer_cast %arg0 : memref<?xf32>
  %1 = memref.tensor_load %0 : memref<?xf32>
  return %1 : tensor<?xf32>
}

// -----

// Test case: Basic folding of memref.buffer_cast(memref.tensor_load(m)) -> m
// CHECK-LABEL: func @buffer_cast_of_tensor_load(
//  CHECK-SAME:   %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
//       CHECK: return %[[MEMREF]]
func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = memref.buffer_cast %0 : memref<?xf32>
  return %1 : memref<?xf32>
}

// -----

// Test case: If the memrefs are not the same type, don't fold them.
// Test case: If the memrefs are not cast-compatible (e.g. different address space),
// don't canonicalize them either.
// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
//  CHECK-SAME:     -> memref<?xf32, 7> {
//       CHECK: %[[TENSOR:.*]] = memref.tensor_load
//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
//       CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.buffer_cast
//  CHECK-SAME:   %[[TENSOR]] : memref<?xf32, 7>
//       CHECK: return %[[MEMREF_ADDRSPACE7]]
func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
  %0 = memref.tensor_load %arg0 : memref<?xf32, 2>
  %1 = memref.buffer_cast %0 : memref<?xf32, 7>
  return %1 : memref<?xf32, 7>
}

// -----

// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>

// Test case: If the memrefs are cast-compatible, canonicalize.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
//  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>)
//  CHECK-SAME:     -> memref<?xf32, #[[$OFF_UNK]]> {
//   CHECK-NOT: memref.tensor_load
//   CHECK-NOT: memref.buffer_cast
//       CHECK: %[[R:.*]] = memref.cast %[[M]]
//  CHECK-SAME:   memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
//       CHECK: return %[[R]]
func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
  -> memref<?xf32, offset: ?, strides: [1]>
{
  %0 = memref.tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
  %1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
  return %1 : memref<?xf32, offset: ?, strides: [1]>
}

// -----

// CHECK-LABEL: func @subview_of_memcast
//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>> {
  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
  %1 = memref.subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
    memref<?x?x16x32xi8> to
    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
}

// -----

// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: memref.subview
// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
  return %0 : memref<4x6x16x32xi8>
}

// -----

#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?x?xf32, #map0>
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
  return %0 : memref<?x?x?xf32, #map0>
}
// CHECK-LABEL: func @subview_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x1x?xf32
//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
//       CHECK:   return %[[RESULT]]

// -----

#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?xf32, #map0>
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
  return %0 : memref<?x?xf32, #map0>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x?xf32
//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
//       CHECK:   return %[[RESULT]]

// -----

// CHECK-LABEL: @clone_before_dealloc
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_before_dealloc
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: "use"(%arg0)
  // CHECK-NEXT: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  memref.dealloc %0 : memref<?xf32>
  return %arg0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_after_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_after_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  // CHECK-NEXT: memref.clone %[[ARG]] : memref<?xf32> to memref<32xf32>
  // CHECK-NOT: memref.cast
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
  return %1 : memref<32xf32>
}

// -----

// CHECK-LABEL: @clone_and_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  // CHECK-NEXT: %[[RES:.*]] = memref.cast %[[ARG]] : memref<?xf32> to memref<32xf32>
  %0 = memref.clone %arg0 : memref<?xf32> to memref<32xf32>
  // CHECK-NEXT: return %[[RES]]
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<32xf32>
}

// -----

// CHECK-LABEL: @alias_is_freed
func @alias_is_freed(%arg0 : memref<?xf32>) {
  // CHECK: memref.clone
  // CHECK: memref.dealloc
  // CHECK: memref.dealloc
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  "use"(%1) : (memref<32xf32>) -> ()
  memref.dealloc %1 : memref<32xf32>
  return
}

// -----

// Verify SimplifyClones skips clones with multiple deallocations.
// CHECK-LABEL: @clone_multiple_dealloc_of_source
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_multiple_dealloc_of_source(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[RES:.*]] = memref.clone %[[ARG]]
  // CHECK: memref.dealloc %[[ARG]]
  // CHECK: memref.dealloc %[[ARG]]
  // CHECK: return %[[RES]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "if_else"() ({
    memref.dealloc %arg0 : memref<?xf32>
    }, {
    memref.dealloc %arg0 : memref<?xf32>
    }) : () -> ()
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_multiple_dealloc_of_clone
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_multiple_dealloc_of_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[CLONE:.*]] = memref.clone %[[ARG]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  "if_else"() ({
    memref.dealloc %0 : memref<?xf32>
    }, {
    memref.dealloc %0 : memref<?xf32>
    }) : () -> ()
  return %arg0 : memref<?xf32>
}

// -----

// CHECK-LABEL: func @dim_of_sized_view
//  CHECK-SAME:   %{{[a-z0-9A-Z_]+}}: memref<?xi8>
//  CHECK-SAME:   %[[SIZE:.[a-z0-9A-Z_]+]]: index
//       CHECK:   return %[[SIZE]] : index
func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
  %c0 = constant 0 : index
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [1] : memref<?xi8> to memref<?xi8>
  %1 = memref.dim %0, %c0 : memref<?xi8>
  return %1 : index
}

// -----

// CHECK-LABEL: func @no_fold_of_store
//  CHECK:   %[[cst:.+]] = memref.cast %arg
//  CHECK:   memref.store %[[cst]]
func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
  memref.store %0, %holder[] : memref<memref<?xi8>>
  return
}

// -----

// Test case: Folding of memref.load(memref.buffer_cast(%v), %idxs)
//            -> tensor.extract(%v, %idxs)
// CHECK-LABEL: func @load_from_buffer_cast(
//  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
//  CHECK-SAME:     %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
//       CHECK:   %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
//   CHECK-NOT:   memref.load
//       CHECK:   return %[[RES]] : f32
func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}

// -----

// Test case: Basic folding of tensor.dim(memref.tensor_load(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
//  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
//       CHECK:   %[[C0:.*]] = constant 0
//       CHECK:   %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
//       CHECK:   return %[[D]] : index
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = constant 0 : index
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
//  CHECK-SAME:     %[[SIZE:[0-9a-z]+]]: index
//  CHECK-NEXT:   return %[[SIZE]] : index
func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
//  CHECK-NEXT:   return %[[RANK]] : index
func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   memref.store
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   %[[CAST:.*]] = index_cast %[[DIM]]
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[CAST]] : index
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// CHECK-LABEL: func @tensor_cast_to_memref
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
//       CHECK:   %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
//       CHECK:   %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
  memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func @alloc_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
  %c4 = constant 4 : index
  %a = memref.alloc(%c4) : memref<?xf32>

  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %1 : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_alignment_const_fold
func @alloc_alignment_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
  %c4 = constant 4 : index
  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>

  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %1 : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
//  CHECK: %[[c1:.+]] = constant 1 : index
//  CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
//  CHECK: return %[[mem1]] : memref<?xi32, #map>
#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
  %c1 = constant 1 : index
  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
  return %0 : memref<?xi32, #map0>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
//  CHECK: %[[c1:.+]] = constant 1 : index
//  CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
//  CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
//  CHECK: return %[[mem2]] : memref<?xi32, #map>
#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
  %c1 = constant 1 : index
  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
  return %0 : memref<?xi32, #map0>
}

// -----

// CHECK-LABEL: func @allocator
// CHECK:   %[[alloc:.+]] = memref.alloc
// CHECK:   memref.store %[[alloc]], %arg0
func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
  %0 = memref.alloc(%arg1) : memref<?xi32>
  memref.store %0, %arg0[] : memref<memref<?xi32>>
  return
}

// -----

func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
                                             -> memref<f32> {
  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
      : memref<1x1x1xf32> into memref<1xf32>
  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
  return %1 : memref<f32>
}
// CHECK-LABEL: collapsing_memref_reshapes_to_zero
//       CHECK:   memref.collapse_shape %{{.*}} []
//  CHECK-SAME:     memref<1x1x1xf32> into memref<f32>

// -----

func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
    -> memref<?x?xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: collapsing_memref_reshapes
//       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   memref.collapse_shape

// -----

func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
    -> memref<?x6x4x5x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
      : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
  return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: expanding_memref_reshapes
//       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
//   CHECK-NOT:   memref.expand_shape

// -----

func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
                                             -> memref<1x1x1xf32> {
  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
  %1 = memref.expand_shape %0 [[0, 1, 2]]
      : memref<1xf32> into memref<1x1x1xf32>
  return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: expanding_memref_reshapes_to_zero
//       CHECK:   memref.expand_shape %{{.*}} []
//  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>

// -----

func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<12x4xf32> into memref<3x4x4xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x4x4xf32> into memref<12x4xf32>
  return %1 : memref<12x4xf32>
}
// CHECK-LABEL: @fold_memref_reshape
//   CHECK-NOT:   memref.{{.*}}_shape

// -----

func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x4x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @fold_memref_reshape_dynamic
//   CHECK-NOT:   memref.{{.*}}_shape
