1// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=1" | \
2// RUN:   FileCheck %s --check-prefix=CHECK-TYPE0
3// RUN: mlir-opt %s -test-sparsification="ptr-type=1 ind-type=2" | \
4// RUN:   FileCheck %s --check-prefix=CHECK-TYPE1
5// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=1" | \
6// RUN:   FileCheck %s --check-prefix=CHECK-TYPE2
7// RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \
8// RUN:   FileCheck %s --check-prefix=CHECK-TYPE3
9// RUN: mlir-opt %s -test-sparsification="ptr-type=3 ind-type=3" | \
10// RUN:   FileCheck %s --check-prefix=CHECK-TYPE4
11// RUN: mlir-opt %s -test-sparsification="ptr-type=4 ind-type=4" | \
12// RUN:   FileCheck %s --check-prefix=CHECK-TYPE5
13
// Trait shared by the linalg.generic op in this test: a 1-D element-wise
// multiply, x(i) = a(i) * b(i).
//   - indexing_maps: all three operands (a, b, x) are accessed with the
//     identity map over the single loop index i.
//   - sparse: operand a is annotated "S" while b and x are annotated "D"
//     (presumably sparse vs. dense; the CHECK lines in this file expect a's
//     values to be reached through pointer/index memrefs, and b/x to be
//     addressed as memref<32xf64>).
//   - iterator_types: the one loop is "parallel".
14#trait_mul_1d = {
15  indexing_maps = [
16    affine_map<(i) -> (i)>,  // a
17    affine_map<(i) -> (i)>,  // b
18    affine_map<(i) -> (i)>   // x (out)
19  ],
20  sparse = [
21    [ "S" ],  // a
22    [ "D" ],  // b
23    [ "D" ]   // x
24  ],
25  iterator_types = ["parallel"],
26  doc = "x(i) = a(i) * b(i)"
27}
28
29// CHECK-TYPE0-LABEL: func @mul_dd(
30// CHECK-TYPE0: %[[C0:.*]] = constant 0 : index
31// CHECK-TYPE0: %[[C1:.*]] = constant 1 : index
32// CHECK-TYPE0: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
33// CHECK-TYPE0: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
34// CHECK-TYPE0: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
35// CHECK-TYPE0: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
36// CHECK-TYPE0: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
37// CHECK-TYPE0:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
38// CHECK-TYPE0:   %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
39// CHECK-TYPE0:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
40// CHECK-TYPE0:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
41// CHECK-TYPE0:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
42// CHECK-TYPE0:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
43// CHECK-TYPE0: }
44
45// CHECK-TYPE1-LABEL: func @mul_dd(
46// CHECK-TYPE1: %[[C0:.*]] = constant 0 : index
47// CHECK-TYPE1: %[[C1:.*]] = constant 1 : index
48// CHECK-TYPE1: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi64>
49// CHECK-TYPE1: %[[B0:.*]] = index_cast %[[P0]] : i64 to index
50// CHECK-TYPE1: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi64>
51// CHECK-TYPE1: %[[B1:.*]] = index_cast %[[P1]] : i64 to index
52// CHECK-TYPE1: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
53// CHECK-TYPE1:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
54// CHECK-TYPE1:   %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
55// CHECK-TYPE1:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
56// CHECK-TYPE1:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
57// CHECK-TYPE1:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
58// CHECK-TYPE1:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
59// CHECK-TYPE1: }
60
61// CHECK-TYPE2-LABEL: func @mul_dd(
62// CHECK-TYPE2: %[[C0:.*]] = constant 0 : index
63// CHECK-TYPE2: %[[C1:.*]] = constant 1 : index
64// CHECK-TYPE2: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
65// CHECK-TYPE2: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
66// CHECK-TYPE2: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
67// CHECK-TYPE2: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
68// CHECK-TYPE2: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
69// CHECK-TYPE2:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi64>
70// CHECK-TYPE2:   %[[INDC:.*]] = index_cast %[[IND0]] : i64 to index
71// CHECK-TYPE2:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
72// CHECK-TYPE2:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
73// CHECK-TYPE2:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
74// CHECK-TYPE2:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
75// CHECK-TYPE2: }
76
77// CHECK-TYPE3-LABEL: func @mul_dd(
78// CHECK-TYPE3: %[[C0:.*]] = constant 0 : index
79// CHECK-TYPE3: %[[C1:.*]] = constant 1 : index
80// CHECK-TYPE3: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi32>
81// CHECK-TYPE3: %[[B0:.*]] = index_cast %[[P0]] : i32 to index
82// CHECK-TYPE3: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi32>
83// CHECK-TYPE3: %[[B1:.*]] = index_cast %[[P1]] : i32 to index
84// CHECK-TYPE3: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
85// CHECK-TYPE3:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi32>
86// CHECK-TYPE3:   %[[INDC:.*]] = index_cast %[[IND0]] : i32 to index
87// CHECK-TYPE3:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
88// CHECK-TYPE3:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
89// CHECK-TYPE3:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
90// CHECK-TYPE3:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
91// CHECK-TYPE3: }
92
93// CHECK-TYPE4-LABEL: func @mul_dd(
94// CHECK-TYPE4: %[[C0:.*]] = constant 0 : index
95// CHECK-TYPE4: %[[C1:.*]] = constant 1 : index
96// CHECK-TYPE4: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi16>
97// CHECK-TYPE4: %[[B0:.*]] = index_cast %[[P0]] : i16 to index
98// CHECK-TYPE4: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi16>
99// CHECK-TYPE4: %[[B1:.*]] = index_cast %[[P1]] : i16 to index
100// CHECK-TYPE4: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
101// CHECK-TYPE4:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi16>
102// CHECK-TYPE4:   %[[INDC:.*]] = index_cast %[[IND0]] : i16 to index
103// CHECK-TYPE4:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
104// CHECK-TYPE4:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
105// CHECK-TYPE4:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
106// CHECK-TYPE4:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
107// CHECK-TYPE4: }
108
109// CHECK-TYPE5-LABEL: func @mul_dd(
110// CHECK-TYPE5: %[[C0:.*]] = constant 0 : index
111// CHECK-TYPE5: %[[C1:.*]] = constant 1 : index
112// CHECK-TYPE5: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi8>
113// CHECK-TYPE5: %[[B0:.*]] = index_cast %[[P0]] : i8 to index
114// CHECK-TYPE5: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi8>
115// CHECK-TYPE5: %[[B1:.*]] = index_cast %[[P1]] : i8 to index
116// CHECK-TYPE5: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
117// CHECK-TYPE5:   %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi8>
118// CHECK-TYPE5:   %[[INDC:.*]] = index_cast %[[IND0]] : i8 to index
119// CHECK-TYPE5:   %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64>
120// CHECK-TYPE5:   %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64>
121// CHECK-TYPE5:   %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64
122// CHECK-TYPE5:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
123// CHECK-TYPE5: }
124
// Test input: multiplies two tensor<32xf64> operands element-wise through a
// linalg.generic that uses #trait_mul_1d, and returns the resulting tensor.
//
// The RUN lines feed this function to -test-sparsification with six
// ptr-type/ind-type combinations; the matching CHECK-TYPEn blocks expect the
// same loop nest each time and differ only in the integer element type of the
// pointer and index memrefs (option value 1 -> i64, 2 -> i32, 3 -> i16,
// 4 -> i8), each followed by an index_cast to index.
//
// NOTE(review): despite the "_dd" suffix, the trait annotates operand a as
// "S", so the name does not reflect the annotation. The CHECK-LABEL lines
// match on this name, so it must stay in sync with them.
125func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> {
126  %0 = linalg.generic #trait_mul_1d
127     ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>)
    // The output operand reuses %arga; no separate output tensor is passed in.
128    outs(%arga : tensor<32xf64>) {
      // %a and %b are the current elements of the inputs; %s (the output
      // element) is not read by the body.
129      ^bb(%a: f64, %b: f64, %s: f64):
130        %0 = mulf %a, %b  : f64
131        linalg.yield %0 : f64
132  } -> tensor<32xf64>
133  return %0 : tensor<32xf64>
134}
135
136