# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-locals, too-many-arguments
"""Deformable Conv2D operators"""
import tvm
from tvm import te

from .util import get_pad_tuple
from ..util import get_const_tuple
from ..cpp.util import bilinear_sample_nchw


def deformable_conv2d_nchw(
    data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
):
    """Deformable conv2D operator in NCHW layout.

    The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    offset : tvm.te.Tensor
        4-D with shape [batch, deformable_groups * filter_height * filter_width * 2,
        out_height, out_width].

    kernel : tvm.te.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    deformable_groups : int
        number of deformable groups

    groups : int
        number of groups

    out_dtype : str or None
        output data type; defaults to the input data type when None

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = data.dtype

    if isinstance(strides, int):
        stride_h = stride_w = strides
    else:
        stride_h, stride_w = strides

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = get_const_tuple(data.shape)
    out_channel, channel, kernel_h, kernel_w = get_const_tuple(kernel.shape)
    # The output spatial extent is dictated by the offset tensor's trailing dims.
    _, _, out_height, out_width = get_const_tuple(offset.shape)
    assert in_channel % deformable_groups == 0, "Input channels must divide deformable group size"
    assert groups == 1, "deformable_conv2d_nchw does not support groups > 1"

    # Number of input channels that share one set of sampling offsets.
    ic_per_dgroup = channel // deformable_groups

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    zero = tvm.tir.const(0.0, data.dtype)

    def _bilinear(n, c, h, w):
        # Bilinearly sample data at a fractional (h, w); positions that fall
        # outside the input image contribute zero (implicit zero padding).
        outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width)
        val = bilinear_sample_nchw(data, (n, c, h, w), in_height - 1, in_width - 1)
        return tvm.tir.if_then_else(outside, zero, val)

    # Gather the deformed input: for every (kernel tap, output pixel) pair,
    # shift the regular sampling grid by the learned offset and sample bilinearly.
    # Offsets are stored interleaved per deformable group as
    # [..., (kh, kw) y-offset, (kh, kw) x-offset, ...] along channel dim.
    data_deform = te.compute(
        (batch, in_channel, kernel_h, kernel_w, out_height, out_width),
        lambda n, c, kh, kw, y, x: _bilinear(
            n,
            c,
            y * stride_h
            - pad_top
            + kh * dilation_h
            + offset[
                n, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2, y, x
            ],
            x * stride_w
            - pad_left
            + kw * dilation_w
            + offset[
                n,
                c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1,
                y,
                x,
            ],
        ),
        tag="data_deform",
    )
    # Ordinary convolution over the pre-deformed input.
    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda n, f, y, x: te.sum(
            data_deform[n, rc, ry, rx, y, x].astype(out_dtype)
            * kernel[f, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx],
        ),
        tag="deformable_conv2d_nchw",
    )