// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// Contraction traits used by the vector.contract ops in @entry.
//
// Every trait pairs a list of three indexing maps with the iterator type of
// each loop dimension. The maps are listed in operand order: left-hand side,
// right-hand side, accumulator/result. In the formulas below "lhs", "rhs" and
// "acc" refer to those three operands and "res" to the produced value.

// Dot product:
//   res = acc + sum_i lhs[i] * rhs[i]
#dotp_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#dotp_trait = {
  indexing_maps = #dotp_accesses,
  iterator_types = ["reduction"]
}

// Matrix times vector:
//   res[i] = acc[i] + sum_j lhs[i][j] * rhs[j]
#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Transposed matrix times vector:
//   res[i] = acc[i] + sum_j lhs[j][i] * rhs[j]
#mattransvec_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// Matrix times matrix:
//   res[i][j] = acc[i][j] + sum_k lhs[i][k] * rhs[k][j]
#matmat_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmat_trait = {
  indexing_maps = #matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Transposed matrix times matrix:
//   res[i][j] = acc[i][j] + sum_k lhs[k][i] * rhs[k][j]
#mattransmat_accesses = [
  affine_map<(i, j, k) -> (k, i)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#mattransmat_trait = {
  indexing_maps = #mattransmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Matrix times transposed matrix:
//   res[i][j] = acc[i][j] + sum_k lhs[i][k] * rhs[j][k]
#matmattrans_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (j, k)>,
  affine_map<(i, j, k) -> (i, j)>
]
#matmattrans_trait = {
  indexing_maps = #matmattrans_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Transposed matrix times transposed matrix:
//   res[i][j] = acc[i][j] + sum_k lhs[k][i] * rhs[j][k]
#mattransmattrans_accesses = [
  affine_map<(i, j, k) -> (k, i)>,
  affine_map<(i, j, k) -> (j, k)>,
  affine_map<(i, j, k) -> (i, j)>
]
#mattransmattrans_trait = {
  indexing_maps = #mattransmattrans_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Matrix times matrix with a transposed accumulator/result:
//   res[j][i] = acc[j][i] + sum_k lhs[i][k] * rhs[k][j]
#matmat_then_trans_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (j, i)>
]
#matmat_then_trans_trait = {
  indexing_maps = #matmat_then_trans_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Full 2-D reduction to a scalar:
//   res = acc + sum_{i,j} lhs[i][j] * rhs[i][j]
#contract2d_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> ()>
]
#contract2d_trait = {
  indexing_maps = #contract2d_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Full 2-D reduction with both operands indexed (j, i):
//   res = acc + sum_{i,j} lhs[j][i] * rhs[j][i]
#contract2d_alt_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> ()>
]
#contract2d_alt_trait = {
  indexing_maps = #contract2d_alt_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Full 2-D reduction with the right-hand side transposed:
//   res = acc + sum_{i,j} lhs[i][j] * rhs[j][i]
#contract2d_trans_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> ()>
]
#contract2d_trans_trait = {
  indexing_maps = #contract2d_trans_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Full 2-D reduction with the left-hand side transposed:
//   res = acc + sum_{i,j} lhs[j][i] * rhs[i][j]
#contract2d_trans_alt_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> ()>
]
#contract2d_trans_alt_trait = {
  indexing_maps = #contract2d_trans_alt_accesses,
  iterator_types = ["reduction", "reduction"]
}

// Matrix times matrix with all three operands indexed in transposed order:
//   res[j][i] = acc[j][i] + sum_k lhs[k][j] * rhs[i][k]
#column_major_matmat_accesses = [
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (j, i)>
]
#column_major_matmat_trait = {
  indexing_maps = #column_major_matmat_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Test driver. Builds small f32 vectors and matrices, evaluates a series of
