// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

// End-to-end test of vector.contract. Every contraction below is lowered to
// LLVM, executed by the CPU runner, and its printed result is matched by
// FileCheck. Each contraction shape is exercised twice: once with an all-zero
// accumulator and once with a non-zero accumulator.
//
// In every trait the three indexing maps describe, in order, the left-hand
// operand, the right-hand operand and the accumulator/result.

// Dot product: res += x(i) * y(i).
#dotp_trait = {
  indexing_maps = [
    affine_map<(i) -> (i)>,
    affine_map<(i) -> (i)>,
    affine_map<(i) -> ()>
  ],
  iterator_types = ["reduction"]
}

// Matrix times vector: res(i) += M(i, j) * x(j).
#matvec_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (j)>,
    affine_map<(i, j) -> (i)>
  ],
  iterator_types = ["parallel", "reduction"]
}

// Transposed matrix times vector: res(i) += M(j, i) * x(j).
#mattransvec_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (j, i)>,
    affine_map<(i, j) -> (j)>,
    affine_map<(i, j) -> (i)>
  ],
  iterator_types = ["parallel", "reduction"]
}

// Matrix times matrix: res(i, j) += M(i, k) * N(k, j).
#matmat_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (i, k)>,
    affine_map<(i, j, k) -> (k, j)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Transposed matrix times matrix: res(i, j) += M(k, i) * N(k, j).
#mattransmat_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (k, i)>,
    affine_map<(i, j, k) -> (k, j)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Matrix times transposed matrix: res(i, j) += M(i, k) * N(j, k).
#matmattrans_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (i, k)>,
    affine_map<(i, j, k) -> (j, k)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Transposed matrix times transposed matrix: res(i, j) += M(k, i) * N(j, k).
#mattransmattrans_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (k, i)>,
    affine_map<(i, j, k) -> (j, k)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Matrix times matrix with a transposed result: res(j, i) += M(i, k) * N(k, j).
#matmat_then_trans_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (i, k)>,
    affine_map<(i, j, k) -> (k, j)>,
    affine_map<(i, j, k) -> (j, i)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Full 2-D reduction to a scalar: res += M(i, j) * N(i, j).
#contract2d_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> ()>
  ],
  iterator_types = ["reduction", "reduction"]
}

// Same reduction with both operands indexed (j, i): res += M(j, i) * N(j, i).
#contract2d_alt_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (j, i)>,
    affine_map<(i, j) -> (j, i)>,
    affine_map<(i, j) -> ()>
  ],
  iterator_types = ["reduction", "reduction"]
}

// 2-D reduction against the transposed right operand: res += M(i, j) * N(j, i).
#contract2d_trans_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (j, i)>,
    affine_map<(i, j) -> ()>
  ],
  iterator_types = ["reduction", "reduction"]
}

// 2-D reduction against the transposed left operand: res += M(j, i) * N(i, j).
#contract2d_trans_alt_trait = {
  indexing_maps = [
    affine_map<(i, j) -> (j, i)>,
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> ()>
  ],
  iterator_types = ["reduction", "reduction"]
}

