1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, unused-variable, too-many-locals, unused-argument
18"""Depthwise convolution operators"""
19from __future__ import absolute_import as _abs
20from collections import namedtuple
21import tvm
22from tvm import te
23
24from .dilate import dilate
25from .pad import pad
26from .util import get_pad_tuple
27from ..util import simplify
28
# Workload description of depthwise-conv2d (NCHW layout): dtypes, spatial
# sizes, channel counts, kernel size, padding and strides.
Workload = namedtuple(
    "Workload",
    "in_dtype out_dtype height width in_filter out_filter "
    "hkernel wkernel hpad wpad hstride wstride",
)
47
48
def _get_workload(data, kernel, stride, padding, out_dtype):
    """Build the ``Workload`` description for a depthwise conv2d (NCHW).

    Parameters
    ----------
    data : tvm.te.Tensor
        4-D with shape [batch, in_channel, height, width]

    kernel : tvm.te.Tensor
        4-D with shape [in_channel, channel_multiplier, kernel_h, kernel_w]

    stride : int or tuple/list of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    out_dtype : str
        Output data type

    Returns
    -------
    Workload
        The populated workload namedtuple.
    """
    _, in_channel, height, width = [x.value for x in data.shape]
    channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape]
    out_channel = channel * channel_multiplier
    # get_pad_tuple expects the (kernel_h, kernel_w) sizes — consistent with
    # the forward operators below — not the kernel tensor itself.
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (kh, kw))
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride
    # Mixed dtypes are only allowed for the quantized uint8-data/int8-kernel
    # combination.
    assert (data.dtype == kernel.dtype) or (
        data.dtype == "uint8" and kernel.dtype == "int8"
    ), "Do not support inputs with different data types now. {} vs. {}".format(
        data.dtype, kernel.dtype
    )
    return Workload(
        data.dtype,
        out_dtype,
        height,
        width,
        in_channel,
        out_channel,
        kh,
        kw,
        HPAD,
        WPAD,
        HSTR,
        WSTR,
    )
79
80
def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Depthwise convolution nchw forward operator.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.te.Tensor
        4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]

    stride : int or tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : str, optional
        Output data type; defaults to the input dtype.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    if out_dtype is None:
        out_dtype = Input.dtype

    # Normalize scalar stride/dilation into (height, width) pairs.
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (
        (dilation, dilation) if isinstance(dilation, int) else dilation
    )

    batch, in_channel, in_height, in_width = Input.shape
    filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape

    # Effective kernel extent once dilation is applied.
    kernel_extent_h = (filter_height - 1) * dilation_h + 1
    kernel_extent_w = (filter_width - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (kernel_extent_h, kernel_extent_w)
    )

    out_channel = simplify(in_channel * channel_multiplier)
    out_height = simplify((in_height - kernel_extent_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_extent_w + pad_left + pad_right) // stride_w + 1)

    # Pad only the spatial dimensions; batch and channel stay untouched.
    PaddedInput = pad(
        Input, [0, 0, pad_top, pad_left], [0, 0, pad_down, pad_right], name="PaddedInput"
    )

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    di = te.reduce_axis((0, filter_height), name="di")
    dj = te.reduce_axis((0, filter_width), name="dj")

    def _depthwise(b, c, i, j):
        # Output channel c reads input channel c // multiplier and the
        # filter slice (c // multiplier, c % multiplier).
        in_val = PaddedInput[
            b,
            idxdiv(c, channel_multiplier),
            i * stride_h + di * dilation_h,
            j * stride_w + dj * dilation_w,
        ].astype(out_dtype)
        w_val = Filter[
            idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier), di, dj
        ].astype(out_dtype)
        return te.sum(in_val * w_val, axis=[di, dj])

    return te.compute(
        (batch, out_channel, out_height, out_width),
        _depthwise,
        name="DepthwiseConv2d",
        tag="depthwise_conv2d_nchw",
    )
