// RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s

// Test case: Basic folding of memref.tensor_load(memref.buffer_cast(t)) -> t
// CHECK-LABEL: func @tensor_load_of_buffer_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return %[[TENSOR]]
func @tensor_load_of_buffer_cast(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = memref.buffer_cast %arg0 : memref<?xf32>
  %1 = memref.tensor_load %0 : memref<?xf32>
  return %1 : tensor<?xf32>
}

// -----

// Test case: Basic folding of memref.buffer_cast(memref.tensor_load(m)) -> m
// CHECK-LABEL: func @buffer_cast_of_tensor_load(
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
// CHECK: return %[[MEMREF]]
func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = memref.buffer_cast %0 : memref<?xf32>
  return %1 : memref<?xf32>
}

// -----

// Test case: If the memrefs are not the same type, don't fold them.
// Test case: If the memrefs are not cast-compatible (e.g. different address space),
// don't canonicalize them either.
// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
// CHECK-SAME: -> memref<?xf32, 7> {
// CHECK: %[[TENSOR:.*]] = memref.tensor_load
// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.buffer_cast
// CHECK-SAME: %[[TENSOR]] : memref<?xf32, 7>
// CHECK: return %[[MEMREF_ADDRSPACE7]]
func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
  %0 = memref.tensor_load %arg0 : memref<?xf32, 2>
  %1 = memref.buffer_cast %0 : memref<?xf32, 7>
  return %1 : memref<?xf32, 7>
}

// -----

// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>

// Test case: If the memrefs are cast-compatible, canonicalize.
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>)
// CHECK-SAME: -> memref<?xf32, #[[$OFF_UNK]]> {
// CHECK-NOT: memref.tensor_load
// CHECK-NOT: memref.buffer_cast
// CHECK: %[[R:.*]] = memref.cast %[[M]]
// CHECK-SAME: memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
// CHECK: return %[[R]]
func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
    -> memref<?xf32, offset: ?, strides: [1]>
{
  %0 = memref.tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
  %1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
  return %1 : memref<?xf32, offset: ?, strides: [1]>
}

// -----

// CHECK-LABEL: func @subview_of_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
// CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
    memref<?x?x16x32xi8> to
    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
}

// -----

// CHECK-LABEL: func @subview_of_static_full_size
// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
// CHECK-NOT: memref.subview
// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
  return %0 : memref<4x6x16x32xi8>
}

// -----

#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?x?xf32, #map0>
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
  return %0 : memref<?x?x?xf32, #map0>
}
// CHECK-LABEL: func @subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32> to memref<4x1x?xf32
// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
// CHECK: return %[[RESULT]]

// -----

#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> memref<?x?xf32, #map0>
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
  return %0 : memref<?x?xf32, #map0>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : memref<?x?x?xf32> to memref<4x?xf32
// CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
// CHECK: return %[[RESULT]]

// -----

// CHECK-LABEL: @clone_before_dealloc
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_before_dealloc
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_before_dealloc(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: "use"(%arg0)
  // CHECK-NEXT: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  memref.dealloc %0 : memref<?xf32>
  return %arg0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_after_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_after_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  // CHECK-NEXT: memref.clone %[[ARG]] : memref<?xf32> to memref<32xf32>
  // CHECK-NOT: memref.cast
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
  return %1 : memref<32xf32>
}

// -----

// CHECK-LABEL: @clone_and_cast
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
  // CHECK-NEXT: %[[RES:.*]] = memref.cast %[[ARG]] : memref<?xf32> to memref<32xf32>
  %0 = memref.clone %arg0 : memref<?xf32> to memref<32xf32>
  // CHECK-NEXT: return %[[RES]]
  memref.dealloc %arg0 : memref<?xf32>
  return %0 : memref<32xf32>
}

// -----

// CHECK-LABEL: @alias_is_freed
func @alias_is_freed(%arg0 : memref<?xf32>) {
  // CHECK: memref.clone
  // CHECK: memref.dealloc
  // CHECK: memref.dealloc
  %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
  %1 = memref.clone %0 : memref<32xf32> to memref<32xf32>
  memref.dealloc %arg0 : memref<?xf32>
  "use"(%1) : (memref<32xf32>) -> ()
  memref.dealloc %1 : memref<32xf32>
  return
}

// -----

// Verify SimplifyClones skips clones with multiple deallocations.
// CHECK-LABEL: @clone_multiple_dealloc_of_source
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_multiple_dealloc_of_source(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[RES:.*]] = memref.clone %[[ARG]]
  // CHECK: memref.dealloc %[[ARG]]
  // CHECK: memref.dealloc %[[ARG]]
  // CHECK: return %[[RES]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "if_else"() ({
    memref.dealloc %arg0 : memref<?xf32>
  }, {
    memref.dealloc %arg0 : memref<?xf32>
  }) : () -> ()
  return %0 : memref<?xf32>
}

// -----

