# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pad the data by a constant value or by mirroring its borders."""
from __future__ import absolute_import as _abs
import tvm
from ..util import equal_const_int
from .. import tag


def _resolve_pad_widths(data, pad_before, pad_after):
    """Validate per-axis pad widths and compute the padded output shape.

    Parameters
    ----------
    data : tvm.Tensor
        The tensor being padded; only its ``shape`` is read.

    pad_before : list / tuple of n ints
        Pad width before the start of each axis.

    pad_after : list / tuple of n ints, or None
        Pad width after the end of each axis. When falsy (``None`` or
        empty), ``pad_before`` is reused.

    Returns
    -------
    (pad_after, out_shape) : tuple
        The effective ``pad_after`` and a tuple of simplified output extents
        ``shape[i] + pad_before[i] + pad_after[i]``.

    Raises
    ------
    ValueError
        If ``pad_before`` or ``pad_after`` does not have one entry per axis.
    """
    n = len(data.shape)
    pad_after = pad_after if pad_after else pad_before
    if len(pad_before) != n:
        raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
            n, len(pad_before)))
    if len(pad_after) != n:
        # Report pad_after's own length; pad_before's length equals n here.
        raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
            n, len(pad_after)))
    out_shape = tuple(
        tvm.ir_pass.Simplify(data.shape[i] + pad_before[i] + pad_after[i])
        for i in range(n))
    return pad_after, out_shape


@tvm.tag_scope(tag=tag.INJECTIVE + ",pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
    """Pad Input with a constant value.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.

    pad_before : list / tuple of n ints
        Pad width on each dimension, applied before the start of the axis.

    pad_after : list / tuple of n ints, optional
        Pad width on each dimension, applied after the end of the axis.
        Defaults to ``pad_before`` when not given.

    pad_value : float or tvm.expr.Expr, optional
        The value used to fill the padded region (default 0.0). A plain
        Python number is converted to a constant of ``data.dtype``.

    name : str, optional
        The name prefix of the generated operators.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as Input.

    Raises
    ------
    ValueError
        If ``pad_before`` / ``pad_after`` length differs from the input rank.
    """
    n = len(data.shape)
    pad_after, out_shape = _resolve_pad_widths(data, pad_before, pad_after)
    pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr)
                 else tvm.const(pad_value, data.dtype))

    def _pad(*indices):
        in_bounds = []
        index_tuple = []
        for i in range(n):
            if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
                # Unpadded axis: output index maps 1:1 onto the input, no check needed.
                index_tuple.append(indices[i])
            else:
                index_tuple.append(indices[i] - pad_before[i])
                in_bounds.append(indices[i] >= pad_before[i])
                in_bounds.append(indices[i] < data.shape[i] + pad_before[i])
        if in_bounds:
            # Read the input only inside the original extent; elsewhere emit pad_value.
            return tvm.if_then_else(tvm.all(*in_bounds), data(*index_tuple), pad_value)
        return data(*index_tuple)

    return tvm.compute(out_shape, _pad, name=name)


@tvm.tag_scope(tag=tag.INJECTIVE + ",pad")
def mirror_pad(data,
               pad_before,
               pad_after=None,
               mode='SYMMETRIC',
               name="MirrorPadInput"):
    """Pad Input with mirroring, either symmetric or reflected.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.

    pad_before : list / tuple of n ints
        Pad width on each dimension, applied before the start of the axis.

    pad_after : list / tuple of n ints, optional
        Pad width on each dimension, applied after the end of the axis.
        Defaults to ``pad_before`` when not given.

    mode : str, optional
        Type of mirror padding to apply: ``'SYMMETRIC'`` (the border element
        is repeated) or ``'REFLECT'`` (the border element is not repeated).

    name : str, optional
        The name prefix of the generated operators.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as Input.

    Raises
    ------
    ValueError
        If the pad widths do not match the input rank, or ``mode`` is not
        ``'SYMMETRIC'`` or ``'REFLECT'``.
    """
    n = len(data.shape)
    pad_after, out_shape = _resolve_pad_widths(data, pad_before, pad_after)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in ('SYMMETRIC', 'REFLECT'):
        raise ValueError(
            "mirror_pad mode must be 'SYMMETRIC' or 'REFLECT', got %r" % (mode,))
    # SYMMETRIC mirrors about the border element itself (offset 1);
    # REFLECT mirrors about it without repeating it (offset 0).
    offset = int(mode == 'SYMMETRIC')

    def _pad(*indices):
        mapped_tuple = []
        for i in range(n):
            if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
                # Unpadded axis: no mirroring required.
                mapped_tuple.append(indices[i])
                continue
            axis = indices[i] - pad_before[i]
            below = indices[i] < pad_before[i]
            above = indices[i] >= data.shape[i] + pad_before[i]
            # Leading pad: axis in [-pad, -1] mirrors to -axis - offset.
            mapped_axis = tvm.if_then_else(below, -axis - offset, axis)
            # Trailing pad: axis in [shape, shape + pad) mirrors about shape - 1.
            mapped_axis = tvm.if_then_else(
                above, (2 * (data.shape[i] - 1)) - axis + offset, mapped_axis)
            mapped_tuple.append(mapped_axis)
        return data(*mapped_tuple)

    return tvm.compute(out_shape, _pad, name=name)