1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name
18"""Dilation operators"""
19from __future__ import absolute_import as _abs
20import tvm
21from .. import util
22from .. import tag
23
@tvm.tag_scope(tag=tag.INJECTIVE+",dilate")
def dilate(data, strides, name="DilatedInput", dilation_value=0.0):
    """Dilate data by inserting a fill value between its elements.

    Parameters
    ----------
    data : tvm.Tensor
        n-D, can be any layout.

    strides : list / tuple of n ints
        Dilation stride on each dimension, 1 means no dilation.

    name : str, optional
        Name passed to tvm.compute for the generated operator.

    dilation_value : int / float, optional
        Value stored at the positions introduced by the dilation. It is
        converted to a constant of data.dtype. The default of 0.0 gives
        the classic "dilate with zeros" behaviour.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as data. Axis i has extent
        (data.shape[i] - 1) * strides[i] + 1.

    Raises
    ------
    ValueError
        If len(strides) differs from the number of dimensions of data.
    """
    n = len(data.shape)
    if len(strides) != n:
        raise ValueError("data dimension and strides size dismatch : %d vs %d" % (
            n, len(strides)))

    # Consecutive input elements end up strides[i] apart, so the last one
    # lands at (shape[i] - 1) * strides[i]; +1 turns that index into an extent.
    out_shape = tuple(
        tvm.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))

    def _dilate(*indices):
        idxdiv = tvm.indexdiv
        idxmod = tvm.indexmod
        src_index = []   # index into `data` for every axis
        on_grid = []     # one "is a multiple of the stride" test per dilated axis
        for axis, idx in enumerate(indices):
            stride = strides[axis]
            if util.equal_const_int(stride, 1):
                # Undilated axis: index passes through, no predicate needed.
                src_index.append(idx)
            else:
                src_index.append(idxdiv(idx, stride))
                on_grid.append(idxmod(idx, stride).equal(0))

        if not on_grid:
            # Every stride is 1, so the output is a plain copy of the input.
            return data(*src_index)

        # Only positions that are a multiple of the stride on every dilated
        # axis map back to an input element; all others take the fill value.
        return tvm.if_then_else(tvm.all(*on_grid),
                                data(*src_index),
                                tvm.const(dilation_value, data.dtype))

    return tvm.compute(out_shape, _dilate, name=name)
69