// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s

// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
  // Fold the extents of a shape into their product via shape.reduce.
  %init = shape.const_size 1
  %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
    ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
      %acc_next = shape.mul %acc, %extent
          : !shape.size, !shape.size -> !shape.size
      shape.yield %acc_next : !shape.size
  }
  return %num_elements : !shape.size
}

// CHECK-LABEL: extent_tensor_num_elements
func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
  // Same reduction as above but on the extent-tensor / index form.
  %init = constant 1 : index
  %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
    ^bb0(%index : index, %extent : index, %acc : index):
      %acc_next = shape.mul %acc, %extent : index, index -> index
      shape.yield %acc_next : index
  }
  return %num_elements : index
}

func @test_shape_num_elements_unknown() {
  %0 = "shape.unknown_shape"() : () -> !shape.shape
  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
  %2 = "shape.print"(%1) : (!shape.size) -> !shape.size
  return
}

func @const_shape() {
  // shape.const_shape supports !shape.shape and both dynamic and static
  // extent-tensor result types.
  %0 = shape.const_shape [1, 2, 3] : !shape.shape
  %1 = shape.const_shape [4, 5, 6] : tensor<?xindex>
  %2 = shape.const_shape [4, 5, 6] : tensor<3xindex>
  return
}

func @test_shape_num_elements_fixed() {
  %0 = shape.const_shape [1, 57, 92] : !shape.shape
  %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
  %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
  return
}

func @test_broadcast_fixed() {
  %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
  %1 = shape.const_shape [4, 57, 92] : !shape.shape
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

func @test_broadcast_extents() -> tensor<?xindex> {
  %0 = shape.const_shape [10, 1, 57, 92] : tensor<?xindex>
  %1 = shape.const_shape [4, 57, 92] : tensor<?xindex>
  %2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %2 : tensor<?xindex>
}

func @test_shape_any_fixed() {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.const_shape [4, 57, 92] : !shape.shape
  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

func @test_shape_any_unknown() {
  %0 = shape.const_shape [4, -1, 92] : !shape.shape
  %1 = shape.const_shape [-1, 57, 92] : !shape.shape
  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

func @test_shape_any_fixed_mismatch() {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.const_shape [2, 57, 92] : !shape.shape
  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
  return
}

func @test_parse_const_shape() {
  %0 = shape.const_shape [] : !shape.shape
  %1 = shape.const_shape [1, 2, 3] : !shape.shape
  %2 = shape.const_shape [1, 2, 3] : tensor<?xindex>
  return
}

func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

func @test_constraints() {
  // Exercise every witness-producing op and combine them via assuming_all.
  %0 = shape.const_shape [] : !shape.shape
  %1 = shape.const_shape [1, 2, 3] : !shape.shape
  %true = constant true
  %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
  %w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
  %w2 = shape.const_witness true
  %w3 = shape.const_witness false
  %w4 = shape.cstr_require %true, "msg"
  %w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4
  shape.assuming %w_all -> !shape.shape {
    %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
    shape.assuming_yield %2 : !shape.shape
  }
  return
}

func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
                           %rhs : tensor<?xindex>) {
  %w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
  return
}

func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
                                      %rhs : tensor<?xindex>) {
  %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
  return
}

func @mul(%size_arg : !shape.size, %index_arg : index) {
  // shape.mul accepts size/size, index/index, and mixed operand types.
  %size_prod = shape.mul %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_prod = shape.mul %index_arg, %index_arg : index, index -> index
  %mixed_prod = shape.mul %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

func @div(%size_arg : !shape.size, %index_arg : index) {
  %size_div = shape.div %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_div = shape.div %index_arg, %index_arg : index, index -> index
  %mixed_div = shape.div %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

func @add(%size_arg : !shape.size, %index_arg : index) {
  %size_sum = shape.add %size_arg, %size_arg
      : !shape.size, !shape.size -> !shape.size
  %index_sum = shape.add %index_arg, %index_arg : index, index -> index
  %mixed_sum = shape.add %size_arg, %index_arg
      : !shape.size, index -> !shape.size
  return
}

func @const_size() {
  // CHECK: %c1 = shape.const_size 1
  // CHECK: %c2 = shape.const_size 2
  // CHECK: %c2_0 = shape.const_size 2
  %0 = shape.const_size 1
  %1 = shape.const_size 2
  %2 = shape.const_size 2
  return
}

func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
  %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
  %0 = shape.from_extent_tensor %arg : tensor<?xindex>
  return %0 : !shape.shape
}

func @rank(%shape : !shape.shape) -> !shape.size {
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return %rank : !shape.size
}

func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
  %rank = shape.rank %shape : tensor<?xindex> -> index
  return %rank : index
}

func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
  return %result : i1
}

func @shape_eq_on_tensors(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
  %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
  %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
  return %result : i1
}

func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
  %c0 = shape.const_size 0
  %result = shape.get_extent %arg, %c0 :
      !shape.shape, !shape.size -> !shape.size
  return %result : !shape.size
}

func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
  %c0 = constant 0 : index
  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
  return %result : index
}

func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
  %c0 = shape.const_size 0
  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
  return %result : !shape.size
}

func @any() {
  %0 = shape.const_shape [1, 2, 3] : !shape.shape
  %1 = shape.const_shape [4, 5, 6] : !shape.shape
  %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
  %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
  return
}

func @num_elements_extent_tensor(%arg : tensor<?xindex>) -> index {
  %result = shape.num_elements %arg : tensor<?xindex> -> index
  return %result : index
}

func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
  %result = shape.num_elements %arg : !shape.shape -> !shape.size
  return %result : !shape.size
}

// Testing invoking shape function from another. shape_equal_shapes is merely
// a trivial helper function to invoke elsewhere.
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
  %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
  return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
  %0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
  %1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
  %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
  return %2 : !shape.shape
}

func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
    -> !shape.shape {
  %result = shape.any %a, %b, %c
      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
  return %result : !shape.shape
}

func @any_on_mixed(%a : tensor<?xindex>,
                   %b : tensor<?xindex>,
                   %c : !shape.shape) -> !shape.shape {
  %result = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
  return %result : !shape.shape
}

func @any_on_extent_tensors(%a : tensor<?xindex>,
                            %b : tensor<?xindex>,
                            %c : tensor<?xindex>) -> tensor<?xindex> {
  %result = shape.any %a, %b, %c
      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  return %result : tensor<?xindex>
}

func @is_broadcastable_on_extent_tensors(%a : tensor<?xindex>,
                                         %b : tensor<?xindex>) -> i1 {
  %result = shape.is_broadcastable %a, %b
      : tensor<?xindex>, tensor<?xindex>
  return %result : i1
}

func @is_broadcastable_on_shapes(%a : !shape.shape,
                                 %b : !shape.shape) -> i1 {
  %result = shape.is_broadcastable %a, %b
      : !shape.shape, !shape.shape
  return %result : i1
}

func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
      !shape.shape, !shape.shape -> !shape.shape
  return %2 : !shape.shape
}

func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
  %0 = shape.const_shape [4, 57, 92] : !shape.shape
  %1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
      !shape.shape, !shape.shape -> !shape.shape
  return %2 : !shape.shape
}

func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %0 = shape.const_size 5
  %1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
  %2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
      !shape.size, !shape.size -> !shape.size
  return %2 : !shape.size
}

func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
  %0 = shape.const_size 9
  %1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
  %2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
      !shape.size, !shape.size -> !shape.size
  return %2 : !shape.size
}