# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""FIFO buffer op"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from ..transform import concatenate, strided_slice


@tvm.tag_scope(tag=tag.INJECTIVE+",fifo_buffer")
def fifo_buffer(data, buffer, axis):
    """
    FIFO buffer to enable computation reuse in CNNs with sliding window input

    Compute equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.

    Parameters
    ----------
    data : tvm.Tensor
        The input data. Must have the same number of dimensions as ``buffer``,
        the same extent on every axis other than ``axis``, and an extent along
        ``axis`` no larger than ``buffer``'s.
    buffer : tvm.Tensor
        Previous value of the FIFO buffer
    axis : int
        Specify which axis should be used for buffering.
        Must satisfy ``0 <= axis < len(buffer.shape)``.

    Returns
    -------
    result : tvm.Tensor
        Updated value for the buffer, with the same shape as ``buffer``: the
        first ``data.shape[axis]`` entries of ``buffer`` along ``axis`` are
        dropped and ``data`` is appended at the end of that axis.
    """
    # NOTE(review): these checks are plain asserts (skipped under `python -O`);
    # they are kept as asserts so callers keep seeing AssertionError.
    assert len(data.shape) == len(buffer.shape), \
        'buffer and data must have same number of dimensions, ' + \
        'buffer.shape = {}, data.shape = {}'.format(buffer.shape, data.shape)
    assert len(buffer.shape) >= 1, 'Zero-dimension tensor not supported'
    assert 0 <= axis < len(buffer.shape), 'buffer axis out of range'
    # Shapes must be compile-time constants here: int(str(...)) only succeeds
    # for constant extents.
    for i, (data_dim, buffer_dim) in enumerate(zip(data.shape, buffer.shape)):
        data_len = int(str(data_dim))
        buffer_len = int(str(buffer_dim))
        if i == axis:
            assert data_len <= buffer_len, \
                'data must not be longer than buffer along axis {}: ' \
                'data.shape = {}, buffer.shape = {}'.format(axis, data.shape, buffer.shape)
        else:
            assert data_len == buffer_len, \
                'data and buffer must match on non-buffering axis {}: ' \
                'data.shape = {}, buffer.shape = {}'.format(i, data.shape, buffer.shape)

    buflen = buffer.shape[axis]
    data_size = data.shape[axis]

    if len(buffer.shape) > 4:
        # Ranks above 4 keep the concat + slice composition this op has always
        # used for them, so the operators emitted for those ranks are unchanged.
        # Slicing [data_size, buflen + data_size) out of concat(buffer, data)
        # yields a tensor shaped like `buffer`.
        begin = [0] * len(buffer.shape)
        begin[axis] = data.shape[axis]
        end = list(buffer.shape[:])
        end[axis] += data.shape[axis]
        return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)

    def _fcompute(*indices):
        """Element of the updated buffer at ``indices`` (works for any rank)."""
        pos = indices[axis]
        # Positions [0, buflen - data_size) along `axis` keep the surviving tail
        # of the old buffer, i.e. buffer[pos + data_size].
        buffer_idx = list(indices)
        buffer_idx[axis] = pos + data_size
        # The last data_size positions are filled from the new data.
        data_idx = list(indices)
        data_idx[axis] = pos - buflen + data_size
        return tvm.if_then_else(pos < buflen - data_size,
                                buffer(*buffer_idx),
                                data(*data_idx))

    # A variadic fcompute makes tvm.compute supply one index per dimension of
    # buffer.shape, replacing the former per-rank / per-axis hand-written lambdas.
    return tvm.compute(buffer.shape, _fcompute, name='new_buffer')