1// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
2// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s
3
4// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
5// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
6
7// CHECK-LABEL: func @add4x2
8//      CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
9// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
10// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
11// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
12// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
13// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
14// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
15// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
16// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]] : vector<2x2xf32>, vector<2x2xf32>
17// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
18// CHECK-NEXT: return %[[R2]] : vector<4x2xf32>
19
20func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
21  %1 = addf %0, %0: vector<4x2xf32>  // one 4x2 addf; the CHECK lines above expect two 2x2 addf ops instead
22  return %1: vector<4x2xf32>  // must become the insert_slices result (R2) after unrolling
23}
24
25// CHECK-LABEL: func @add4x4
26//      CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
27// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
28
29// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
30// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
31// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
32
33// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
34// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
35// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
36
37// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
38// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
39// CHECK-NEXT: %[[A3:.*]] = addf %[[TG5]], %[[TG6]] : vector<2x2xf32>
40
41// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
42// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
43// CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32>
44
45// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
46
47// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
48// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[A1]] : vector<2x2xf32>
49
50// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
51// CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[A2]] : vector<2x2xf32>
52
53// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
54// CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[A3]] : vector<2x2xf32>
55
56// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
57// CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[A4]] : vector<2x2xf32>
58
59// CHECK-NEXT: %[[R3:.*]] = vector.tuple %[[A5]], %[[A6]], %[[A7]], %[[A8]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
60// CHECK-NEXT: %[[R4:.*]] = vector.insert_slices %[[R3]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
61// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>
62
63func @add4x4(%lhs: vector<4x4xf32>, %rhs: vector<4x4xf32>) -> vector<4x4xf32> {
64  %sum = addf %lhs, %rhs : vector<4x4xf32>  // first 4x4 add
65  %chained = addf %rhs, %sum : vector<4x4xf32>  // second add consumes the first one's result
66  return %chained : vector<4x4xf32>
67}
68
69#contraction_accesses0 = [
70  affine_map<(m, n, k) -> (m, k)>,  // first map: (dim 0, dim 2)
71  affine_map<(m, n, k) -> (k, n)>,  // second map: (dim 2, dim 1)
72  affine_map<(m, n, k) -> (m, n)>   // third map: (dim 0, dim 1)
73]
74#contraction_trait0 = {
75  indexing_maps = #contraction_accesses0,
76  iterator_types = ["parallel", "parallel", "reduction"]  // the last dim (k) is the reduction dim
77}
78
79// CHECK-LABEL: func @contraction4x4_ijk
80
81//      CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 6] : vector<4x6xi1>
82// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [6, 4] : vector<6x4xi1>
83
84// Reducing output vector [0, 0]
85
86// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x6xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
87// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<6x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
88// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
89// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[LMASK]], [2, 2], [1, 1] : vector<4x6xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
90// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %[[RMASK]], [2, 2], [1, 1] : vector<6x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
91
92// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
93// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
94// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
95// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
96// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
97// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
98
99// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
100// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
101// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
102// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
103// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
104
105// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
106// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
107// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
108// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
109// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
110
111// Reducing output vector [0, 2]
112
113// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
114// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
115// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
116// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
117
118// CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
119// CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
120// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
121
122// CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
123// CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
124// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
125
126// Reducing output vector [2, 0]
127
128// CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
129// CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
130// CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
131// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
132
133// CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
134// CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
135// CHECK-NEXT:  %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
136
137// CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
138// CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
139// CHECK-NEXT:  %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
140
141// Reducing output vector [2, 2]
142
143// CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
144// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
145// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
146// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
147
148// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
149// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
150// CHECK-NEXT:  return %[[RES1]] : vector<4x4xf32>
151
152func @contraction4x4_ijk(%lhs : vector<4x6xf32>, %rhs : vector<6x4xf32>,
153                         %acc : vector<4x4xf32>, %unused : index)
154                         -> (vector<4x4xf32>) {
155  %lhs_mask = vector.constant_mask [4, 6] : vector<4x6xi1>  // mask shaped like %lhs
156  %rhs_mask = vector.constant_mask [6, 4] : vector<6x4xi1>  // mask shaped like %rhs
157  %res = vector.contract #contraction_trait0 %lhs, %rhs, %acc, %lhs_mask, %rhs_mask
158      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
159
160  return %res : vector<4x4xf32>
161}
162
163#contraction_accesses1 = [
164  affine_map<(m, k, n) -> (m, k)>,  // first map: (dim 0, dim 1)
165  affine_map<(m, k, n) -> (k, n)>,  // second map: (dim 1, dim 2)
166  affine_map<(m, k, n) -> (m, n)>   // third map: (dim 0, dim 2)
167]
168#contraction_trait1 = {
169  indexing_maps = #contraction_accesses1,
170  iterator_types = ["parallel", "reduction", "parallel"]  // the middle dim (k) is the reduction dim
171}
172
173// CHECK-LABEL: func @contraction4x4_ikj
174
175
176//      CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 2] : vector<4x2xi1>
177// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [2, 4] : vector<2x4xi1>
178
179// Reducing output vector [0, 0]
180
181// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
182// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
183// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
184// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
185// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
186
187// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
188// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
189// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
190// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
191// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
192// CHECK-NEXT:  %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
193
194// Reducing output vector [0, 2]
195
196// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
197// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
198// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
199// CHECK-NEXT:  %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
200
201// Reducing output vector [2, 0]
202
203// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
204// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
205// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
206// CHECK-NEXT:  %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
207
208// Reducing output vector [2, 2]
209
210// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
211// CHECK-NEXT:  %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
212
213// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
214// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
215// CHECK-NEXT:  return %[[RES1]] : vector<4x4xf32>
216
217func @contraction4x4_ikj(%lhs : vector<4x2xf32>, %rhs : vector<2x4xf32>,
218                         %acc : vector<4x4xf32>, %unused : index)
219                         -> (vector<4x4xf32>) {
220  %lhs_mask = vector.constant_mask [4, 2] : vector<4x2xi1>  // mask shaped like %lhs
221  %rhs_mask = vector.constant_mask [2, 4] : vector<2x4xi1>  // mask shaped like %rhs
222  %res = vector.contract #contraction_trait1 %lhs, %rhs, %acc, %lhs_mask, %rhs_mask
223      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
224
225  return %res : vector<4x4xf32>
226}
227
228// CHECK-LABEL: func @contraction4x4_ikj_xfer_read
229
230// CHECK:      %[[C0:.*]] = constant 0 : index
231// CHECK:      %[[C2:.*]] = constant 2 : index
232
233// Check that the LHS vector.transfer_read is split for each user.
234
235//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
236// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
237
238// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
239// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
240
241// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
242// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
243// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
244// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
245
246// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
247// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
248// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
249// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
250
251// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
252// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
253// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
254// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
255// CHECK-NEXT: return
256
257func @contraction4x4_ikj_xfer_read(%lhs_buf : memref<4x2xf32>,
258                                   %rhs_buf : memref<2x4xf32>,
259                                   %acc_buf : memref<4x4xf32>) {
260  %c0 = constant 0 : index
261  %pad = constant 0.0 : f32
262
263  %lhs = vector.transfer_read %lhs_buf[%c0, %c0], %pad
264    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
265    : memref<4x2xf32>, vector<4x2xf32>
266
267  %rhs = vector.transfer_read %rhs_buf[%c0, %c0], %pad
268    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
269    : memref<2x4xf32>, vector<2x4xf32>
270
271  %acc = vector.transfer_read %acc_buf[%c0, %c0], %pad
272    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
273    : memref<4x4xf32>, vector<4x4xf32>
274
275  %res = vector.contract #contraction_trait1 %lhs, %rhs, %acc
276    : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
277
278  vector.transfer_write %res, %acc_buf[%c0, %c0]
279    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
280    : vector<4x4xf32>, memref<4x4xf32>  // result is written back to the accumulator memref
281  return
282}
283
284// TODO: Update test with VTR split transform.
285// CHECK-LABEL: func @vector_transfers
286// CHECK-COUNT-8: vector.transfer_read
287// CHECK-COUNT-4: addf
288// CHECK-COUNT-4: vector.transfer_write
289
290func @vector_transfers(%rows: index, %cols: index) {
291  %pad = constant 0.000000e+00 : f32
292  %in0 = alloc(%rows, %cols) : memref<?x?xf32>
293  %in1 = alloc(%rows, %cols) : memref<?x?xf32>
294  %out = alloc(%rows, %cols) : memref<?x?xf32>
295  %one = constant 1.000000e+00 : f32
296  %two = constant 2.000000e+00 : f32
297  affine.for %i = 0 to %rows step 4 {
298    affine.for %j = 0 to %cols step 4 {
299      %a = vector.transfer_read %in0[%i, %j], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
300      %b = vector.transfer_read %in1[%i, %j], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
301      %sum = addf %a, %b : vector<4x4xf32>
302      vector.transfer_write %sum, %out[%i, %j] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
303    }
304  }
305  return
306}
307
308// CHECK-LABEL: func @tuple_get(
309//  CHECK-SAME: %{{.*}}: vector<4xf32>, %[[A1:.*]]: vector<8xf32>
310//       CHECK: return %[[A1]] : vector<8xf32>
311func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
312  %0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
313  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>  // element 1 of the tuple is %arg1
314  return %1 : vector<8xf32>  // CHECK above expects the second argument to be returned directly
315}
316
317// CHECK-LABEL: func @tuple_get_producer_consumer
318// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
319// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
320// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
321// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
322// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
323// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
324// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
325// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
326//      CHECK: return %[[A7]] : vector<2x4xf32>
327
328func @tuple_get_producer_consumer(
329  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
330  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
331  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
332  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
333  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
334    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
335      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
336  // The comments below track where %arg7 lives after each op: %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
337  %1 = vector.insert_slices %0, [2, 4], [1, 1]
338    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
339            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
340      into vector<4x16xf32>
341  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
342  %2 = vector.extract_slices %1, [4, 8], [1, 1]
343    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
344  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
345  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
346                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
347  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
348  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
349  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
350  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
351  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
352  %6 = vector.extract_slices %5, [2, 4], [1, 1]
353    : vector<4x8xf32> into
354      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
355  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
356  %7 = vector.tuple_get %6, 3
357    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
358  // %arg7 == %7, so the CHECK lines expect the function to return %arg7 directly.
359  return %7 : vector<2x4xf32>
360}
361
362// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
363// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
364// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
365// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
366// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
367// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
368// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
369// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
370// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
371//      CHECK: return %[[A7]] : vector<2x4xf32>
372
373func @tuple_get_producer_consumer_swizzle(
374  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
375  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
376  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
377  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
378  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
379    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
380      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
381  // The comments below track where %arg7 lives after each op: %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
382  %1 = vector.insert_slices %0, [2, 4], [1, 1]
383    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
384            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
385      into vector<4x16xf32>
386  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
387  %2 = vector.extract_slices %1, [4, 8], [1, 1]
388    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
389  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
390  %3= vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
391                             tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
392  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
393
394  // Extract both tuple elements so they can be re-packed in swapped order.
395  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
396  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
397  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
398
399  // Swizzle tuple elements: %5 (element 1) now comes first, %4 (element 0) second.
400  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
401  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
402  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
403                              tuple<vector<4x8xf32>, vector<4x8xf32>>
404  // %arg7 == %7 at tupleIndex = 0, offsets = [2, 4]
405  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
406  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
407  %9 = vector.extract_slices %8, [2, 4], [1, 1]
408    : vector<4x8xf32> into
409      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
410  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
411  %10 = vector.tuple_get %9, 3
412    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
413  // %arg7 == %10, so the CHECK lines expect the function to return %arg7 directly.
414  return %10 : vector<2x4xf32>
415}
416
417// CHECK-LABEL: func @cancelling_shape_cast_ops
418//  CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
419//       CHECK: return %[[A0]] : vector<2x4xf32>
420func @cancelling_shape_cast_ops(%src : vector<2x4xf32>) -> vector<2x4xf32> {
421  %flat = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>  // 2x4 -> 8
422  %back = vector.shape_cast %flat : vector<8xf32> to vector<2x4xf32>  // 8 -> 2x4, the inverse of the cast above
423  return %back : vector<2x4xf32>
424}
425
426// CHECK-LABEL: func @vector_transfers_vector_element_type
427//      CHECK: %[[C0:.*]] = constant 0 : index
428//      CHECK: %[[C1:.*]] = constant 1 : index
429//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
430// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
431// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
432// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
433
434func @vector_transfers_vector_element_type() {
435  %c0 = constant 0 : index
436  %zero = constant 0.000000e+00 : f32
437  %pad = splat %zero : vector<2x4xf32>
438
439  %buf = alloc() : memref<6x2x1xvector<2x4xf32>>
440
441  %read = vector.transfer_read %buf[%c0, %c0, %c0], %pad
442    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
443    : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>
444
445  %slices = vector.extract_slices %read, [1, 1, 2, 4], [1, 1, 1, 1]
446    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
447  %slice0 = vector.tuple_get %slices, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
448  %slice1 = vector.tuple_get %slices, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
449  %repacked = vector.tuple %slice0, %slice1 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
450  %joined = vector.insert_slices %repacked, [1, 1, 2, 4], [1, 1, 1, 1]
451    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>
452
453  vector.transfer_write %joined, %buf[%c0, %c0, %c0]
454    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
455    : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
456
457  return
458}
459
460// Test that ShapeCastOp on tuple of vectors, decomposes to multiple
461// ShapeCastOps on vectors.
462// CHECK-LABEL: func @shape_cast_decomposition
463//       CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
464//  CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
465//  CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
466
467func @shape_cast_decomposition(%a : vector<5x4x2xf32>,
468                               %b : vector<3x4x2xf32>)
469  -> (vector<20x2xf32>, vector<12x2xf32>) {
470  %packed = vector.tuple %a, %b : vector<5x4x2xf32>, vector<3x4x2xf32>
471  %cast = vector.shape_cast %packed : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
472                                      tuple<vector<20x2xf32>, vector<12x2xf32>>
473  %a_flat = vector.tuple_get %cast, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
474  %b_flat = vector.tuple_get %cast, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
475  return %a_flat, %b_flat : vector<20x2xf32>, vector<12x2xf32>
476}
477
478// Test that cancelling ShapeCastOps are canonicalized away.
479// EX:
480//
481//  The following MLIR with cancelling ShapeCastOps:
482//
483//   %0 = source : vector<5x4x2xf32>
484//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
485//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
486//   %3 = user %2 : vector<5x4x2xf32>
487//
488//  Should canonicalize to the following:
489//
490//
491//   %0 = source : vector<5x4x2xf32>
492//   %1 = user %0 : vector<5x4x2xf32>
493//
494
495// The test below applies the cancelling ShapeCastOps to tuples of vectors.
496// CHECK-LABEL: func @shape_cast_fold
497//       CHECK: return %{{.*}},  %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>
498
499func @shape_cast_fold(%a : vector<5x4x2xf32>, %b : vector<3x4x2xf32>)
500  -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
501  %packed = vector.tuple %a, %b : vector<5x4x2xf32>, vector<3x4x2xf32>
502
503  %flat = vector.shape_cast %packed : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
504                                      tuple<vector<20x2xf32>, vector<12x2xf32>>
505
506  %a_flat = vector.tuple_get %flat, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
507  %b_flat = vector.tuple_get %flat, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
508
509  %repacked = vector.tuple %a_flat, %b_flat : vector<20x2xf32>, vector<12x2xf32>
510  %restored = vector.shape_cast %repacked : tuple<vector<20x2xf32>, vector<12x2xf32>> to
511                                            tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
512
513  %a_out = vector.tuple_get %restored, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
514  %b_out = vector.tuple_get %restored, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
515
516  return %a_out, %b_out : vector<5x4x2xf32>, vector<3x4x2xf32>
517}
518