1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, too-many-locals, too-many-arguments
18"""Deformable Conv2D operators"""
19import tvm
20from tvm import te
21
22from .util import get_pad_tuple
23from ..util import get_const_tuple
24from ..cpp.util import bilinear_sample_nchw
25
26
def deformable_conv2d_nchw(
    data, offset, kernel, strides, padding, dilation, deformable_groups, groups, out_dtype
):
    """Deformable conv2D operator in NCHW layout.

    The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    offset : tvm.te.Tensor
        4-D with shape [batch, deformable_groups * filter_height * filter_width * 2,
        out_height, out_width].

    kernel : tvm.te.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    deformable_groups : int
        number of deformable groups

    groups : int
        number of groups

    out_dtype : str or None
        output data type; defaults to the input data's dtype when None

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = data.dtype

    if isinstance(strides, int):
        stride_h = stride_w = strides
    else:
        stride_h, stride_w = strides

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = get_const_tuple(data.shape)
    out_channel, channel, kernel_h, kernel_w = get_const_tuple(kernel.shape)
    # The output spatial size is implied by the offset tensor, which has one
    # (dy, dx) pair per deformable group per kernel tap per output position.
    _, _, out_height, out_width = get_const_tuple(offset.shape)
    assert in_channel % deformable_groups == 0, "Input channels must divide deformable group size"
    assert groups == 1, "deformable_conv2d_nchw does not support groups > 1"

    # Number of input channels sharing one deformable-offset group.
    ic_per_dgroup = channel // deformable_groups

    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, _, _ = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")

    zero = tvm.tir.const(0.0, data.dtype)

    def _bilinear(n, c, h, w):
        # Bilinearly sample data[n, c] at fractional coordinates (h, w);
        # sample points that fall entirely outside the input map to zero.
        outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width)
        val = bilinear_sample_nchw(data, (n, c, h, w), in_height - 1, in_width - 1)
        return tvm.tir.if_then_else(outside, zero, val)

    # Gather the (possibly fractional) sampling locations into a dense
    # [batch, in_channel, kernel_h, kernel_w, out_height, out_width] tensor so
    # the reduction below is a plain dense convolution over it.
    # Offset channel layout per deformable group: 2 values (dy then dx) for
    # each kernel tap, in row-major (kh, kw) order.
    data_deform = te.compute(
        (batch, in_channel, kernel_h, kernel_w, out_height, out_width),
        lambda n, c, kh, kw, y, x: _bilinear(
            n,
            c,
            y * stride_h
            - pad_top
            + kh * dilation_h
            + offset[
                n, c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2, y, x
            ],
            x * stride_w
            - pad_left
            + kw * dilation_w
            + offset[
                n,
                c // ic_per_dgroup * (kernel_w * kernel_h * 2) + (kh * kernel_w + kw) * 2 + 1,
                y,
                x,
            ],
        ),
        tag="data_deform",
    )
    # Standard convolution reduction over channel and kernel taps, accumulated
    # in out_dtype.
    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda n, f, y, x: te.sum(
            data_deform[n, rc, ry, rx, y, x].astype(out_dtype)
            * kernel[f, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx],
        ),
        tag="deformable_conv2d_nchw",
    )
133