1// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
2// RUN: mlir-opt %s -linalg-promote-subviews -test-linalg-promote-dynamic | FileCheck %s --check-prefix=DYNAMIC
3
// Affine map aliases used by the test input below.
// #map0 is the layout attached to the results of the three `view` ops in
// @matmul and to the sources of the `dim` and `std.subview` ops; it maps
// (d0, d1) with symbols [s0, s1] to d0 * s1 + s0 + d1.
4#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// NOTE(review): #map1, #map2 and #map3 are not referenced anywhere in the
// visible part of this file -- confirm whether they are still needed.
5#map1 = affine_map<(d0) -> (d0 + 2)>
6#map2 = affine_map<(d0) -> (d0 + 4)>
7#map3 = affine_map<(d0) -> (d0 + 3)>
8
9// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
10// CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
11// CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
12
// Test input: a matmul written as a three-deep loop nest whose body takes
// fixed-size subviews of three 2-D f32 views and feeds them to linalg.matmul.
// The expectations after this module look for each subview operand to be
// backed by a freshly allocated buffer, filled, copied in, used by the matmul,
// copied back for the output operand, and deallocated.
13module {
  // %A is a 1-D i8 buffer that all three views are carved from; %M, %N and %K
  // supply the dynamic sizes of those views.
14  func @matmul(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
    // Index constants: %c2, %c3 and %c4 serve both as loop steps and as
    // subview sizes; %c0 is the loop lower bound and the view operand; %c1 is
    // the subview stride.
15    %c4 = constant 4 : index
16    %c3 = constant 3 : index
17    %c2 = constant 2 : index
18    %c0 = constant 0 : index
19    %c1 = constant 1 : index
    // Three f32 views of the same buffer %A, all taken with the %c0 operand:
    // %3 is sized [%M, %K], %4 is [%K, %N] and %5 is [%M, %N].
20    %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, #map0>
21    %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, #map0>
22    %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, #map0>
    // Loop upper bounds: %6 and %7 are dimensions 0 and 1 of %3, %8 is
    // dimension 1 of %4.
23    %6 = dim %3, 0 : memref<?x?xf32, #map0>
24    %7 = dim %3, 1 : memref<?x?xf32, #map0>
25    %8 = dim %4, 1 : memref<?x?xf32, #map0>
    // %arg4 walks [0, %6) in steps of 2, %arg5 walks [0, %8) in steps of 3 and
    // %arg6 walks [0, %7) in steps of 4.
26    loop.for %arg4 = %c0 to %6 step %c2 {
27      loop.for %arg5 = %c0 to %8 step %c3 {
28        loop.for %arg6 = %c0 to %7 step %c4 {
          // Unit-stride subviews at the current loop offsets:
          //   %11: 2x4 window of %3 at (%arg4, %arg6)
          //   %14: 4x3 window of %4 at (%arg6, %arg5)
          //   %17: 2x3 window of %5 at (%arg4, %arg5)
          // The 32-, 48- and 24-byte buffers in the expectations below line up
          // with these window sizes at 4 bytes per f32 element.
29          %11 = std.subview %3[%arg4, %arg6][%c2, %c4][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, offset: ?, strides: [?, ?]>
30          %14 = std.subview %4[%arg6, %arg5][%c4, %c3][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, offset: ?, strides: [?, ?]>
31          %17 = std.subview %5[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref<?x?xf32, #map0> to memref<?x?xf32, offset: ?, strides: [?, ?]>
          // The matmul takes the three subviews as its operands; the
          // expectations below look for it to be rewritten onto the new
          // buffers instead.
32          linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
33        }
34      }
35    }
36    return
37  }
38}
39
40// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
41//       CHECK:   loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
42//       CHECK:     loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
43//       CHECK:       loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
44//       CHECK:         %[[vA:.*]] = std.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
45//       CHECK:         %[[vB:.*]] = std.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
46//       CHECK:         %[[vC:.*]] = std.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
47///
48//       CHECK:         %[[tmpA:.*]] = alloc() : memref<32xi8>
49//       CHECK:         %[[fullA:.*]] = std.view %[[tmpA]][][{{.*}}] : memref<32xi8> to memref<?x?xf32>
50//     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
51//       CHECK:         %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
52///
53//       CHECK:         %[[tmpB:.*]] = alloc() : memref<48xi8>
54//       CHECK:         %[[fullB:.*]] = std.view %[[tmpB]][][{{.*}}] : memref<48xi8> to memref<?x?xf32>
55//     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
56//       CHECK:         %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
57///
58//       CHECK:         %[[tmpC:.*]] = alloc() : memref<24xi8>
59//       CHECK:         %[[fullC:.*]] = std.view %[[tmpC]][][{{.*}}] : memref<24xi8> to memref<?x?xf32>
60//     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
61//       CHECK:         %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
62
63//       CHECK:         linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
64//       CHECK:         linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
65//       CHECK:         linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
66//       CHECK:         linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
67//       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
68//       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
69//
70//       CHECK:         linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
71//
72//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
73//
74//       CHECK:         dealloc %[[tmpA]] : memref<32xi8>
75//       CHECK:         dealloc %[[tmpB]] : memref<48xi8>
76//       CHECK:         dealloc %[[tmpC]] : memref<24xi8>
77