1// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
2
// tosa.argmax along axis 0 of a ?x1 tensor: the test expects the op to still
// be present after canonicalization.
// CHECK-LABEL: @argmax_nofold
func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.argmax"
  %argmax = "tosa.argmax"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %argmax : tensor<?x1xf32>
}
9
10// -----
11
// tosa.cast whose source and destination types are identical: the test
// expects the function to return its argument directly.
// CHECK-LABEL: @cast_fold
func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %casted = "tosa.cast"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %casted : tensor<?x1xf32>
}
18
19// -----
20
// tosa.cast from f32 to i32 elements: the types differ, and the test expects
// the cast to remain in the output.
// CHECK-LABEL: @cast_nofold
func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
  // CHECK: "tosa.cast"
  %casted = "tosa.cast"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xi32>
  return %casted : tensor<?x1xi32>
}
27
28// -----
29
// tosa.concat with a single operand and an identical result type: the test
// expects the function to return its argument directly.
// CHECK-LABEL: @concat_fold
func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %joined = "tosa.concat"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %joined : tensor<?x1xf32>
}
36
37// -----
38
// tosa.concat with a single operand whose result type (?x?) differs from the
// operand type (?x1): the test expects a tensor.cast of the argument to be
// returned instead of the concat.
// CHECK-LABEL: @concat_fold_cast
func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
  // CHECK: return %[[VAR0]]
  %joined = "tosa.concat"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x?xf32>
  return %joined : tensor<?x?xf32>
}
46
47// -----
48
// tosa.reduce_all along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_all_fold
func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_all"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
55
56// -----
57
// tosa.reduce_all along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_all_nofold
func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_all"
  %reduced = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
64
65// -----
66
// tosa.reduce_any along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_any_fold
func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_any"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
73
74// -----
75
// tosa.reduce_any along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_any_nofold
func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_any"
  %reduced = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
82
83// -----
84
// tosa.reduce_max along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_max_fold
func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
91
92// -----
93
// tosa.reduce_max along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_max_nofold
func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_max"
  %reduced = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
100
101// -----
102
// tosa.reduce_min along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_min_fold
func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
109
110// -----
111
// tosa.reduce_min along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_min_nofold
func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_min"
  %reduced = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
118
119// -----
120
// tosa.reduce_prod along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_prod_fold
func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
127
128// -----
129
// tosa.reduce_prod along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_prod_nofold
func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_prod"
  %reduced = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
136
137// -----
138
// tosa.reduce_sum along axis 1 of a ?x1 tensor, where that dimension has
// static extent 1: the test expects the function to return its argument.
// CHECK-LABEL: @reduce_sum_fold
func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: return %arg0
  %reduced = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
145
146// -----
147
// tosa.reduce_sum along axis 0 of a ?x1 tensor, where that dimension is
// dynamic: the test expects the op to remain in the output.
// CHECK-LABEL: @reduce_sum_nofold
func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
  // CHECK: "tosa.reduce_sum"
  %reduced = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x1xf32>
  return %reduced : tensor<?x1xf32>
}
154
155// -----
156
// tosa.reshape whose result type equals its operand type (?x10): the test
// expects the function to return its argument directly.
// CHECK-LABEL: @reshape_canonicalize
func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
  // CHECK: return %arg0
  %reshaped = "tosa.reshape"(%arg0) {new_shape = [-1, 10]} : (tensor<?x10xf32>) -> tensor<?x10xf32>
  return %reshaped : tensor<?x10xf32>
}
163
164// -----
165
// Two back-to-back tosa.reshape ops: the test expects a single reshape that
// applies the final shape [-1, 5] directly to the function argument.
// CHECK-LABEL: @reshape_canonicalize_double
func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
  // CHECK: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = [-1, 5]}
  // CHECK: return %[[VAR0]]
  %first = "tosa.reshape"(%arg0) {new_shape = [5, -1]} : (tensor<?x10xf32>) -> tensor<5x?xf32>
  %second = "tosa.reshape"(%first) {new_shape = [-1, 5]} : (tensor<5x?xf32>) -> tensor<?x5xf32>
  return %second : tensor<?x5xf32>
}
174
175// -----
176
// tosa.slice starting at [0, 0] with size [3, 4] on a static 3x4 tensor, so
// the result type equals the operand type: the test expects the function to
// return its argument directly.
// CHECK-LABEL: @slice_fold
func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
  // CHECK: return %arg0
  %sliced = "tosa.slice"(%arg0) {size = [3, 4], start = [0, 0]} : (tensor<3x4xf32>) -> tensor<3x4xf32>
  return %sliced : tensor<3x4xf32>
}
183
184// -----
185
// tosa.slice on a ?x4 tensor whose leading dimension is dynamic: the test
// expects the slice to remain in the output.
// CHECK-LABEL: @slice_nofold
func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
  // CHECK: "tosa.slice"
  %sliced = "tosa.slice"(%arg0) {size = [3, 4], start = [0, 0]} : (tensor<?x4xf32>) -> tensor<?x4xf32>
  return %sliced : tensor<?x4xf32>
}
192
193// -----
194
// tosa.tile with every multiple equal to 1 and an unchanged result type: the
// test expects the function to return its argument directly.
// CHECK-LABEL: @tile_fold
func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
  // CHECK: return %arg0
  %tiled = "tosa.tile"(%arg0) {multiples = [1, 1]} : (tensor<3x4xf32>) -> tensor<3x4xf32>
  return %tiled : tensor<3x4xf32>
}
201
202// -----
203
// tosa.tile with multiples [1, 2], producing 3x8 from 3x4: the test expects
// the tile to remain in the output.
// CHECK-LABEL: @tile_nofold
func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
  // CHECK: "tosa.tile"
  %tiled = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<3x4xf32>) -> tensor<3x8xf32>
  return %tiled : tensor<3x8xf32>
}
210
211// -----
212
// tosa.transpose with the identity permutation [0, 1] and an unchanged result
// type: the test expects the function to return its argument directly.
// The permutation is supplied solely through the constant operand.
// CHECK-LABEL: @transpose_fold
func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
  // CHECK: return %arg0
  %0 = constant dense<[0, 1]> : tensor<2xi32>
  %1 = "tosa.transpose"(%arg0, %0) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
  return %1 : tensor<3x4xf32>
}
220
221// -----
222
// tosa.transpose with the non-identity permutation [1, 0] on a 3x3 tensor.
// Operand and result types coincide, yet the test expects the transpose to
// remain in the output. The permutation is supplied solely through the
// constant operand.
// CHECK-LABEL: @transpose_nofold
func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
  // CHECK: "tosa.transpose"
  %0 = constant dense<[1, 0]> : tensor<2xi32>
  %1 = "tosa.transpose"(%arg0, %0) : (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
  return %1 : tensor<3x3xf32>
}
230
231// -----
232
// tosa.transpose with the identity permutation [0, 1] but a result type
// (?x?) that differs from the operand type (3x4): the test expects the
// transpose to remain in the output. The permutation is supplied solely
// through the constant operand.
// CHECK-LABEL: @transpose_nofold_shape
func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
  // CHECK: "tosa.transpose"
  %0 = constant dense<[0, 1]> : tensor<2xi32>
  %1 = "tosa.transpose"(%arg0, %0) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
