// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

// Canonicalization tests for the linalg dialect: memref/tensor cast folding
// into linalg ops, and composition/folding of tensor_expand_shape /
// tensor_collapse_shape pairs.

// The matmul should be canonicalized to operate on the statically-shaped
// memref instead of the casted dynamic one.
// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c8 = constant 8 : index
  %c16 = constant 16 : index
  %1 = memref.alloc (%b) : memref<?xi8>
  %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
  %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32>

  // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
               outs(%3: memref<?x?xf32>)
  return %3: memref<?x?xf32>
}

// -----

#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// The cast to a layout-mapped memref should be folded into the tiled_loop
// outs operand.
// CHECK-LABEL: func @memref_cast_into_tiled_loop(
func @memref_cast_into_tiled_loop(%arg0: memref<192xf32>) {
  %0 = memref.cast %arg0
    : memref<192xf32> to memref<192xf32, #map>
  %cst = constant 0.000000e+00 : f32
  %c24 = constant 24 : index
  %c0 = constant 0 : index
  %c192 = constant 192 : index
  // CHECK: linalg.tiled_loop
  // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>)
  linalg.tiled_loop (%arg3) = (%c0) to (%c192) step (%c24)
      outs (%out = %0: memref<192xf32, #map>) {
    %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3)
    %16 = memref.subview %out[%arg3] [%14] [1]
      : memref<192xf32, #map> to memref<?xf32, #map>
    linalg.fill(%cst, %16) : f32, memref<?xf32, #map>
    linalg.yield
  }
  return
}

// -----

// A chain of expand/collapse through unit dims that nets out to the identity
// should fold away completely.
// CHECK-LABEL: zero_rank_reshape_multi
func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: return %arg0
  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
  %2 = linalg.tensor_collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
  return %2 : tensor<f32>
}

// -----

// Two consecutive collapses compose into a single collapse.
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: collapsing_tensor_reshapes
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_collapse_shape

// -----

func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
    -> tensor<f32> {
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2]]
      : tensor<1x1x1xf32> into tensor<1xf32>
  %1 = linalg.tensor_collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
  return %1 : tensor<f32>
}
// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
// CHECK: linalg.tensor_collapse_shape %{{.*}} []
// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>

// -----

// Two consecutive expands compose into a single expand.
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4]]
      : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
  return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK-NOT: linalg.tensor_expand_shape

// -----

func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
    -> tensor<1x1x1xf32> {
  %0 = linalg.tensor_expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2]]
      : tensor<1xf32> into tensor<1x1x1xf32>
  return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: expanding_tensor_reshapes_to_zero
// CHECK: linalg.tensor_expand_shape %{{.*}} []
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>

// -----

// expand followed by the inverse collapse folds to the identity.
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<3x4x4xf32> into tensor<12x4xf32>
  return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_tensor_reshape
// CHECK-NOT: linalg.{{.*}}shape

// -----

func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2]]
      : tensor<?x?xf32> into tensor<?x4x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
      : tensor<?x4x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_tensor_reshape_dynamic
// CHECK-NOT: linalg.{{.*}}_shape

// -----

// collapse-to-1D then expand, where the result is coarser than the source:
// folds to a single collapse with recomputed reassociation.
func @reshape_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>) -> tensor<24x5x42x8xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3]]
      : tensor<40320xf32> into tensor<24x5x42x8xf32>
  return %1 : tensor<24x5x42x8xf32>
}
// CHECK: func @reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]

// -----

// Same as above but the result is finer than the source: folds to a single
// expand.
func @reshape_expand(%arg0 : tensor<24x5x42x8xf32>) -> tensor<2x3x4x5x6x7x8xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2, 3]]
      : tensor<24x5x42x8xf32> into tensor<40320xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
      : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
  return %1 : tensor<2x3x4x5x6x7x8xf32>
}
// CHECK: func @reshape_expand
// CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
// CHECK: %[[RESULT:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
// CHECK: return %[[RESULT]]

// -----