163
164
def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None):
    """Depthwise convolution nhwc forward operator.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

    stride : int or tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    out_dtype : str, optional
        Output data type; defaults to the input dtype.

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    if out_dtype is None:
        out_dtype = Input.dtype

    # Normalize scalar stride/dilation into (height, width) pairs.
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (
        (dilation, dilation) if isinstance(dilation, int) else dilation
    )

    batch, in_height, in_width, in_channel = Input.shape
    filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape

    # Effective kernel extent once dilation is applied.
    kernel_extent_h = (filter_height - 1) * dilation_h + 1
    kernel_extent_w = (filter_width - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (kernel_extent_h, kernel_extent_w)
    )

    out_channel = simplify(in_channel * channel_multiplier)
    out_height = simplify((in_height - kernel_extent_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_extent_w + pad_left + pad_right) // stride_w + 1)

    # Pad only the spatial dimensions; batch and channel stay untouched.
    PaddedInput = pad(
        Input, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="PaddedInput"
    )

    idxdiv = tvm.tir.indexdiv
    idxmod = tvm.tir.indexmod
    di = te.reduce_axis((0, filter_height), name="di")
    dj = te.reduce_axis((0, filter_width), name="dj")

    def _depthwise(b, i, j, c):
        # Output channel c reads input channel c // multiplier and the
        # filter slice (c // multiplier, c % multiplier).
        in_val = PaddedInput[
            b,
            i * stride_h + di * dilation_h,
            j * stride_w + dj * dilation_w,
            idxdiv(c, channel_multiplier),
        ].astype(out_dtype)
        w_val = Filter[
            di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
        ].astype(out_dtype)
        return te.sum(in_val * w_val, axis=[di, dj])

    return te.compute(
        (batch, out_height, out_width, out_channel),
        _depthwise,
        name="DepthwiseConv2d",
        tag="depthwise_conv2d_nhwc",
    )
248
249
def depthwise_conv2d_backward_input_nhwc(Filter, Out_grad, oshape, ishape, stride, padding):
    """Depthwise convolution nhwc backward wrt input operator.

    Parameters
    ----------
    Filter : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]

    Out_grad : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]

    oshape : list or tuple of four ints
        Shape of the forward output: [batch, out_height, out_width, out_channel]

    ishape : list or tuple of four ints
        Shape of the forward input: [batch, in_height, in_width, in_channel]

    stride : int or tuple of two ints
        The spatial stride along height and width used in the forward pass

    padding : int or str
        Padding size, or ['VALID', 'SAME'] (as used in the forward pass)

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]
    """
    batch, in_h, in_w, in_c = ishape
    _, out_h, out_w, out_c = oshape
    filter_h, filter_w, _, channel_multiplier = Filter.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    # Insert (stride - 1) zeros between gradient elements so the backward
    # pass can be expressed as a stride-1 convolution below.
    dilated_out_grad = dilate(Out_grad, [1, stride_h, stride_w, 1], name="dilated_out_grad")

    # padding params in forward propagation
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (filter_h, filter_w))
    # padding params in backward propagation: pad the dilated gradient so that
    # correlating it with the flipped filter covers every forward input
    # position; the extra (stride - 1) on bottom/right accounts for input
    # positions past the last forward window.
    bpad_top = filter_h - 1 - fpad_top
    bpad_bottom = (filter_h - 1 - fpad_bottom) + (stride_h - 1)
    bpad_left = filter_w - 1 - fpad_left
    bpad_right = (filter_w - 1 - fpad_right) + (stride_w - 1)

    padded_out_grad = pad(
        dilated_out_grad,
        [0, bpad_top, bpad_left, 0],
        [0, bpad_bottom, bpad_right, 0],
        name="padded_out_grad",
    )

    # Reduction axes: filter height/width and the channel multiplier.
    dh = te.reduce_axis((0, filter_h), name="dh")
    dw = te.reduce_axis((0, filter_w), name="dw")
    dc = te.reduce_axis((0, channel_multiplier), name="dc")

    In_grad = te.compute(
        (batch, in_h, in_w, in_c),
        # Input channel c contributes to output channels
        # [c * multiplier, (c + 1) * multiplier); the filter is indexed
        # reversed in both spatial axes (180-degree rotation).
        lambda b, h, w, c: te.sum(
            padded_out_grad[b, h + dh, w + dw, c * channel_multiplier + dc]
            * Filter[filter_h - 1 - dh, filter_w - 1 - dw, c, dc],
            axis=[dh, dw, dc],
        ),
        tag="depthwise_conv2d_backward_input_nhwc",
    )

    return In_grad
312
313
def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, stride, padding):
    """Depthwise convolution nhwc backward wrt weight operator.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Out_grad : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]

    oshape : list or tuple of four ints
        Shape of the forward output: [batch, out_height, out_width, out_channel]

    fshape : list or tuple of four ints
        Shape of the filter: [filter_height, filter_width, in_channel, channel_multiplier]

    stride : int or tuple of two ints
        The spatial stride along height and width used in the forward pass

    padding : int or str
        Padding size, or ['VALID', 'SAME'] (as used in the forward pass)

    Returns
    -------
    Output : tvm.te.Tensor
        4-D with shape [filter_height, filter_width, in_channel, channel_multiplier]
    """
    batch, out_h, out_w, out_c = oshape
    filter_h, filter_w, _, channel_multiplier = fshape
    in_c = Input.shape[3].value
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    # Re-apply the forward padding so input positions line up with Out_grad.
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_h, filter_w))

    padded_in = pad(
        Input, [0, pad_top, pad_left, 0], [0, pad_bottom, pad_right, 0], name="padded_in"
    )

    # Reduce over the batch and every spatial output position.
    dh = te.reduce_axis((0, Out_grad.shape[1].value), name="dh")
    dw = te.reduce_axis((0, Out_grad.shape[2].value), name="dw")
    db = te.reduce_axis((0, batch), name="db")
    # (Only indexmod is needed here; the unused indexdiv alias was removed.)
    idxmod = tvm.tir.indexmod

    Weight_grad = te.compute(
        (filter_h, filter_w, in_c, channel_multiplier),
        # m already ranges over [0, channel_multiplier), so idxmod keeps the
        # index in canonical (simplifiable) form without changing its value.
        lambda fh, fw, c, m: te.sum(
            Out_grad[db, dh, dw, c * channel_multiplier + idxmod(m, channel_multiplier)]
            * padded_in[db, fh + dh * stride_h, fw + dw * stride_w, c],
            axis=[db, dh, dw],
        ),
        tag="depthwise_conv2d_backward_weight_nhwc",
    )

    return Weight_grad
367
368
def depthwise_conv2d_NCHWc(
    Input, Filter, stride, padding, dilation, layout, out_layout, out_dtype=None
):
    """Depthwise convolution NCHW[x]c forward operator.

    Parameters
    ----------
    Input : tvm.te.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    Filter : tvm.te.Tensor
        6-D with shape [out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block]
        In NCHWc depthwise convolution,
        we group kernel's in_channel and channel_multiplier together then do the tiling.

    stride : tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    dilation : int or a list/tuple of two ints
        dilation size, or [dilation_height, dilation_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str, optional
        Output data type

    Returns
    -------
    Output : tvm.te.Tensor
        5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """
    # No generic implementation exists; a target-specific strategy must
    # register one before this operator can be used.
    raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
408
409
@tvm.target.generic_func
def depthwise_conv2d_infer_layout(workload, cfg):
    """Infer input/output shapes and layouts from a workload and cfg.

    Generic-function fallback: concrete targets register their own
    implementation, and this default only signals that none was registered.

    Parameters
    ----------
    workload : tuple
        conv2d workload

    cfg : tuple
        tvm.autotvm config

    Returns
    -------
    Output : [tuple of tuple and str, tuple of tuple and str]
        Input shapes and layouts, and output shapes and layouts
    """
    raise ValueError("missing register for topi.nn.depthwise_conv2d_infer_layout")
428