1// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=0 ptr-type=2 ind-type=2 vl=16" | \
2// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
3// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=1 ptr-type=2 ind-type=2 vl=16" | \
4// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
5// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \
6// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
7
// Trait for a 1-D scaling kernel x(i) = a(i) * b where both the input and
// the output are annotated dense ("D"), so sparsification produces a plain
// loop over all 1024 elements with no pointer/index indirection.
#trait_scale_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  sparse = [
    [ "D" ],  // a
    [ "D" ]   // x
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}
20
21//
22// CHECK-VEC0-LABEL: func @scale_d
23// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
24// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
25// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
26// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
27// CHECK-VEC0:         %[[l:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
28// CHECK-VEC0:         %[[m:.*]] = mulf %[[l]], %{{.*}} : f32
29// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
30// CHECK-VEC0:       }
31// CHECK-VEC0:       return
32//
33// CHECK-VEC1-LABEL: func @scale_d
34// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
35// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
36// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
37// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
38// CHECK-VEC1:         %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
39// CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
40// CHECK-VEC1:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
41// CHECK-VEC1:         vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
42// CHECK-VEC1:       }
43// CHECK-VEC1:       return
44//
45// CHECK-VEC2-LABEL: func @scale_d
46// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
47// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
48// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
49// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
50// CHECK-VEC2:         %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
51// CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
52// CHECK-VEC2:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
53// CHECK-VEC2:         vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
54// CHECK-VEC2:       }
55// CHECK-VEC2:       return
56//
// Dense scaling kernel: x(i) = a(i) * scale over a static 1024-element tensor.
// Strategy 0 keeps a scalar loop (CHECK-VEC0); strategies 1 and 2 both emit
// unmasked 16-wide vector.transfer_read/write since the trip count 1024 is a
// multiple of vl=16 (CHECK-VEC1/CHECK-VEC2).
func @scale_d(%arga: tensor<1024xf32>, %scale: f32) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_scale_d
    ins(%arga: tensor<1024xf32>)
    outs(%arga: tensor<1024xf32>) {
      ^bb(%a: f32, %s : f32):
        %0 = mulf %a, %scale : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}
