// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" | FileCheck %s --check-prefix=DYNAMIC
// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" | FileCheck %s --check-prefix=ALLOCA

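// Promotion copies each tiled matmul operand into a local full-tile buffer
// before the computation and copies the result back out afterwards. By
// default the buffers are statically sized memref.alloc's; with
// "test-promote-dynamic" they are dynamically sized (DYNAMIC checks), and
// with "test-use-alloca" they are stack allocations (ALLOCA checks).
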
#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 4)>
#map3 = affine_map<(d0) -> (d0 + 3)>

// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>

func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32>
  %6 = memref.dim %3, %c0 : memref<?x?xf32>
  %7 = memref.dim %3, %c1 : memref<?x?xf32>
  %8 = memref.dim %4, %c1 : memref<?x?xf32>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
        linalg.matmul
          ins(%11, %14: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                        memref<?x?xf32, offset: ?, strides: [?, 1]>)
         outs(%17: memref<?x?xf32, offset: ?, strides: [?, 1]>)
      }
    }
  }
  return
}

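// Full tiles are 2x4 (A), 4x3 (B), and 2x3 (C) f32 elements, so the promoted
// buffers take 32, 48, and 24 bytes respectively.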
// CHECK-LABEL: func @matmul_f32(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:         %[[vA:.*]] = memref.subview {{.*}} : memref<?x?xf32>
//       CHECK:         %[[vB:.*]] = memref.subview {{.*}} : memref<?x?xf32>
//       CHECK:         %[[vC:.*]] = memref.subview {{.*}} : memref<?x?xf32>
///
//       CHECK:         %[[tmpA:.*]] = memref.alloc() : memref<32xi8>
//      ALLOCA:         %[[tmpA:.*]] = memref.alloca() : memref<32xi8>
//       CHECK:         %[[fullA:.*]] = memref.view %[[tmpA]][{{.*}}][{{.*}}] : memref<32xi8> to memref<?x?xf32>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
//       CHECK:         %[[partialA:.*]] = memref.subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>
///
//       CHECK:         %[[tmpB:.*]] = memref.alloc() : memref<48xi8>
//      ALLOCA:         %[[tmpB:.*]] = memref.alloca() : memref<48xi8>
//       CHECK:         %[[fullB:.*]] = memref.view %[[tmpB]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf32>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
//       CHECK:         %[[partialB:.*]] = memref.subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>
///
//       CHECK:         %[[tmpC:.*]] = memref.alloc() : memref<24xi8>
//      ALLOCA:         %[[tmpC:.*]] = memref.alloca() : memref<24xi8>
//       CHECK:         %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref<?x?xf32>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
//       CHECK:         %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[$strided2D]]>

//       CHECK:         linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
//       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
//       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[$strided2D]]>, memref<?x?xf32, #[[$strided2D]]>
//
//       CHECK:         linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) :
//       CHECK:           memref<?x?xf32, #[[$strided2D]]>,
//       CHECK:           memref<?x?xf32, #[[$strided2D]]>
//
//       CHECK:         memref.dealloc %[[tmpA]] : memref<32xi8>
//       CHECK:         memref.dealloc %[[tmpB]] : memref<48xi8>
//       CHECK:         memref.dealloc %[[tmpC]] : memref<24xi8>
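// Stack buffers created under "test-use-alloca" must not be deallocated: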
//  ALLOCA-NOT:         memref.dealloc %[[tmpA]] : memref<32xi8>
//  ALLOCA-NOT:         memref.dealloc %[[tmpB]] : memref<48xi8>
//  ALLOCA-NOT:         memref.dealloc %[[tmpC]] : memref<24xi8>

// -----

func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
  %c4 = constant 4 : index
  %c3 = constant 3 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %3 = memref.view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf64>
  %4 = memref.view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf64>
  %5 = memref.view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf64>
  %6 = memref.dim %3, %c0 : memref<?x?xf64>
  %7 = memref.dim %3, %c1 : memref<?x?xf64>
  %8 = memref.dim %4, %c1 : memref<?x?xf64>
  scf.for %arg4 = %c0 to %6 step %c2 {
    scf.for %arg5 = %c0 to %8 step %c3 {
      scf.for %arg6 = %c0 to %7 step %c4 {
        %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
        linalg.matmul
          ins(%11, %14: memref<?x?xf64, offset: ?, strides: [?, 1]>,
                        memref<?x?xf64, offset: ?, strides: [?, 1]>)
         outs(%17: memref<?x?xf64, offset: ?, strides: [?, 1]>)
      }
    }
  }
  return
}

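// Same tile shapes with f64 elements, so every buffer is twice as large:
// 64, 96, and 48 bytes.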
// CHECK-LABEL: func @matmul_f64(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
//       CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
//       CHECK:         %[[vA_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
//       CHECK:         %[[vB_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
//       CHECK:         %[[vC_f64:.*]] = memref.subview {{.*}} : memref<?x?xf64>
///
//       CHECK:         %[[tmpA_f64:.*]] = memref.alloc() : memref<64xi8>
//       CHECK:         %[[fullA_f64:.*]] = memref.view %[[tmpA_f64]][{{.*}}][{{.*}}] : memref<64xi8> to memref<?x?xf64>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
//       CHECK:         %[[partialA_f64:.*]] = memref.subview %[[fullA_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>
///
//       CHECK:         %[[tmpB_f64:.*]] = memref.alloc() : memref<96xi8>
//       CHECK:         %[[fullB_f64:.*]] = memref.view %[[tmpB_f64]][{{.*}}][{{.*}}] : memref<96xi8> to memref<?x?xf64>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
//       CHECK:         %[[partialB_f64:.*]] = memref.subview %[[fullB_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>
///
//       CHECK:         %[[tmpC_f64:.*]] = memref.alloc() : memref<48xi8>
//       CHECK:         %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref<?x?xf64>
//     DYNAMIC:         memref.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
//       CHECK:         %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<?x?xf64> to memref<?x?xf64, #[[$strided2D]]>

//       CHECK:         linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
//       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
//       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[$strided2D]]>, memref<?x?xf64, #[[$strided2D]]>
//
//       CHECK:         linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
//       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
//       CHECK:           memref<?x?xf64, #[[$strided2D]]>,
//       CHECK:           memref<?x?xf64, #[[$strided2D]]>
//
//       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
//       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
//       CHECK:         memref.dealloc %[[tmpC_f64]] : memref<48xi8>