1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name 18"""A prelude containing useful global functions and ADT definitions.""" 19from tvm.ir import IRModule, TypeCall 20from tvm.relay.transform import ToANormalFormExpr 21 22from .ty import GlobalTypeVar, TensorType, Any, scalar_type 23from .expr import Var, GlobalVar, If, const 24from .function import Function 25from .op.tensor import add, subtract, equal 26from .adt import Constructor, TypeData, Clause, Match 27from .adt import PatternConstructor, PatternVar, PatternWildcard 28from . import op, transform 29from .analysis import free_vars 30 31 32def get_tensor_array_shape(expr, dtype, prelude): 33 """Get the static shape of a tensor array if it has fixed rank shape. 34 35 By design, static ADT tensor in TVM has type name in the format 36 of static_tensor_dim0_dim1_..._dimN_t. 37 38 Parameters 39 ---------- 40 expr : Relay Expr 41 Input expression. 42 43 dtype : str 44 Data type. 45 46 prelude : Prelude 47 Tensor array prelude 48 49 Returns 50 ------- 51 shape : tuple of (int, Any) or None 52 The output shape. None if input tensor array 53 has dynamic shape. 54 """ 55 mod = prelude.mod 56 mod["main"] = Function(free_vars(expr), expr) 57 mod = transform.InferType()(mod) 58 checked_type = mod["main"].body.checked_type 59 assert isinstance(checked_type, TypeCall), "Input must be a tensor array." 60 ta_type_str = checked_type.args[0].func.name_hint 61 static_ta_ty_start = "static_tensor_{}".format(dtype) 62 if ta_type_str.startswith(static_ta_ty_start): 63 shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), "").replace("_t", "") 64 shape = [] 65 if "scalar" not in shape_str: 66 for dim_str in shape_str.split("_"): 67 if dim_str == "?": 68 shape.append(Any()) 69 else: 70 shape.append(int(dim_str)) 71 return tuple(shape) 72 return None 73 74 75def _get_name_static(canonical, dtype, shape): 76 """Get name for static shape tensor array op corresponding 77 to the canonical name""" 78 shape_str = "_".join([str(dim) for dim in shape]) 79 if len(shape_str) == 0: 80 shape_str = "scalar" 81 if canonical == "tensor_t": 82 return "static_tensor_{}_{}_t".format(dtype, shape_str) 83 return "{}_{}_{}".format(canonical, dtype, shape_str) 84 85 86class StaticTensorArrayOps(object): 87 """Contains tensor array related ops for fixed rank tensor array""" 88 89 def __init__(self, prelude, dtype, shape): 90 """Create tensor array ops registry""" 91 self.prelude = prelude 92 self.dtype = dtype 93 self.shape = shape 94 95 def get_name(self, canonical): 96 """Get name corresponding to the canonical name""" 97 return _get_name_static(canonical, self.dtype, self.shape) 98 99 def get_var(self, canonical): 100 """Get var corresponding to the canonical name""" 101 name = self.get_name(canonical) 102 return getattr(self.prelude, name) 103 104 def define_tensor_adt(self): 105 """Defines the static tensor ADT, which is the container for tensors 106 with fixed shapes.""" 107 tensor_type_name = self.get_name("tensor_t") 108 # Skip register if tensor type is already registered. 109 global_type_names = set() 110 for g_ty_var in self.prelude.mod.get_global_type_vars(): 111 global_type_names.add(g_ty_var.name_hint) 112 if tensor_type_name in global_type_names: 113 return 114 115 tensor_type_var = GlobalTypeVar(tensor_type_name) 116 setattr(self.prelude, tensor_type_name, tensor_type_var) 117 tensor_type = TensorType(self.shape, self.dtype) 118 tensor_constructor_name = self.get_name("tensor_constructor") 119 120 tensor_nil_name = self.get_name("tensor_nil") 121 tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) 122 tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var) 123 124 setattr(self.prelude, tensor_nil_name, tensor_nil_case) 125 setattr(self.prelude, tensor_constructor_name, tensor_case) 126 self.prelude.mod[tensor_type_var] = TypeData( 127 tensor_type_var, [], [tensor_nil_case, tensor_case] 128 ) 129 130 def define_tensor_array(self): 131 """Defines a function to create a tensor array with size n. 132 tensor_array(n) : Tensor[(), int32] -> list[tensor_t] 133 """ 134 tensor_array_constructor_name = self.get_name("tensor_array") 135 tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name) 136 setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) 137 tensor_nil_var = self.get_var("tensor_nil") 138 tensor_type_var = self.get_var("tensor_t") 139 n = Var("x", scalar_type("int32")) 140 body = If( 141 equal(n, const(0)), 142 self.prelude.nil(), 143 self.prelude.cons( 144 tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1))) 145 ), 146 ) 147 self.prelude.mod[tensor_array_constructor_var] = Function( 148 [n], body, self.prelude.l(tensor_type_var()), [] 149 ) 150 151 def define_tensor_take(self): 152 """Defines a function to return a range of tensor_t on axis 0. 153 tensor_take(t, lower, upper) : 154 tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t 155 """ 156 # We don't register take for scalar tensor. 157 ndim = len(self.shape) 158 if ndim == 0: 159 return 160 161 take_name = self.get_name("tensor_take") 162 take_var = self._create_global_var(take_name) 163 setattr(self.prelude, take_name, take_var) 164 origin_tensor_constructor = self.get_var("tensor_constructor") 165 166 output_shape = [ 167 Any(), 168 ] + list(self.shape[1:]) 169 tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape) 170 171 t = Var("tensor", self.get_var("tensor_t")()) 172 lower = Var("lower", scalar_type("int32")) 173 upper = Var("upper", scalar_type("int32")) 174 tvar = Var("t") 175 case = Clause( 176 PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]), 177 tensor_constructor(op.take(tvar, op.arange(lower, upper, dtype="int32"), axis=0)), 178 ) 179 self.prelude.mod[take_var] = Function( 180 [t, lower, upper], Match(t, [case], False), tensor_type_var(), [] 181 ) 182 183 def define_tensor_concatenate(self): 184 """Defines a function to concatenate two tensor_t on axis 0. 185 tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t 186 """ 187 # We don't register concatenate for scalar tensor. 188 ndim = len(self.shape) 189 if ndim == 0: 190 return 191 192 concat_name = self.get_name("tensor_concatenate") 193 concat_var = self._create_global_var(concat_name) 194 setattr(self.prelude, concat_name, concat_var) 195 output_shape = [ 196 Any(), 197 ] + list(self.shape[1:]) 198 tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape) 199 200 origin_tensor_constructor = self.get_var("tensor_constructor") 201 origin_tensor_type_var = self.get_var("tensor_t") 202 x = Var("x", origin_tensor_type_var()) 203 y = Var("y", origin_tensor_type_var()) 204 t1 = Var("t1") 205 t2 = Var("t2") 206 207 case = Clause( 208 PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]), 209 Match( 210 y, 211 [ 212 Clause( 213 PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]), 214 tensor_constructor(op.concatenate([t1, t2], axis=0)), 215 ) 216 ], 217 False, 218 ), 219 ) 220 221 self.prelude.mod[concat_var] = Function( 222 [x, y], Match(x, [case], False), tensor_type_var(), [] 223 ) 224 225 def define_tensor_expand_dims(self): 226 """Defines a function to grow a tensor_t's rank by adding one dimension in front 227 of the original tensor_t. 228 tensor_expand_dims(t) : tensor_t -> tensor_t 229 """ 230 expand_dims_name = self.get_name("tensor_expand_dims") 231 expand_dims_var = self._create_global_var(expand_dims_name) 232 setattr(self.prelude, expand_dims_name, expand_dims_var) 233 origin_tensor_type_var = self.get_var("tensor_t") 234 origin_tensor_constructor = self.get_var("tensor_constructor") 235 x = Var("x", origin_tensor_type_var()) 236 237 # Note: we set the added axis to be Any() instead of 1 due to 238 # in stack op, we need to recursively concatenate. 239 tensor_type_var, tensor_constructor = self._get_adt_by_shape( 240 [ 241 Any(), 242 ] 243 + list(self.shape) 244 ) 245 t = Var("t") 246 case = Clause( 247 PatternConstructor(origin_tensor_constructor, [PatternVar(t)]), 248 tensor_constructor(op.expand_dims(t, 0, 1)), 249 ) 250 251 self.prelude.mod[expand_dims_var] = Function( 252 [x], Match(x, [case], False), tensor_type_var(), [] 253 ) 254 255 def define_tensor_array_read(self): 256 """Defines a function to get the nth element of a list. Assume the list has at least one 257 element. 258 tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] -> 259 Tensor[self.shape, self.dtype] 260 """ 261 read_name = self.get_name("tensor_array_read") 262 read_var = self._create_global_var(read_name) 263 setattr(self.prelude, read_name, read_var) 264 tensor_type_var = self.get_var("tensor_t") 265 266 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 267 n = Var("x", scalar_type("int32")) 268 self.prelude.mod[read_var] = Function( 269 [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [] 270 ) 271 272 def define_tensor_array_write(self): 273 """Defines a function to update a tensor array at index n with value v. 274 tensor_array_write(ta, n, v) : 275 list[static_tensor_t] -> Tensor[(), int32] -> Tensor[self.shape, self.dtype] -> 276 list[static_tensor_t] 277 """ 278 write_name = self.get_name("tensor_array_write") 279 write_var = self._create_global_var(write_name) 280 setattr(self.prelude, write_name, write_var) 281 tensor_type_var = self.get_var("tensor_t") 282 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 283 n = Var("x", scalar_type("int32")) 284 v = Var("v", tensor_type_var()) 285 self.prelude.mod[write_var] = Function( 286 [tensor_array, n, v], 287 self.prelude.update(tensor_array, n, v), 288 self.prelude.l(tensor_type_var()), 289 [], 290 ) 291 292 def define_tensor_array_unstack(self): 293 """Defines a function to unstack the values of a tensor_t in a tensor array. 294 tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t] 295 """ 296 ndim = len(self.shape) 297 # We don't register unstack for scalar tensor array 298 if ndim == 0: 299 return 300 301 helper_name = self.get_name("tensor_array_unstack_helper") 302 helper_var = self._create_global_var(helper_name) 303 setattr(self.prelude, helper_name, helper_var) 304 tensor = Var("t", TensorType(self.shape, self.dtype)) 305 up = Var("up", scalar_type("int32")) 306 i = Var("i", scalar_type("int32")) 307 tensor_var = Var("tensor", TensorType(self.shape, self.dtype)) 308 309 reduced_tensor_type_var, tensor_constructor = self._get_adt_by_shape(self.shape[1:]) 310 helper_body = If( 311 equal(i, up), 312 self.prelude.nil(), 313 self.prelude.cons( 314 tensor_constructor(op.take(tensor, i, axis=0)), 315 helper_var(add(i, const(1)), up, tensor), 316 ), 317 ) 318 self.prelude.mod[helper_var] = Function( 319 [i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), [] 320 ) 321 322 unstack_name = self.get_name("tensor_array_unstack") 323 unstack_var = self._create_global_var(unstack_name) 324 setattr(self.prelude, unstack_name, unstack_var) 325 shape = op.shape_of(tensor_var) 326 unstack_length = op.take(shape, const(0)) 327 self.prelude.mod[unstack_var] = Function( 328 [tensor_var], 329 helper_var(const(0), unstack_length, tensor_var), 330 self.prelude.l(reduced_tensor_type_var()), 331 [], 332 ) 333 334 def define_tensor_array_scatter(self, indices_shape=None, force_update=False): 335 """Defines a function to scatter the values of a tensor_t in indices of a tensor array. 336 tensor_array_scatter(ta, indices, value) : 337 list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] 338 339 Set static indices shape by specifying indices_shape. 340 Set force_update to get static indices shape operator. 341 """ 342 # When this operator has already been registered, only update 343 # when force_update is set. This should be used only when we need to 344 # redefine this op for static indices shape. 345 tensor_array_scatter_name = self.get_name("tensor_array_scatter") 346 if hasattr(self.prelude, tensor_array_scatter_name) and not force_update: 347 return 348 349 tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") 350 tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name) 351 tensor_type_var = self.get_var("tensor_t") 352 ta = Var("ta", self.prelude.l(tensor_type_var())) 353 current = Var("current", scalar_type("int32")) 354 limit = Var("limit", scalar_type("int32")) 355 indices_ = Var("indices_", TensorType(indices_shape or [Any()], "int32")) 356 values_ = Var("values_", self.prelude.l(tensor_type_var())) 357 write_var = self.get_var("tensor_array_write") 358 read_var = self.get_var("tensor_array_read") 359 helper_body = If( 360 equal(current, limit), 361 ta, 362 tensor_array_scatter_helper_var( 363 write_var(ta, op.take(indices_, current), read_var(values_, current)), 364 add(current, const(1)), 365 limit, 366 indices_, 367 values_, 368 ), 369 ) 370 self.prelude.mod[tensor_array_scatter_helper_var] = Function( 371 [ta, current, limit, indices_, values_], 372 helper_body, 373 self.prelude.l(tensor_type_var()), 374 [], 375 ) 376 377 tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name) 378 setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) 379 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 380 381 indices = Var("indices", TensorType(indices_shape or [Any()], "int32")) 382 values = Var("values", self.prelude.l(tensor_type_var())) 383 if indices_shape is None: 384 indices_shape = op.shape_of(indices) 385 limit = op.take(indices_shape, const(0)) 386 else: 387 limit = const(indices_shape[0]) 388 389 body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) 390 self.prelude.mod[tensor_array_scatter_var] = Function( 391 [tensor_array, indices, values], body, self.prelude.l(tensor_type_var()), [] 392 ) 393 394 def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_update=False): 395 """Defines a function to split the values of a tensor_t into a tensor array. 396 tensor_array_split(ta, value, lengths) : 397 list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] 398 399 Set static value and lengths shapes by specifying value_shape and lengths_shape. 400 Set force_update to get static value and lengths shape operator. 401 """ 402 # Skip scalar case 403 ndim = len(self.shape) 404 if ndim == 0: 405 return 406 407 # When this operator has already been registered, only update 408 # when force_update is set. This should be used only when we need to 409 # redefine this op for static value/indices shape. 410 split_name = self.get_name("tensor_array_split") 411 if hasattr(self.prelude, split_name) and not force_update: 412 return 413 414 tensor_type_var = self.get_var("tensor_t") 415 tensor_array_split_helper_name = self.get_name("ta_split_helper") 416 tensor_array_split_helper_var = self._create_global_var(tensor_array_split_helper_name) 417 setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) 418 output_shape = [ 419 Any(), 420 ] + list(self.shape[1:]) 421 output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) 422 423 if value_shape is None: 424 value_type_var = tensor_type_var 425 take_var = self.get_var("tensor_take") 426 else: 427 value_type_var, _ = self._get_adt_by_shape(value_shape) 428 # Also get static shape take operator 429 origin_shape = list(self.shape) 430 self.shape = value_shape 431 self.define_tensor_take() 432 take_var = self.get_var("tensor_take") 433 self.shape = origin_shape 434 435 ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var())) 436 value1 = Var("value1", value_type_var()) 437 offset1 = Var("offset1", scalar_type("int32")) 438 current1 = Var("current1", scalar_type("int32")) 439 limit1 = Var("limit1", scalar_type("int32")) 440 lengths1 = Var("lengths", TensorType(lengths_shape or [Any()], "int32")) 441 442 # Register write for output shape 443 origin_shape = list(self.shape) 444 self.shape = output_shape 445 self.define_tensor_array_write() 446 write_var = self.get_var("tensor_array_write") 447 self.shape = origin_shape 448 helper1_body = If( 449 equal(current1, limit1), 450 ta1, 451 write_var( 452 tensor_array_split_helper_var( 453 ta1, 454 value1, 455 add(offset1, op.take(lengths1, current1)), 456 add(current1, const(1)), 457 limit1, 458 lengths1, 459 ), 460 current1, 461 take_var(value1, offset1, add(op.take(lengths1, current1), offset1)), 462 ), 463 ) 464 self.prelude.mod[tensor_array_split_helper_var] = Function( 465 [ta1, value1, offset1, current1, limit1, lengths1], 466 helper1_body, 467 self.prelude.l(output_tensor_type_var()), 468 [], 469 ) 470 split_var = self._create_global_var(split_name) 471 setattr(self.prelude, split_name, split_var) 472 tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var())) 473 474 value = Var("value", value_type_var()) 475 lengths = Var("lengths", TensorType(lengths_shape or [Any()], "int32")) 476 if lengths_shape is None: 477 lengths_shape = op.shape_of(lengths) 478 lengths_limit = op.take(lengths_shape, const(0)) 479 else: 480 lengths_limit = const(lengths_shape[0]) 481 body = tensor_array_split_helper_var( 482 tensor_array, value, const(0), const(0), lengths_limit, lengths 483 ) 484 485 self.prelude.mod[split_var] = Function( 486 [tensor_array, value, lengths], body, self.prelude.l(output_tensor_type_var()), [] 487 ) 488 489 def define_tensor_array_concat(self): 490 """Defines a function to return the values in the tensor array as concatenated tensor_t. 491 tensor_array_concat(ta) : list[tensor_t] -> tensor_t 492 """ 493 # We don't register concat for scalar tensor array. 494 ndim = len(self.shape) 495 if ndim == 0: 496 return 497 498 concat_name = self.get_name("tensor_array_concat") 499 concat_var = self._create_global_var(concat_name) 500 setattr(self.prelude, concat_name, concat_var) 501 502 output_shape = [ 503 Any(), 504 ] + list(self.shape[1:]) 505 tensor_type_var, _ = self._get_adt_by_shape(output_shape) 506 507 # Register tensor concatenate and get tensor_nil var for output shape 508 origin_shape = self.shape 509 self.shape = output_shape 510 self.define_tensor_concatenate() 511 tensor_concat_var = self.get_var("tensor_concatenate") 512 tensor_nil_var = self.get_var("tensor_nil") 513 self.shape = origin_shape 514 515 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 516 hd = Var("hd") 517 tl = Var("tl") 518 nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) 519 cons_case = Clause( 520 PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), 521 Match( 522 tl, 523 [ 524 Clause(PatternConstructor(self.prelude.nil), hd), 525 Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))), 526 ], 527 False, 528 ), 529 ) 530 self.prelude.mod[concat_var] = Function( 531 [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), [] 532 ) 533 534 def define_tensor_array_stack(self): 535 """Defines a function to get the values in the tensor array as a stack tensor_t. 536 tensor_array_stack(l) : list[tensor_t] -> tensor_t 537 """ 538 stack_name = self.get_name("tensor_array_stack") 539 stack_var = self._create_global_var(stack_name) 540 setattr(self.prelude, stack_name, stack_var) 541 tensor_type_var = self.get_var("tensor_t") 542 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 543 expand_dims_var = self.get_var("tensor_expand_dims") 544 545 # Register tensor_concatenate for output_shape 546 origin_shape = self.shape 547 output_shape = [ 548 Any(), 549 ] + list(self.shape) 550 self.shape = output_shape 551 self.define_tensor_concatenate() 552 concat_var = self.get_var("tensor_concatenate") 553 self.shape = origin_shape 554 555 tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) 556 tensors = self.prelude.foldl( 557 concat_var, 558 self.prelude.hd(tensor_array_expand_dims), 559 self.prelude.tl(tensor_array_expand_dims), 560 ) 561 output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) 562 self.prelude.mod[stack_var] = Function( 563 [tensor_array], tensors, output_tensor_type_var(), [] 564 ) 565 566 def define_tensor_array_gather(self): 567 """Defines a function to return the selected values in a tensor array as tensor_t. 568 tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t 569 """ 570 helper_name = self.get_name("tensor_array_gather_helper") 571 helper_var = self._create_global_var(helper_name) 572 setattr(self.prelude, helper_name, helper_var) 573 tensor_type_var = self.get_var("tensor_t") 574 output_shape = [ 575 Any(), 576 ] + list(self.shape) 577 output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) 578 stack_var = self.get_var("tensor_array_stack") 579 read_var = self.get_var("tensor_array_read") 580 ta = Var("ta", self.prelude.l(tensor_type_var())) 581 accu = Var("accu", self.prelude.l(tensor_type_var())) 582 current = Var("current", scalar_type("int32")) 583 limit = Var("limit", scalar_type("int32")) 584 indices_ = Var("indices_", TensorType([Any()], "int32")) 585 helper_body = If( 586 equal(current, const(0)), 587 stack_var(accu), 588 helper_var( 589 ta, 590 self.prelude.cons( 591 read_var(ta, op.take(indices_, subtract(current, const(1)))), accu 592 ), 593 subtract(current, const(1)), 594 limit, 595 indices_, 596 ), 597 ) 598 self.prelude.mod[helper_var] = Function( 599 [ta, accu, current, limit, indices_], helper_body, output_tensor_type_var(), [] 600 ) 601 gather_name = self.get_name("tensor_array_gather") 602 gather_var = self._create_global_var(gather_name) 603 setattr(self.prelude, gather_name, gather_var) 604 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 605 indices = Var("indices", TensorType([Any()], "int32")) 606 indices_shape = op.shape_of(indices) 607 limit = op.take(indices_shape, const(0)) 608 body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) 609 self.prelude.mod[gather_var] = Function( 610 [tensor_array, indices], body, output_tensor_type_var(), [] 611 ) 612 613 def define_tensor_get_data(self): 614 """Defines a function to get a Tensor from tensor_t with given shape.""" 615 tensor_get_data_name = self.get_name("tensor_get_data") 616 tensor_get_data_var = self._create_global_var(tensor_get_data_name) 617 setattr(self.prelude, tensor_get_data_name, tensor_get_data_var) 618 tensor_type_var = self.get_var("tensor_t") 619 tensor_constructor = self.get_var("tensor_constructor") 620 t = Var("tensor", tensor_type_var()) 621 tvar = Var("t") 622 case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar) 623 self.prelude.mod[tensor_get_data_var] = Function( 624 [t], Match(t, [case], False), TensorType(self.shape, self.dtype), [] 625 ) 626 627 def register(self): 628 """Register all tensor array ops in Prelude""" 629 self.define_tensor_adt() 630 self.define_tensor_take() 631 self.define_tensor_concatenate() 632 self.define_tensor_expand_dims() 633 self.define_tensor_array() 634 self.define_tensor_array_read() 635 self.define_tensor_array_write() 636 self.define_tensor_array_unstack() 637 self.define_tensor_array_scatter() 638 self.define_tensor_array_split() 639 self.define_tensor_array_concat() 640 self.define_tensor_array_stack() 641 self.define_tensor_array_gather() 642 self.define_tensor_get_data() 643 644 def _get_adt_by_shape(self, shape): 645 """Get ADT type and constructor with given shape.""" 646 origin_shape = self.shape 647 self.shape = shape 648 self.define_tensor_adt() 649 tensor_type_var = self.get_var("tensor_t") 650 tensor_constructor = self.get_var("tensor_constructor") 651 self.shape = origin_shape 652 return tensor_type_var, tensor_constructor 653 654 def _create_global_var(self, name): 655 """Create a GlobalVar if doesn't exist in prelude.""" 656 global_var_name_set = set() 657 for g_var_name in self.prelude.mod.get_global_vars(): 658 global_var_name_set.add(g_var_name.name_hint) 659 if name not in global_var_name_set: 660 gvar = GlobalVar(name) 661 else: 662 gvar = self.prelude.mod.get_global_var(name) 663 664 return gvar 665 666 667class TensorArrayOps(object): 668 """Contains tensor array related ops""" 669 670 def __init__(self, prelude, dtype): 671 """Create tensor array ops registry""" 672 self.prelude = prelude 673 self.dtype = dtype 674 675 def get_name(self, canonical): 676 """Get name corresponding to the canonical name""" 677 return self.prelude.get_name(canonical, self.dtype) 678 679 def get_var(self, canonical): 680 """Get var corresponding to the canonical name""" 681 return self.prelude.get_var(canonical, self.dtype) 682 683 def define_tensor_adt(self): 684 """Defines the dynamic tensor ADT, which is the container for tensors 685 with variable shapes.""" 686 tensor_type_name = self.get_name("tensor_t") 687 tensor_type_var = GlobalTypeVar(tensor_type_name) 688 setattr(self.prelude, tensor_type_name, tensor_type_var) 689 tensor0_type = TensorType([], self.dtype) 690 tensor1_type = TensorType([Any()], self.dtype) 691 tensor2_type = TensorType([Any(), Any()], self.dtype) 692 tensor3_type = TensorType([Any(), Any(), Any()], self.dtype) 693 tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype) 694 tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype) 695 tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype) 696 tensor_nil_name = self.get_name("tensor_nil") 697 tensor0_name = self.get_name("tensor0") 698 tensor1_name = self.get_name("tensor1") 699 tensor2_name = self.get_name("tensor2") 700 tensor3_name = self.get_name("tensor3") 701 tensor4_name = self.get_name("tensor4") 702 tensor5_name = self.get_name("tensor5") 703 tensor6_name = self.get_name("tensor6") 704 tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var) 705 tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var) 706 tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var) 707 tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var) 708 tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var) 709 tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var) 710 tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var) 711 tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var) 712 setattr(self.prelude, tensor_nil_name, tensor_nil_case) 713 setattr(self.prelude, tensor0_name, tensor0_case) 714 setattr(self.prelude, tensor1_name, tensor1_case) 715 setattr(self.prelude, tensor2_name, tensor2_case) 716 setattr(self.prelude, tensor3_name, tensor3_case) 717 setattr(self.prelude, tensor4_name, tensor4_case) 718 setattr(self.prelude, tensor5_name, tensor5_case) 719 setattr(self.prelude, tensor6_name, tensor6_case) 720 self.prelude.mod[tensor_type_var] = TypeData( 721 tensor_type_var, 722 [], 723 [ 724 tensor_nil_case, 725 tensor0_case, 726 tensor1_case, 727 tensor2_case, 728 tensor3_case, 729 tensor4_case, 730 tensor5_case, 731 tensor6_case, 732 ], 733 ) 734 735 def define_tensor_take(self): 736 """Defines a function to return a range of tensor_t on axis 0. 737 tensor_take(t, lower, upper) : 738 tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t 739 """ 740 take_name = self.get_name("tensor_take") 741 take_var = GlobalVar(take_name) 742 setattr(self.prelude, take_name, take_var) 743 tensor_t = self.get_var("tensor_t") 744 tensor1_var = self.get_var("tensor1") 745 tensor2_var = self.get_var("tensor2") 746 tensor3_var = self.get_var("tensor3") 747 tensor4_var = self.get_var("tensor4") 748 tensor5_var = self.get_var("tensor5") 749 tensor6_var = self.get_var("tensor6") 750 t = Var("tensor", tensor_t()) 751 lower = Var("lower", scalar_type("int32")) 752 upper = Var("upper", scalar_type("int32")) 753 t1 = Var("t1") 754 t2 = Var("t2") 755 t3 = Var("t3") 756 t4 = Var("t4") 757 t5 = Var("t5") 758 t6 = Var("t6") 759 tensor1_case = Clause( 760 PatternConstructor(tensor1_var, [PatternVar(t1)]), 761 tensor1_var(op.take(t1, op.arange(lower, upper, dtype="int32"))), 762 ) 763 tensor2_case = Clause( 764 PatternConstructor(tensor2_var, [PatternVar(t2)]), 765 tensor2_var(op.take(t2, op.arange(lower, upper, dtype="int32"), axis=0)), 766 ) 767 tensor3_case = Clause( 768 PatternConstructor(tensor3_var, [PatternVar(t3)]), 769 tensor3_var(op.take(t3, op.arange(lower, upper, dtype="int32"), axis=0)), 770 ) 771 tensor4_case = Clause( 772 PatternConstructor(tensor4_var, [PatternVar(t4)]), 773 tensor4_var(op.take(t4, op.arange(lower, upper, dtype="int32"), axis=0)), 774 ) 775 tensor5_case = Clause( 776 PatternConstructor(tensor5_var, [PatternVar(t5)]), 777 tensor5_var(op.take(t5, op.arange(lower, upper, dtype="int32"), axis=0)), 778 ) 779 tensor6_case = Clause( 780 PatternConstructor(tensor6_var, [PatternVar(t6)]), 781 tensor6_var(op.take(t6, op.arange(lower, upper, dtype="int32"), axis=0)), 782 ) 783 self.prelude.mod[take_var] = Function( 784 [t, lower, upper], 785 Match( 786 t, 787 [ 788 tensor1_case, 789 tensor2_case, 790 tensor3_case, 791 tensor4_case, 792 tensor5_case, 793 tensor6_case, 794 ], 795 False, 796 ), 797 tensor_t(), 798 [], 799 ) 800 801 def define_tensor_expand_dims(self): 802 """Defines a function to grow a tensor_t's rank by adding one dimension in front 803 of the original tensor_t. 804 tensor_expand_dims(t) : tensor_t -> tensor_t 805 """ 806 expand_dims_name = self.get_name("tensor_expand_dims") 807 expand_dims_var = GlobalVar(expand_dims_name) 808 setattr(self.prelude, expand_dims_name, expand_dims_var) 809 tensor_type_var = self.get_var("tensor_t") 810 x = Var("x", tensor_type_var()) 811 t0 = Var("t0") 812 t1 = Var("t1") 813 t2 = Var("t2") 814 t3 = Var("t3") 815 t4 = Var("t4") 816 t5 = Var("t5") 817 tensor0_var = self.get_var("tensor0") 818 tensor1_var = self.get_var("tensor1") 819 tensor2_var = self.get_var("tensor2") 820 tensor3_var = self.get_var("tensor3") 821 tensor4_var = self.get_var("tensor4") 822 tensor5_var = self.get_var("tensor5") 823 tensor6_var = self.get_var("tensor6") 824 tensor0_case = Clause( 825 PatternConstructor(tensor0_var, [PatternVar(t0)]), tensor1_var(op.expand_dims(t0, 0, 1)) 826 ) 827 tensor1_case = Clause( 828 PatternConstructor(tensor1_var, [PatternVar(t1)]), tensor2_var(op.expand_dims(t1, 0, 1)) 829 ) 830 tensor2_case = Clause( 831 PatternConstructor(tensor2_var, [PatternVar(t2)]), tensor3_var(op.expand_dims(t2, 0, 1)) 832 ) 833 tensor3_case = Clause( 834 PatternConstructor(tensor3_var, [PatternVar(t3)]), tensor4_var(op.expand_dims(t3, 0, 1)) 835 ) 836 tensor4_case = Clause( 837 PatternConstructor(tensor4_var, [PatternVar(t4)]), tensor5_var(op.expand_dims(t4, 0, 1)) 838 ) 839 tensor5_case = Clause( 840 PatternConstructor(tensor5_var, [PatternVar(t5)]), tensor6_var(op.expand_dims(t5, 0, 1)) 841 ) 842 self.prelude.mod[expand_dims_var] = Function( 843 [x], 844 Match( 845 x, 846 [ 847 tensor0_case, 848 tensor1_case, 849 tensor2_case, 850 tensor3_case, 851 tensor4_case, 852 tensor5_case, 853 ], 854 False, 855 ), 856 ) 857 858 def define_tensor_concat(self): 859 """Defines a function to concatenate two tensor_t on the first axis 860 861 tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t 862 """ 863 concat_name = self.get_name("tensor_concatenate") 864 concat_var = GlobalVar(concat_name) 865 setattr(self.prelude, concat_name, concat_var) 866 tensor_type_var = self.get_var("tensor_t") 867 x = Var("x", tensor_type_var()) 868 y = Var("y", tensor_type_var()) 869 870 tensor1_var = self.get_var("tensor1") 871 tensor2_var = self.get_var("tensor2") 872 tensor3_var = self.get_var("tensor3") 873 tensor4_var = self.get_var("tensor4") 874 t11 = Var("t11") 875 t12 = Var("t12") 876 t21 = Var("t21") 877 t22 = Var("t22") 878 t31 = Var("t31") 879 t32 = Var("t32") 880 t41 = Var("t41") 881 t42 = Var("t42") 882 tensor1_case = Clause( 883 PatternConstructor(tensor1_var, [PatternVar(t11)]), 884 Match( 885 y, 886 [ 887 Clause( 888 PatternConstructor(tensor1_var, [PatternVar(t12)]), 889 tensor1_var(op.concatenate([t11, t12], axis=0)), 890 ) 891 ], 892 False, 893 ), 894 ) 895 tensor2_case = Clause( 896 PatternConstructor(tensor2_var, [PatternVar(t21)]), 897 Match( 898 y, 899 [ 900 Clause( 901 PatternConstructor(tensor2_var, [PatternVar(t22)]), 902 tensor2_var(op.concatenate([t21, t22], axis=0)), 903 ) 904 ], 905 False, 906 ), 907 ) 908 tensor3_case = Clause( 909 PatternConstructor(tensor3_var, [PatternVar(t31)]), 910 Match( 911 y, 912 [ 913 Clause( 914 PatternConstructor(tensor3_var, [PatternVar(t32)]), 915 tensor3_var(op.concatenate([t31, t32], axis=0)), 916 ) 917 ], 918 False, 919 ), 920 ) 921 tensor4_case = Clause( 922 PatternConstructor(tensor4_var, [PatternVar(t41)]), 923 Match( 924 y, 925 [ 926 Clause( 927 PatternConstructor(tensor4_var, [PatternVar(t42)]), 928 tensor4_var(op.concatenate([t41, t42], axis=0)), 929 ) 930 ], 931 False, 932 ), 933 ) 934 # op.concatenate does not support tensor with rank higher than 4 935 self.prelude.mod[concat_var] = Function( 936 [x, y], Match(x, [tensor1_case, tensor2_case, tensor3_case, tensor4_case], False) 937 ) 938 939 def define_tensor_array(self): 940 """Defines a function to create a tensor array with size n. 941 tensor_array(n) : Tensor[(), int32] -> list[tensor_t] 942 """ 943 tensor_array_constructor_name = self.get_name("tensor_array") 944 tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name) 945 setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var) 946 tensor_nil_var = self.get_var("tensor_nil") 947 tensor_type_var = self.get_var("tensor_t") 948 n = Var("x", scalar_type("int32")) 949 body = If( 950 equal(n, const(0)), 951 self.prelude.nil(), 952 self.prelude.cons( 953 tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1))) 954 ), 955 ) 956 self.prelude.mod[tensor_array_constructor_var] = Function( 957 [n], body, self.prelude.l(tensor_type_var()), [] 958 ) 959 960 def define_tensor_array_read(self): 961 """Defines a function to get the head of a list. Assume the list has at least one 962 element. 963 964 tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t 965 """ 966 read_name = self.get_name("tensor_array_read") 967 read_var = GlobalVar(read_name) 968 setattr(self.prelude, read_name, read_var) 969 tensor_type_var = self.get_var("tensor_t") 970 971 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 972 n = Var("x", scalar_type("int32")) 973 self.prelude.mod[read_var] = Function( 974 [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), [] 975 ) 976 977 def define_tensor_array_write(self): 978 """Defines a function to update a tensor array at index n with value v. 979 tensor_array_write(ta, n, v) : 980 list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t] 981 """ 982 write_name = self.get_name("tensor_array_write") 983 write_var = GlobalVar(write_name) 984 setattr(self.prelude, write_name, write_var) 985 tensor_type_var = self.get_var("tensor_t") 986 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 987 n = Var("x", scalar_type("int32")) 988 v = Var("v", tensor_type_var()) 989 self.prelude.mod[write_var] = Function( 990 [tensor_array, n, v], 991 self.prelude.update(tensor_array, n, v), 992 self.prelude.l(tensor_type_var()), 993 [], 994 ) 995 996 def define_tensor_array_unstack_tensor1(self): 997 """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array. 998 tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t] 999 """ 1000 helper_name = self.get_name("tensor_array_unstack_tensor1_helper") 1001 helper_var = GlobalVar(helper_name) 1002 setattr(self.prelude, helper_name, helper_var) 1003 tensor = Var("t", TensorType([Any()], self.dtype)) 1004 up = Var("up", scalar_type("int32")) 1005 i = Var("i", scalar_type("int32")) 1006 tensor_type_var = self.get_var("tensor_t") 1007 tensor0_var = self.get_var("tensor0") 1008 helper_body = If( 1009 equal(i, up), 1010 self.prelude.nil(), 1011 self.prelude.cons( 1012 tensor0_var(op.take(tensor, i)), helper_var(add(i, const(1)), up, tensor) 1013 ), 1014 ) 1015 self.prelude.mod[helper_var] = Function( 1016 [i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), [] 1017 ) 1018 unstack_name = self.get_name("tensor_array_unstack_tensor1") 1019 unstack_var = GlobalVar(unstack_name) 1020 setattr(self.prelude, unstack_name, unstack_var) 1021 tensor1 = Var("tensor", TensorType([Any()], self.dtype)) 1022 shape = op.shape_of(tensor1) 1023 ndim = op.take(shape, const(0)) 1024 self.prelude.mod[unstack_var] = Function( 1025 [tensor1], helper_var(const(0), ndim, tensor1), self.prelude.l(tensor_type_var()), [] 1026 ) 1027 1028 def define_tensor_array_unstack_tensor2(self): 1029 """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array. 1030 1031 tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t] 1032 """ 1033 helper_name = self.get_name("tensor_array_unstack_tensor2_helper") 1034 helper_var = GlobalVar(helper_name) 1035 setattr(self.prelude, helper_name, helper_var) 1036 tensor = Var("t", TensorType([Any(), Any()], self.dtype)) 1037 up = Var("up", scalar_type("int32")) 1038 i = Var("i", scalar_type("int32")) 1039 1040 helper_body = If( 1041 equal(i, up), 1042 self.prelude.nil(), 1043 self.prelude.cons( 1044 self.get_var("tensor1")(op.take(tensor, i, axis=0)), 1045 helper_var(add(i, const(1)), up, tensor), 1046 ), 1047 ) 1048 self.prelude.mod[helper_var] = Function( 1049 [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] 1050 ) 1051 1052 tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2") 1053 tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name) 1054 setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var) 1055 tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype)) 1056 shape = op.shape_of(tensor2) 1057 ndim = op.take(shape, const(0)) 1058 self.prelude.mod[tensor_array_unstack_tensor2_var] = Function( 1059 [tensor2], 1060 helper_var(const(0), ndim, tensor2), 1061 self.prelude.l(self.get_var("tensor_t")()), 1062 [], 1063 ) 1064 1065 def define_tensor_array_unstack_tensor3(self): 1066 """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array. 1067 1068 tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t] 1069 """ 1070 helper_name = self.get_name("tensor_array_unstack_tensor3_helper") 1071 helper_var = GlobalVar(helper_name) 1072 setattr(self.prelude, helper_name, helper_var) 1073 tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype)) 1074 up = Var("up", scalar_type("int32")) 1075 i = Var("i", scalar_type("int32")) 1076 1077 helper_body = If( 1078 equal(i, up), 1079 self.prelude.nil(), 1080 self.prelude.cons( 1081 self.get_var("tensor2")(op.take(tensor, i, axis=0)), 1082 helper_var(add(i, const(1)), up, tensor), 1083 ), 1084 ) 1085 self.prelude.mod[helper_var] = Function( 1086 [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] 1087 ) 1088 1089 tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3") 1090 tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name) 1091 setattr(self.prelude, tensor_array_unstack_tensor3_name, tensor_array_unstack_tensor3_var) 1092 tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype)) 1093 shape = op.shape_of(tensor3) 1094 ndim = op.take(shape, const(0)) 1095 self.prelude.mod[tensor_array_unstack_tensor3_var] = Function( 1096 [tensor3], 1097 helper_var(const(0), ndim, tensor3), 1098 self.prelude.l(self.get_var("tensor_t")()), 1099 [], 1100 ) 1101 1102 def define_tensor_array_unstack_tensor4(self): 1103 """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array. 1104 1105 tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t] 1106 """ 1107 helper_name = self.get_name("tensor_array_unstack_tensor4_helper") 1108 helper_var = GlobalVar(helper_name) 1109 setattr(self.prelude, helper_name, helper_var) 1110 tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype)) 1111 up = Var("up", scalar_type("int32")) 1112 i = Var("i", scalar_type("int32")) 1113 1114 helper_body = If( 1115 equal(i, up), 1116 self.prelude.nil(), 1117 self.prelude.cons( 1118 self.get_var("tensor3")(op.take(tensor, i, axis=0)), 1119 helper_var(add(i, const(1)), up, tensor), 1120 ), 1121 ) 1122 self.prelude.mod[helper_var] = Function( 1123 [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] 1124 ) 1125 1126 tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4") 1127 tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name) 1128 setattr(self.prelude, tensor_array_unstack_tensor4_name, tensor_array_unstack_tensor4_var) 1129 tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype)) 1130 shape = op.shape_of(tensor4) 1131 ndim = op.take(shape, const(0)) 1132 self.prelude.mod[tensor_array_unstack_tensor4_var] = Function( 1133 [tensor4], 1134 helper_var(const(0), ndim, tensor4), 1135 self.prelude.l(self.get_var("tensor_t")()), 1136 [], 1137 ) 1138 1139 def define_tensor_array_unstack_tensor5(self): 1140 """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array. 1141 1142 tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t] 1143 """ 1144 helper_name = self.get_name("tensor_array_unstack_tensor5_helper") 1145 helper_var = GlobalVar(helper_name) 1146 setattr(self.prelude, helper_name, helper_var) 1147 tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)) 1148 up = Var("up", scalar_type("int32")) 1149 i = Var("i", scalar_type("int32")) 1150 1151 helper_body = If( 1152 equal(i, up), 1153 self.prelude.nil(), 1154 self.prelude.cons( 1155 self.get_var("tensor4")(op.take(tensor, i, axis=0)), 1156 helper_var(add(i, const(1)), up, tensor), 1157 ), 1158 ) 1159 self.prelude.mod[helper_var] = Function( 1160 [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] 1161 ) 1162 1163 tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5") 1164 tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name) 1165 setattr(self.prelude, tensor_array_unstack_tensor5_name, tensor_array_unstack_tensor5_var) 1166 tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)) 1167 shape = op.shape_of(tensor5) 1168 ndim = op.take(shape, const(0)) 1169 self.prelude.mod[tensor_array_unstack_tensor5_var] = Function( 1170 [tensor5], 1171 helper_var(const(0), ndim, tensor5), 1172 self.prelude.l(self.get_var("tensor_t")()), 1173 [], 1174 ) 1175 1176 def define_tensor_array_unstack_tensor6(self): 1177 """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array. 1178 1179 tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t] 1180 """ 1181 helper_name = self.get_name("tensor_array_unstack_tensor6_helper") 1182 helper_var = GlobalVar(helper_name) 1183 setattr(self.prelude, helper_name, helper_var) 1184 tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)) 1185 up = Var("up", scalar_type("int32")) 1186 i = Var("i", scalar_type("int32")) 1187 1188 helper_body = If( 1189 equal(i, up), 1190 self.prelude.nil(), 1191 self.prelude.cons( 1192 self.get_var("tensor5")(op.take(tensor, i, axis=0)), 1193 helper_var(add(i, const(1)), up, tensor), 1194 ), 1195 ) 1196 self.prelude.mod[helper_var] = Function( 1197 [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), [] 1198 ) 1199 1200 tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6") 1201 tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name) 1202 setattr(self.prelude, tensor_array_unstack_tensor6_name, tensor_array_unstack_tensor6_var) 1203 tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)) 1204 shape = op.shape_of(tensor6) 1205 ndim = op.take(shape, const(0)) 1206 self.prelude.mod[tensor_array_unstack_tensor6_var] = Function( 1207 [tensor6], 1208 helper_var(const(0), ndim, tensor6), 1209 self.prelude.l(self.get_var("tensor_t")()), 1210 [], 1211 ) 1212 1213 def define_tensor_array_scatter(self): 1214 """Defines a function to scatter the values of a tensor_t in indices of a tensor array. 1215 tensor_array_scatter(ta, indices, value) : 1216 list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] 1217 """ 1218 tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") 1219 tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name) 1220 tensor_t = self.get_var("tensor_t") 1221 ta = Var("ta", self.prelude.l(tensor_t())) 1222 current = Var("current", scalar_type("int32")) 1223 limit = Var("limit", scalar_type("int32")) 1224 indices_ = Var("indices_", TensorType([Any()], "int32")) 1225 values_ = Var("values_", self.prelude.l(tensor_t())) 1226 write_var = self.get_var("tensor_array_write") 1227 read_var = self.get_var("tensor_array_read") 1228 helper_body = If( 1229 equal(current, limit), 1230 ta, 1231 tensor_array_scatter_helper_var( 1232 write_var(ta, op.take(indices_, current), read_var(values_, current)), 1233 add(current, const(1)), 1234 limit, 1235 indices_, 1236 values_, 1237 ), 1238 ) 1239 self.prelude.mod[tensor_array_scatter_helper_var] = Function( 1240 [ta, current, limit, indices_, values_], helper_body, self.prelude.l(tensor_t()), [] 1241 ) 1242 tensor_array_scatter_name = self.get_name("tensor_array_scatter") 1243 tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name) 1244 setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var) 1245 tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) 1246 indices = Var("indices", TensorType([Any()], "int32")) 1247 values = Var("values", self.prelude.l(tensor_t())) 1248 indices_shape = op.shape_of(indices) 1249 limit = op.take(indices_shape, const(0)) 1250 body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values) 1251 self.prelude.mod[tensor_array_scatter_var] = Function( 1252 [tensor_array, indices, values], body, self.prelude.l(tensor_t()), [] 1253 ) 1254 1255 def define_tensor_array_split(self): 1256 """Defines a function to split the values of a tensor_t into a tensor array. 1257 tensor_array_split(ta, value, lengths) : 1258 list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t] 1259 """ 1260 tensor_t = self.get_var("tensor_t") 1261 tensor_array_split_helper_name = self.get_name("ta_split_helper") 1262 tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name) 1263 setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var) 1264 ta1 = Var("tensor_array", self.prelude.l(tensor_t())) 1265 value1 = Var("value1", tensor_t()) 1266 offset1 = Var("offset1", scalar_type("int32")) 1267 current1 = Var("current1", scalar_type("int32")) 1268 limit1 = Var("limit1", scalar_type("int32")) 1269 lengths1 = Var("lengths", TensorType([Any()], "int32")) 1270 write_var = self.get_var("tensor_array_write") 1271 take_var = self.get_var("tensor_take") 1272 helper1_body = If( 1273 equal(current1, limit1), 1274 ta1, 1275 write_var( 1276 tensor_array_split_helper_var( 1277 ta1, 1278 value1, 1279 add(offset1, op.take(lengths1, current1)), 1280 add(current1, const(1)), 1281 limit1, 1282 lengths1, 1283 ), 1284 current1, 1285 take_var(value1, offset1, add(op.take(lengths1, current1), offset1)), 1286 ), 1287 ) 1288 self.prelude.mod[tensor_array_split_helper_var] = Function( 1289 [ta1, value1, offset1, current1, limit1, lengths1], 1290 helper1_body, 1291 self.prelude.l(tensor_t()), 1292 [], 1293 ) 1294 split_name = self.get_name("tensor_array_split") 1295 split_var = GlobalVar(split_name) 1296 setattr(self.prelude, split_name, split_var) 1297 tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) 1298 value = Var("value", tensor_t()) 1299 lengths = Var("lengths", TensorType([Any()], "int32")) 1300 lengths_shape = op.shape_of(lengths) 1301 lengths_limit = op.take(lengths_shape, const(0)) 1302 body = tensor_array_split_helper_var( 1303 tensor_array, value, const(0), const(0), lengths_limit, lengths 1304 ) 1305 self.prelude.mod[split_var] = Function( 1306 [tensor_array, value, lengths], body, self.prelude.l(tensor_t()), [] 1307 ) 1308 1309 def define_tensor_array_concat(self): 1310 """Defines a function to return the values in the tensor array as concatenated tensor_t. 1311 tensor_array_concat(ta) : list[tensor_t] -> tensor_t 1312 """ 1313 concat_name = self.get_name("tensor_array_concat") 1314 concat_var = GlobalVar(concat_name) 1315 setattr(self.prelude, concat_name, concat_var) 1316 tensor_concat_var = self.get_var("tensor_concatenate") 1317 tensor_t = self.get_var("tensor_t") 1318 tensor_nil_var = self.get_var("tensor_nil") 1319 tensor_array = Var("tensor_array", self.prelude.l(tensor_t())) 1320 hd = Var("hd") 1321 tl = Var("tl") 1322 nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var()) 1323 cons_case = Clause( 1324 PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]), 1325 Match( 1326 tl, 1327 [ 1328 Clause(PatternConstructor(self.prelude.nil), hd), 1329 Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))), 1330 ], 1331 False, 1332 ), 1333 ) 1334 self.prelude.mod[concat_var] = Function( 1335 [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_t(), [] 1336 ) 1337 1338 def define_tensor_array_gather(self): 1339 """Defines a function to return the selected values in a tensor array as tensor_t. 1340 tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t 1341 """ 1342 helper_name = self.get_name("tensor_array_gather_helper") 1343 helper_var = GlobalVar(helper_name) 1344 setattr(self.prelude, helper_name, helper_var) 1345 tensor_type_var = self.get_var("tensor_t") 1346 stack_var = self.get_var("tensor_array_stack") 1347 read_var = self.get_var("tensor_array_read") 1348 ta = Var("ta", self.prelude.l(tensor_type_var())) 1349 accu = Var("accu", self.prelude.l(tensor_type_var())) 1350 current = Var("current", scalar_type("int32")) 1351 limit = Var("limit", scalar_type("int32")) 1352 indices_ = Var("indices_", TensorType([Any()], "int32")) 1353 helper_body = If( 1354 equal(current, const(0)), 1355 stack_var(accu), 1356 helper_var( 1357 ta, 1358 self.prelude.cons( 1359 read_var(ta, op.take(indices_, subtract(current, const(1)))), accu 1360 ), 1361 subtract(current, const(1)), 1362 limit, 1363 indices_, 1364 ), 1365 ) 1366 self.prelude.mod[helper_var] = Function( 1367 [ta, accu, current, limit, indices_], helper_body, tensor_type_var(), [] 1368 ) 1369 gather_name = self.get_name("tensor_array_gather") 1370 gather_var = GlobalVar(gather_name) 1371 setattr(self.prelude, gather_name, gather_var) 1372 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 1373 indices = Var("indices", TensorType([Any()], "int32")) 1374 indices_shape = op.shape_of(indices) 1375 limit = op.take(indices_shape, const(0)) 1376 body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) 1377 self.prelude.mod[gather_var] = Function( 1378 [tensor_array, indices], body, tensor_type_var(), [] 1379 ) 1380 1381 def define_tensor_array_stack(self): 1382 """Defines a function to get the values in the tensor array as a stack tensor_t. 1383 tensor_array_stack(l) : list[tensor_t] -> tensor_t 1384 """ 1385 stack_name = self.get_name("tensor_array_stack") 1386 stack_var = GlobalVar(stack_name) 1387 setattr(self.prelude, stack_name, stack_var) 1388 tensor_type_var = self.get_var("tensor_t") 1389 tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var())) 1390 expand_dims_var = self.get_var("tensor_expand_dims") 1391 concat_var = self.get_var("tensor_concatenate") 1392 tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array) 1393 tensors = self.prelude.foldl( 1394 concat_var, 1395 self.prelude.hd(tensor_array_expand_dims), 1396 self.prelude.tl(tensor_array_expand_dims), 1397 ) 1398 self.prelude.mod[stack_var] = Function( 1399 [tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [] 1400 ) 1401 1402 def register(self): 1403 """Register all tensor array ops in Prelude""" 1404 self.define_tensor_adt() 1405 self.define_tensor_take() 1406 self.define_tensor_expand_dims() 1407 self.define_tensor_concat() 1408 self.define_tensor_array() 1409 self.define_tensor_array_read() 1410 self.define_tensor_array_write() 1411 self.define_tensor_array_unstack_tensor1() 1412 self.define_tensor_array_unstack_tensor2() 1413 self.define_tensor_array_unstack_tensor3() 1414 self.define_tensor_array_unstack_tensor4() 1415 self.define_tensor_array_unstack_tensor5() 1416 self.define_tensor_array_unstack_tensor6() 1417 self.define_tensor_array_scatter() 1418 self.define_tensor_array_split() 1419 self.define_tensor_array_concat() 1420 self.define_tensor_array_stack() 1421 # TODO(wweic): Gather fails in PartialEvaluate 1422 # self.define_tensor_array_gather() 1423 1424 1425class Prelude: 1426 """Contains standard definitions.""" 1427 1428 def __init__(self, mod=None): 1429 if mod is None: 1430 mod = IRModule() 1431 self.mod = mod 1432 self.load_prelude() 1433 1434 def get_name(self, canonical, dtype): 1435 """Get name corresponding to the canonical name""" 1436 if canonical == "tensor_t": 1437 return "tensor_{}_t".format(dtype) 1438 return "{}_{}".format(canonical, dtype) 1439 1440 def get_var(self, canonical, dtype): 1441 """Get var corresponding to the canonical name""" 1442 name = self.get_name(canonical, dtype) 1443 return getattr(self, name) 1444 1445 def get_name_static(self, canonical, dtype, shape): 1446 """Get name corresponding to the canonical name""" 1447 return _get_name_static(canonical, dtype, shape) 1448 1449 def get_var_static(self, canonical, dtype, shape): 1450 """Get var corresponding to the canonical name""" 1451 name = self.get_name_static(canonical, dtype, shape) 1452 return getattr(self, name) 1453 1454 def load_prelude(self): 1455 """Parses the Prelude from Relay's text format into a module.""" 1456 # TODO(@jroesch): we should remove this helper when we port over prelude 1457 self.mod.import_from_std("prelude.rly") 1458 1459 self.l = self.mod.get_global_type_var("List") 1460 list_adt = self.mod[self.l] 1461 self.cons = list_adt.constructors[0] 1462 self.nil = list_adt.constructors[1] 1463 1464 self.optional = self.mod.get_global_type_var("Option") 1465 optional_adt = self.mod[self.optional] 1466 self.some = optional_adt.constructors[0] 1467 self.none = optional_adt.constructors[1] 1468 1469 self.tree = self.mod.get_global_type_var("Tree") 1470 tree_adt = self.mod[self.tree] 1471 self.rose = tree_adt.constructors[0] 1472 1473 GLOBAL_DEFS = [ 1474 "id", 1475 "compose", 1476 "flip", 1477 "hd", 1478 "tl", 1479 "nth", 1480 "update", 1481 "map", 1482 "foldl", 1483 "foldr", 1484 "foldr1", 1485 "concat", 1486 "filter", 1487 "zip", 1488 "rev", 1489 "map_accuml", 1490 "map_accumr", 1491 "unfoldl", 1492 "unfoldr", 1493 "sum", 1494 "length", 1495 "tmap", 1496 "size", 1497 "iterate", 1498 ] 1499 for global_def in GLOBAL_DEFS: 1500 setattr(self, global_def, self.mod.get_global_var(global_def)) 1501 1502 for dtype in [ 1503 "float32", 1504 "float16", 1505 "float64", 1506 "int32", 1507 "uint8", 1508 "int8", 1509 "int16", 1510 "uint16", 1511 "int64", 1512 ]: 1513 tensor_array_ops = TensorArrayOps(self, dtype) 1514 tensor_array_ops.register() 1515