1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, unused-variable, unused-argument
18"""1D convolution operators."""
19from tvm import te
20from .pad import pad
21from ..util import simplify
22from .util import get_pad_tuple1d
23
24
25def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
26    """1D convolution forward operator.
27
28    Parameters
29    ----------
30    data : tvm.te.Tensor
31        3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
32        and [batch, in_width, in_channel] for layout == 'NWC'
33
34    kernel : tvm.te.Tensor
35        3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
36        and [filter_size, in_channel, num_filter] for layout == 'NWC'
37
38    strides : int or tuple
39        The spatial stride along width
40
41    padding : int or str
42        Padding size, or ['VALID', 'SAME']
43
44    dilation : int or tuple
45        Dilation rate if convolution should be dilated.
46
47    layout : str
48        How input data is laid out, must be one of ['NCW', 'NWC']
49
50    out_dtype : str
51        The output data type. If None then output is same type as input.
52    """
53    if out_dtype is None:
54        out_dtype = data.dtype
55    if isinstance(strides, (tuple, list)):
56        strides = strides[0]
57    if isinstance(dilation, (tuple, list)):
58        dilation = dilation[0]
59
60    if layout == "NCW":
61        return conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
62    if layout == "NWC":
63        return conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
64    raise ValueError("This layout is not yet supported: {}".format(layout))
65
66
67def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
68    """1D convolution forward operator for NCW layout.
69
70    Parameters
71    ----------
72    data : tvm.te.Tensor
73        3-D with shape [batch, in_channel, in_width]
74
75    kernel : tvm.te.Tensor
76        3-D with shape [num_filter, in_channel, filter_size]
77
78    strides : int or tuple
79        The spatial stride along width
80
81    padding : int, tuple, or str
82        Padding size can be an integer for equal padding,
83        a tuple of (left, right) or a string in ['VALID', 'SAME'].
84
85    dilation : int or tuple
86        Dilation rate if convolution should be dilated.
87
88    out_dtype : str
89        The output data type. If None then output is same type as input.
90    """
91    if out_dtype is None:
92        out_dtype = data.dtype
93    if isinstance(strides, (tuple, list)):
94        strides = strides[0]
95    if isinstance(dilation, (tuple, list)):
96        dilation = dilation[0]
97
98    batch, in_channels, data_width = data.shape
99    out_channels, _, kernel_size = kernel.shape
100
101    # Compute the output shape
102    dilated_kernel_size = (kernel_size - 1) * dilation + 1
103    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
104    out_channels = simplify(out_channels)
105    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
106
107    # Apply padding
108    pad_before = [0, 0, pad_left]
109    pad_after = [0, 0, pad_right]
110    temp = pad(data, pad_before, pad_after, name="pad_temp")
111
112    # Compute graph
113    rc = te.reduce_axis((0, in_channels), name="rc")
114    rw = te.reduce_axis((0, kernel_size), name="rw")
115
116    return te.compute(
117        (batch, out_channels, out_width),
118        lambda b, c, w: te.sum(
119            temp[b, rc, w * strides + rw * dilation].astype(out_dtype)
120            * kernel[c, rc, rw].astype(out_dtype),
121            axis=[rc, rw],
122        ),
123        tag="conv1d_ncw",
124    )
125
126
127def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
128    """1D convolution forward operator for NWC layout.
129
130    Parameters
131    ----------
132    data : tvm.te.Tensor
133        3-D with shape [batch, in_width, in_channel]
134
135    kernel : tvm.te.Tensor
136        3-D with shape [filter_size, in_channel, num_filter]
137
138    strides : int or tuple
139        The spatial stride along width
140
141    padding : int, tuple, or str
142        Padding size can be an integer for equal padding,
143        a tuple of (left, right) or a string in ['VALID', 'SAME'].
144
145    dilation : int or tuple
146        Dilation rate if convolution should be dilated.
147
148    out_dtype : str
149        The output data type. If None then output is same type as input.
150    """
151    if out_dtype is None:
152        out_dtype = data.dtype
153    if isinstance(strides, (tuple, list)):
154        strides = strides[0]
155    if isinstance(dilation, (tuple, list)):
156        dilation = dilation[0]
157
158    batch, data_width, in_channels = data.shape
159    kernel_size, _, out_channels = kernel.shape
160
161    # Compute the output shape
162    dilated_kernel_size = (kernel_size - 1) * dilation + 1
163    pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size,))
164    out_channels = simplify(out_channels)
165    out_width = simplify((data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
166
167    # Apply padding
168    pad_before = [0, pad_left, 0]
169    pad_after = [0, pad_right, 0]
170    temp = pad(data, pad_before, pad_after, name="pad_temp")
171
172    # Compute graph
173    rc = te.reduce_axis((0, in_channels), name="rc")
174    rw = te.reduce_axis((0, kernel_size), name="rw")
175
176    return te.compute(
177        (batch, out_width, out_channels),
178        lambda b, w, c: te.sum(
179            temp[b, w * strides + rw * dilation, rc].astype(out_dtype)
180            * kernel[rw, rc, c].astype(out_dtype),
181            axis=[rc, rw],
182        ),
183        tag="conv1d_nwc",
184    )
185