# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Dilation2D operators"""
from __future__ import absolute_import as _abs
from tvm import te
from tvm.topi.util import simplify
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
26
def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
    """Morphological dilation operator in NCHW layout.

    Pads the input, then for every output position takes the maximum of
    ``input + filter`` over the (dilated) filter window, per channel.

    Parameters
    ----------
    input : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    filter : tvm.te.Tensor
        3-D with shape [in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size

    dilations : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : Optional[str]
        Specifies the output data type. Defaults to the input dtype.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, in_channel, out_height, out_width]
    """
    out_dtype = input.dtype if out_dtype is None else out_dtype

    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilations, int) or len(dilations) == 2
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (
        (dilations, dilations) if isinstance(dilations, int) else dilations
    )

    batch, in_channel, in_height, in_width = input.shape
    channel, kernel_h, kernel_w = filter.shape
    assert (
        in_channel.value == channel.value
    ), "For Dilation2D input and filter channels should be same."

    # Effective kernel footprint once dilation gaps are accounted for.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # Pad only the spatial (H, W) axes.
    temp = pad(input, [0, 0, pad_top, pad_left], [0, 0, pad_down, pad_right], name="pad_temp")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    def _dilate(nn, ff, yy, xx):
        # Max over the filter window of (input + filter), cast to out_dtype.
        h_idx = yy * stride_h + ry * dilation_h
        w_idx = xx * stride_w + rx * dilation_w
        return te.max(
            temp[nn, ff, h_idx, w_idx].astype(out_dtype)
            + filter[ff, ry, rx].astype(out_dtype),
            axis=[ry, rx],
        )

    return te.compute(
        (batch, in_channel, out_height, out_width), _dilate, tag="dilation2d_nchw"
    )
102
103
def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
    """Morphological 2d dilation NHWC layout.

    Pads the input, then for every output position takes the maximum of
    ``input + filter`` over the (dilated) filter window, per channel.

    Parameters
    ----------
    input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    filter : tvm.te.Tensor
        3-D with shape [filter_height, filter_width, in_channel]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size

    dilations : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : Optional[str]
        Specifies the output data type. Defaults to the input dtype.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, in_channel]
    """
    if out_dtype is None:
        out_dtype = input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilations, int) or len(dilations) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilations, int):
        dilation_h = dilation_w = dilations
    else:
        dilation_h, dilation_w = dilations

    batch, in_height, in_width, in_channel = input.shape
    kernel_h, kernel_w, channel = filter.shape
    assert (
        in_channel.value == channel.value
    ), "For Dilation2D input and filter channels should be same."

    # compute the output shape; the dilated kernel is the effective window extent
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # Pad only the spatial (H, W) axes.
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    padded_input = pad(input, pad_before, pad_after, name="padded_input")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    return te.compute(
        (batch, out_height, out_width, in_channel),
        lambda nn, yy, xx, ff: te.max(
            padded_input[
                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ff
            ].astype(out_dtype)
            + filter[ry, rx, ff].astype(out_dtype),
            axis=[ry, rx],
        ),
        # Fixed tag typo: was "dilation2d_nhcw". Any schedule matching the old
        # tag string must be updated to match this corrected spelling.
        tag="dilation2d_nhwc",
    )
178