1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
18"""A prelude containing useful global functions and ADT definitions."""
19from tvm.ir import IRModule, TypeCall
20from tvm.relay.transform import ToANormalFormExpr
21
22from .ty import GlobalTypeVar, TensorType, Any, scalar_type
23from .expr import Var, GlobalVar, If, const
24from .function import Function
25from .op.tensor import add, subtract, equal
26from .adt import Constructor, TypeData, Clause, Match
27from .adt import PatternConstructor, PatternVar, PatternWildcard
28from . import op, transform
29from .analysis import free_vars
30
31
32def get_tensor_array_shape(expr, dtype, prelude):
33    """Get the static shape of a tensor array if it has fixed rank shape.
34
35    By design, static ADT tensor in TVM has type name in the format
36    of static_tensor_dim0_dim1_..._dimN_t.
37
38    Parameters
39    ----------
40    expr : Relay Expr
41        Input expression.
42
43    dtype : str
44        Data type.
45
46    prelude : Prelude
47        Tensor array prelude
48
49    Returns
50    -------
51    shape : tuple of (int, Any) or None
52        The output shape. None if input tensor array
53        has dynamic shape.
54    """
55    mod = prelude.mod
56    mod["main"] = Function(free_vars(expr), expr)
57    mod = transform.InferType()(mod)
58    checked_type = mod["main"].body.checked_type
59    assert isinstance(checked_type, TypeCall), "Input must be a tensor array."
60    ta_type_str = checked_type.args[0].func.name_hint
61    static_ta_ty_start = "static_tensor_{}".format(dtype)
62    if ta_type_str.startswith(static_ta_ty_start):
63        shape_str = ta_type_str.replace("{}_".format(static_ta_ty_start), "").replace("_t", "")
64        shape = []
65        if "scalar" not in shape_str:
66            for dim_str in shape_str.split("_"):
67                if dim_str == "?":
68                    shape.append(Any())
69                else:
70                    shape.append(int(dim_str))
71        return tuple(shape)
72    return None
73
74
75def _get_name_static(canonical, dtype, shape):
76    """Get name for static shape tensor array op corresponding
77    to the canonical name"""
78    shape_str = "_".join([str(dim) for dim in shape])
79    if len(shape_str) == 0:
80        shape_str = "scalar"
81    if canonical == "tensor_t":
82        return "static_tensor_{}_{}_t".format(dtype, shape_str)
83    return "{}_{}_{}".format(canonical, dtype, shape_str)
84
85
86class StaticTensorArrayOps(object):
87    """Contains tensor array related ops for fixed rank tensor array"""
88
89    def __init__(self, prelude, dtype, shape):
90        """Create tensor array ops registry"""
91        self.prelude = prelude
92        self.dtype = dtype
93        self.shape = shape
94
95    def get_name(self, canonical):
96        """Get name corresponding to the canonical name"""
97        return _get_name_static(canonical, self.dtype, self.shape)
98
99    def get_var(self, canonical):
100        """Get var corresponding to the canonical name"""
101        name = self.get_name(canonical)
102        return getattr(self.prelude, name)
103
104    def define_tensor_adt(self):
105        """Defines the static tensor ADT, which is the container for tensors
106        with fixed shapes."""
107        tensor_type_name = self.get_name("tensor_t")
108        # Skip register if tensor type is already registered.
109        global_type_names = set()
110        for g_ty_var in self.prelude.mod.get_global_type_vars():
111            global_type_names.add(g_ty_var.name_hint)
112        if tensor_type_name in global_type_names:
113            return
114
115        tensor_type_var = GlobalTypeVar(tensor_type_name)
116        setattr(self.prelude, tensor_type_name, tensor_type_var)
117        tensor_type = TensorType(self.shape, self.dtype)
118        tensor_constructor_name = self.get_name("tensor_constructor")
119
120        tensor_nil_name = self.get_name("tensor_nil")
121        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
122        tensor_case = Constructor(tensor_constructor_name, [tensor_type], tensor_type_var)
123
124        setattr(self.prelude, tensor_nil_name, tensor_nil_case)
125        setattr(self.prelude, tensor_constructor_name, tensor_case)
126        self.prelude.mod[tensor_type_var] = TypeData(
127            tensor_type_var, [], [tensor_nil_case, tensor_case]
128        )
129
130    def define_tensor_array(self):
131        """Defines a function to create a tensor array with size n.
132        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
133        """
134        tensor_array_constructor_name = self.get_name("tensor_array")
135        tensor_array_constructor_var = self._create_global_var(tensor_array_constructor_name)
136        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
137        tensor_nil_var = self.get_var("tensor_nil")
138        tensor_type_var = self.get_var("tensor_t")
139        n = Var("x", scalar_type("int32"))
140        body = If(
141            equal(n, const(0)),
142            self.prelude.nil(),
143            self.prelude.cons(
144                tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))
145            ),
146        )
147        self.prelude.mod[tensor_array_constructor_var] = Function(
148            [n], body, self.prelude.l(tensor_type_var()), []
149        )
150
151    def define_tensor_take(self):
152        """Defines a function to return a range of tensor_t on axis 0.
153        tensor_take(t, lower, upper) :
154        tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
155        """
156        # We don't register take for scalar tensor.
157        ndim = len(self.shape)
158        if ndim == 0:
159            return
160
161        take_name = self.get_name("tensor_take")
162        take_var = self._create_global_var(take_name)
163        setattr(self.prelude, take_name, take_var)
164        origin_tensor_constructor = self.get_var("tensor_constructor")
165
166        output_shape = [
167            Any(),
168        ] + list(self.shape[1:])
169        tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape)
170
171        t = Var("tensor", self.get_var("tensor_t")())
172        lower = Var("lower", scalar_type("int32"))
173        upper = Var("upper", scalar_type("int32"))
174        tvar = Var("t")
175        case = Clause(
176            PatternConstructor(origin_tensor_constructor, [PatternVar(tvar)]),
177            tensor_constructor(op.take(tvar, op.arange(lower, upper, dtype="int32"), axis=0)),
178        )
179        self.prelude.mod[take_var] = Function(
180            [t, lower, upper], Match(t, [case], False), tensor_type_var(), []
181        )
182
183    def define_tensor_concatenate(self):
184        """Defines a function to concatenate two tensor_t on axis 0.
185        tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
186        """
187        # We don't register concatenate for scalar tensor.
188        ndim = len(self.shape)
189        if ndim == 0:
190            return
191
192        concat_name = self.get_name("tensor_concatenate")
193        concat_var = self._create_global_var(concat_name)
194        setattr(self.prelude, concat_name, concat_var)
195        output_shape = [
196            Any(),
197        ] + list(self.shape[1:])
198        tensor_type_var, tensor_constructor = self._get_adt_by_shape(output_shape)
199
200        origin_tensor_constructor = self.get_var("tensor_constructor")
201        origin_tensor_type_var = self.get_var("tensor_t")
202        x = Var("x", origin_tensor_type_var())
203        y = Var("y", origin_tensor_type_var())
204        t1 = Var("t1")
205        t2 = Var("t2")
206
207        case = Clause(
208            PatternConstructor(origin_tensor_constructor, [PatternVar(t1)]),
209            Match(
210                y,
211                [
212                    Clause(
213                        PatternConstructor(origin_tensor_constructor, [PatternVar(t2)]),
214                        tensor_constructor(op.concatenate([t1, t2], axis=0)),
215                    )
216                ],
217                False,
218            ),
219        )
220
221        self.prelude.mod[concat_var] = Function(
222            [x, y], Match(x, [case], False), tensor_type_var(), []
223        )
224
225    def define_tensor_expand_dims(self):
226        """Defines a function to grow a tensor_t's rank by adding one dimension in front
227        of the original tensor_t.
228        tensor_expand_dims(t) : tensor_t -> tensor_t
229        """
230        expand_dims_name = self.get_name("tensor_expand_dims")
231        expand_dims_var = self._create_global_var(expand_dims_name)
232        setattr(self.prelude, expand_dims_name, expand_dims_var)
233        origin_tensor_type_var = self.get_var("tensor_t")
234        origin_tensor_constructor = self.get_var("tensor_constructor")
235        x = Var("x", origin_tensor_type_var())
236
237        # Note: we set the added axis to be Any() instead of 1 due to
238        # in stack op, we need to recursively concatenate.
239        tensor_type_var, tensor_constructor = self._get_adt_by_shape(
240            [
241                Any(),
242            ]
243            + list(self.shape)
244        )
245        t = Var("t")
246        case = Clause(
247            PatternConstructor(origin_tensor_constructor, [PatternVar(t)]),
248            tensor_constructor(op.expand_dims(t, 0, 1)),
249        )
250
251        self.prelude.mod[expand_dims_var] = Function(
252            [x], Match(x, [case], False), tensor_type_var(), []
253        )
254
255    def define_tensor_array_read(self):
256        """Defines a function to get the nth element of a list. Assume the list has at least one
257        element.
258        tensor_array_read(ta, n) : list[static_tensor_t] -> Tensor[(), int32] ->
259        Tensor[self.shape, self.dtype]
260        """
261        read_name = self.get_name("tensor_array_read")
262        read_var = self._create_global_var(read_name)
263        setattr(self.prelude, read_name, read_var)
264        tensor_type_var = self.get_var("tensor_t")
265
266        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
267        n = Var("x", scalar_type("int32"))
268        self.prelude.mod[read_var] = Function(
269            [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []
270        )
271
272    def define_tensor_array_write(self):
273        """Defines a function to update a tensor array at index n with value v.
274        tensor_array_write(ta, n, v) :
275            list[static_tensor_t] -> Tensor[(), int32] -> Tensor[self.shape, self.dtype] ->
276            list[static_tensor_t]
277        """
278        write_name = self.get_name("tensor_array_write")
279        write_var = self._create_global_var(write_name)
280        setattr(self.prelude, write_name, write_var)
281        tensor_type_var = self.get_var("tensor_t")
282        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
283        n = Var("x", scalar_type("int32"))
284        v = Var("v", tensor_type_var())
285        self.prelude.mod[write_var] = Function(
286            [tensor_array, n, v],
287            self.prelude.update(tensor_array, n, v),
288            self.prelude.l(tensor_type_var()),
289            [],
290        )
291
292    def define_tensor_array_unstack(self):
293        """Defines a function to unstack the values of a tensor_t in a tensor array.
294        tensor_array_unstack_tensor(t) : tensor_t -> list[tensor_t]
295        """
296        ndim = len(self.shape)
297        # We don't register unstack for scalar tensor array
298        if ndim == 0:
299            return
300
301        helper_name = self.get_name("tensor_array_unstack_helper")
302        helper_var = self._create_global_var(helper_name)
303        setattr(self.prelude, helper_name, helper_var)
304        tensor = Var("t", TensorType(self.shape, self.dtype))
305        up = Var("up", scalar_type("int32"))
306        i = Var("i", scalar_type("int32"))
307        tensor_var = Var("tensor", TensorType(self.shape, self.dtype))
308
309        reduced_tensor_type_var, tensor_constructor = self._get_adt_by_shape(self.shape[1:])
310        helper_body = If(
311            equal(i, up),
312            self.prelude.nil(),
313            self.prelude.cons(
314                tensor_constructor(op.take(tensor, i, axis=0)),
315                helper_var(add(i, const(1)), up, tensor),
316            ),
317        )
318        self.prelude.mod[helper_var] = Function(
319            [i, up, tensor], helper_body, self.prelude.l(reduced_tensor_type_var()), []
320        )
321
322        unstack_name = self.get_name("tensor_array_unstack")
323        unstack_var = self._create_global_var(unstack_name)
324        setattr(self.prelude, unstack_name, unstack_var)
325        shape = op.shape_of(tensor_var)
326        unstack_length = op.take(shape, const(0))
327        self.prelude.mod[unstack_var] = Function(
328            [tensor_var],
329            helper_var(const(0), unstack_length, tensor_var),
330            self.prelude.l(reduced_tensor_type_var()),
331            [],
332        )
333
334    def define_tensor_array_scatter(self, indices_shape=None, force_update=False):
335        """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
336        tensor_array_scatter(ta, indices, value) :
337            list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
338
339        Set static indices shape by specifying indices_shape.
340        Set force_update to get static indices shape operator.
341        """
342        # When this operator has already been registered, only update
343        # when force_update is set. This should be used only when we need to
344        # redefine this op for static indices shape.
345        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
346        if hasattr(self.prelude, tensor_array_scatter_name) and not force_update:
347            return
348
349        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
350        tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name)
351        tensor_type_var = self.get_var("tensor_t")
352        ta = Var("ta", self.prelude.l(tensor_type_var()))
353        current = Var("current", scalar_type("int32"))
354        limit = Var("limit", scalar_type("int32"))
355        indices_ = Var("indices_", TensorType(indices_shape or [Any()], "int32"))
356        values_ = Var("values_", self.prelude.l(tensor_type_var()))
357        write_var = self.get_var("tensor_array_write")
358        read_var = self.get_var("tensor_array_read")
359        helper_body = If(
360            equal(current, limit),
361            ta,
362            tensor_array_scatter_helper_var(
363                write_var(ta, op.take(indices_, current), read_var(values_, current)),
364                add(current, const(1)),
365                limit,
366                indices_,
367                values_,
368            ),
369        )
370        self.prelude.mod[tensor_array_scatter_helper_var] = Function(
371            [ta, current, limit, indices_, values_],
372            helper_body,
373            self.prelude.l(tensor_type_var()),
374            [],
375        )
376
377        tensor_array_scatter_var = self._create_global_var(tensor_array_scatter_name)
378        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
379        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
380
381        indices = Var("indices", TensorType(indices_shape or [Any()], "int32"))
382        values = Var("values", self.prelude.l(tensor_type_var()))
383        if indices_shape is None:
384            indices_shape = op.shape_of(indices)
385            limit = op.take(indices_shape, const(0))
386        else:
387            limit = const(indices_shape[0])
388
389        body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
390        self.prelude.mod[tensor_array_scatter_var] = Function(
391            [tensor_array, indices, values], body, self.prelude.l(tensor_type_var()), []
392        )
393
394    def define_tensor_array_split(self, value_shape=None, lengths_shape=None, force_update=False):
395        """Defines a function to split the values of a tensor_t into a tensor array.
396        tensor_array_split(ta, value, lengths) :
397            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
398
399        Set static value and lengths shapes by specifying value_shape and lengths_shape.
400        Set force_update to get static value and lengths shape operator.
401        """
402        # Skip scalar case
403        ndim = len(self.shape)
404        if ndim == 0:
405            return
406
407        # When this operator has already been registered, only update
408        # when force_update is set. This should be used only when we need to
409        # redefine this op for static value/indices shape.
410        split_name = self.get_name("tensor_array_split")
411        if hasattr(self.prelude, split_name) and not force_update:
412            return
413
414        tensor_type_var = self.get_var("tensor_t")
415        tensor_array_split_helper_name = self.get_name("ta_split_helper")
416        tensor_array_split_helper_var = self._create_global_var(tensor_array_split_helper_name)
417        setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
418        output_shape = [
419            Any(),
420        ] + list(self.shape[1:])
421        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
422
423        if value_shape is None:
424            value_type_var = tensor_type_var
425            take_var = self.get_var("tensor_take")
426        else:
427            value_type_var, _ = self._get_adt_by_shape(value_shape)
428            # Also get static shape take operator
429            origin_shape = list(self.shape)
430            self.shape = value_shape
431            self.define_tensor_take()
432            take_var = self.get_var("tensor_take")
433            self.shape = origin_shape
434
435        ta1 = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
436        value1 = Var("value1", value_type_var())
437        offset1 = Var("offset1", scalar_type("int32"))
438        current1 = Var("current1", scalar_type("int32"))
439        limit1 = Var("limit1", scalar_type("int32"))
440        lengths1 = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))
441
442        # Register write for output shape
443        origin_shape = list(self.shape)
444        self.shape = output_shape
445        self.define_tensor_array_write()
446        write_var = self.get_var("tensor_array_write")
447        self.shape = origin_shape
448        helper1_body = If(
449            equal(current1, limit1),
450            ta1,
451            write_var(
452                tensor_array_split_helper_var(
453                    ta1,
454                    value1,
455                    add(offset1, op.take(lengths1, current1)),
456                    add(current1, const(1)),
457                    limit1,
458                    lengths1,
459                ),
460                current1,
461                take_var(value1, offset1, add(op.take(lengths1, current1), offset1)),
462            ),
463        )
464        self.prelude.mod[tensor_array_split_helper_var] = Function(
465            [ta1, value1, offset1, current1, limit1, lengths1],
466            helper1_body,
467            self.prelude.l(output_tensor_type_var()),
468            [],
469        )
470        split_var = self._create_global_var(split_name)
471        setattr(self.prelude, split_name, split_var)
472        tensor_array = Var("tensor_array", self.prelude.l(output_tensor_type_var()))
473
474        value = Var("value", value_type_var())
475        lengths = Var("lengths", TensorType(lengths_shape or [Any()], "int32"))
476        if lengths_shape is None:
477            lengths_shape = op.shape_of(lengths)
478            lengths_limit = op.take(lengths_shape, const(0))
479        else:
480            lengths_limit = const(lengths_shape[0])
481        body = tensor_array_split_helper_var(
482            tensor_array, value, const(0), const(0), lengths_limit, lengths
483        )
484
485        self.prelude.mod[split_var] = Function(
486            [tensor_array, value, lengths], body, self.prelude.l(output_tensor_type_var()), []
487        )
488
489    def define_tensor_array_concat(self):
490        """Defines a function to return the values in the tensor array as concatenated tensor_t.
491        tensor_array_concat(ta) : list[tensor_t] -> tensor_t
492        """
493        # We don't register concat for scalar tensor array.
494        ndim = len(self.shape)
495        if ndim == 0:
496            return
497
498        concat_name = self.get_name("tensor_array_concat")
499        concat_var = self._create_global_var(concat_name)
500        setattr(self.prelude, concat_name, concat_var)
501
502        output_shape = [
503            Any(),
504        ] + list(self.shape[1:])
505        tensor_type_var, _ = self._get_adt_by_shape(output_shape)
506
507        # Register tensor concatenate and get tensor_nil var for output shape
508        origin_shape = self.shape
509        self.shape = output_shape
510        self.define_tensor_concatenate()
511        tensor_concat_var = self.get_var("tensor_concatenate")
512        tensor_nil_var = self.get_var("tensor_nil")
513        self.shape = origin_shape
514
515        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
516        hd = Var("hd")
517        tl = Var("tl")
518        nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
519        cons_case = Clause(
520            PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
521            Match(
522                tl,
523                [
524                    Clause(PatternConstructor(self.prelude.nil), hd),
525                    Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))),
526                ],
527                False,
528            ),
529        )
530        self.prelude.mod[concat_var] = Function(
531            [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_type_var(), []
532        )
533
534    def define_tensor_array_stack(self):
535        """Defines a function to get the values in the tensor array as a stack tensor_t.
536        tensor_array_stack(l) : list[tensor_t] -> tensor_t
537        """
538        stack_name = self.get_name("tensor_array_stack")
539        stack_var = self._create_global_var(stack_name)
540        setattr(self.prelude, stack_name, stack_var)
541        tensor_type_var = self.get_var("tensor_t")
542        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
543        expand_dims_var = self.get_var("tensor_expand_dims")
544
545        # Register tensor_concatenate for output_shape
546        origin_shape = self.shape
547        output_shape = [
548            Any(),
549        ] + list(self.shape)
550        self.shape = output_shape
551        self.define_tensor_concatenate()
552        concat_var = self.get_var("tensor_concatenate")
553        self.shape = origin_shape
554
555        tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
556        tensors = self.prelude.foldl(
557            concat_var,
558            self.prelude.hd(tensor_array_expand_dims),
559            self.prelude.tl(tensor_array_expand_dims),
560        )
561        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
562        self.prelude.mod[stack_var] = Function(
563            [tensor_array], tensors, output_tensor_type_var(), []
564        )
565
566    def define_tensor_array_gather(self):
567        """Defines a function to return the selected values in a tensor array as tensor_t.
568        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t
569        """
570        helper_name = self.get_name("tensor_array_gather_helper")
571        helper_var = self._create_global_var(helper_name)
572        setattr(self.prelude, helper_name, helper_var)
573        tensor_type_var = self.get_var("tensor_t")
574        output_shape = [
575            Any(),
576        ] + list(self.shape)
577        output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
578        stack_var = self.get_var("tensor_array_stack")
579        read_var = self.get_var("tensor_array_read")
580        ta = Var("ta", self.prelude.l(tensor_type_var()))
581        accu = Var("accu", self.prelude.l(tensor_type_var()))
582        current = Var("current", scalar_type("int32"))
583        limit = Var("limit", scalar_type("int32"))
584        indices_ = Var("indices_", TensorType([Any()], "int32"))
585        helper_body = If(
586            equal(current, const(0)),
587            stack_var(accu),
588            helper_var(
589                ta,
590                self.prelude.cons(
591                    read_var(ta, op.take(indices_, subtract(current, const(1)))), accu
592                ),
593                subtract(current, const(1)),
594                limit,
595                indices_,
596            ),
597        )
598        self.prelude.mod[helper_var] = Function(
599            [ta, accu, current, limit, indices_], helper_body, output_tensor_type_var(), []
600        )
601        gather_name = self.get_name("tensor_array_gather")
602        gather_var = self._create_global_var(gather_name)
603        setattr(self.prelude, gather_name, gather_var)
604        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
605        indices = Var("indices", TensorType([Any()], "int32"))
606        indices_shape = op.shape_of(indices)
607        limit = op.take(indices_shape, const(0))
608        body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
609        self.prelude.mod[gather_var] = Function(
610            [tensor_array, indices], body, output_tensor_type_var(), []
611        )
612
613    def define_tensor_get_data(self):
614        """Defines a function to get a Tensor from tensor_t with given shape."""
615        tensor_get_data_name = self.get_name("tensor_get_data")
616        tensor_get_data_var = self._create_global_var(tensor_get_data_name)
617        setattr(self.prelude, tensor_get_data_name, tensor_get_data_var)
618        tensor_type_var = self.get_var("tensor_t")
619        tensor_constructor = self.get_var("tensor_constructor")
620        t = Var("tensor", tensor_type_var())
621        tvar = Var("t")
622        case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar)
623        self.prelude.mod[tensor_get_data_var] = Function(
624            [t], Match(t, [case], False), TensorType(self.shape, self.dtype), []
625        )
626
627    def register(self):
628        """Register all tensor array ops in Prelude"""
629        self.define_tensor_adt()
630        self.define_tensor_take()
631        self.define_tensor_concatenate()
632        self.define_tensor_expand_dims()
633        self.define_tensor_array()
634        self.define_tensor_array_read()
635        self.define_tensor_array_write()
636        self.define_tensor_array_unstack()
637        self.define_tensor_array_scatter()
638        self.define_tensor_array_split()
639        self.define_tensor_array_concat()
640        self.define_tensor_array_stack()
641        self.define_tensor_array_gather()
642        self.define_tensor_get_data()
643
644    def _get_adt_by_shape(self, shape):
645        """Get ADT type and constructor with given shape."""
646        origin_shape = self.shape
647        self.shape = shape
648        self.define_tensor_adt()
649        tensor_type_var = self.get_var("tensor_t")
650        tensor_constructor = self.get_var("tensor_constructor")
651        self.shape = origin_shape
652        return tensor_type_var, tensor_constructor
653
654    def _create_global_var(self, name):
655        """Create a GlobalVar if doesn't exist in prelude."""
656        global_var_name_set = set()
657        for g_var_name in self.prelude.mod.get_global_vars():
658            global_var_name_set.add(g_var_name.name_hint)
659        if name not in global_var_name_set:
660            gvar = GlobalVar(name)
661        else:
662            gvar = self.prelude.mod.get_global_var(name)
663
664        return gvar
665
666
667class TensorArrayOps(object):
668    """Contains tensor array related ops"""
669
670    def __init__(self, prelude, dtype):
671        """Create tensor array ops registry"""
672        self.prelude = prelude
673        self.dtype = dtype
674
675    def get_name(self, canonical):
676        """Get name corresponding to the canonical name"""
677        return self.prelude.get_name(canonical, self.dtype)
678
679    def get_var(self, canonical):
680        """Get var corresponding to the canonical name"""
681        return self.prelude.get_var(canonical, self.dtype)
682
683    def define_tensor_adt(self):
684        """Defines the dynamic tensor ADT, which is the container for tensors
685        with variable shapes."""
686        tensor_type_name = self.get_name("tensor_t")
687        tensor_type_var = GlobalTypeVar(tensor_type_name)
688        setattr(self.prelude, tensor_type_name, tensor_type_var)
689        tensor0_type = TensorType([], self.dtype)
690        tensor1_type = TensorType([Any()], self.dtype)
691        tensor2_type = TensorType([Any(), Any()], self.dtype)
692        tensor3_type = TensorType([Any(), Any(), Any()], self.dtype)
693        tensor4_type = TensorType([Any(), Any(), Any(), Any()], self.dtype)
694        tensor5_type = TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype)
695        tensor6_type = TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype)
696        tensor_nil_name = self.get_name("tensor_nil")
697        tensor0_name = self.get_name("tensor0")
698        tensor1_name = self.get_name("tensor1")
699        tensor2_name = self.get_name("tensor2")
700        tensor3_name = self.get_name("tensor3")
701        tensor4_name = self.get_name("tensor4")
702        tensor5_name = self.get_name("tensor5")
703        tensor6_name = self.get_name("tensor6")
704        tensor_nil_case = Constructor(tensor_nil_name, [], tensor_type_var)
705        tensor0_case = Constructor(tensor0_name, [tensor0_type], tensor_type_var)
706        tensor1_case = Constructor(tensor1_name, [tensor1_type], tensor_type_var)
707        tensor2_case = Constructor(tensor2_name, [tensor2_type], tensor_type_var)
708        tensor3_case = Constructor(tensor3_name, [tensor3_type], tensor_type_var)
709        tensor4_case = Constructor(tensor4_name, [tensor4_type], tensor_type_var)
710        tensor5_case = Constructor(tensor5_name, [tensor5_type], tensor_type_var)
711        tensor6_case = Constructor(tensor6_name, [tensor6_type], tensor_type_var)
712        setattr(self.prelude, tensor_nil_name, tensor_nil_case)
713        setattr(self.prelude, tensor0_name, tensor0_case)
714        setattr(self.prelude, tensor1_name, tensor1_case)
715        setattr(self.prelude, tensor2_name, tensor2_case)
716        setattr(self.prelude, tensor3_name, tensor3_case)
717        setattr(self.prelude, tensor4_name, tensor4_case)
718        setattr(self.prelude, tensor5_name, tensor5_case)
719        setattr(self.prelude, tensor6_name, tensor6_case)
720        self.prelude.mod[tensor_type_var] = TypeData(
721            tensor_type_var,
722            [],
723            [
724                tensor_nil_case,
725                tensor0_case,
726                tensor1_case,
727                tensor2_case,
728                tensor3_case,
729                tensor4_case,
730                tensor5_case,
731                tensor6_case,
732            ],
733        )
734
735    def define_tensor_take(self):
736        """Defines a function to return a range of tensor_t on axis 0.
737        tensor_take(t, lower, upper) :
738        tensor_t -> Tensor[(), int32] -> Tensor[(), int32] -> tensor_t
739        """
740        take_name = self.get_name("tensor_take")
741        take_var = GlobalVar(take_name)
742        setattr(self.prelude, take_name, take_var)
743        tensor_t = self.get_var("tensor_t")
744        tensor1_var = self.get_var("tensor1")
745        tensor2_var = self.get_var("tensor2")
746        tensor3_var = self.get_var("tensor3")
747        tensor4_var = self.get_var("tensor4")
748        tensor5_var = self.get_var("tensor5")
749        tensor6_var = self.get_var("tensor6")
750        t = Var("tensor", tensor_t())
751        lower = Var("lower", scalar_type("int32"))
752        upper = Var("upper", scalar_type("int32"))
753        t1 = Var("t1")
754        t2 = Var("t2")
755        t3 = Var("t3")
756        t4 = Var("t4")
757        t5 = Var("t5")
758        t6 = Var("t6")
759        tensor1_case = Clause(
760            PatternConstructor(tensor1_var, [PatternVar(t1)]),
761            tensor1_var(op.take(t1, op.arange(lower, upper, dtype="int32"))),
762        )
763        tensor2_case = Clause(
764            PatternConstructor(tensor2_var, [PatternVar(t2)]),
765            tensor2_var(op.take(t2, op.arange(lower, upper, dtype="int32"), axis=0)),
766        )
767        tensor3_case = Clause(
768            PatternConstructor(tensor3_var, [PatternVar(t3)]),
769            tensor3_var(op.take(t3, op.arange(lower, upper, dtype="int32"), axis=0)),
770        )
771        tensor4_case = Clause(
772            PatternConstructor(tensor4_var, [PatternVar(t4)]),
773            tensor4_var(op.take(t4, op.arange(lower, upper, dtype="int32"), axis=0)),
774        )
775        tensor5_case = Clause(
776            PatternConstructor(tensor5_var, [PatternVar(t5)]),
777            tensor5_var(op.take(t5, op.arange(lower, upper, dtype="int32"), axis=0)),
778        )
779        tensor6_case = Clause(
780            PatternConstructor(tensor6_var, [PatternVar(t6)]),
781            tensor6_var(op.take(t6, op.arange(lower, upper, dtype="int32"), axis=0)),
782        )
783        self.prelude.mod[take_var] = Function(
784            [t, lower, upper],
785            Match(
786                t,
787                [
788                    tensor1_case,
789                    tensor2_case,
790                    tensor3_case,
791                    tensor4_case,
792                    tensor5_case,
793                    tensor6_case,
794                ],
795                False,
796            ),
797            tensor_t(),
798            [],
799        )
800
801    def define_tensor_expand_dims(self):
802        """Defines a function to grow a tensor_t's rank by adding one dimension in front
803        of the original tensor_t.
804        tensor_expand_dims(t) : tensor_t -> tensor_t
805        """
806        expand_dims_name = self.get_name("tensor_expand_dims")
807        expand_dims_var = GlobalVar(expand_dims_name)
808        setattr(self.prelude, expand_dims_name, expand_dims_var)
809        tensor_type_var = self.get_var("tensor_t")
810        x = Var("x", tensor_type_var())
811        t0 = Var("t0")
812        t1 = Var("t1")
813        t2 = Var("t2")
814        t3 = Var("t3")
815        t4 = Var("t4")
816        t5 = Var("t5")
817        tensor0_var = self.get_var("tensor0")
818        tensor1_var = self.get_var("tensor1")
819        tensor2_var = self.get_var("tensor2")
820        tensor3_var = self.get_var("tensor3")
821        tensor4_var = self.get_var("tensor4")
822        tensor5_var = self.get_var("tensor5")
823        tensor6_var = self.get_var("tensor6")
824        tensor0_case = Clause(
825            PatternConstructor(tensor0_var, [PatternVar(t0)]), tensor1_var(op.expand_dims(t0, 0, 1))
826        )
827        tensor1_case = Clause(
828            PatternConstructor(tensor1_var, [PatternVar(t1)]), tensor2_var(op.expand_dims(t1, 0, 1))
829        )
830        tensor2_case = Clause(
831            PatternConstructor(tensor2_var, [PatternVar(t2)]), tensor3_var(op.expand_dims(t2, 0, 1))
832        )
833        tensor3_case = Clause(
834            PatternConstructor(tensor3_var, [PatternVar(t3)]), tensor4_var(op.expand_dims(t3, 0, 1))
835        )
836        tensor4_case = Clause(
837            PatternConstructor(tensor4_var, [PatternVar(t4)]), tensor5_var(op.expand_dims(t4, 0, 1))
838        )
839        tensor5_case = Clause(
840            PatternConstructor(tensor5_var, [PatternVar(t5)]), tensor6_var(op.expand_dims(t5, 0, 1))
841        )
842        self.prelude.mod[expand_dims_var] = Function(
843            [x],
844            Match(
845                x,
846                [
847                    tensor0_case,
848                    tensor1_case,
849                    tensor2_case,
850                    tensor3_case,
851                    tensor4_case,
852                    tensor5_case,
853                ],
854                False,
855            ),
856        )
857
858    def define_tensor_concat(self):
859        """Defines a function to concatenate two tensor_t on the first axis
860
861        tensor_concatenate(t) : tensor_t -> tensor_t -> tensor_t
862        """
863        concat_name = self.get_name("tensor_concatenate")
864        concat_var = GlobalVar(concat_name)
865        setattr(self.prelude, concat_name, concat_var)
866        tensor_type_var = self.get_var("tensor_t")
867        x = Var("x", tensor_type_var())
868        y = Var("y", tensor_type_var())
869
870        tensor1_var = self.get_var("tensor1")
871        tensor2_var = self.get_var("tensor2")
872        tensor3_var = self.get_var("tensor3")
873        tensor4_var = self.get_var("tensor4")
874        t11 = Var("t11")
875        t12 = Var("t12")
876        t21 = Var("t21")
877        t22 = Var("t22")
878        t31 = Var("t31")
879        t32 = Var("t32")
880        t41 = Var("t41")
881        t42 = Var("t42")
882        tensor1_case = Clause(
883            PatternConstructor(tensor1_var, [PatternVar(t11)]),
884            Match(
885                y,
886                [
887                    Clause(
888                        PatternConstructor(tensor1_var, [PatternVar(t12)]),
889                        tensor1_var(op.concatenate([t11, t12], axis=0)),
890                    )
891                ],
892                False,
893            ),
894        )
895        tensor2_case = Clause(
896            PatternConstructor(tensor2_var, [PatternVar(t21)]),
897            Match(
898                y,
899                [
900                    Clause(
901                        PatternConstructor(tensor2_var, [PatternVar(t22)]),
902                        tensor2_var(op.concatenate([t21, t22], axis=0)),
903                    )
904                ],
905                False,
906            ),
907        )
908        tensor3_case = Clause(
909            PatternConstructor(tensor3_var, [PatternVar(t31)]),
910            Match(
911                y,
912                [
913                    Clause(
914                        PatternConstructor(tensor3_var, [PatternVar(t32)]),
915                        tensor3_var(op.concatenate([t31, t32], axis=0)),
916                    )
917                ],
918                False,
919            ),
920        )
921        tensor4_case = Clause(
922            PatternConstructor(tensor4_var, [PatternVar(t41)]),
923            Match(
924                y,
925                [
926                    Clause(
927                        PatternConstructor(tensor4_var, [PatternVar(t42)]),
928                        tensor4_var(op.concatenate([t41, t42], axis=0)),
929                    )
930                ],
931                False,
932            ),
933        )
934        # op.concatenate does not support tensor with rank higher than 4
935        self.prelude.mod[concat_var] = Function(
936            [x, y], Match(x, [tensor1_case, tensor2_case, tensor3_case, tensor4_case], False)
937        )
938
939    def define_tensor_array(self):
940        """Defines a function to create a tensor array with size n.
941        tensor_array(n) : Tensor[(), int32] -> list[tensor_t]
942        """
943        tensor_array_constructor_name = self.get_name("tensor_array")
944        tensor_array_constructor_var = GlobalVar(tensor_array_constructor_name)
945        setattr(self.prelude, tensor_array_constructor_name, tensor_array_constructor_var)
946        tensor_nil_var = self.get_var("tensor_nil")
947        tensor_type_var = self.get_var("tensor_t")
948        n = Var("x", scalar_type("int32"))
949        body = If(
950            equal(n, const(0)),
951            self.prelude.nil(),
952            self.prelude.cons(
953                tensor_nil_var(), tensor_array_constructor_var(subtract(n, const(1)))
954            ),
955        )
956        self.prelude.mod[tensor_array_constructor_var] = Function(
957            [n], body, self.prelude.l(tensor_type_var()), []
958        )
959
960    def define_tensor_array_read(self):
961        """Defines a function to get the head of a list. Assume the list has at least one
962        element.
963
964        tensor_array_read(ta, n) : list[tensor_t] -> Tensor[(), int32] -> tensor_t
965        """
966        read_name = self.get_name("tensor_array_read")
967        read_var = GlobalVar(read_name)
968        setattr(self.prelude, read_name, read_var)
969        tensor_type_var = self.get_var("tensor_t")
970
971        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
972        n = Var("x", scalar_type("int32"))
973        self.prelude.mod[read_var] = Function(
974            [tensor_array, n], self.prelude.nth(tensor_array, n), tensor_type_var(), []
975        )
976
977    def define_tensor_array_write(self):
978        """Defines a function to update a tensor array at index n with value v.
979        tensor_array_write(ta, n, v) :
980            list[tensor_t] -> Tensor[(), int32] -> tensor_t -> list[tensor_t]
981        """
982        write_name = self.get_name("tensor_array_write")
983        write_var = GlobalVar(write_name)
984        setattr(self.prelude, write_name, write_var)
985        tensor_type_var = self.get_var("tensor_t")
986        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
987        n = Var("x", scalar_type("int32"))
988        v = Var("v", tensor_type_var())
989        self.prelude.mod[write_var] = Function(
990            [tensor_array, n, v],
991            self.prelude.update(tensor_array, n, v),
992            self.prelude.l(tensor_type_var()),
993            [],
994        )
995
996    def define_tensor_array_unstack_tensor1(self):
997        """Defines a function to unstack the values of a tensor_t with rank 1 in a tensor array.
998        tensor_array_unstack_tensor1(t) : tensor_t -> list[tensor_t]
999        """
1000        helper_name = self.get_name("tensor_array_unstack_tensor1_helper")
1001        helper_var = GlobalVar(helper_name)
1002        setattr(self.prelude, helper_name, helper_var)
1003        tensor = Var("t", TensorType([Any()], self.dtype))
1004        up = Var("up", scalar_type("int32"))
1005        i = Var("i", scalar_type("int32"))
1006        tensor_type_var = self.get_var("tensor_t")
1007        tensor0_var = self.get_var("tensor0")
1008        helper_body = If(
1009            equal(i, up),
1010            self.prelude.nil(),
1011            self.prelude.cons(
1012                tensor0_var(op.take(tensor, i)), helper_var(add(i, const(1)), up, tensor)
1013            ),
1014        )
1015        self.prelude.mod[helper_var] = Function(
1016            [i, up, tensor], helper_body, self.prelude.l(tensor_type_var()), []
1017        )
1018        unstack_name = self.get_name("tensor_array_unstack_tensor1")
1019        unstack_var = GlobalVar(unstack_name)
1020        setattr(self.prelude, unstack_name, unstack_var)
1021        tensor1 = Var("tensor", TensorType([Any()], self.dtype))
1022        shape = op.shape_of(tensor1)
1023        ndim = op.take(shape, const(0))
1024        self.prelude.mod[unstack_var] = Function(
1025            [tensor1], helper_var(const(0), ndim, tensor1), self.prelude.l(tensor_type_var()), []
1026        )
1027
1028    def define_tensor_array_unstack_tensor2(self):
1029        """Defines a function to unstack the values of a tensor_t with rank 2 in a tensor array.
1030
1031        tensor_array_unstack_tensor2(t) : tensor_t -> list[tensor_t]
1032        """
1033        helper_name = self.get_name("tensor_array_unstack_tensor2_helper")
1034        helper_var = GlobalVar(helper_name)
1035        setattr(self.prelude, helper_name, helper_var)
1036        tensor = Var("t", TensorType([Any(), Any()], self.dtype))
1037        up = Var("up", scalar_type("int32"))
1038        i = Var("i", scalar_type("int32"))
1039
1040        helper_body = If(
1041            equal(i, up),
1042            self.prelude.nil(),
1043            self.prelude.cons(
1044                self.get_var("tensor1")(op.take(tensor, i, axis=0)),
1045                helper_var(add(i, const(1)), up, tensor),
1046            ),
1047        )
1048        self.prelude.mod[helper_var] = Function(
1049            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
1050        )
1051
1052        tensor_array_unstack_tensor2_name = self.get_name("tensor_array_unstack_tensor2")
1053        tensor_array_unstack_tensor2_var = GlobalVar(tensor_array_unstack_tensor2_name)
1054        setattr(self.prelude, tensor_array_unstack_tensor2_name, tensor_array_unstack_tensor2_var)
1055        tensor2 = Var("tensor", TensorType([Any(), Any()], self.dtype))
1056        shape = op.shape_of(tensor2)
1057        ndim = op.take(shape, const(0))
1058        self.prelude.mod[tensor_array_unstack_tensor2_var] = Function(
1059            [tensor2],
1060            helper_var(const(0), ndim, tensor2),
1061            self.prelude.l(self.get_var("tensor_t")()),
1062            [],
1063        )
1064
1065    def define_tensor_array_unstack_tensor3(self):
1066        """Defines a function to unstack the values of a tensor_t with rank 3 in a tensor array.
1067
1068        tensor_array_unstack_tensor3(t) : tensor_t -> list[tensor_t]
1069        """
1070        helper_name = self.get_name("tensor_array_unstack_tensor3_helper")
1071        helper_var = GlobalVar(helper_name)
1072        setattr(self.prelude, helper_name, helper_var)
1073        tensor = Var("t", TensorType([Any(), Any(), Any()], self.dtype))
1074        up = Var("up", scalar_type("int32"))
1075        i = Var("i", scalar_type("int32"))
1076
1077        helper_body = If(
1078            equal(i, up),
1079            self.prelude.nil(),
1080            self.prelude.cons(
1081                self.get_var("tensor2")(op.take(tensor, i, axis=0)),
1082                helper_var(add(i, const(1)), up, tensor),
1083            ),
1084        )
1085        self.prelude.mod[helper_var] = Function(
1086            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
1087        )
1088
1089        tensor_array_unstack_tensor3_name = self.get_name("tensor_array_unstack_tensor3")
1090        tensor_array_unstack_tensor3_var = GlobalVar(tensor_array_unstack_tensor3_name)
1091        setattr(self.prelude, tensor_array_unstack_tensor3_name, tensor_array_unstack_tensor3_var)
1092        tensor3 = Var("tensor", TensorType([Any(), Any(), Any()], self.dtype))
1093        shape = op.shape_of(tensor3)
1094        ndim = op.take(shape, const(0))
1095        self.prelude.mod[tensor_array_unstack_tensor3_var] = Function(
1096            [tensor3],
1097            helper_var(const(0), ndim, tensor3),
1098            self.prelude.l(self.get_var("tensor_t")()),
1099            [],
1100        )
1101
1102    def define_tensor_array_unstack_tensor4(self):
1103        """Defines a function to unstack the values of a tensor_t with rank 4 in a tensor array.
1104
1105        tensor_array_unstack_tensor4(t) : tensor_t -> list[tensor_t]
1106        """
1107        helper_name = self.get_name("tensor_array_unstack_tensor4_helper")
1108        helper_var = GlobalVar(helper_name)
1109        setattr(self.prelude, helper_name, helper_var)
1110        tensor = Var("t", TensorType([Any(), Any(), Any(), Any()], self.dtype))
1111        up = Var("up", scalar_type("int32"))
1112        i = Var("i", scalar_type("int32"))
1113
1114        helper_body = If(
1115            equal(i, up),
1116            self.prelude.nil(),
1117            self.prelude.cons(
1118                self.get_var("tensor3")(op.take(tensor, i, axis=0)),
1119                helper_var(add(i, const(1)), up, tensor),
1120            ),
1121        )
1122        self.prelude.mod[helper_var] = Function(
1123            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
1124        )
1125
1126        tensor_array_unstack_tensor4_name = self.get_name("tensor_array_unstack_tensor4")
1127        tensor_array_unstack_tensor4_var = GlobalVar(tensor_array_unstack_tensor4_name)
1128        setattr(self.prelude, tensor_array_unstack_tensor4_name, tensor_array_unstack_tensor4_var)
1129        tensor4 = Var("tensor", TensorType([Any(), Any(), Any(), Any()], self.dtype))
1130        shape = op.shape_of(tensor4)
1131        ndim = op.take(shape, const(0))
1132        self.prelude.mod[tensor_array_unstack_tensor4_var] = Function(
1133            [tensor4],
1134            helper_var(const(0), ndim, tensor4),
1135            self.prelude.l(self.get_var("tensor_t")()),
1136            [],
1137        )
1138
1139    def define_tensor_array_unstack_tensor5(self):
1140        """Defines a function to unstack the values of a tensor_t with rank 5 in a tensor array.
1141
1142        tensor_array_unstack_tensor5(t) : tensor_t -> list[tensor_t]
1143        """
1144        helper_name = self.get_name("tensor_array_unstack_tensor5_helper")
1145        helper_var = GlobalVar(helper_name)
1146        setattr(self.prelude, helper_name, helper_var)
1147        tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
1148        up = Var("up", scalar_type("int32"))
1149        i = Var("i", scalar_type("int32"))
1150
1151        helper_body = If(
1152            equal(i, up),
1153            self.prelude.nil(),
1154            self.prelude.cons(
1155                self.get_var("tensor4")(op.take(tensor, i, axis=0)),
1156                helper_var(add(i, const(1)), up, tensor),
1157            ),
1158        )
1159        self.prelude.mod[helper_var] = Function(
1160            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
1161        )
1162
1163        tensor_array_unstack_tensor5_name = self.get_name("tensor_array_unstack_tensor5")
1164        tensor_array_unstack_tensor5_var = GlobalVar(tensor_array_unstack_tensor5_name)
1165        setattr(self.prelude, tensor_array_unstack_tensor5_name, tensor_array_unstack_tensor5_var)
1166        tensor5 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any()], self.dtype))
1167        shape = op.shape_of(tensor5)
1168        ndim = op.take(shape, const(0))
1169        self.prelude.mod[tensor_array_unstack_tensor5_var] = Function(
1170            [tensor5],
1171            helper_var(const(0), ndim, tensor5),
1172            self.prelude.l(self.get_var("tensor_t")()),
1173            [],
1174        )
1175
1176    def define_tensor_array_unstack_tensor6(self):
1177        """Defines a function to unstack the values of a tensor_t with rank 6 in a tensor array.
1178
1179        tensor_array_unstack_tensor6(t) : tensor_t -> list[tensor_t]
1180        """
1181        helper_name = self.get_name("tensor_array_unstack_tensor6_helper")
1182        helper_var = GlobalVar(helper_name)
1183        setattr(self.prelude, helper_name, helper_var)
1184        tensor = Var("t", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
1185        up = Var("up", scalar_type("int32"))
1186        i = Var("i", scalar_type("int32"))
1187
1188        helper_body = If(
1189            equal(i, up),
1190            self.prelude.nil(),
1191            self.prelude.cons(
1192                self.get_var("tensor5")(op.take(tensor, i, axis=0)),
1193                helper_var(add(i, const(1)), up, tensor),
1194            ),
1195        )
1196        self.prelude.mod[helper_var] = Function(
1197            [i, up, tensor], helper_body, self.prelude.l(self.get_var("tensor_t")()), []
1198        )
1199
1200        tensor_array_unstack_tensor6_name = self.get_name("tensor_array_unstack_tensor6")
1201        tensor_array_unstack_tensor6_var = GlobalVar(tensor_array_unstack_tensor6_name)
1202        setattr(self.prelude, tensor_array_unstack_tensor6_name, tensor_array_unstack_tensor6_var)
1203        tensor6 = Var("tensor", TensorType([Any(), Any(), Any(), Any(), Any(), Any()], self.dtype))
1204        shape = op.shape_of(tensor6)
1205        ndim = op.take(shape, const(0))
1206        self.prelude.mod[tensor_array_unstack_tensor6_var] = Function(
1207            [tensor6],
1208            helper_var(const(0), ndim, tensor6),
1209            self.prelude.l(self.get_var("tensor_t")()),
1210            [],
1211        )
1212
1213    def define_tensor_array_scatter(self):
1214        """Defines a function to scatter the values of a tensor_t in indices of a tensor array.
1215        tensor_array_scatter(ta, indices, value) :
1216            list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t]
1217        """
1218        tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper")
1219        tensor_array_scatter_helper_var = GlobalVar(tensor_array_scatter_helper_name)
1220        tensor_t = self.get_var("tensor_t")
1221        ta = Var("ta", self.prelude.l(tensor_t()))
1222        current = Var("current", scalar_type("int32"))
1223        limit = Var("limit", scalar_type("int32"))
1224        indices_ = Var("indices_", TensorType([Any()], "int32"))
1225        values_ = Var("values_", self.prelude.l(tensor_t()))
1226        write_var = self.get_var("tensor_array_write")
1227        read_var = self.get_var("tensor_array_read")
1228        helper_body = If(
1229            equal(current, limit),
1230            ta,
1231            tensor_array_scatter_helper_var(
1232                write_var(ta, op.take(indices_, current), read_var(values_, current)),
1233                add(current, const(1)),
1234                limit,
1235                indices_,
1236                values_,
1237            ),
1238        )
1239        self.prelude.mod[tensor_array_scatter_helper_var] = Function(
1240            [ta, current, limit, indices_, values_], helper_body, self.prelude.l(tensor_t()), []
1241        )
1242        tensor_array_scatter_name = self.get_name("tensor_array_scatter")
1243        tensor_array_scatter_var = GlobalVar(tensor_array_scatter_name)
1244        setattr(self.prelude, tensor_array_scatter_name, tensor_array_scatter_var)
1245        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
1246        indices = Var("indices", TensorType([Any()], "int32"))
1247        values = Var("values", self.prelude.l(tensor_t()))
1248        indices_shape = op.shape_of(indices)
1249        limit = op.take(indices_shape, const(0))
1250        body = tensor_array_scatter_helper_var(tensor_array, const(0), limit, indices, values)
1251        self.prelude.mod[tensor_array_scatter_var] = Function(
1252            [tensor_array, indices, values], body, self.prelude.l(tensor_t()), []
1253        )
1254
1255    def define_tensor_array_split(self):
1256        """Defines a function to split the values of a tensor_t into a tensor array.
1257        tensor_array_split(ta, value, lengths) :
1258            list[tensor_t] -> tensor_t -> Tensor[(Any), int32] -> list[tensor_t]
1259        """
1260        tensor_t = self.get_var("tensor_t")
1261        tensor_array_split_helper_name = self.get_name("ta_split_helper")
1262        tensor_array_split_helper_var = GlobalVar(tensor_array_split_helper_name)
1263        setattr(self.prelude, tensor_array_split_helper_name, tensor_array_split_helper_var)
1264        ta1 = Var("tensor_array", self.prelude.l(tensor_t()))
1265        value1 = Var("value1", tensor_t())
1266        offset1 = Var("offset1", scalar_type("int32"))
1267        current1 = Var("current1", scalar_type("int32"))
1268        limit1 = Var("limit1", scalar_type("int32"))
1269        lengths1 = Var("lengths", TensorType([Any()], "int32"))
1270        write_var = self.get_var("tensor_array_write")
1271        take_var = self.get_var("tensor_take")
1272        helper1_body = If(
1273            equal(current1, limit1),
1274            ta1,
1275            write_var(
1276                tensor_array_split_helper_var(
1277                    ta1,
1278                    value1,
1279                    add(offset1, op.take(lengths1, current1)),
1280                    add(current1, const(1)),
1281                    limit1,
1282                    lengths1,
1283                ),
1284                current1,
1285                take_var(value1, offset1, add(op.take(lengths1, current1), offset1)),
1286            ),
1287        )
1288        self.prelude.mod[tensor_array_split_helper_var] = Function(
1289            [ta1, value1, offset1, current1, limit1, lengths1],
1290            helper1_body,
1291            self.prelude.l(tensor_t()),
1292            [],
1293        )
1294        split_name = self.get_name("tensor_array_split")
1295        split_var = GlobalVar(split_name)
1296        setattr(self.prelude, split_name, split_var)
1297        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
1298        value = Var("value", tensor_t())
1299        lengths = Var("lengths", TensorType([Any()], "int32"))
1300        lengths_shape = op.shape_of(lengths)
1301        lengths_limit = op.take(lengths_shape, const(0))
1302        body = tensor_array_split_helper_var(
1303            tensor_array, value, const(0), const(0), lengths_limit, lengths
1304        )
1305        self.prelude.mod[split_var] = Function(
1306            [tensor_array, value, lengths], body, self.prelude.l(tensor_t()), []
1307        )
1308
1309    def define_tensor_array_concat(self):
1310        """Defines a function to return the values in the tensor array as concatenated tensor_t.
1311        tensor_array_concat(ta) : list[tensor_t] -> tensor_t
1312        """
1313        concat_name = self.get_name("tensor_array_concat")
1314        concat_var = GlobalVar(concat_name)
1315        setattr(self.prelude, concat_name, concat_var)
1316        tensor_concat_var = self.get_var("tensor_concatenate")
1317        tensor_t = self.get_var("tensor_t")
1318        tensor_nil_var = self.get_var("tensor_nil")
1319        tensor_array = Var("tensor_array", self.prelude.l(tensor_t()))
1320        hd = Var("hd")
1321        tl = Var("tl")
1322        nil_case = Clause(PatternConstructor(self.prelude.nil), tensor_nil_var())
1323        cons_case = Clause(
1324            PatternConstructor(self.prelude.cons, [PatternVar(hd), PatternVar(tl)]),
1325            Match(
1326                tl,
1327                [
1328                    Clause(PatternConstructor(self.prelude.nil), hd),
1329                    Clause(PatternWildcard(), tensor_concat_var(hd, concat_var(tl))),
1330                ],
1331                False,
1332            ),
1333        )
1334        self.prelude.mod[concat_var] = Function(
1335            [tensor_array], Match(tensor_array, [nil_case, cons_case], False), tensor_t(), []
1336        )
1337
1338    def define_tensor_array_gather(self):
1339        """Defines a function to return the selected values in a tensor array as tensor_t.
1340        tensor_array_gather(ta, indices) : list[tensor_t] -> Tensor[(Any), int32] -> tensor_t
1341        """
1342        helper_name = self.get_name("tensor_array_gather_helper")
1343        helper_var = GlobalVar(helper_name)
1344        setattr(self.prelude, helper_name, helper_var)
1345        tensor_type_var = self.get_var("tensor_t")
1346        stack_var = self.get_var("tensor_array_stack")
1347        read_var = self.get_var("tensor_array_read")
1348        ta = Var("ta", self.prelude.l(tensor_type_var()))
1349        accu = Var("accu", self.prelude.l(tensor_type_var()))
1350        current = Var("current", scalar_type("int32"))
1351        limit = Var("limit", scalar_type("int32"))
1352        indices_ = Var("indices_", TensorType([Any()], "int32"))
1353        helper_body = If(
1354            equal(current, const(0)),
1355            stack_var(accu),
1356            helper_var(
1357                ta,
1358                self.prelude.cons(
1359                    read_var(ta, op.take(indices_, subtract(current, const(1)))), accu
1360                ),
1361                subtract(current, const(1)),
1362                limit,
1363                indices_,
1364            ),
1365        )
1366        self.prelude.mod[helper_var] = Function(
1367            [ta, accu, current, limit, indices_], helper_body, tensor_type_var(), []
1368        )
1369        gather_name = self.get_name("tensor_array_gather")
1370        gather_var = GlobalVar(gather_name)
1371        setattr(self.prelude, gather_name, gather_var)
1372        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
1373        indices = Var("indices", TensorType([Any()], "int32"))
1374        indices_shape = op.shape_of(indices)
1375        limit = op.take(indices_shape, const(0))
1376        body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
1377        self.prelude.mod[gather_var] = Function(
1378            [tensor_array, indices], body, tensor_type_var(), []
1379        )
1380
1381    def define_tensor_array_stack(self):
1382        """Defines a function to get the values in the tensor array as a stack tensor_t.
1383        tensor_array_stack(l) : list[tensor_t] -> tensor_t
1384        """
1385        stack_name = self.get_name("tensor_array_stack")
1386        stack_var = GlobalVar(stack_name)
1387        setattr(self.prelude, stack_name, stack_var)
1388        tensor_type_var = self.get_var("tensor_t")
1389        tensor_array = Var("tensor_array", self.prelude.l(tensor_type_var()))
1390        expand_dims_var = self.get_var("tensor_expand_dims")
1391        concat_var = self.get_var("tensor_concatenate")
1392        tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
1393        tensors = self.prelude.foldl(
1394            concat_var,
1395            self.prelude.hd(tensor_array_expand_dims),
1396            self.prelude.tl(tensor_array_expand_dims),
1397        )
1398        self.prelude.mod[stack_var] = Function(
1399            [tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), []
1400        )
1401
1402    def register(self):
1403        """Register all tensor array ops in Prelude"""
1404        self.define_tensor_adt()
1405        self.define_tensor_take()
1406        self.define_tensor_expand_dims()
1407        self.define_tensor_concat()
1408        self.define_tensor_array()
1409        self.define_tensor_array_read()
1410        self.define_tensor_array_write()
1411        self.define_tensor_array_unstack_tensor1()
1412        self.define_tensor_array_unstack_tensor2()
1413        self.define_tensor_array_unstack_tensor3()
1414        self.define_tensor_array_unstack_tensor4()
1415        self.define_tensor_array_unstack_tensor5()
1416        self.define_tensor_array_unstack_tensor6()
1417        self.define_tensor_array_scatter()
1418        self.define_tensor_array_split()
1419        self.define_tensor_array_concat()
1420        self.define_tensor_array_stack()
1421        # TODO(wweic): Gather fails in PartialEvaluate
1422        # self.define_tensor_array_gather()
1423
1424
1425class Prelude:
1426    """Contains standard definitions."""
1427
1428    def __init__(self, mod=None):
1429        if mod is None:
1430            mod = IRModule()
1431        self.mod = mod
1432        self.load_prelude()
1433
1434    def get_name(self, canonical, dtype):
1435        """Get name corresponding to the canonical name"""
1436        if canonical == "tensor_t":
1437            return "tensor_{}_t".format(dtype)
1438        return "{}_{}".format(canonical, dtype)
1439
1440    def get_var(self, canonical, dtype):
1441        """Get var corresponding to the canonical name"""
1442        name = self.get_name(canonical, dtype)
1443        return getattr(self, name)
1444
1445    def get_name_static(self, canonical, dtype, shape):
1446        """Get name corresponding to the canonical name"""
1447        return _get_name_static(canonical, dtype, shape)
1448
1449    def get_var_static(self, canonical, dtype, shape):
1450        """Get var corresponding to the canonical name"""
1451        name = self.get_name_static(canonical, dtype, shape)
1452        return getattr(self, name)
1453
1454    def load_prelude(self):
1455        """Parses the Prelude from Relay's text format into a module."""
1456        # TODO(@jroesch): we should remove this helper when we port over prelude
1457        self.mod.import_from_std("prelude.rly")
1458
1459        self.l = self.mod.get_global_type_var("List")
1460        list_adt = self.mod[self.l]
1461        self.cons = list_adt.constructors[0]
1462        self.nil = list_adt.constructors[1]
1463
1464        self.optional = self.mod.get_global_type_var("Option")
1465        optional_adt = self.mod[self.optional]
1466        self.some = optional_adt.constructors[0]
1467        self.none = optional_adt.constructors[1]
1468
1469        self.tree = self.mod.get_global_type_var("Tree")
1470        tree_adt = self.mod[self.tree]
1471        self.rose = tree_adt.constructors[0]
1472
1473        GLOBAL_DEFS = [
1474            "id",
1475            "compose",
1476            "flip",
1477            "hd",
1478            "tl",
1479            "nth",
1480            "update",
1481            "map",
1482            "foldl",
1483            "foldr",
1484            "foldr1",
1485            "concat",
1486            "filter",
1487            "zip",
1488            "rev",
1489            "map_accuml",
1490            "map_accumr",
1491            "unfoldl",
1492            "unfoldr",
1493            "sum",
1494            "length",
1495            "tmap",
1496            "size",
1497            "iterate",
1498        ]
1499        for global_def in GLOBAL_DEFS:
1500            setattr(self, global_def, self.mod.get_global_var(global_def))
1501
1502        for dtype in [
1503            "float32",
1504            "float16",
1505            "float64",
1506            "int32",
1507            "uint8",
1508            "int8",
1509            "int16",
1510            "uint16",
1511            "int64",
1512        ]:
1513            tensor_array_ops = TensorArrayOps(self, dtype)
1514            tensor_array_ops.register()
1515