// expand/collapse pairs that only differ by unit dims fold to one reshape.
func @expand_reshape_1D(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3]]
      : tensor<2048xf32> into tensor<1x4x1x512xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
      : tensor<1x4x1x512xf32> into tensor<4x512xf32>
  return %1 : tensor<4x512xf32>
}
// CHECK: func @expand_reshape_1D
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>

// -----

func @fold_reshape_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2], [3]]
      : tensor<4x512xf32> into tensor<1x4x1x512xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
      : tensor<1x4x1x512xf32> into tensor<2048xf32>
  return %1 : tensor<2048xf32>
}
// CHECK: func @fold_reshape_1D
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>

// -----

func @fold_reshape_unit_dims(%arg0 : tensor<2048x1x1xf32>) -> tensor<4x512x1x1xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
      : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
      : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
  return %1 : tensor<4x512x1x1xf32>
}
// CHECK: func @fold_reshape_unit_dims
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>

// -----

func @expand_reshape_unit_dims(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
      : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
      : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
  return %1 : tensor<4x512x1x512x4xf32>
}
// CHECK: func @expand_reshape_unit_dims
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>

// -----

func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
      : tensor<2xf32> into tensor<2x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>

// -----

func @collapse_reshape_unit_dims_dynamic(%arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
      : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
      : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}
// CHECK: func @collapse_reshape_unit_dims_dynamic
// CHECK: linalg.tensor_collapse_shape
// CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
// CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>

// -----

// NOTE(review): this test is an exact duplicate of
// @fold_reshape_trailing_unit_dims above; consider removing one copy.
func @fold_reshape_trailing_unit_dims(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1, 2]]
      : tensor<2xf32> into tensor<2x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2]]
      : tensor<2x1x1xf32> into tensor<2x1xf32>
  return %1 : tensor<2x1xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>

// -----

func @fold_reshape_trailing_unit_dims_dynamic(%arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32>
{
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
      : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2, 3]]
      : tensor<?x1x1x1xf32> into tensor<?xf32>
  return %1 : tensor<?xf32>
}
// CHECK: func @fold_reshape_trailing_unit_dims_dynamic
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
// CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>

// -----

// The pair cannot fold: the collapse mixes a dynamic dim with the expanded
// unit dim, so both reshapes must remain.
func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
      : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]]
      : tensor<?x?x1x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func @no_fold_reshapes
// CHECK: linalg.tensor_expand_shape
// CHECK: linalg.tensor_collapse_shape

// -----

// The reassociations are incompatible (dims regrouped across boundaries),
// so no folding happens.
func @no_fold_reshape_incompatible(%arg0 : tensor<4x6x8xf32>) -> tensor<2x6x16xf32>
{
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2, 3], [4]]
      : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2], [3, 4]]
      : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
  return %1 : tensor<2x6x16xf32>
}
// CHECK-LABEL: func @no_fold_reshape_incompatible
// CHECK: linalg.tensor_expand_shape
// CHECK: linalg.tensor_collapse_shape

// -----

func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
      : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3]]
      : tensor<3x2x2x1xf32> into tensor<12x1xf32>
  return %1 : tensor<12x1xf32>
}
// CHECK: func @no_fold_reshape_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = linalg.tensor_expand_shape %[[ARG0]]
// CHECK-SAME: [0], [1], [2, 3]
// CHECK: %[[RES:.+]] = linalg.tensor_collapse_shape %[[RARG0]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK: return %[[RES:.+]] : tensor<12x1xf32>

// -----

#accesses = [
  affine_map<(i) -> (i)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"]
}

// Pure-memref linalg ops with no observable effect are dead; ops producing
// a used tensor result are not.
func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
  // memref<0xf32> is expected to be dce'ed
  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>

  // tensor<0xf32> cannot be dce'ed
  %1 = linalg.generic #trait outs(%arg1 : tensor<0xf32>) {
  ^bb(%0: f32) :
    linalg.yield %0 : f32
  } -> tensor<0xf32>

  return %1: tensor<0xf32>
}
// CHECK-LABEL: @dce_zero_memref
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
// CHECK-NOT: linalg.copy
// CHECK-NEXT: return %[[ARG1]]

// -----

