1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \brief local response normalization op constructions
22  * \file nn/local_response_norm.h
23  */
24 #ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
25 #define TOPI_NN_LOCAL_RESPONSE_NORM_H_
26 
#include <string>

#include "topi/broadcast.h"
#include "topi/nn.h"
#include "topi/tags.h"
#include "tvm/operation.h"
31 
32 namespace topi {
33 namespace nn {
34 using namespace tvm;
35 
36 /*!
37 * \brief Local response normalization inference operator
38 *
39 * \param data The input tensor. 4-D shape NCHW or NHWC
40 * \param size Integer to define normalisation window size
41 * \param axis Input data layout channel axis
42 * \param alpha Float scaling factor
43 * \param beta Exponent value
44 * \param bias Offset to avoid dividing by zero
45 * \param name The name of the operation
46 * \param tag The tag to mark the operation
47 *
48 * \return A Tensor whose op member is the Local response normalization operation
49 */
50 inline Tensor lrn(const Tensor& data,
51                   int size,
52                   int axis = 1,
53                   float alpha = 0.0001,
54                   float beta = 0.75,
55                   float bias = 2,
56                   std::string name = "tensor",
57                   std::string tag = kBroadcast) {
58   CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
59   CHECK_EQ(size % 2, 1) << "size should be odd number";
60   CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
61   auto input_shape = data->shape;
62   Array<Expr> pad_before{ 0, 0, 0, 0};
63   Array<Expr> pad_after{ 0, 0, 0, 0};
64   pad_before.Set(axis, static_cast<Expr>(size/2));
65   pad_after.Set(axis, static_cast<Expr>(size/2));
66   auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
67   auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
68   Tensor sqr_sum;
69   if (axis == 1) {
70     sqr_sum = tvm::compute(input_shape,
71                            [&](Var i, Var l, Var j, Var k) {
72                            return tvm::sum(pad_data(i, l + rxs, j, k) *
73                                            pad_data(i, l + rxs, j, k),
74                                            {rxs});
75                            });
76   } else if (axis == 3) {
77     sqr_sum = tvm::compute(input_shape,
78                            [&](Var i, Var l, Var j, Var k) {
79                            return tvm::sum(pad_data(i, l, j, k + rxs) *
80                                            pad_data(i, l, j, k + rxs),
81                                            {rxs});
82                            });
83   }
84   auto sqrt_sum_up = tvm::compute(
85       input_shape,
86       [&](Var i, Var j, Var k, Var l) {
87         return tvm::pow(bias +
88                         (div(alpha * sqr_sum(i, j, k, l), size)),
89                         beta);
90       });
91   return topi::divide(data, sqrt_sum_up);
92 }
93 }  // namespace nn
94 }  // namespace topi
95 #endif  // TOPI_NN_LOCAL_RESPONSE_NORM_H_
96