1// Verify the printed output can be parsed.
2// RUN: mlir-opt %s | mlir-opt | FileCheck %s
3// Verify the generic form can be parsed.
4// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
5
// Computes the element count of a !shape.shape: shape.reduce folds shape.mul
// over the extents, starting from an accumulator of shape.const_size 1.
// The label includes `func @...(` because the bare name also appears in other
// function names and in call sites further down the printed output.
// CHECK-LABEL: func @shape_num_elements(
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
  %init = shape.const_size 1
  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
    ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
      %acc_next = shape.mul %acc, %extent
          : !shape.size, !shape.size -> !shape.size
      shape.yield %acc_next : !shape.size
  }
  return %num_elements : !shape.size
}
17
// CHECK-LABEL: extent_tensor_num_elements
// Element count of an extent tensor: the reduction starts from the index
// constant 1 and multiplies the running product by every extent.
func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
  %one = constant 1 : index
  %product = shape.reduce(%shape, %one) : tensor<?xindex> -> index {
    ^bb0(%i : index, %dim : index, %running : index):
      %next = shape.mul %running, %dim
          : index, index -> index
      shape.yield %next : index
  }
  return %product : index
}
28
// Passes a value produced by the generic-form "shape.unknown_shape" op through
// @shape_num_elements and feeds the resulting size to "shape.print".
func @test_shape_num_elements_unknown() {
  %unknown = "shape.unknown_shape"() : () -> !shape.shape
  %count = call @shape_num_elements(%unknown)
      : (!shape.shape) -> (!shape.size)
  %printed = "shape.print"(%count) : (!shape.size) -> !shape.size
  return
}
35
// shape.const_shape with three different result types: !shape.shape, a
// dynamically sized extent tensor, and a statically sized extent tensor.
func @const_shape() {
  %as_shape = shape.const_shape [1, 2, 3] : !shape.shape
  %as_dyn_tensor = shape.const_shape [4, 5, 6] : tensor<?xindex>
  %as_static_tensor = shape.const_shape [4, 5, 6] : tensor<3xindex>
  return
}
42
// Runs @shape_num_elements on the constant shape [1, 57, 92] and feeds the
// resulting size to "shape.print".
func @test_shape_num_elements_fixed() {
  %fixed = shape.const_shape [1, 57, 92] : !shape.shape
  %count = call @shape_num_elements(%fixed)
      : (!shape.shape) -> (!shape.size)
  %printed = "shape.print"(%count) : (!shape.size) -> !shape.size
  return
}
49
// shape.broadcast of two constant !shape.shape values of different rank; the
// result is handed to "shape.print".
func @test_broadcast_fixed() {
  %lhs = shape.const_shape [10, 1, 57, 92] : !shape.shape
  %rhs = shape.const_shape [4, 57, 92] : !shape.shape
  %bcast = shape.broadcast %lhs, %rhs
      : !shape.shape, !shape.shape -> !shape.shape
  %printed = "shape.print"(%bcast) : (!shape.shape) -> !shape.shape
  return
}
57
// shape.broadcast of two constant extent tensors of different rank; the
// broadcast extent tensor is returned.
func @test_broadcast_extents() -> tensor<?xindex> {
  %lhs = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
  %rhs = shape.const_shape [4, 57, 92] : tensor<?xindex>
  %bcast = shape.broadcast %lhs, %rhs
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %bcast : tensor<?xindex>
}
64
// Despite the name, this exercises the generic form of "shape.join" on two
// identical constant shapes; the joined value goes to "shape.print".
func @test_shape_any_fixed() {
  %a = shape.const_shape [4, 57, 92] : !shape.shape
  %b = shape.const_shape [4, 57, 92] : !shape.shape
  %joined = "shape.join"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%joined) : (!shape.shape) -> !shape.shape
  return
}
72
// Generic-form "shape.join" of two constant shapes whose -1 entries sit in
// different positions; the joined value goes to "shape.print".
func @test_shape_any_unknown() {
  %a = shape.const_shape [4, -1, 92] : !shape.shape
  %b = shape.const_shape [-1, 57, 92] : !shape.shape
  %joined = "shape.join"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%joined) : (!shape.shape) -> !shape.shape
  return
}
80
// Generic-form "shape.join" of two constant shapes that differ only in their
// leading extent (4 vs. 2); the joined value goes to "shape.print".
func @test_shape_any_fixed_mismatch() {
  %a = shape.const_shape [4, 57, 92] : !shape.shape
  %b = shape.const_shape [2, 57, 92] : !shape.shape
  %joined = "shape.join"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %printed = "shape.print"(%joined) : (!shape.shape) -> !shape.shape
  return
}
88
// Parses shape.const_shape with an empty extent list, and with [1, 2, 3] as
// both a !shape.shape and an extent-tensor result.
func @test_parse_const_shape() {
  %empty = shape.const_shape [] : !shape.shape
  %s123 = shape.const_shape [1, 2, 3] : !shape.shape
  %t123 = shape.const_shape [1, 2, 3] : tensor<?xindex>
  return
}
95
// shape.shape_of on a ranked f32 tensor, producing an extent tensor.
func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
  %extents = shape.shape_of %arg0
      : tensor<?xf32> -> tensor<?xindex>
  return %extents : tensor<?xindex>
}
100
// Creates a witness with cstr_broadcastable, cstr_eq, const_witness (true and
// false) and cstr_require, combines them with shape.assuming_all, and guards a
// generic-form "shape.any" inside a shape.assuming region.
func @test_constraints() {
  %empty = shape.const_shape [] : !shape.shape
  %s123 = shape.const_shape [1, 2, 3] : !shape.shape
  %cond = constant true
  %w_bcast = shape.cstr_broadcastable %empty, %s123
      : !shape.shape, !shape.shape
  %w_eq = shape.cstr_eq %empty, %s123
      : !shape.shape, !shape.shape
  %w_pass = shape.const_witness true
  %w_fail = shape.const_witness false
  %w_req = shape.cstr_require %cond, "msg"
  %w_all = shape.assuming_all %w_bcast, %w_eq, %w_pass, %w_fail, %w_req
  shape.assuming %w_all -> !shape.shape {
    %picked = "shape.any"(%empty, %s123)
        : (!shape.shape, !shape.shape) -> !shape.shape
    shape.assuming_yield %picked : !shape.shape
  }
  return
}
117
// shape.cstr_eq applied directly to two extent tensors.
func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
                           %rhs : tensor<?xindex>) {
  %witness = shape.cstr_eq %lhs, %rhs
      : tensor<?xindex>, tensor<?xindex>
  return
}
123
// shape.cstr_broadcastable applied directly to two extent tensors.
func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                      %rhs : tensor<?xindex>) {
  %witness = shape.cstr_broadcastable %lhs, %rhs
      : tensor<?xindex>, tensor<?xindex>
  return
}
129
// shape.mul with size*size, index*index and mixed size*index operands.
func @mul(%size_arg : !shape.size, %index_arg : index) {
  %ss = shape.mul %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %ii = shape.mul %index_arg, %index_arg
      : index, index -> index
  %si = shape.mul %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
138
// shape.div with size/size, index/index and mixed size/index operands.
func @div(%size_arg : !shape.size, %index_arg : index) {
  %ss = shape.div %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %ii = shape.div %index_arg, %index_arg
      : index, index -> index
  %si = shape.div %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
147
// shape.add with size+size, index+index and mixed size+index operands.
func @add(%size_arg : !shape.size, %index_arg : index) {
  %ss = shape.add %size_arg, %size_arg : !shape.size, !shape.size -> !shape.size
  %ii = shape.add %index_arg, %index_arg
      : index, index -> index
  %si = shape.add %size_arg, %index_arg : !shape.size, index -> !shape.size
  return
}
156
// Pins the SSA names the printer assigns to shape.const_size results (%c1,
// %c2) and the suffix used to disambiguate the repeated constant 2 (%c2_0).
// The label scopes these lines to this function, so they cannot match
// shape.const_size ops printed in other functions.
// CHECK-LABEL: func @const_size(
// CHECK: %c1 = shape.const_size 1
// CHECK-NEXT: %c2 = shape.const_size 2
// CHECK-NEXT: %c2_0 = shape.const_size 2
func @const_size() {
  %0 = shape.const_size 1
  %1 = shape.const_size 2
  %2 = shape.const_size 2
  return
}
166
// shape.to_extent_tensor converting a !shape.shape into a statically sized
// extent tensor.
func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
  %extents = shape.to_extent_tensor %arg
      : !shape.shape -> tensor<3xindex>
  return %extents : tensor<3xindex>
}
171
// shape.from_extent_tensor converting an extent tensor into a !shape.shape.
func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
  %as_shape = shape.from_extent_tensor %arg : tensor<?xindex>
  return %as_shape : !shape.shape
}
176
// shape.rank of a !shape.shape, yielding a !shape.size.
func @rank(%shape : !shape.shape) -> !shape.size {
  %r = shape.rank %shape
      : !shape.shape -> !shape.size
  return %r : !shape.size
}
181
// shape.rank of an extent tensor, yielding a plain index.
func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
  %r = shape.rank %shape
      : tensor<?xindex> -> index
  return %r : index
}
186
// shape.shape_eq comparing two !shape.shape operands; returns the i1 result.
func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
  %same = shape.shape_eq %a, %b
      : !shape.shape, !shape.shape
  return %same : i1
}
191
// shape.shape_eq comparing two extent tensors; returns the i1 result.
func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  %same = shape.shape_eq %a, %b
      : tensor<?xindex>, tensor<?xindex>
  return %same : i1
}
196
// shape.shape_eq comparing an extent tensor against a !shape.shape.
func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
  %same = shape.shape_eq %a, %b
      : tensor<?xindex>, !shape.shape
  return %same : i1
}
201
// shape.get_extent on a !shape.shape, indexed by the constant size 0.
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
  %zero = shape.const_size 0
  %extent = shape.get_extent %arg, %zero
      : !shape.shape, !shape.size -> !shape.size
  return %extent : !shape.size
}
208
// shape.get_extent on an extent tensor, indexed by the index constant 0.
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
  %zero = constant 0 : index
  %extent = shape.get_extent %arg, %zero
      : tensor<?xindex>, index -> index
  return %extent : index
}
214
// shape.get_extent on an extent tensor but indexed by a !shape.size, so the
// result is a !shape.size.
func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
  %zero = shape.const_size 0
  %extent = shape.get_extent %arg, %zero
      : tensor<?xindex>, !shape.size -> !shape.size
  return %extent : !shape.size
}
220
// Generic-form "shape.any" on a pair of !shape.shape constants, then on a pair
// of extent-tensor constants.
func @any() {
  %s_a = shape.const_shape [1, 2, 3] : !shape.shape
  %s_b = shape.const_shape [4, 5, 6] : !shape.shape
  %s_any = "shape.any"(%s_a, %s_b)
      : (!shape.shape, !shape.shape) -> !shape.shape
  %t_a = shape.const_shape [1, 2, 3] : tensor<?xindex>
  %t_b = shape.const_shape [4, 5, 6] : tensor<?xindex>
  %t_any = "shape.any"(%t_a, %t_b)
      : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return
}
230
// shape.num_elements of an extent tensor, yielding a plain index.
func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
  %count = shape.num_elements %arg
      : tensor<?xindex> -> index
  return %count : index
}
235
// shape.num_elements of a !shape.shape, yielding a !shape.size.
func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
  %count = shape.num_elements %arg
      : !shape.shape -> !shape.size
  return %count : !shape.size
}
240
// Tests invoking one shape function from another. @shape_equal_shapes is
// merely a trivial helper for @shape_with_shape to call.
// Takes shape.shape_of of both !shape.value_shape operands and combines the
// two shapes with the generic form of "shape.join".
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %shape_a = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %shape_b = shape.shape_of %b : !shape.value_shape -> !shape.shape
  %joined = "shape.join"(%shape_a, %shape_b)
      : (!shape.shape, !shape.shape) -> !shape.shape
  return %joined : !shape.shape
}
// Attaches the shape of %a to %b with shape.with_shape, then passes %a and the
// reshaped value to @shape_equal_shapes.
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %shape_a = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %b_with_a_shape = shape.with_shape %b, %shape_a
      : !shape.value_shape, !shape.shape
  %result = call @shape_equal_shapes(%a, %b_with_a_shape)
      : (!shape.value_shape, !shape.value_shape) -> !shape.shape
  return %result : !shape.shape
}
255
// Custom-form shape.any over three !shape.shape operands.
func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
    -> !shape.shape {
  %chosen = shape.any %a, %b, %c : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
  return %chosen : !shape.shape
}
262
// Custom-form shape.any mixing two extent tensors with a !shape.shape; the
// result type is !shape.shape.
func @any_on_mixed(%a : tensor<?xindex>,
                   %b : tensor<?xindex>,
                   %c : !shape.shape) -> !shape.shape {
  %chosen = shape.any %a, %b, %c : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
  return %chosen : !shape.shape
}
270
// Custom-form shape.any over three extent tensors, yielding an extent tensor.
func @any_on_extent_tensors(%a : tensor<?xindex>,
                            %b : tensor<?xindex>,
                            %c : tensor<?xindex>) -> tensor<?xindex> {
  %chosen = shape.any %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %chosen : tensor<?xindex>
}
278
// shape.is_broadcastable on two extent tensors; returns the i1 result.
func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
                                         %b : tensor<?xindex>) -> i1 {
  %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
  return %ok : i1
}
285
// shape.is_broadcastable on two !shape.shape operands; returns the i1 result.
func @is_broadcastable_on_shapes(%a : !shape.shape,
                                 %b : !shape.shape) -> i1 {
  %ok = shape.is_broadcastable %a, %b : !shape.shape, !shape.shape
  return %ok : i1
}
292
// Takes shape.max of %a and a constant bound, then joins that with the bound
// itself; the join carries an error attribute naming the upper-bound failure.
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %bound = shape.const_shape [4, 57, 92] : !shape.shape
  %hi = shape.max %a, %bound
      : !shape.shape, !shape.shape -> !shape.shape
  %checked = shape.join %bound, %hi, error="exceeded element-wise upper bound"
      : !shape.shape, !shape.shape -> !shape.shape
  return %checked : !shape.shape
}
300
// Takes shape.min of %a and a constant bound, then joins that with the bound
// itself; the join carries an error attribute naming the lower-bound failure.
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %bound = shape.const_shape [4, 57, 92] : !shape.shape
  %lo = shape.min %a, %bound
      : !shape.shape, !shape.shape -> !shape.shape
  %checked = shape.join %bound, %lo, error="lower bound element-wise exceeded"
      : !shape.shape, !shape.shape -> !shape.shape
  return %checked : !shape.shape
}
308
// Size variant of the upper-bound pattern: shape.max of %a and the constant
// size 5, joined with the constant under an error attribute.
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %bound = shape.const_size 5
  %hi = shape.max %a, %bound
      : !shape.size, !shape.size -> !shape.size
  %checked = shape.join %bound, %hi, error="exceeded element-wise upper bound"
      : !shape.size, !shape.size -> !shape.size
  return %checked : !shape.size
}
316
// Size variant of the lower-bound pattern: shape.min of %a and the constant
// size 9, joined with the constant under an error attribute.
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %bound = shape.const_size 9
  %lo = shape.min %a, %bound
      : !shape.size, !shape.size -> !shape.size
  %checked = shape.join %bound, %lo, error="lower bound element-wise exceeded"
      : !shape.size, !shape.size -> !shape.size
  return %checked : !shape.size
}
324