// vector.contract ops on them, and prints every result. Each group of prints
// is followed by the output FileCheck expects, in print order. Most
// contractions are run twice: once with a zero accumulator and once with a
// non-zero one, so the accumulation itself is covered as well.
func @entry() {
  %f0 = constant 0.0: f32
  %f1 = constant 1.0: f32
  %f2 = constant 2.0: f32
  %f3 = constant 3.0: f32
  %f4 = constant 4.0: f32
  %f5 = constant 5.0: f32
  %f6 = constant 6.0: f32
  %f7 = constant 7.0: f32
  %f8 = constant 8.0: f32

  // Zero vectors, used as neutral accumulators.
  %z1 = vector.broadcast %f0 : f32 to vector<2xf32>
  %z2 = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %z3 = vector.broadcast %f0 : f32 to vector<3x4xf32>

  // Construct test vectors: broadcast the first element, then overwrite
  // position 1 with the second element.
  %0 = vector.broadcast %f1 : f32 to vector<2xf32>
  %a = vector.insert %f2, %0[1] : f32 into vector<2xf32>
  %1 = vector.broadcast %f3 : f32 to vector<2xf32>
  %b = vector.insert %f4, %1[1] : f32 into vector<2xf32>
  %2 = vector.broadcast %f5 : f32 to vector<2xf32>
  %c = vector.insert %f6, %2[1] : f32 into vector<2xf32>
  %3 = vector.broadcast %f7 : f32 to vector<2xf32>
  %d = vector.insert %f8, %3[1] : f32 into vector<2xf32>

  vector.print %a : vector<2xf32>
  vector.print %b : vector<2xf32>
  vector.print %c : vector<2xf32>
  vector.print %d : vector<2xf32>
  //
  // test vectors:
  //
  // CHECK: ( 1, 2 )
  // CHECK: ( 3, 4 )
  // CHECK: ( 5, 6 )
  // CHECK: ( 7, 8 )

  // Construct test matrices row by row from the test vectors:
  //   A = [a; b] (2x2), B = [c; d] (2x2), C = [a; b; c] (3x2).
  // D (2x4) is assembled from the 2x2 tiles A and B placed side by side.
  %4 = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %5 = vector.insert %a, %4[0] : vector<2xf32> into vector<2x2xf32>
  %A = vector.insert %b, %5[1] : vector<2xf32> into vector<2x2xf32>
  %6 = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %7 = vector.insert %c, %6[0] : vector<2xf32> into vector<2x2xf32>
  %B = vector.insert %d, %7[1] : vector<2xf32> into vector<2x2xf32>
  %8 = vector.broadcast %f0 : f32 to vector<3x2xf32>
  %9 = vector.insert %a, %8[0] : vector<2xf32> into vector<3x2xf32>
  %10 = vector.insert %b, %9[1] : vector<2xf32> into vector<3x2xf32>
  %C = vector.insert %c, %10[2] : vector<2xf32> into vector<3x2xf32>
  %11 = vector.tuple %A, %B : vector<2x2xf32>, vector<2x2xf32>
  %D = vector.insert_slices %11, [2, 2], [1, 1]
    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<2x4xf32>

  vector.print %A : vector<2x2xf32>
  vector.print %B : vector<2x2xf32>
  vector.print %C : vector<3x2xf32>
  vector.print %D : vector<2x4xf32>
  //
  // test matrices:
  //
  // CHECK: ( ( 1, 2 ), ( 3, 4 ) )
  // CHECK: ( ( 5, 6 ), ( 7, 8 ) )
  // CHECK: ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) )
  // CHECK: ( ( 1, 2, 5, 6 ), ( 3, 4, 7, 8 ) )

  // Contraction: dot-product a x b
  // (accumulator 0, then accumulator 1).
  %dp1 = vector.contract #dotp_trait %a, %b, %f0
    : vector<2xf32>, vector<2xf32> into f32
  %dp2 = vector.contract #dotp_trait %a, %b, %f1
    : vector<2xf32>, vector<2xf32> into f32

  vector.print %dp1 : f32
  vector.print %dp2 : f32
  //
  // dot products:
  //
  // CHECK: 11
  // CHECK: 12

  // Contraction: matrix-vector A x c
  // (zero accumulator, then accumulator a).
  %mv1 = vector.contract #matvec_trait %A, %c, %z1
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  %mv2 = vector.contract #matvec_trait %A, %c, %a
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>

  vector.print %mv1 : vector<2xf32>
  vector.print %mv2 : vector<2xf32>
  //
  // matrix x vector:
  //
  // CHECK: ( 17, 39 )
  // CHECK: ( 18, 41 )

  // Contraction: matrix-trans-vector A^T x c
  // (zero accumulator, then accumulator a).
  %mv3 = vector.contract #mattransvec_trait %A, %c, %z1
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  %mv4 = vector.contract #mattransvec_trait %A, %c, %a
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>

  vector.print %mv3 : vector<2xf32>
  vector.print %mv4 : vector<2xf32>
  //
  // matrix^T x vector:
  //
  // CHECK: ( 23, 34 )
  // CHECK: ( 24, 36 )

  // Contraction: matrix-matrix A x B
  // (zero accumulator, then accumulator A).
  %mm1 = vector.contract #matmat_trait %A, %B, %z2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mm2 = vector.contract #matmat_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mm1 : vector<2x2xf32>
  vector.print %mm2 : vector<2x2xf32>
  //
  // matrix x matrix:
  //
  // CHECK: ( ( 19, 22 ), ( 43, 50 ) )
  // CHECK: ( ( 20, 24 ), ( 46, 54 ) )

  // Contraction: matrix-matrix A x B where A, B, C have column-major layout.
  // ( 1 * 5 + 3 * 6 = 23, 2 * 5 + 4 * 6 = 34)
  // ( 1 * 7 + 3 * 8 = 31, 2 * 7 + 4 * 8 = 46)
  // +
  // ( ( 1, 2 ), ( 3, 4 ) )
  // The two rows above list the zero-accumulator result column by column;
  // the second contraction adds A on top of it.
  %llvm_matrix_column_major_mm0 =
    vector.contract #column_major_matmat_trait %A, %B, %z2
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %llvm_matrix_column_major_mm1 =
    vector.contract #column_major_matmat_trait %A, %B, %A
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %llvm_matrix_column_major_mm0 : vector<2x2xf32>
  vector.print %llvm_matrix_column_major_mm1 : vector<2x2xf32>
  //
  // column-major matrix x matrix:
  //
  // CHECK: ( ( 23, 31 ), ( 34, 46 ) )
  // CHECK: ( ( 24, 33 ), ( 37, 50 ) )

  // Contraction: matrix-trans-matrix A^T x B
  // (zero accumulator, then accumulator A).
  %mm3 = vector.contract #mattransmat_trait %A, %B, %z2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mm4 = vector.contract #mattransmat_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mm3 : vector<2x2xf32>
  vector.print %mm4 : vector<2x2xf32>
  //
  // matrix^T x matrix:
  //
  // CHECK: ( ( 26, 30 ), ( 38, 44 ) )
  // CHECK: ( ( 27, 32 ), ( 41, 48 ) )

  // Contraction: matrix-matrix-trans A x B^T
  // (zero accumulator, then accumulator A).
  %mm5 = vector.contract #matmattrans_trait %A, %B, %z2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mm6 = vector.contract #matmattrans_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mm5 : vector<2x2xf32>
  vector.print %mm6 : vector<2x2xf32>
  //
  // matrix x matrix^T:
  //
  // CHECK: ( ( 17, 23 ), ( 39, 53 ) )
  // CHECK: ( ( 18, 25 ), ( 42, 57 ) )

  // Contraction: matrix-trans-matrix-trans A^T x B^T
  // (zero accumulator, then accumulator A).
  %mm7 = vector.contract #mattransmattrans_trait %A, %B, %z2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mm8 = vector.contract #mattransmattrans_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mm7 : vector<2x2xf32>
  vector.print %mm8 : vector<2x2xf32>
  //
  // matrix^T x matrix^T:
  //
  // CHECK: ( ( 23, 31 ), ( 34, 46 ) )
  // CHECK: ( ( 24, 33 ), ( 37, 50 ) )

  // Contraction: matrix-matrix-then-trans (A x B)^T
  // (zero accumulator, then accumulator A).
  %mm9 = vector.contract #matmat_then_trans_trait %A, %B, %z2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mm10 = vector.contract #matmat_then_trans_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mm9 : vector<2x2xf32>
  vector.print %mm10 : vector<2x2xf32>
  //
  // (matrix x matrix)^T:
  //
  // CHECK: ( ( 19, 43 ), ( 22, 50 ) )
  // CHECK: ( ( 20, 45 ), ( 25, 54 ) )

  // Contraction: matrix-matrix C x D (3x2 times 2x4).
  // The second contraction accumulates onto the first result, so it prints
  // twice the values of the first.
  %mm11 = vector.contract #matmat_trait %C, %D, %z3
    : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>
  %mm12 = vector.contract #matmat_trait %C, %D, %mm11
    : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>

  vector.print %mm11 : vector<3x4xf32>
  vector.print %mm12 : vector<3x4xf32>
  //
  // non-square matrix x matrix:
  //
  // CHECK: ( ( 7, 10, 19, 22 ), ( 15, 22, 43, 50 ), ( 23, 34, 67, 78 ) )
  // CHECK: ( ( 14, 20, 38, 44 ), ( 30, 44, 86, 100 ), ( 46, 68, 134, 156 ) )

  // Contractions in 2D: both dimensions are reduced, giving a scalar.
  // Each trait is exercised with accumulator 0 and then accumulator 1.
  %c1 = vector.contract #contract2d_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c2 = vector.contract #contract2d_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c3 = vector.contract #contract2d_alt_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c4 = vector.contract #contract2d_alt_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c5 = vector.contract #contract2d_trans_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c6 = vector.contract #contract2d_trans_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c7 = vector.contract #contract2d_trans_alt_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %c8 = vector.contract #contract2d_trans_alt_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32

  vector.print %c1 : f32
  vector.print %c2 : f32
  vector.print %c3 : f32
  vector.print %c4 : f32
  vector.print %c5 : f32
  vector.print %c6 : f32
  vector.print %c7 : f32
  vector.print %c8 : f32
  //
  // 2D contractions:
  //
  // CHECK: 70
  // CHECK: 71
  // CHECK: 70
  // CHECK: 71
  // CHECK: 69
  // CHECK: 70
  // CHECK: 69
  // CHECK: 70

  return
}