67
// Trait for a 1-D elementwise multiply x(i) = a(i) * b(i) where operand a is
// annotated sparse ("S"): the generated loop iterates only over a's stored
// entries via a pointer/index (compressed) scheme, while b and x stay dense.
#trait_mul_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  sparse = [
    [ "S" ],  // a
    [ "D" ],  // b
    [ "D" ]   // x
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}
82
83//
84// CHECK-VEC0-LABEL: func @mul_s
85// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
86// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
87// CHECK-VEC0:       %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
88// CHECK-VEC0:       %[[q:.*]] = index_cast %[[p]] : i32 to index
89// CHECK-VEC0:       %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
90// CHECK-VEC0:       %[[s:.*]] = index_cast %[[r]] : i32 to index
91// CHECK-VEC0:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
92// CHECK-VEC0:         %[[li:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
93// CHECK-VEC0:         %[[ci:.*]] = index_cast %[[li]] : i32 to index
94// CHECK-VEC0:         %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<?xf32>
95// CHECK-VEC0:         %[[lb:.*]] = load %{{.*}}[%[[ci]]] : memref<1024xf32>
96// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
97// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
98// CHECK-VEC0:       }
99// CHECK-VEC0:       return
100//
101// CHECK-VEC1-LABEL: func @mul_s
102// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
103// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
104// CHECK-VEC1:       %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
105// CHECK-VEC1:       %[[q:.*]] = index_cast %[[p]] : i32 to index
106// CHECK-VEC1:       %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
107// CHECK-VEC1:       %[[s:.*]] = index_cast %[[r]] : i32 to index
108// CHECK-VEC1:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
109// CHECK-VEC1:         %[[li:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
110// CHECK-VEC1:         %[[ci:.*]] = index_cast %[[li]] : i32 to index
111// CHECK-VEC1:         %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<?xf32>
112// CHECK-VEC1:         %[[lb:.*]] = load %{{.*}}[%[[ci]]] : memref<1024xf32>
113// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
114// CHECK-VEC1:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
115// CHECK-VEC1:       }
116// CHECK-VEC1:       return
117//
118// CHECK-VEC2-LABEL: func @mul_s
119// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
120// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
121// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
122// CHECK-VEC2:       %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
123// CHECK-VEC2:       %[[q:.*]] = index_cast %[[p]] : i32 to index
124// CHECK-VEC2:       %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
125// CHECK-VEC2:       %[[s:.*]] = index_cast %[[r]] : i32 to index
126// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
127// CHECK-VEC2:         %[[sub:.*]] = subi %[[s]], %[[i]] : index
128// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
129// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
130// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
131// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
132// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
133// CHECK-VEC2:         vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
134// CHECK-VEC2:       }
135// CHECK-VEC2:       return
136//
// Sparse-times-dense multiply: x(i) = a(i) * b(i) with a stored sparse.
// Strategies 0 and 1 both produce the same scalar loop over a's compressed
// range (CHECK-VEC0/CHECK-VEC1: inner-loop vectorization is not attempted for
// the sparse dimension); strategy 2 vectorizes it with vector.create_mask,
// maskedload, gather and scatter (CHECK-VEC2).
func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul_s
    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
    outs(%arga: tensor<1024xf32>) {
      ^bb(%a: f32, %b: f32, %s : f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}
147
// Trait for a dot-product style reduction x += a(i) * b(i): both vector
// inputs are dense and the output is a rank-0 tensor (empty affine map and
// empty sparse annotation), iterated with a single "reduction" dimension.
#trait_reduction_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  sparse = [
    [ "D" ],  // a
    [ "D" ],  // b
    [     ]   // x
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}
162
163//
164// CHECK-VEC0-LABEL: func @reduction_d
165// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
166// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
167// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
168// CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
169// CHECK-VEC0:         %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
170// CHECK-VEC0:         %[[lb:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
171// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
172// CHECK-VEC0:         %[[a:.*]] = addf %[[red_in]], %[[m]] : f32
173// CHECK-VEC0:         scf.yield %[[a]] : f32
174// CHECK-VEC0:       }
175// CHECK-VEC0:       return
176//
177// CHECK-VEC1-LABEL: func @reduction_d
178// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
179// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
180// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
181// CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
182// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
// CHECK-VEC1:         %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:         %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
185// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
186// CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
187// CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
188// CHECK-VEC1:       }
189// CHECK-VEC1:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
190// CHECK-VEC1:       return
191//
192// CHECK-VEC2-LABEL: func @reduction_d
193// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
194// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
195// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
196// CHECK-VEC2-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
197// CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
// CHECK-VEC2:         %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
200// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
201// CHECK-VEC2:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
202// CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>
203// CHECK-VEC2:       }
204// CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
205// CHECK-VEC2:       return
206//
// Dense reduction (dot product) into a scalar tensor. Strategy 0 carries a
// scalar f32 accumulator through scf.for iter_args (CHECK-VEC0); strategies
// 1 and 2 carry a vector<16xf32> accumulator initialized from a zero splat
// and collapse it after the loop with vector.reduction "add"
// (CHECK-VEC1/CHECK-VEC2).
func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction_d
    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b : f32, %x : f32):
        %0 = mulf %a, %b : f32
        %1 = addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}