// CHECK-LABEL: @clone_multiple_dealloc_of_clone
// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
func @clone_multiple_dealloc_of_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
  // CHECK-NEXT: %[[CLONE:.*]] = memref.clone %[[ARG]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: memref.dealloc %[[CLONE]]
  // CHECK: return %[[ARG]]
  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
  "use"(%0) : (memref<?xf32>) -> ()
  "if_else"() ({
    memref.dealloc %0 : memref<?xf32>
  }, {
    memref.dealloc %0 : memref<?xf32>
  }) : () -> ()
  return %arg0 : memref<?xf32>
}

// -----

// CHECK-LABEL: func @dim_of_sized_view
// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref<?xi8>
// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index
// CHECK: return %[[SIZE]] : index
func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
  %c0 = constant 0 : index
  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref<?xi8> to memref<?xi8>
  %1 = memref.dim %0, %c0 : memref<?xi8>
  return %1 : index
}

// -----

// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: memref.store %[[cst]]
func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
  memref.store %0, %holder[] : memref<memref<?xi8>>
  return
}

// -----

// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
//            -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
// CHECK-NOT: memref.load
// CHECK: return %[[RES]] : f32
func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}

// -----


// Test case: Basic folding of tensor.dim(memref.tensor_load(m)) -> memref.dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
// CHECK: %[[C0:.*]] = constant 0
// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
// CHECK: return %[[D]] : index
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
  %c0 = constant 0 : index
  %0 = memref.tensor_load %arg0 : memref<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
// CHECK-LABEL: func @dim_of_alloca(
// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
// CHECK-NEXT: return %[[SIZE]] : index
func @dim_of_alloca(%size: index) -> index {
  %0 = memref.alloca(%size) : memref<?xindex>
  %c0 = constant 0 : index
  %1 = memref.dim %0, %c0 : memref<?xindex>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
// CHECK-NEXT: return %[[RANK]] : index
func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
  %0 = rank %arg0 : memref<*xf32>
  %1 = memref.alloca(%0) : memref<?xindex>
  %c0 = constant 0 : index
  %2 = memref.dim %1, %c0 : memref<?xindex>
  return %2 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: memref.store
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
// CHECK-NEXT: %[[IDX:.*]] = constant 3
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
// CHECK-NOT: memref.dim
// CHECK: return %[[CAST]] : index
func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}

// -----

// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
  memref<?x?x16x32xi8> {
  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
  return %1 : memref<?x?x16x32xi8>
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func @alloc_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
  %c4 = constant 4 : index
  %a = memref.alloc(%c4) : memref<?xf32>

  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %1 : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_alignment_const_fold
func @alloc_alignment_const_fold() -> memref<?xf32> {
  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
  %c4 = constant 4 : index
  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>

  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
  // CHECK-NEXT: return %1 : memref<?xf32>
  return %a : memref<?xf32>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
// CHECK: %[[c1:.+]] = constant 1 : index
// CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
// CHECK: return %[[mem1]] : memref<?xi32, #map>
#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
  %c1 = constant 1 : index
  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
  return %0 : memref<?xi32, #map0>
}

// -----

// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
// CHECK: %[[c1:.+]] = constant 1 : index
// CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
// CHECK: return %[[mem2]] : memref<?xi32, #map>
#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
  %c1 = constant 1 : index
  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
  return %0 : memref<?xi32, #map0>
}

// -----
// CHECK-LABEL: func @allocator
// CHECK: %[[alloc:.+]] = memref.alloc
// CHECK: memref.store %[[alloc:.+]], %arg0
func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index) {
  %0 = memref.alloc(%arg1) : memref<?xi32>
  memref.store %0, %arg0[] : memref<memref<?xi32>>
  return
}

// -----

func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
    -> memref<f32> {
  %0 = memref.collapse_shape %arg0 [[0, 1, 2]]
      : memref<1x1x1xf32> into memref<1xf32>
  %1 = memref.collapse_shape %0 [] : memref<1xf32> into memref<f32>
  return %1 : memref<f32>
}
// CHECK-LABEL: collapsing_memref_reshapes_to_zero
// CHECK: memref.collapse_shape %{{.*}} []
// CHECK-SAME: memref<1x1x1xf32> into memref<f32>

// -----

func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>)
    -> memref<?x?xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: collapsing_memref_reshapes
// CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.collapse_shape

// -----

func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>)
    -> memref<?x6x4x5x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
      : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
  return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: expanding_memref_reshapes
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: memref.expand_shape

// -----

func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
    -> memref<1x1x1xf32> {
  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
  %1 = memref.expand_shape %0 [[0, 1, 2]]
      : memref<1xf32> into memref<1x1x1xf32>
  return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: expanding_memref_reshapes_to_zero
// CHECK: memref.expand_shape %{{.*}} []
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>

// -----

func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<12x4xf32> into memref<3x4x4xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<3x4x4xf32> into memref<12x4xf32>
  return %1 : memref<12x4xf32>
}
// CHECK-LABEL: @fold_memref_reshape
// CHECK-NOT: memref.{{.*}}_shape

// -----

func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<?x?xf32> into memref<?x4x?xf32>
  %1 = memref.collapse_shape %0 [[0, 1], [2]]
      : memref<?x4x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @fold_memref_reshape_dynamic
// CHECK-NOT: memref.{{.*}}_shape