1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
"""Pad the data by a constant value or by mirroring its edges."""
18from __future__ import absolute_import as _abs
19import tvm
20from ..util import equal_const_int
21from .. import tag
22
@tvm.tag_scope(tag=tag.INJECTIVE+",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
    """Pad Input with a constant value.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.

    pad_before : list / tuple of n ints
        Pad width on each dimension, inserted before the start of the axis.

    pad_after : list / tuple of n ints, optional
        Pad width on each dimension, inserted after the end of the axis.
        When omitted (or empty), ``pad_before`` is reused.

    pad_value : float or tvm.expr.Expr, optional
        The value written into the padded region. A non-expression value is
        converted with ``tvm.const(pad_value, data.dtype)``; an existing
        ``tvm.expr.Expr`` is used as-is.

    name : str, optional
        The name prefix operators generated

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as Input. Each output extent is
        ``data.shape[i] + pad_before[i] + pad_after[i]``.

    Raises
    ------
    ValueError
        If ``pad_before`` or ``pad_after`` does not have exactly one entry
        per dimension of ``data``.
    """
    n = len(data.shape)
    # A falsy pad_after (None or empty) means "same as pad_before".
    pad_after = pad_after if pad_after else pad_before
    if len(pad_before) != n:
        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
            n, len(pad_before)))
    if len(pad_after) != n:
        # Report the length of pad_after itself (previously pad_before's
        # length was printed here, which made the message misleading).
        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
            n, len(pad_after)))
    out_shape = tuple(
        tvm.ir_pass.Simplify(
            (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n))
    pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr)
                 else tvm.const(pad_value, data.dtype))

    def _pad(*indices):
        # Conditions that must all hold for the output element to come from
        # `data`; anything outside them is filled with pad_value.
        not_zero = []
        index_tuple = []
        for i in range(n):
            if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
                # Unpadded axis: no bounds check needed, index passes through.
                index_tuple.append(indices[i])
            else:
                index_tuple.append(indices[i] - pad_before[i])
                # In-bounds iff pad_before[i] <= idx < shape[i] + pad_before[i].
                not_zero.append(indices[i] >= pad_before[i])
                not_zero.append(indices[i] < data.shape[i] + pad_before[i])
        if not_zero:
            not_zero = tvm.all(*not_zero)
            return tvm.if_then_else(not_zero, data(*index_tuple), pad_value)
        # No axis is padded: the compute is a plain copy of `data`.
        return data(*index_tuple)
    return tvm.compute(out_shape, _pad, name=name)
77
78
@tvm.tag_scope(tag=tag.INJECTIVE + ",pad")
def mirror_pad(data,
               pad_before,
               pad_after=None,
               mode='SYMMETRIC',
               name="MirrorPadInput"):
    """Pad Input by mirroring its edges, either symmetric or reflected.

    For an axis holding ``[a b c]`` padded by 2 on each side:

    * ``SYMMETRIC`` repeats the edge element: ``[b a | a b c | c b]``
    * ``REFLECT`` mirrors around the edge element: ``[c b | a b c | b a]``

    Pad widths are not validated here. For the mirrored index to stay in
    range, each pad width must be at most ``data.shape[i]`` for SYMMETRIC
    and at most ``data.shape[i] - 1`` for REFLECT.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.

    pad_before : list / tuple of n ints
        Pad width on each dimension, inserted before the start of the axis.

    pad_after : list / tuple of n ints, optional
        Pad width on each dimension, inserted after the end of the axis.
        When omitted (or empty), ``pad_before`` is reused.

    mode: str, optional
        Type of mirror padding to apply. Must be SYMMETRIC or REFLECT

    name : str, optional
        The name prefix operators generated

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as Input. Each output extent is
        ``data.shape[i] + pad_before[i] + pad_after[i]``.

    Raises
    ------
    ValueError
        If ``pad_before`` or ``pad_after`` does not have exactly one entry
        per dimension of ``data``.
    AssertionError
        If ``mode`` is not ``'SYMMETRIC'`` or ``'REFLECT'``.
    """
    n = len(data.shape)
    # A falsy pad_after (None or empty) means "same as pad_before".
    pad_after = pad_after if pad_after else pad_before
    if len(pad_before) != n:
        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" %
                         (n, len(pad_before)))
    if len(pad_after) != n:
        # Report the length of pad_after itself, not pad_before.
        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" %
                         (n, len(pad_after)))
    out_shape = tuple(
        tvm.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i]))
        for i in range(n))
    assert mode in ('SYMMETRIC', 'REFLECT')
    # SYMMETRIC includes the edge element in the mirror and REFLECT does not.
    # The two differ by a one-element shift of the mirrored index.
    offset = int(mode == 'SYMMETRIC')

    def _pad(*indices):
        mapped_tuple = []
        for i in range(n):
            if equal_const_int(pad_before[i], 0) and equal_const_int(
                    pad_after[i], 0):
                # Unpadded axis: pass the index through without building
                # always-false Select nodes.
                mapped_tuple.append(indices[i])
                continue
            # Position relative to the start of the original data.
            axis = indices[i] - pad_before[i]
            below = indices[i] < pad_before[i]
            above = indices[i] >= data.shape[i] + pad_before[i]
            # Leading pad: -1 -> 0 (SYMMETRIC) or 1 (REFLECT).
            mapped_axis = tvm.if_then_else(below, -axis - offset, axis)
            # Trailing pad: shape -> shape-1 (SYMMETRIC) or shape-2 (REFLECT).
            mapped_axis = tvm.if_then_else(
                above, (2 * (data.shape[i] - 1)) - axis + offset, mapped_axis)
            mapped_tuple.append(mapped_axis)
        return data(*mapped_tuple)

    return tvm.compute(out_shape, _pad, name=name)
146