// Reshapes of splat constants fold into a reshaped splat constant, for each
// element type below.
func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
{
  %c0 = constant dense<42> : tensor<2x8xi32>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xi32> into tensor<2x4x2xi32>
  return %0 : tensor<2x4x2xi32>
}
// CHECK-LABEL: @reshape_splat_constant_int32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]

func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
{
  %c0 = constant dense<42> : tensor<2x8xi16>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xi16> into tensor<2x4x2xi16>
  return %0 : tensor<2x4x2xi16>
}
// CHECK-LABEL: @reshape_splat_constant_int16
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]

func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
{
  %c0 = constant dense<42.0> : tensor<2x8xf32>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xf32> into tensor<2x4x2xf32>
  return %0 : tensor<2x4x2xf32>
}
// CHECK-LABEL: @reshape_splat_constant_float32
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]

func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
{
  %c0 = constant dense<42.0> : tensor<2x8xf64>
  %0 = linalg.tensor_expand_shape %c0 [[0], [1, 2]]
      : tensor<2x8xf64> into tensor<2x4x2xf64>
  return %0 : tensor<2x4x2xf64>
}
// CHECK-LABEL: @reshape_splat_constant_float64
// CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
// CHECK-NOT: linalg.tensor_expand_shape
// CHECK: return %[[CST]]

// -----

// tensor.cast ops fold into linalg ops on both the ins/outs side and the
// result side.
// CHECK-LABEL: func @tensor.cast(
func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
    -> tensor<3x?xf32>
{
  %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
  %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
  %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>

  // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
  // CHECK-SAME: outs({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>

  return %1: tensor<3x?xf32>
}

// -----

// A matmul with an unused tensor result is dead; one writing a memref is not.
// CHECK-LABEL: func @linalg_effects(
// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32>
func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
  // CHECK-NOT: %{{.*}} = linalg.matmul
  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>

  // CHECK: linalg.matmul
  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
               outs(%b : memref<?x?xf32>)
  return
}

// -----

// Constant sizes fold into init_tensor, with a cast back to the original type.
func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
  %c6 = constant 6 : index
  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
  return %0 : tensor<4x5x?xf32>
}
// CHECK: func @init_tensor_canonicalize
// CHECK: %[[T0:.+]] = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]

// -----

// Reshapes of init_tensor fold into a reshaped init_tensor with the dynamic
// size recomputed via affine.apply.
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7]
// CHECK-NEXT: return %[[INIT]]

// -----

