1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, unused-variable
18"""NN operator common utilities"""
import numbers

from tvm.ir import container
20
21
def get_pad_tuple1d(padding):
    """Common code to get the 1 dimensional pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. Accepted forms:

        * an integer (any ``numbers.Integral``, e.g. ``int`` or a NumPy
          integer scalar): that amount is padded on both sides;
        * a 1-element sequence ``(w,)``: for integer ``w``, ``w`` is padded
          on both sides;
        * a 2-element sequence ``(left, right)``: returned unchanged.

        A ``tvm.ir.container.Array`` is converted to a list first.

    Returns
    -------
    pad_left : int
        Padding size on left
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 1 or 2 is given, or if
        ``padding`` is neither a sequence nor an integer.
    """
    # compute the padding size
    if isinstance(padding, container.Array):
        # Normalize a TVM Array to a Python list so the len()/indexing
        # logic below applies uniformly.
        padding = list(padding)
    if isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_w = padding[0] * 2
        elif len(padding) == 2:
            # Explicit (left, right) padding: return as given.
            return padding[0], padding[1]
        else:
            raise ValueError("Size of padding can only be 1 or 2")
    elif isinstance(padding, numbers.Integral):
        # numbers.Integral covers int/bool and NumPy integer scalars, which a
        # plain isinstance(..., int) check would reject.
        pad_w = padding * 2
    else:
        raise ValueError("Unknown padding option %s" % padding)
    # pad_w is the total padding along the axis; ceil-halve it so that any
    # odd remainder goes to the leading (left) side. For integral inputs the
    # total is even, so the split is symmetric.
    pad_left = (pad_w + 1) // 2
    return pad_left, pad_w - pad_left
51
52
def get_pad_tuple2d(padding):
    """Common code to get the pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. Accepted forms:

        * an integer (any ``numbers.Integral``, e.g. ``int`` or a NumPy
          integer scalar): that amount is padded on every side;
        * a 2-element sequence ``(h, w)``: for integer values, ``h`` is padded
          on top and bottom and ``w`` on left and right;
        * a 4-element sequence ``(top, left, down, right)``: returned
          unchanged.

        A ``tvm.ir.container.Array`` is converted to a list first.

    Returns
    -------
    pad_top : int
        Padding size on top
    pad_left : int
        Padding size on left
    pad_down : int
        Padding size on down.
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 2 or 4 is given, or if
        ``padding`` is neither a sequence nor an integer.
    """
    # compute the padding size
    if isinstance(padding, container.Array):
        # Normalize a TVM Array to a Python list so the len()/indexing
        # logic below applies uniformly.
        padding = list(padding)
    if isinstance(padding, (tuple, list)):
        if len(padding) == 2:
            pad_h = padding[0] * 2
            pad_w = padding[1] * 2
        elif len(padding) == 4:
            # Explicit (top, left, down, right) padding: return as given.
            return padding[0], padding[1], padding[2], padding[3]
        else:
            raise ValueError("Size of padding can only be 2 or 4")
    elif isinstance(padding, numbers.Integral):
        # numbers.Integral covers int/bool and NumPy integer scalars, which a
        # plain isinstance(..., int) check would reject.
        pad_h = pad_w = padding * 2
    else:
        raise ValueError("Unknown padding option %s" % padding)
    # pad_h / pad_w are per-axis totals; ceil-halve them so any odd remainder
    # goes to the leading (top/left) side. For integral inputs the totals are
    # even, so the split is symmetric.
    pad_top = (pad_h + 1) // 2
    pad_left = (pad_w + 1) // 2
    return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
88
89
def get_pad_tuple3d(padding):
    """Common code to get the pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. Accepted forms:

        * an integer (any ``numbers.Integral``, e.g. ``int`` or a NumPy
          integer scalar): that amount is padded on every side;
        * a 3-element sequence ``(d, h, w)``: for integer values, each amount
          is padded on both sides of its axis;
        * a 6-element sequence
          ``(front, top, left, back, down, right)``: returned unchanged.

        A ``tvm.ir.container.Array`` is converted to a list first.

    Returns
    -------
    pad_front : int
        Padding size on front
    pad_top : int
        Padding size on top
    pad_left : int
        Padding size on left
    pad_back : int
        Padding size on back
    pad_down : int
        Padding size on down.
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 3 or 6 is given, or if
        ``padding`` is neither a sequence nor an integer.
    """
    # compute the padding size
    if isinstance(padding, container.Array):
        # Normalize a TVM Array to a Python list so the len()/indexing
        # logic below applies uniformly.
        padding = list(padding)
    if isinstance(padding, (tuple, list)):
        if len(padding) == 3:
            pad_d = padding[0] * 2
            pad_h = padding[1] * 2
            pad_w = padding[2] * 2
        elif len(padding) == 6:
            # Explicit (front, top, left, back, down, right) padding:
            # return as given.
            return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
        else:
            raise ValueError("Size of padding can only be 3 or 6")
    elif isinstance(padding, numbers.Integral):
        # numbers.Integral covers int/bool and NumPy integer scalars, which a
        # plain isinstance(..., int) check would reject.
        pad_d = pad_h = pad_w = padding * 2
    else:
        raise ValueError("Unknown padding option %s" % padding)
    # pad_d / pad_h / pad_w are per-axis totals; ceil-halve them so any odd
    # remainder goes to the leading (front/top/left) side. For integral
    # inputs the totals are even, so the split is symmetric.
    pad_front = (pad_d + 1) // 2
    pad_top = (pad_h + 1) // 2
    pad_left = (pad_w + 1) // 2
    return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
131