// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "layernorm.h"

#include <math.h>

namespace ncnn {

LayerNorm::LayerNorm()
{
    one_blob_only = true;
    support_inplace = true;
}

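// Parameters read from the param dict:
//   0 = affine_size  number of elements normalized together (w for 2-D input, w*h for 3-D input)
//   1 = eps          value added to var for numerical stability (default 0.001f)
//   2 = affine       whether learned gamma/beta are applied (default 1)
// An illustrative .param layer line (name and values are hypothetical):
//   LayerNorm        ln0   1 1 in out 0=256 1=1.000000e-05 2=1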
int LayerNorm::load_param(const ParamDict& pd)
{
    affine_size = pd.get(0, 0);
    eps = pd.get(1, 0.001f);
    affine = pd.get(2, 1);

    return 0;
}

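// gamma and beta weights are only present in the model when affine is enabled;
// each is loaded as a flat array of affine_size floats.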
int LayerNorm::load_model(const ModelBin& mb)
{
    if (affine == 0)
        return 0;

    gamma_data = mb.load(affine_size, 1);
    if (gamma_data.empty())
        return -100;

    beta_data = mb.load(affine_size, 1);
    if (beta_data.empty())
        return -100;

    return 0;
}

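// In-place normalization:
//   2-D input: each of the h rows is normalized over its w elements
//   3-D input: each of the c channels is normalized over its w*h elements
// Inputs with any other dims pass through unchanged.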
int LayerNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / sqrt(var + eps) * gamma + beta

    int dims = bottom_top_blob.dims;

    if (dims == 2)
    {
        int w = bottom_top_blob.w;
        int h = bottom_top_blob.h;
        // assert affine_size == w

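        // two-pass statistics per row: first accumulate the sum for the mean,
        // then accumulate squared deviations from the mean for the variance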
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < h; i++)
        {
            float* ptr = bottom_top_blob.row(i);

            // mean and var
            float sum = 0.f;
            float sqsum = 0.f;
            for (int j = 0; j < w; j++)
            {
                sum += ptr[j];
                //sqsum += ptr[j] * ptr[j];
            }
            float mean = sum / w;
            float tmp = 0.f;
            for (int j = 0; j < w; j++)
            {
                tmp = ptr[j] - mean;
                sqsum += tmp * tmp;
            }
            float var = sqsum / w;
            // the one-pass formula below can yield a slightly negative var due to floating-point error
            //float var = sqsum / w - mean * mean;

            // fold the normalization into a single multiply-add: x_norm = x * a + b
            float a = static_cast<float>(1.f / (sqrt(var + eps)));
            float b = -mean * a;

            if (affine)
            {
                for (int j = 0; j < w; j++)
                {
                    // apply gamma/beta to the normalized value, matching the formula above
                    ptr[j] = (ptr[j] * a + b) * gamma_data[j] + beta_data[j];
                }
            }
            else
            {
                for (int j = 0; j < w; j++)
                {
                    ptr[j] = ptr[j] * a + b;
                }
            }
        }
    }

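    // 3-D input: same two-pass scheme, with statistics computed per channel over w*h elements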
    if (dims == 3)
    {
        int w = bottom_top_blob.w;
        int h = bottom_top_blob.h;
        int channels = bottom_top_blob.c;
        int size = w * h;
        // assert affine_size == size

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            float* ptr = bottom_top_blob.channel(q);

            // mean and var
            float sum = 0.f;
            float sqsum = 0.f;
            for (int i = 0; i < size; i++)
            {
                sum += ptr[i];
                //sqsum += ptr[i] * ptr[i];
            }
            float mean = sum / size;
            float tmp = 0.f;
            for (int i = 0; i < size; i++)
            {
                tmp = ptr[i] - mean;
                sqsum += tmp * tmp;
            }
            float var = sqsum / size;
            // the one-pass formula below can yield a slightly negative var due to floating-point error
            //float var = sqsum / size - mean * mean;

            // fold the normalization into a single multiply-add: x_norm = x * a + b
            float a = static_cast<float>(1.f / (sqrt(var + eps)));
            float b = -mean * a;

            if (affine)
            {
                for (int i = 0; i < size; i++)
                {
                    // apply gamma/beta to the normalized value, matching the formula above
                    ptr[i] = (ptr[i] * a + b) * gamma_data[i] + beta_data[i];
                }
            }
            else
            {
                for (int i = 0; i < size; i++)
                {
                    ptr[i] = ptr[i] * a + b;
                }
            }
        }
    }

    return 0;
}

} // namespace ncnn