// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" -cse -split-input-file | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
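
// Note (inferred from the CHECK patterns below): strategy 0 keeps all loops
// scalar, strategy 1 vectorizes only dense (unit-stride) loops, and strategy 2
// also vectorizes compressed loops using masked and indirect (gather/scatter)
// memory operations. With enable-simd-index32=true, gather/scatter indices
// stay in their native 32-bit form instead of being widened to i64.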

#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

#trait_scale_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}

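// Since the trip count (1024) is a multiple of vl=16, strategies 1 and 2 are
// both expected to emit unmasked vector.load/vector.store for this kernel.
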
//
// CHECK-VEC0-LABEL: func @scale_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[l]], %{{.*}} : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @scale_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @scale_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_scale_d
    ins(%arga: tensor<1024xf32, #DenseVector>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

// -----

#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#trait_mul_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

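// The loop bounds and indices come from 32-bit pointer/index arrays, hence the
// zexti/index_cast sequences below. Only strategy 2 is expected to vectorize
// this compressed loop, guarding the trailing iterations with a mask and
// accessing b and x indirectly through gather/scatter.
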
//
// CHECK-VEC0-LABEL: func @mul_s
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC0:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC0:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC0:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC0:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC0:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[zi:.*]] = zexti %[[li]] : i32 to i64
// CHECK-VEC0:         %[[ci:.*]] = index_cast %[[zi]] : i64 to index
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_s
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC1:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC1:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC1:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC1:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC1:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC1:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[zi:.*]] = zexti %[[li]] : i32 to i64
// CHECK-VEC1:         %[[ci:.*]] = index_cast %[[zi]] : i64 to index
// CHECK-VEC1:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC1:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @mul_s
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC2:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC2:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC2:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC2:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC2:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:         %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC3-LABEL: func @mul_s
// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC3:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC3:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC3:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC3:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:         %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]]
// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul_s
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

// -----

#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

#trait_reduction_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}

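// A vectorized reduction is expected to insert the scalar initial value into
// lane 0 of a zero vector, accumulate vector partial sums across the loop,
// and fold the result back to a scalar with vector.reduction afterwards.
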
//
// CHECK-VEC0-LABEL: func @reduction_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         %[[a:.*]] = addf %[[red_in]], %[[m]] : f32
// CHECK-VEC0:         scf.yield %[[a]] : f32
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @reduction_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[i0:.*]] = constant 0 : i32
// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC1:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC1:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32>
// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       %{{.*}} = vector.reduction "add", %[[red]] : vector<16xf32> into f32
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @reduction_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[i0:.*]] = constant 0 : i32
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC2:       %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
// CHECK-VEC2:       %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[i0]] : i32] : vector<16xf32>
// CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]] : vector<16xf32> into f32
// CHECK-VEC2:       return
//
func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction_d
    ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        %1 = addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

// -----

#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#trait_mul_ds = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * B(i,j)"
}

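// Only the innermost (compressed) loop is a candidate for vectorization here;
// the outer dense loop over the rows is expected to stay scalar under every
// strategy, as reflected in the CHECK patterns below.
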
//
// CHECK-VEC0-LABEL: func @mul_ds
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC0:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC0:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC0:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC0:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC0:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC0:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC0:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC0:           %[[zj:.*]] = zexti %[[lj]] : i32 to i64
// CHECK-VEC0:           %[[cj:.*]] = index_cast %[[zj]] : i64 to index
// CHECK-VEC0:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC0:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:         }
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_ds
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC1-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC1:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC1:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC1:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC1:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC1:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC1:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC1:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC1:           %[[zj:.*]] = zexti %[[lj]] : i32 to i64
// CHECK-VEC1:           %[[cj:.*]] = index_cast %[[zj]] : i64 to index
// CHECK-VEC1:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC1:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:         }
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @mul_ds
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC2:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC2:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC2:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC2:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC2:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC2:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC2:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:           %[[zj:.*]] = zexti %[[lj]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:         }
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC3-LABEL: func @mul_ds
// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC3:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC3:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC3:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC3:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:           %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]]
// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:         }
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
  %0 = linalg.generic #trait_mul_ds
    ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
    outs(%argx: tensor<512x1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<512x1024xf32>
  return %0 : tensor<512x1024xf32>
}

// -----

#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>

#trait_affine = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,
    affine_map<(i,j) -> (i+1,j)>
  ],
  iterator_types = ["parallel","parallel"],
  doc = "X(i+1,j) += A(i,j)"
}

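// This encoding keeps the default 64-bit pointer/index widths, so positions
// and indices are loaded from memref<?xindex> directly, without the
// zexti/index_cast sequences used in the 32-bit cases above. The affine
// expression i+1 in the output map shifts each updated row down by one.
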
//
// CHECK-VEC0-LABEL: func @add_dense
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c32:.*]] = constant 32 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC0:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC0:         %[[i1:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC0:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC0:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK-VEC0:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK-VEC0:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC0:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK-VEC0:           %[[s:.*]] = addf %[[x]], %[[a]] : f64
// CHECK-VEC0:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC0:         }
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @add_dense
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC1-DAG:   %[[c32:.*]] = constant 32 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC1:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC1:         %[[i1:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC1:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC1:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] {
// CHECK-VEC1:           %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xindex>
// CHECK-VEC1:           %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC1:           %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref<?xf64>
// CHECK-VEC1:           %[[s:.*]] = addf %[[x]], %[[a]] : f64
// CHECK-VEC1:           memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64>
// CHECK-VEC1:         }
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2:       #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1)
// CHECK-VEC2-LABEL: func @add_dense
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c32:.*]] = constant 32 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] {
// CHECK-VEC2:         %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xindex>
// CHECK-VEC2:         %[[i1:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC2:         %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref<?xindex>
// CHECK-VEC2:         scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] {
// CHECK-VEC2:           %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]]
// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xindex>
// CHECK-VEC2:           %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64>
// CHECK-VEC2:           %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref<?xf64>
// CHECK-VEC2:           %[[s:.*]] = addf %[[x]], %[[a]] : vector<16xf64>
// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64>
// CHECK-VEC2:         }
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>,
                %argx: tensor<33x64xf64> {linalg.inplaceable = true}) -> tensor<33x64xf64> {
  %0 = linalg.generic #trait_affine
     ins(%arga: tensor<32x64xf64, #SparseMatrix>)
    outs(%argx: tensor<33x64xf64>) {
      ^bb(%a: f64, %x: f64):
        %0 = addf %x, %a : f64
        linalg.yield %0 : f64
  } -> tensor<33x64xf64>
  return %0 : tensor<33x64xf64>
}