240
241// -----
242
// tosa.transpose of a splat constant (all elements 4.0) with constant perms
// [1, 0]: the test expects a single 3x2 splat constant to be returned.
// CHECK-LABEL: @transpose_fold_splat
func @transpose_fold_splat() -> tensor<3x2xf32> {
  //               CHECK: %[[CST:.+]] = "tosa.const"()
  // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
  // CHECK: return %[[CST]]
  %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %transposed : tensor<3x2xf32>
}
253
254// -----
255
// tosa.transpose of a dense 2x3 f32 constant with constant perms [1, 0]: the
// test expects a single 3x2 constant holding the transposed elements to be
// returned.
// CHECK-LABEL: @transpose_fold_2d_float
func @transpose_fold_2d_float() -> tensor<3x2xf32> {
  //               CHECK: %[[CST:.+]] = "tosa.const"()
  // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
  // CHECK: return %[[CST]]
  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %transposed : tensor<3x2xf32>
}
266
267// -----
268
// tosa.transpose of a dense 1x2x3x4 i32 constant with constant i64 perms
// [2, 0, 3, 1]: the test expects a single 3x1x4x2 constant holding the
// permuted elements to be returned.
// CHECK-LABEL: @transpose_fold_4d_int
func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
  //               CHECK: %[[CST:.+]] = "tosa.const"()
  // CHECK-SAME{LITERAL}: value = dense<[
  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
  // CHECK-SAME{LITERAL}: ]>
  // CHECK: return %[[CST]]
  %input = "tosa.const"() {value = dense<[[
    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
  ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
  %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
  return %transposed : tensor<3x1x4x2xi32>
}
286
287// -----
288
// tosa.transpose with constant perms but a function argument as input: the
// test expects the transpose to remain in the output.
// CHECK-LABEL: @transpose_nofold_non_cst_input
func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
  // CHECK: tosa.transpose
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %transposed : tensor<3x2xf32>
}
296
297// -----
298
// tosa.transpose with a constant input but perms taken from a function
// argument: the test expects the transpose to remain in the output.
// CHECK-LABEL: @transpose_nofold_non_cst_perms
func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
  // CHECK: tosa.transpose
  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %transposed : tensor<3x2xf32>
}
306
307// -----
308
// tosa.transpose of a constant that is also returned from the function, so
// the constant has a second user besides the transpose: the test expects the
// transpose to remain in the output.
// CHECK-LABEL: @transpose_nofold_multi_users
func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
  // CHECK: tosa.transpose
  %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
  return %transposed, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
317
318// -----
319
// tosa.transpose from a plain i8 constant to a result whose element type is a
// per-axis !quant.uniform type (quantized axis 3, 16 scales): the test
// expects the transpose to remain in the output.
// CHECK-LABEL: @transpose_nofold_quantized_types
func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
  // CHECK: tosa.transpose
  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
  %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
  %transposed = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
  return %transposed : tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
}
328