218
// Trait for a 2-D elementwise multiply x(i,j) = a(i,j) * b(i,j) where a has
// a dense outer dimension and a sparse (compressed) inner dimension
// ("D","S"); b and x are fully dense.
#trait_mul_ds = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // a
    affine_map<(i,j) -> (i,j)>,  // b
    affine_map<(i,j) -> (i,j)>   // x (out)
  ],
  sparse = [
    [ "D", "S" ],  // a
    [ "D", "D" ],  // b
    [ "D", "D" ]   // x
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "x(i,j) = a(i,j) * b(i,j)"
}
233
234//
235// CHECK-VEC0-LABEL: func @mul_ds
236// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
237// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
238// CHECK-VEC0-DAG:   %[[c512:.*]] = constant 512 : index
239// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
240// CHECK-VEC0:         %[[p:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
241// CHECK-VEC0:         %[[q:.*]] = index_cast %[[p]] : i32 to index
242// CHECK-VEC0:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
243// CHECK-VEC0:         %[[r:.*]] = load %{{.*}}[%[[a]]] : memref<?xi32>
244// CHECK-VEC0:         %[[s:.*]] = index_cast %[[r]] : i32 to index
245// CHECK-VEC0:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
246// CHECK-VEC0:           %[[lj:.*]] = load %{{.*}}[%[[j]]] : memref<?xi32>
247// CHECK-VEC0:           %[[cj:.*]] = index_cast %[[lj]] : i32 to index
248// CHECK-VEC0:           %[[la:.*]] = load %{{.*}}[%[[j]]] : memref<?xf32>
249// CHECK-VEC0:           %[[lb:.*]] = load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
250// CHECK-VEC0:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
251// CHECK-VEC0:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
252// CHECK-VEC0:         }
253// CHECK-VEC0:       }
254// CHECK-VEC0:       return
255//
256// CHECK-VEC1-LABEL: func @mul_ds
257// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
258// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
259// CHECK-VEC1-DAG:   %[[c512:.*]] = constant 512 : index
260// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
261// CHECK-VEC1:         %[[p:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
262// CHECK-VEC1:         %[[q:.*]] = index_cast %[[p]] : i32 to index
263// CHECK-VEC1:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
264// CHECK-VEC1:         %[[r:.*]] = load %{{.*}}[%[[a]]] : memref<?xi32>
265// CHECK-VEC1:         %[[s:.*]] = index_cast %[[r]] : i32 to index
266// CHECK-VEC1:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
267// CHECK-VEC1:           %[[lj:.*]] = load %{{.*}}[%[[j]]] : memref<?xi32>
268// CHECK-VEC1:           %[[cj:.*]] = index_cast %[[lj]] : i32 to index
269// CHECK-VEC1:           %[[la:.*]] = load %{{.*}}[%[[j]]] : memref<?xf32>
270// CHECK-VEC1:           %[[lb:.*]] = load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
271// CHECK-VEC1:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
272// CHECK-VEC1:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
273// CHECK-VEC1:         }
274// CHECK-VEC1:       }
275// CHECK-VEC1:       return
276//
277// CHECK-VEC2-LABEL: func @mul_ds
278// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
279// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
280// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
281// CHECK-VEC2-DAG:   %[[c512:.*]] = constant 512 : index
282// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
283// CHECK-VEC2:         %[[p:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
284// CHECK-VEC2:         %[[q:.*]] = index_cast %[[p]] : i32 to index
285// CHECK-VEC2:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
286// CHECK-VEC2:         %[[r:.*]] = load %{{.*}}[%[[a]]] : memref<?xi32>
287// CHECK-VEC2:         %[[s:.*]] = index_cast %[[r]] : i32 to index
288// CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
289// CHECK-VEC2:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
290// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
293// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
294// CHECK-VEC2:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
295// CHECK-VEC2:           vector.scatter %{{.*}}[%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
296// CHECK-VEC2:         }
297// CHECK-VEC2:       }
298// CHECK-VEC2:       return
299//
// 2-D dense-over-sparse multiply: the outer i-loop over 512 rows is always
// scalar, and the inner j-loop walks row i's compressed entries. Strategies
// 0 and 1 leave the inner loop scalar (CHECK-VEC0/CHECK-VEC1); strategy 2
// vectorizes it with masked loads, gather and scatter (CHECK-VEC2).
func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
  %0 = linalg.generic #trait_mul_ds
    ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)
    outs(%arga: tensor<512x1024xf32>) {
      ^bb(%a: f32, %b: f32, %s : f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<512x1024xf32>
  return %0 : tensor<512x1024xf32>
}
310
311