# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Dilation operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import tag


@tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
def dilate(data, strides, name="DilatedInput", dilation_value=0.0):
    """Dilate data by inserting ``dilation_value`` between its elements.

    Along each dimension ``i`` the output has size
    ``(data.shape[i] - 1) * strides[i] + 1``.  Output index ``j`` reads
    input index ``j // strides[i]`` when ``j % strides[i] == 0``; every
    other position is filled with ``dilation_value``.

    Parameters
    ----------
    data : tvm.Tensor
        n-D, can be any layout.

    strides : list / tuple of n ints
        Dilation stride on each dimension, 1 means no dilation.

    name : str, optional
        The name prefix of the generated operators.

    dilation_value : int / float, optional
        Value written into the inserted positions, cast to ``data.dtype``.
        Defaults to 0.0, i.e. zero-fill.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as data.

    Raises
    ------
    ValueError
        If ``len(strides)`` does not equal the number of dimensions of ``data``.
    """
    n = len(data.shape)
    if len(strides) != n:
        raise ValueError("data dimension and strides size mismatch : %d vs %d" % (
            n, len(strides)))

    out_shape = tuple(
        tvm.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    def _dilate(*indices):
        # Per-axis conditions "this output index lands on an input element".
        on_grid = []
        index_tuple = []
        for i in range(n):
            if util.equal_const_int(strides[i], 1):
                # Stride 1: no dilation on this axis, index passes straight through.
                index_tuple.append(indices[i])
            else:
                index_tuple.append(idxdiv(indices[i], strides[i]))
                on_grid.append(idxmod(indices[i], strides[i]).equal(0))
        if on_grid:
            return tvm.if_then_else(tvm.all(*on_grid),
                                    data(*index_tuple),
                                    tvm.const(dilation_value, data.dtype))
        # Every stride is 1: the output is a plain copy of the input.
        return data(*index_tuple)

    return tvm.compute(out_shape, _dilate, name=name)