// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" | FileCheck %s --check-prefix=DYNAMIC
// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" | FileCheck %s --check-prefix=ALLOCA

// Tests for the linalg subview-promotion pass: each tiled matmul operand
// subview is promoted into a local buffer (static alloc by default, alloca
// under test-use-alloca, dynamically-sized view under test-promote-dynamic),
// copied in, computed on, copied back for the output, and deallocated.

#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 4)>
#map3 = affine_map<(d0) -> (d0 + 3)>

// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
  %6 = memref.dim %3, %c0 : memref<?x?xf32>
  %7 = memref.dim %3, %c1 : memref<?x?xf32>
  %8 = memref.dim %4, %c1 : memref<?x?xf32>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        linalg.matmul
          ins(%11, %14: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                        memref<?x?xf32, offset: ?, strides: [?, 1]>)
         outs(%17: memref<?x?xf32, offset: ?, strides: [?, 1]>)
      }
    }
  }
  return
}

// CHECK-LABEL: func @matmul_f32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: %[[vA:.*]] = memref.subview {{.*}} : memref<?x?xf32>
// CHECK: %[[vB:.*]] = memref.subview {{.*}} : memref<?x?xf32>
// CHECK: %[[vC:.*]] = memref.subview {{.*}} : memref<?x?xf32>
///
// CHECK: %[[tmpA:.*]] = memref.alloc() : memref<32xi8>
// ALLOCA: %[[tmpA:.*]] = memref.alloca() : memref<32xi8>
// CHECK: %[[fullA:.*]] = memref.view %[[tmpA]][{{.*}}][{{.*}}] : memref<32xi8> to memref<?x?xf32>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>
///
// CHECK: %[[tmpB:.*]] = memref.alloc() : memref<48xi8>
// ALLOCA: %[[tmpB:.*]] = memref.alloca() : memref<48xi8>
// CHECK: %[[fullB:.*]] = memref.view %[[tmpB]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf32>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>
///
// CHECK: %[[tmpC:.*]] = memref.alloc() : memref<24xi8>
// ALLOCA: %[[tmpC:.*]] = memref.alloca() : memref<24xi8>
// CHECK: %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>

// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
//
// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) :
// CHECK: memref<?x?xf32, #[[$strided2D]]>,
// CHECK: memref<?x?xf32, #[[$strided2D]]>
//
// CHECK: memref.dealloc %[[tmpA]] : memref<32xi8>
// CHECK: memref.dealloc %[[tmpB]] : memref<48xi8>
// CHECK: memref.dealloc %[[tmpC]] : memref<24xi8>
// ALLOCA-NOT: memref.dealloc %[[tmpA]] : memref<32xi8>
// ALLOCA-NOT: memref.dealloc %[[tmpB]] : memref<48xi8>
// ALLOCA-NOT: memref.dealloc %[[tmpC]] : memref<24xi8>

// -----

func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf64>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf64>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf64>
  %6 = memref.dim %3, %c0 : memref<?x?xf64>
  %7 = memref.dim %3, %c1 : memref<?x?xf64>
  %8 = memref.dim %4, %c1 : memref<?x?xf64>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        linalg.matmul
          ins(%11, %14: memref<?x?xf64, offset: ?, strides: [?, 1]>,
                        memref<?x?xf64, offset: ?, strides: [?, 1]>)
         outs(%17: memref<?x?xf64, offset: ?, strides: [?, 1]>)
      }
    }
  }
  return
}

// CHECK-LABEL: func @matmul_f64(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK: %[[vA_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
// CHECK: %[[vB_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
// CHECK: %[[vC_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
///
// CHECK: %[[tmpA_f64:.*]] = memref.alloc() : memref<64xi8>
// CHECK: %[[fullA_f64:.*]] = memref.view %[[tmpA_f64]][{{.*}}][{{.*}}] : memref<64xi8> to memref<?x?xf64>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>
///
// CHECK: %[[tmpB_f64:.*]] = memref.alloc() : memref<96xi8>
// CHECK: %[[fullB_f64:.*]] = memref.view %[[tmpB_f64]][{{.*}}][{{.*}}] : memref<96xi8> to memref<?x?xf64>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>
///
// CHECK: %[[tmpC_f64:.*]] = memref.alloc() : memref<48xi8>
// CHECK: %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
// DYNAMIC: memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>

// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
//
// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
// CHECK: memref<?x?xf64, #[[$strided2D]]>,
// CHECK: memref<?x?xf64, #[[$strided2D]]>
//
// CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8>
// CHECK: memref.dealloc %[[tmpC_f64]] : memref<48xi8>