func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  return %1 : tensor<6x5x?xf32>
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]]
// CHECK-NEXT: return %[[INIT]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// A generic that merely permutes its inputs to its results is a no-op and is
// replaced by the (swapped) operands.
func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
    -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
  %4, %5 = linalg.generic {
    indexing_maps = [#map, #map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
    linalg.yield %arg3, %arg2 : f32, f32
  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @remove_no_op
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: return %[[ARG1]], %[[ARG0]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
// The yielded value comes from a block argument of the enclosing function's
// CFG, not from the generic's ins, so the generic is not a no-op.
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %cst = constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
  br ^bb1(%cst : f32)

^bb1(%arg1 : f32):
  %3 = linalg.generic
    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
  ^bb0(%arg2: f32, %arg3 : f32):
    linalg.yield %arg1 : f32
  } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}
// CHECK-LABEL: func @keep_not_noop
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK: return %[[RESULT]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
// Multi-result variant: one result is not a pure forward of an input, so the
// whole generic must be kept.
func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %cst = constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
  br ^bb1(%cst : f32)

^bb1(%arg2 : f32):
  %3:2 = linalg.generic
    {indexing_maps = [#map, #map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
  ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
    linalg.yield %arg2, %arg4 : f32, f32
  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @keep_not_noop
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1

// -----

// extract_slice of init_tensor folds to an init_tensor of the slice's shape.
func @fold_init_tensor_with_slice
  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
{
  %0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
  return %1 : tensor<5x?x20xf32>
}
// CHECK: func @fold_init_tensor_with_slice
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
// CHECK: return %[[T0]]

// -----

#accesses = [
  affine_map<(i, j) -> (i, j)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"]
}

// All of these tensor-producing ops have unused results and no memory
// effects, so they are all dead.
// CHECK-LABEL: func @dead_linalg_tensor
// CHECK-NOT: linalg.fill
// CHECK-NOT: linalg.matmul
// CHECK-NOT: linalg.generic
// CHECK-NOT: linalg.pad_tensor
// CHECK: return
func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
                         %arg2: tensor<?x?xf32>, %high : index) {
  %c0_i32 = constant 0 : i32
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.fill(%c0_i32, %arg0) : i32, tensor<7x7xi32> -> tensor<7x7xi32>
  %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>)
                    outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32>
  %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) {
  ^bb(%3: i32) :
    linalg.yield %3 : i32
  } -> tensor<7x7xi32>
  %3 = linalg.pad_tensor %arg2 low[%c0, %c0] high[%high, %high] {
  ^bb0(%arg9: index, %arg10: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>
  return
}

// -----

// A pad whose source and result types are identical static shapes is a no-op.
// CHECK-LABEL: func @pad_tensor_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK: return %[[ARG0]]
func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
    -> tensor<5x6xf32> {
  %cst = constant 0.000000e+00 : f32
  %0 = linalg.pad_tensor %arg0 low[%a, 0] high[0, %a] {
  ^bb0(%arg1: index, %arg2: index):
    linalg.yield %cst : f32
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

// Renamed from @propogate_casts (spelling fix); constant sizes propagate into
// init_tensor and a cast is inserted after the insert_slice.
func @propagate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
                      %arg3 : index) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c21 = constant 21 : index
  %c42 = constant 42 : index
  %0 = linalg.init_tensor [%c21, %c42] : tensor<?x?xf32>
  %1 = linalg.fill(%arg1, %0) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
  %2 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %3 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  return %4 : tensor<?x?xf32>
}
// CHECK-LABEL: func @propagate_casts
// CHECK: %[[INIT:.+]] = linalg.init_tensor [21, 42]
// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]]
// CHECK: %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
// CHECK: return %[[RESULT]]

// -----

// Copying a memref onto itself is a no-op.
// CHECK-LABEL: @self_copy
func @self_copy(%arg0 : memref<2x3x?x4xf32>) {

// CHECK-NOT: linalg.copy
  linalg.copy(%arg0, %arg0): memref<2x3x?x4xf32>, memref<2x3x?x4xf32>

// CHECK: return
  return
}

// -----

// A collapse of a filled tensor becomes a fill of the collapsed init_tensor.
// CHECK-LABEL: func @fold_fill_reshape()
func @fold_fill_reshape() -> tensor<6x4xf32> {
  %zero = constant 0.0 : f32
  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
  %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
  // CHECK: %[[FILL:.+]] = linalg.fill(%cst, %[[INIT]]) : f32, tensor<6x4xf32> -> tensor<6x4xf32>
  %fill = linalg.fill(%zero, %init) : f32, tensor<1x2x3x4xf32> -> tensor<1x2x3x4xf32>
  %reshape = linalg.tensor_collapse_shape %fill [[0, 1, 2], [3]]
      : tensor<1x2x3x4xf32> into tensor<6x4xf32>
  // CHECK: return %[[FILL]] : tensor<6x4xf32>
  return %reshape : tensor<6x4xf32>
}

// -----

// Same folding with dynamic shapes: the reshape moves before the fill.
// CHECK: func @fold_fill_reshape_dynamic
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
  %zero = constant 0.0 : f32
  // CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
  %0 = linalg.fill(%zero, %arg0) : f32, tensor<?x?x?x?x?xf32> -> tensor<?x?x?x?x?xf32>
  // CHECK: %[[RESULT:.+]] = linalg.fill(%{{.+}}, %[[RESHAPE]])
  %1 = linalg.tensor_collapse_shape %0 [[0, 1, 2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
  return %1 : tensor<?x?xf32>
}

// -----

func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
                  %C: memref<48xf32>) -> (tensor<48xf32>)

// An unused tiled_loop tensor result that is just passed through (yielded
// unchanged) is removed from the loop.
func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
    %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c48 = constant 48 : index
  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
      ins (%A_ = %A: memref<48xf32>)
      outs (%B_ = %B: tensor<48xf32>,
            %CT_ = %C_tensor: tensor<48xf32>,
            %C_ = %C: memref<48xf32>) {
    %result = call @foo(%A_, %B_, %C_)
      : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
  }
  return %useful : tensor<48xf32>
}

// CHECK-LABEL: func @fold_tiled_loop_results(
// CHECK-SAME: %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
// CHECK-SAME: %[[C:.*]]: [[BUF_TY]], %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {

// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C24:.*]] = constant 24 : index
// CHECK-DAG: %[[C48:.*]] = constant 48 : index

// CHECK-NOT: %{{.*}} = linalg.tiled_loop
// CHECK: %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
// CHECK-NEXT: %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
// CHECK-NEXT: linalg.yield %[[RES]] :

// CHECK: return %[[RESULT]]

// -----

func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>

// Unused tiled_loop ins operands are dropped.
func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
  %c0 = constant 0 : index
  %c24 = constant 24 : index
  %c192 = constant 192 : index
  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
      outs (%BT_ = %B_tensor: tensor<192xf32>) {
    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
    linalg.yield %0 : tensor<192xf32>
  }
  return %result : tensor<192xf32>
}

