1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \brief local response normalization op constructions 22 * \file nn/local_response_norm.h 23 */ 24 #ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_ 25 #define TOPI_NN_LOCAL_RESPONSE_NORM_H_ 26 27 #include <string> 28 29 #include "topi/tags.h" 30 #include "tvm/operation.h" 31 32 namespace topi { 33 namespace nn { 34 using namespace tvm; 35 36 /*! 37 * \brief Local response normalization inference operator 38 * 39 * \param data The input tensor. 4-D shape NCHW or NHWC 40 * \param size Integer to define normalisation window size 41 * \param axis Input data layout channel axis 42 * \param alpha Float scaling factor 43 * \param beta Exponent value 44 * \param bias Offset to avoid dividing by zero 45 * \param name The name of the operation 46 * \param tag The tag to mark the operation 47 * 48 * \return A Tensor whose op member is the Local response normalization operation 49 */ 50 inline Tensor lrn(const Tensor& data, 51 int size, 52 int axis = 1, 53 float alpha = 0.0001, 54 float beta = 0.75, 55 float bias = 2, 56 std::string name = "tensor", 57 std::string tag = kBroadcast) { 58 CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; 59 CHECK_EQ(size % 2, 1) << "size should be odd number"; 60 CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; 61 auto input_shape = data->shape; 62 Array<Expr> pad_before{ 0, 0, 0, 0}; 63 Array<Expr> pad_after{ 0, 0, 0, 0}; 64 pad_before.Set(axis, static_cast<Expr>(size/2)); 65 pad_after.Set(axis, static_cast<Expr>(size/2)); 66 auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); 67 auto rxs = tvm::reduce_axis(Range(0, size), "rxs"); 68 Tensor sqr_sum; 69 if (axis == 1) { 70 sqr_sum = tvm::compute(input_shape, 71 [&](Var i, Var l, Var j, Var k) { 72 return tvm::sum(pad_data(i, l + rxs, j, k) * 73 pad_data(i, l + rxs, j, k), 74 {rxs}); 75 }); 76 } else if (axis == 3) { 77 sqr_sum = tvm::compute(input_shape, 78 [&](Var i, Var l, Var j, Var k) { 79 return tvm::sum(pad_data(i, l, j, k + rxs) * 80 pad_data(i, l, j, k + rxs), 81 {rxs}); 82 }); 83 } 84 auto sqrt_sum_up = tvm::compute( 85 input_shape, 86 [&](Var i, Var j, Var k, Var l) { 87 return tvm::pow(bias + 88 (div(alpha * sqr_sum(i, j, k, l), size)), 89 beta); 90 }); 91 return topi::divide(data, sqrt_sum_up); 92 } 93 } // namespace nn 94 } // namespace topi 95 #endif // TOPI_NN_LOCAL_RESPONSE_NORM_H_ 96