// Matmat with every operand indexed in swapped order:
// res(j, i) += M(k, j) * N(i, k).
#column_major_matmat_trait = {
  indexing_maps = [
    affine_map<(i, j, k) -> (k, j)>,
    affine_map<(i, j, k) -> (i, k)>,
    affine_map<(i, j, k) -> (j, i)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

func @entry() {
  // Scalar building blocks for all operands.
  %f0 = constant 0.0 : f32
  %f1 = constant 1.0 : f32
  %f2 = constant 2.0 : f32
  %f3 = constant 3.0 : f32
  %f4 = constant 4.0 : f32
  %f5 = constant 5.0 : f32
  %f6 = constant 6.0 : f32
  %f7 = constant 7.0 : f32
  %f8 = constant 8.0 : f32

  // All-zero accumulators, one per result shape.
  %zero_2 = vector.broadcast %f0 : f32 to vector<2xf32>
  %zero_2x2 = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %zero_3x4 = vector.broadcast %f0 : f32 to vector<3x4xf32>

  // 1-D operands: splat the first element, then overwrite position 1.
  %a_splat = vector.broadcast %f1 : f32 to vector<2xf32>
  %a = vector.insert %f2, %a_splat[1] : f32 into vector<2xf32>
  %b_splat = vector.broadcast %f3 : f32 to vector<2xf32>
  %b = vector.insert %f4, %b_splat[1] : f32 into vector<2xf32>
  %c_splat = vector.broadcast %f5 : f32 to vector<2xf32>
  %c = vector.insert %f6, %c_splat[1] : f32 into vector<2xf32>
  %d_splat = vector.broadcast %f7 : f32 to vector<2xf32>
  %d = vector.insert %f8, %d_splat[1] : f32 into vector<2xf32>

  vector.print %a : vector<2xf32>
  vector.print %b : vector<2xf32>
  vector.print %c : vector<2xf32>
  vector.print %d : vector<2xf32>
  //
  // Expected 1-D operands a, b, c, d:
  //
  // CHECK: ( 1, 2 )
  // CHECK: ( 3, 4 )
  // CHECK: ( 5, 6 )
  // CHECK: ( 7, 8 )

  // 2-D operands, assembled row by row from the vectors above:
  //   A = rows (a, b), B = rows (c, d), C = rows (a, b, c),
  //   D = A and B placed side by side.
  %A_base = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %A_row0 = vector.insert %a, %A_base[0] : vector<2xf32> into vector<2x2xf32>
  %A = vector.insert %b, %A_row0[1] : vector<2xf32> into vector<2x2xf32>
  %B_base = vector.broadcast %f0 : f32 to vector<2x2xf32>
  %B_row0 = vector.insert %c, %B_base[0] : vector<2xf32> into vector<2x2xf32>
  %B = vector.insert %d, %B_row0[1] : vector<2xf32> into vector<2x2xf32>
  %C_base = vector.broadcast %f0 : f32 to vector<3x2xf32>
  %C_row0 = vector.insert %a, %C_base[0] : vector<2xf32> into vector<3x2xf32>
  %C_row1 = vector.insert %b, %C_row0[1] : vector<2xf32> into vector<3x2xf32>
  %C = vector.insert %c, %C_row1[2] : vector<2xf32> into vector<3x2xf32>
  %AB = vector.tuple %A, %B : vector<2x2xf32>, vector<2x2xf32>
  %D = vector.insert_slices %AB, [2, 2], [1, 1]
    : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<2x4xf32>

  vector.print %A : vector<2x2xf32>
  vector.print %B : vector<2x2xf32>
  vector.print %C : vector<3x2xf32>
  vector.print %D : vector<2x4xf32>
  //
  // Expected 2-D operands A, B, C, D:
  //
  // CHECK: ( ( 1, 2 ), ( 3, 4 ) )
  // CHECK: ( ( 5, 6 ), ( 7, 8 ) )
  // CHECK: ( ( 1, 2 ), ( 3, 4 ), ( 5, 6 ) )
  // CHECK: ( ( 1, 2, 5, 6 ), ( 3, 4, 7, 8 ) )

  // Dot product a . b = 1*3 + 2*4 = 11, accumulated onto 0 and onto 1.
  %dot_zero = vector.contract #dotp_trait %a, %b, %f0
    : vector<2xf32>, vector<2xf32> into f32
  %dot_acc = vector.contract #dotp_trait %a, %b, %f1
    : vector<2xf32>, vector<2xf32> into f32

  vector.print %dot_zero : f32
  vector.print %dot_acc : f32
  //
  // dot products:
  //
  // CHECK: 11
  // CHECK: 12

  // Matrix-vector A x c = ( 1*5 + 2*6, 3*5 + 4*6 ), accumulated onto zero
  // and onto a.
  %matvec_zero = vector.contract #matvec_trait %A, %c, %zero_2
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  %matvec_acc = vector.contract #matvec_trait %A, %c, %a
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>

  vector.print %matvec_zero : vector<2xf32>
  vector.print %matvec_acc : vector<2xf32>
  //
  // matrix x vector:
  //
  // CHECK: ( 17, 39 )
  // CHECK: ( 18, 41 )

  // Transposed-matrix-vector A^T x c = ( 1*5 + 3*6, 2*5 + 4*6 ), accumulated
  // onto zero and onto a.
  %mattransvec_zero = vector.contract #mattransvec_trait %A, %c, %zero_2
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  %mattransvec_acc = vector.contract #mattransvec_trait %A, %c, %a
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>

  vector.print %mattransvec_zero : vector<2xf32>
  vector.print %mattransvec_acc : vector<2xf32>
  //
  // matrix^T x vector:
  //
  // CHECK: ( 23, 34 )
  // CHECK: ( 24, 36 )

  // Matrix-matrix A x B, accumulated onto zero and onto A.
  %matmat_zero = vector.contract #matmat_trait %A, %B, %zero_2x2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %matmat_acc = vector.contract #matmat_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %matmat_zero : vector<2x2xf32>
  vector.print %matmat_acc : vector<2x2xf32>
  //
  // matrix x matrix:
  //
  // CHECK: ( ( 19, 22 ), ( 43, 50 ) )
  // CHECK: ( ( 20, 24 ), ( 46, 54 ) )

  // Contraction through #column_major_matmat_trait:
  //   res(j, i) = acc(j, i) + sum_k A(k, j) * B(i, k)
  // With a zero accumulator the printed result is
  //   ( ( 1*5 + 3*6, 1*7 + 3*8 ), ( 2*5 + 4*6, 2*7 + 4*8 ) )
  // and the second run adds ( ( 1, 2 ), ( 3, 4 ) ) element-wise.
  // Up to renaming of i and j these are the same maps as
  // #mattransmattrans_trait, so both sections expect identical output.
  %colmajor_zero =
    vector.contract #column_major_matmat_trait %A, %B, %zero_2x2
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %colmajor_acc =
    vector.contract #column_major_matmat_trait %A, %B, %A
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %colmajor_zero : vector<2x2xf32>
  vector.print %colmajor_acc : vector<2x2xf32>
  //
  // column-major matrix x matrix:
  //
  // CHECK: ( ( 23, 31 ), ( 34, 46 ) )
  // CHECK: ( ( 24, 33 ), ( 37, 50 ) )

  // Transposed-matrix-matrix A^T x B, accumulated onto zero and onto A.
  %mattransmat_zero = vector.contract #mattransmat_trait %A, %B, %zero_2x2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mattransmat_acc = vector.contract #mattransmat_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mattransmat_zero : vector<2x2xf32>
  vector.print %mattransmat_acc : vector<2x2xf32>
  //
  // matrix^T x matrix:
  //
  // CHECK: ( ( 26, 30 ), ( 38, 44 ) )
  // CHECK: ( ( 27, 32 ), ( 41, 48 ) )

  // Matrix-transposed-matrix A x B^T, accumulated onto zero and onto A.
  %matmattrans_zero = vector.contract #matmattrans_trait %A, %B, %zero_2x2
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %matmattrans_acc = vector.contract #matmattrans_trait %A, %B, %A
    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %matmattrans_zero : vector<2x2xf32>
  vector.print %matmattrans_acc : vector<2x2xf32>
  //
  // matrix x matrix^T:
  //
  // CHECK: ( ( 17, 23 ), ( 39, 53 ) )
  // CHECK: ( ( 18, 25 ), ( 42, 57 ) )

  // Both operands transposed, A^T x B^T, accumulated onto zero and onto A.
  %mattransmattrans_zero =
    vector.contract #mattransmattrans_trait %A, %B, %zero_2x2
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %mattransmattrans_acc =
    vector.contract #mattransmattrans_trait %A, %B, %A
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %mattransmattrans_zero : vector<2x2xf32>
  vector.print %mattransmattrans_acc : vector<2x2xf32>
  //
  // matrix^T x matrix^T:
  //
  // CHECK: ( ( 23, 31 ), ( 34, 46 ) )
  // CHECK: ( ( 24, 33 ), ( 37, 50 ) )

  // Product with transposed result, (A x B)^T. The accumulator is indexed
  // like the result, so the second run adds A to the transposed product.
  %matmat_then_trans_zero =
    vector.contract #matmat_then_trans_trait %A, %B, %zero_2x2
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
  %matmat_then_trans_acc =
    vector.contract #matmat_then_trans_trait %A, %B, %A
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

  vector.print %matmat_then_trans_zero : vector<2x2xf32>
  vector.print %matmat_then_trans_acc : vector<2x2xf32>
  //
  // (matrix x matrix)^T:
  //
  // CHECK: ( ( 19, 43 ), ( 22, 50 ) )
  // CHECK: ( ( 20, 45 ), ( 25, 54 ) )

  // Non-square matrix-matrix C x D (3x2 times 2x4). The second run uses the
  // first result as its accumulator, so every element doubles.
  %CD_zero = vector.contract #matmat_trait %C, %D, %zero_3x4
    : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>
  %CD_acc = vector.contract #matmat_trait %C, %D, %CD_zero
    : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>

  vector.print %CD_zero : vector<3x4xf32>
  vector.print %CD_acc : vector<3x4xf32>
  //
  // non-square matrix x matrix:
  //
  // CHECK: ( ( 7, 10, 19, 22 ), ( 15, 22, 43, 50 ), ( 23, 34, 67, 78 ) )
  // CHECK: ( ( 14, 20, 38, 44 ), ( 30, 44, 86, 100 ), ( 46, 68, 134, 156 ) )

  // Full 2-D reductions of A against B down to a scalar, each accumulated
  // onto 0 and onto 1:
  //   same indexing on both operands:  1*5 + 2*6 + 3*7 + 4*8 = 70
  //   one operand transposed:          1*5 + 2*7 + 3*6 + 4*8 = 69
  %sum2d_zero = vector.contract #contract2d_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_acc = vector.contract #contract2d_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_alt_zero = vector.contract #contract2d_alt_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_alt_acc = vector.contract #contract2d_alt_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_trans_zero = vector.contract #contract2d_trans_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_trans_acc = vector.contract #contract2d_trans_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_trans_alt_zero = vector.contract #contract2d_trans_alt_trait %A, %B, %f0
    : vector<2x2xf32>, vector<2x2xf32> into f32
  %sum2d_trans_alt_acc = vector.contract #contract2d_trans_alt_trait %A, %B, %f1
    : vector<2x2xf32>, vector<2x2xf32> into f32

  vector.print %sum2d_zero : f32
  vector.print %sum2d_acc : f32
  vector.print %sum2d_alt_zero : f32
  vector.print %sum2d_alt_acc : f32
  vector.print %sum2d_trans_zero : f32
  vector.print %sum2d_trans_acc : f32
  vector.print %sum2d_trans_alt_zero : f32
  vector.print %sum2d_trans_alt_acc : f32
  //
  // 2D contractions:
  //
  // CHECK: 70
  // CHECK: 71
  // CHECK: 70
  // CHECK: 71
  // CHECK: 69
  // CHECK: 70
  // CHECK: 69
  // CHECK: 70

  return
}