// CHECK-LABEL: func @fold_tiled_loop_inputs
// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)

// CHECK: return %[[RESULT]]

// -----

// A zero-padding of a casted tensor back to the same static type is a no-op.
func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %c0 = constant 0 : index
  %cst = constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %c0] {
  ^bb0(%arg1: index, %arg2: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}
// CHECK-LABEL: @tensor_pad_cast
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
// CHECK: return %[[ARG0]]

// -----

// The source cast folds into pad_tensor when it only adds static information.
// CHECK-LABEL: func @fold_pad_tensor_source_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
// CHECK-NOT: tensor.cast
// CHECK: %[[RESULT:.*]] = linalg.pad_tensor %[[ARG0]]
func @fold_pad_tensor_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
  %cst = constant 0.0 : f32
  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low[0, 0] high[0, 1] {
  ^bb0(%arg1: index, %arg2: index):  // no predecessors
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<4x4xf32>
  return %1 : tensor<4x4xf32>
}

// -----

// A pad with all-zero low/high amounts is replaced by a tensor.cast to the
// static result type.
// CHECK-LABEL: func @pad_static_zero_cast(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
// CHECK: return %[[RESULT]]
func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
  %c0 = constant 0 : index
  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    linalg.yield %pad_value : f32
  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>

  return %0 : tensor<2x3x4xf32>
}

// -----

func private @some_use(%i : index, %j : index)

// tensor.dim of init_tensor resolves to the size operand / constant.
// CHECK-LABEL: func @init_canonicalize
// CHECK-SAME: %[[I:.*]]: index
func @init_canonicalize(%i : index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // CHECK-NOT: init_tensor
  %0 = linalg.init_tensor [%i, 42] : tensor<?x42xf32>

  // CHECK-NOT: tensor.dim
  %1 = tensor.dim %0, %c0: tensor<?x42xf32>
  %2 = tensor.dim %0, %c1: tensor<?x42xf32>

  // CHECK: %[[c42:.*]] = constant 42 : index
  // CHECK: call @some_use(%[[I]], %[[c42]])
  call @some_use(%1, %2) : (index, index) -> ()

  return
}

// -----

// A rank-reducing extract_slice of init_tensor folds to a smaller init_tensor.
// CHECK-LABEL: func @rank_reducing_init_extract
func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
  // CHECK: linalg.init_tensor [2] : tensor<2xf32>
  %a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>

  // CHECK-NOT: extract
  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
  return %r: tensor<2xf32>
}

// -----

// tensor.dim of a pad_tensor with an `into` destination resolves to the dim
// of the destination tensor.
// CHECK-LABEL: func @dim_of_pad_tensor(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK: return %[[RESULT]]
func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                        %pad_value: f32) -> index {
  %c0 = constant 0 : index
  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
  ^bb0(%arg2: index, %arg3: index):
    linalg.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
  return %r : index
}