1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""FIFO buffer op"""
19from __future__ import absolute_import as _abs
20import tvm
21from .. import tag
22from ..transform import concatenate, strided_slice
23
@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
def fifo_buffer(data, buffer, axis):
    """
    FIFO buffer to enable computation reuse in CNNs with sliding window input

    Compute equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    i.e. ``data`` is appended to ``buffer`` along ``axis`` and the oldest
    ``data.shape[axis]`` entries are dropped, so the result has the same
    shape as ``buffer``.

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.

    Parameters
    ----------
    data : tvm.Tensor
        The input data. Must have the same rank as ``buffer``, the same
        extent on every axis other than ``axis``, and an extent along
        ``axis`` no larger than that of ``buffer``.
    buffer : tvm.Tensor
        Previous value of the FIFO buffer
    axis : int
        Specify which axis should be used for buffering. Negative values
        count from the last axis.

    Returns
    -------
    result : tvm.Tensor
        Updated value for the buffer, with the same shape as ``buffer``

    Raises
    ------
    AssertionError
        If the ranks or shapes of ``data`` and ``buffer`` are incompatible,
        or if ``axis`` is out of range. Shapes must be static, because each
        extent is converted with ``int(str(...))``.
    """
    ndim = len(buffer.shape)
    assert len(data.shape) == ndim, \
        'buffer and data must have same number of dimensions, ' + \
        'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
    assert ndim >= 1, 'Zero-dimension tensor not supported'
    requested_axis = axis
    if axis < 0:
        # Allow numpy-style negative axes.
        axis += ndim
    assert 0 <= axis < ndim, \
        'buffer axis out of range: axis = {}, ndim = {}'.format(requested_axis, ndim)

    for i, (data_dim, buffer_dim) in enumerate(zip(data.shape, buffer.shape)):
        data_len, buffer_len = int(str(data_dim)), int(str(buffer_dim))
        if i == axis:
            # The incoming data must fit into the buffer along the buffering axis.
            assert data_len <= buffer_len, \
                'data.shape[{0}] = {1} exceeds buffer.shape[{0}] = {2}'.format(
                    i, data_len, buffer_len)
        else:
            assert data_len == buffer_len, \
                'data.shape[{0}] = {1} must equal buffer.shape[{0}] = {2}'.format(
                    i, data_len, buffer_len)

    buflen = buffer.shape[axis]
    data_size = data.shape[axis]

    if ndim <= 4:
        # Up to 4-D, write out the formula directly as one elementwise op.
        def _new_buffer(*indices):
            pos = indices[axis]
            buffer_index = list(indices)
            buffer_index[axis] = pos + data_size
            data_index = list(indices)
            data_index[axis] = pos - buflen + data_size
            # The first (buflen - data_size) slots hold the surviving tail of
            # the old buffer; the remaining data_size slots come from data.
            return tvm.if_then_else(pos < buflen - data_size,
                                    buffer[tuple(buffer_index)],
                                    data[tuple(data_index)])
        return tvm.compute(buffer.shape, _new_buffer, name='new_buffer')

    # 5-D and higher: implement the FIFO buffer as concatenate + strided_slice,
    # keeping the last buflen entries of concat(buffer, data) along axis.
    begin = [0] * ndim
    begin[axis] = data_size
    end = list(buffer.shape)
    end[axis] = buflen + data_size
    return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)
156