# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
"""NN operator common utilities"""
import numbers

from tvm.ir import container


def _get_pad_tuple(padding, ndim):
    """Normalize ``padding`` into explicit before/after pads for ``ndim`` axes.

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...], List[int], tvm.ir.container.Array]
        Either a single integer applied symmetrically to every axis, a
        sequence of ``ndim`` values (one symmetric pad per axis), or a
        sequence of ``2 * ndim`` values that is already explicit and is
        returned unchanged as a tuple.
    ndim : int
        Number of padded axes (1, 2 or 3 for the public wrappers).

    Returns
    -------
    pads : tuple
        ``ndim`` "before" pads followed by ``ndim`` "after" pads.

    Raises
    ------
    ValueError
        If a sequence has a length other than ``ndim`` or ``2 * ndim``, or
        ``padding`` is neither a sequence nor an integer.
    """
    # A TVM Array is turned into a list so it can be sized and indexed below.
    if isinstance(padding, container.Array):
        padding = list(padding)

    if isinstance(padding, (tuple, list)):
        if len(padding) == 2 * ndim:
            # Already explicit: (before_0, ..., before_{n-1}, after_0, ..., after_{n-1}).
            return tuple(padding)
        if len(padding) != ndim:
            raise ValueError("Size of padding can only be %d or %d" % (ndim, 2 * ndim))
        per_axis = padding
    elif isinstance(padding, numbers.Integral):
        # numbers.Integral also accepts integer scalars such as numpy.int64,
        # which a plain ``int`` check would reject.
        per_axis = [padding] * ndim
    else:
        raise ValueError("Unknown padding option %s" % padding)

    # Split the total pad of each axis into before/after halves. For plain
    # ints this yields (p, p); the arithmetic is kept in this form so that
    # non-int elements (e.g. items taken from a TVM Array) produce the same
    # values as the original per-dimension implementations.
    totals = [pad * 2 for pad in per_axis]
    before = [(total + 1) // 2 for total in totals]
    after = [total - first for total, first in zip(totals, before)]
    return tuple(before) + tuple(after)


def get_pad_tuple1d(padding):
    """Common code to get the 1 dimensional pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. An integer or a 1-element sequence pads both sides
        equally; a 2-element sequence is taken as (left, right).

    Returns
    -------
    pad_left : int
        Padding size on left
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 1 or 2, or an unsupported type,
        is given.
    """
    return _get_pad_tuple(padding, 1)


def get_pad_tuple2d(padding):
    """Common code to get the 2 dimensional pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. An integer pads every side equally; a 2-element
        sequence is (height, width) applied symmetrically; a 4-element
        sequence is taken as (top, left, down, right).

    Returns
    -------
    pad_top : int
        Padding size on top
    pad_left : int
        Padding size on left
    pad_down : int
        Padding size on down (bottom).
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 2 or 4, or an unsupported type,
        is given.
    """
    return _get_pad_tuple(padding, 2)


def get_pad_tuple3d(padding):
    """Common code to get the 3 dimensional pad option

    Parameters
    ----------
    padding : Union[int, Tuple[int, ...]]
        Padding size. An integer pads every side equally; a 3-element
        sequence is (depth, height, width) applied symmetrically; a
        6-element sequence is taken as (front, top, left, back, down, right).

    Returns
    -------
    pad_front : int
        Padding size on front
    pad_top : int
        Padding size on top
    pad_left : int
        Padding size on left
    pad_back : int
        Padding size on back
    pad_down : int
        Padding size on down (bottom).
    pad_right : int
        Padding size on right.

    Raises
    ------
    ValueError
        If a sequence of length other than 3 or 6, or an unsupported type,
        is given.
    """
    return _get_pad_tuple(padding, 3)