// RUN: mlir-opt %s -sparsification="vectorization-strategy=0 vl=16" | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
// RUN: mlir-opt %s -sparsification="vectorization-strategy=1 vl=16" | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16" | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
// RUN: mlir-opt %s -sparsification="vectorization-strategy=2 vl=16 enable-simd-index32=true" | \
// RUN:   FileCheck %s --check-prefix=CHECK-VEC3
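
// The four configurations above exercise increasingly aggressive SIMD code
// generation for the kernels below (vector length vl=16):
//   vectorization-strategy=0 : no vectorization; every loop stays scalar
//                              (CHECK-VEC0).
//   vectorization-strategy=1 : vectorize dense inner loops only (CHECK-VEC1).
//   vectorization-strategy=2 : also vectorize inner loops over compressed
//                              storage, using create_mask + maskedload and
//                              gather/scatter for the indirection (CHECK-VEC2).
//   enable-simd-index32=true : keep the 32-bit index vectors for gather and
//                              scatter instead of widening them to i64 first
//                              (CHECK-VEC3; compare with CHECK-VEC2 in
//                              @mul_s and @mul_ds).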

#DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>

#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  pointerBitWidth = 32,
  indexBitWidth = 32
}>

#trait_scale_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b"
}

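// @scale_d scales a dense-annotated 1024-vector by a scalar. Because the
// single loop is dense, strategies 1 and 2 produce identical code: a
// vector.load/vector.store loop with step 16. Since 1024 is a multiple of
// 16, the trip count divides evenly and no masking is required.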
//
// CHECK-VEC0-LABEL: func @scale_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
// CHECK-VEC0:         %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[l]], %{{.*}} : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @scale_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC1:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC1:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @scale_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
// CHECK-VEC2:         %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
// CHECK-VEC2:         vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_scale_d
    ins(%arga: tensor<1024xf32, #DenseVector>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

#trait_mul_s = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

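// @mul_s multiplies a compressed sparse vector with a dense vector. The loop
// runs over the compressed entries only, so strategy 1 (dense loops only)
// degenerates to the same scalar loop as strategy 0. Strategy 2 vectorizes
// it: a create_mask guards the tail, values and indices are read with
// vector.maskedload, and the dense operand is accessed indirectly through
// vector.gather/vector.scatter. Without enable-simd-index32, the i32 index
// vector is first widened to i64 (CHECK-VEC2); with it, the gather and
// scatter consume the i32 indices directly (CHECK-VEC3).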
//
// CHECK-VEC0-LABEL: func @mul_s
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC0:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC0:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC0:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC0:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC0:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[zi:.*]] = zexti %[[li]] : i32 to i64
// CHECK-VEC0:         %[[ci:.*]] = index_cast %[[zi]] : i64 to index
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_s
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC1:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC1:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC1:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC1:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC1:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC1:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:         %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[zi:.*]] = zexti %[[li]] : i32 to i64
// CHECK-VEC1:         %[[ci:.*]] = index_cast %[[zi]] : i64 to index
// CHECK-VEC1:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC1:         %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:         store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @mul_s
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC2:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC2:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC2:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC2:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC2:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:         %[[sub:.*]] = subi %[[s]], %[[i]] : index
// CHECK-VEC2:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:         %[[zi:.*]] = zexti %[[li]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3-LABEL: func @mul_s
// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC3:       %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
// CHECK-VEC3:       %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC3:       %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC3:       %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref<?xi32>
// CHECK-VEC3:       %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC3:       %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:         %[[sub:.*]] = subi %{{.*}}, %[[i]] : index
// CHECK-VEC3:         %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:         %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:         %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:         vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> {
  %0 = linalg.generic #trait_mul_s
    ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>)
    outs(%argx: tensor<1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<1024xf32>
  return %0 : tensor<1024xf32>
}

#trait_reduction_d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> ()>    // x (out)
  ],
  iterator_types = ["reduction"],
  doc = "x += a(i) * b(i)"
}

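// @reduction_d computes a dot-product style reduction over a dense-annotated
// vector. The vectorized versions carry a vector<16xf32> partial sum through
// iter_args (initialized to all zeros) and fold it, together with the
// incoming scalar, into the final f32 with vector.reduction after the loop.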
//
// CHECK-VEC0-LABEL: func @reduction_d
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC0:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
// CHECK-VEC0:         %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xf32>
// CHECK-VEC0:         %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32>
// CHECK-VEC0:         %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:         %[[a:.*]] = addf %[[red_in]], %[[m]] : f32
// CHECK-VEC0:         scf.yield %[[a]] : f32
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @reduction_d
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC1-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC1-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC1:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
// CHECK-VEC1:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC1:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC1:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC1:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC1:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC1:       }
// CHECK-VEC1:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @reduction_d
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c1024:.*]] = constant 1024 : index
// CHECK-VEC2-DAG:   %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
// CHECK-VEC2:       %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
// CHECK-VEC2:         %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
// CHECK-VEC2:         %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
// CHECK-VEC2:         %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:         %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
// CHECK-VEC2:         scf.yield %[[a]] : vector<16xf32>
// CHECK-VEC2:       }
// CHECK-VEC2:       %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
// CHECK-VEC2:       return
//
func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor<f32>) -> tensor<f32> {
  %0 = linalg.generic #trait_reduction_d
    ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>)
    outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        %1 = addf %x, %0 : f32
        linalg.yield %1 : f32
  } -> tensor<f32>
  return %0 : tensor<f32>
}

#trait_mul_ds = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>,  // B
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * B(i,j)"
}

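// @mul_ds multiplies a dense-by-compressed matrix element-wise with a dense
// matrix. The outer i loop over the dense dimension stays scalar under every
// strategy; only strategy 2 vectorizes the inner compressed j loop, reusing
// the masked load/gather/scatter pattern from @mul_s (again with the i32
// indices kept narrow when enable-simd-index32 is set, as in CHECK-VEC3).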
//
// CHECK-VEC0-LABEL: func @mul_ds
// CHECK-VEC0-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC0-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC0-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC0:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC0:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC0:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC0:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC0:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC0:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC0:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC0:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC0:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC0:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC0:           %[[zj:.*]] = zexti %[[lj]] : i32 to i64
// CHECK-VEC0:           %[[cj:.*]] = index_cast %[[zj]] : i64 to index
// CHECK-VEC0:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC0:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC0:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC0:         }
// CHECK-VEC0:       }
// CHECK-VEC0:       return
//
// CHECK-VEC1-LABEL: func @mul_ds
// CHECK-VEC1-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC1-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC1-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC1:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC1:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC1:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC1:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC1:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC1:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC1:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC1:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC1:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] {
// CHECK-VEC1:           %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xi32>
// CHECK-VEC1:           %[[zj:.*]] = zexti %[[lj]] : i32 to i64
// CHECK-VEC1:           %[[cj:.*]] = index_cast %[[zj]] : i64 to index
// CHECK-VEC1:           %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref<?xf32>
// CHECK-VEC1:           %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:           %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
// CHECK-VEC1:           store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32>
// CHECK-VEC1:         }
// CHECK-VEC1:       }
// CHECK-VEC1:       return
//
// CHECK-VEC2-LABEL: func @mul_ds
// CHECK-VEC2-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC2-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC2-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC2-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC2:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC2:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC2:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC2:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC2:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC2:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC2:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC2:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC2:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC2:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
// CHECK-VEC2:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC2:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC2:           %[[zj:.*]] = zexti %[[lj]] : vector<16xi32> to vector<16xi64>
// CHECK-VEC2:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC2:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC2:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
// CHECK-VEC2:         }
// CHECK-VEC2:       }
// CHECK-VEC2:       return
//
// CHECK-VEC3-LABEL: func @mul_ds
// CHECK-VEC3-DAG:   %[[c0:.*]] = constant 0 : index
// CHECK-VEC3-DAG:   %[[c1:.*]] = constant 1 : index
// CHECK-VEC3-DAG:   %[[c16:.*]] = constant 16 : index
// CHECK-VEC3-DAG:   %[[c512:.*]] = constant 512 : index
// CHECK-VEC3:       scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] {
// CHECK-VEC3:         %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref<?xi32>
// CHECK-VEC3:         %[[a:.*]] = zexti %[[p]] : i32 to i64
// CHECK-VEC3:         %[[q:.*]] = index_cast %[[a]] : i64 to index
// CHECK-VEC3:         %[[a:.*]] = addi %[[i]], %[[c1]] : index
// CHECK-VEC3:         %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref<?xi32>
// CHECK-VEC3:         %[[b:.*]] = zexti %[[r]] : i32 to i64
// CHECK-VEC3:         %[[s:.*]] = index_cast %[[b]] : i64 to index
// CHECK-VEC3:         scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
// CHECK-VEC3:           %[[sub:.*]] = subi %[[s]], %[[j]] : index
// CHECK-VEC3:           %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
// CHECK-VEC3:           %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
// CHECK-VEC3:           %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK-VEC3:           %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
// CHECK-VEC3:           vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK-VEC3:         }
// CHECK-VEC3:       }
// CHECK-VEC3:       return
//
func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
  %0 = linalg.generic #trait_mul_ds
    ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>)
    outs(%argx: tensor<512x1024xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<512x1024xf32>
  return %0 : tensor<512x1024xf32>
}