# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Dilation2D operators"""
from __future__ import absolute_import as _abs
from tvm import te
from tvm.topi.util import simplify
from ..nn.pad import pad
from ..nn.util import get_pad_tuple


def _as_pair(value, name):
    """Normalize an int or a two-element sequence into an ``(h, w)`` tuple.

    Parameters
    ----------
    value : int or a list/tuple of two ints
        A single value used for both axes, or ``[height_value, width_value]``.

    name : str
        Parameter name, used only in the assertion message.

    Returns
    -------
    pair : tuple
        ``(value, value)`` for an int, otherwise the two elements of ``value``.
    """
    assert isinstance(value, int) or len(value) == 2, (
        "{} must be an int or a list/tuple of two ints".format(name)
    )
    if isinstance(value, int):
        return value, value
    first, second = value
    return first, second


def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
    """Morphological dilation operator in NCHW layout.

    Each output element is the maximum, over the kernel window, of the padded,
    strided and dilated input value plus the matching filter value. The filter
    has no output-channel axis, so every channel is processed independently
    and the output keeps ``in_channel`` channels.

    Parameters
    ----------
    input : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    filter : tvm.te.Tensor
        3-D with shape [in_channel, filter_height, filter_width]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size; forwarded to ``get_pad_tuple`` together with the
        dilated kernel size, so any form it accepts may be used.

    dilations : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : Optional[str]
        Specifies the output data type. Defaults to ``input.dtype``.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, in_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = input.dtype
    stride_h, stride_w = _as_pair(stride, "stride")
    dilation_h, dilation_w = _as_pair(dilations, "dilations")

    batch, in_channel, in_height, in_width = input.shape
    channel, kernel_h, kernel_w = filter.shape
    # ``.value`` requires both channel extents to be constant integers.
    assert (
        in_channel.value == channel.value
    ), "For Dilation2D input and filter channels should be same."

    # A dilated kernel spans (k - 1) * d + 1 input pixels along each axis.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # Pad only the spatial (H, W) axes of the NCHW input.
    # NOTE(review): the fill value is ``pad``'s default; for a max-plus operator a
    # -inf fill would make padded positions neutral -- confirm intended semantics.
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(input, pad_before, pad_after, name="pad_temp")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    return te.compute(
        (batch, in_channel, out_height, out_width),
        lambda nn, ff, yy, xx: te.max(
            temp[nn, ff, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
                out_dtype
            )
            + filter[ff, ry, rx].astype(out_dtype),
            axis=[ry, rx],
        ),
        tag="dilation2d_nchw",
    )


def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
    """Morphological 2d dilation NHWC layout.

    Each output element is the maximum, over the kernel window, of the padded,
    strided and dilated input value plus the matching filter value. The filter
    has no output-channel axis, so every channel is processed independently
    and the output keeps ``in_channel`` channels.

    Parameters
    ----------
    input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    filter : tvm.te.Tensor
        3-D with shape [filter_height, filter_width, in_channel]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or str
        Padding size; forwarded to ``get_pad_tuple`` together with the
        dilated kernel size, so any form it accepts may be used.

    dilations : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : Optional[str]
        Specifies the output data type. Defaults to ``input.dtype``.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, in_channel]
    """
    if out_dtype is None:
        out_dtype = input.dtype
    stride_h, stride_w = _as_pair(stride, "stride")
    dilation_h, dilation_w = _as_pair(dilations, "dilations")

    batch, in_height, in_width, in_channel = input.shape
    kernel_h, kernel_w, channel = filter.shape
    # ``.value`` requires both channel extents to be constant integers.
    assert (
        in_channel.value == channel.value
    ), "For Dilation2D input and filter channels should be same."

    # A dilated kernel spans (k - 1) * d + 1 input pixels along each axis.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )

    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

    # Pad only the spatial (H, W) axes of the NHWC input.
    # NOTE(review): the fill value is ``pad``'s default; for a max-plus operator a
    # -inf fill would make padded positions neutral -- confirm intended semantics.
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    padded_input = pad(input, pad_before, pad_after, name="padded_input")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    return te.compute(
        (batch, out_height, out_width, in_channel),
        lambda nn, yy, xx, ff: te.max(
            padded_input[
                nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ff
            ].astype(out_dtype)
            + filter[ry, rx, ff].astype(out_dtype),
            axis=[ry, rx],
        ),
        # NOTE(review): "nhcw" looks like a typo for "nhwc"; kept unchanged because
        # schedules elsewhere may match on this exact tag string -- verify before renaming.
        tag="dilation2d_nhcw",
    )