/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file layer_norm.cu
 * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
 */
#include "./layer_norm-inl.h"

using namespace mshadow::cuda;

namespace mxnet {
namespace op {

template <typename DType>
__device__ __forceinline__ DType warp_shfl(DType value, int src_lane,
                                           int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

template <typename DType>
__device__ __forceinline__ DType warp_shfl_xor(DType value, int laneMask,
                                               int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}


/* A single update step of Welford's online algorithm for computing the mean and variance.
 * The value 'curr' is accumulated into the (mean, sigma2, count) triplet, where sigma2 holds
 * the running sum of squared deviations from the mean.
 */
template<typename DType, typename IType>
__device__ __forceinline__ void StepWelfordOnlineSum(const DType curr,
                                                     DType& mean,         //NOLINT
                                                     DType& sigma2,       //NOLINT
                                                     IType& count) {      //NOLINT
  count += IType(1);
  DType delta = curr - mean;
  mean += delta / count;
  sigma2 += delta * (curr - mean);
}
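
/* Illustrative sketch (for exposition only; nothing in this file calls it, and the helper name
 * is hypothetical): running the Welford update sequentially over a short array. After the loop,
 * `mean` is the sample mean and `sigma2 / count` is the biased variance, which matches how the
 * forward kernel below normalizes the accumulated sigma2 (it divides by nchannel).
 */
__device__ __forceinline__ float WelfordBiasedVarianceSketch(const float* vals, int n) {
  float mean = 0.f, sigma2 = 0.f;
  int count = 0;
  for (int i = 0; i < n; ++i) {
    StepWelfordOnlineSum(vals[i], mean, sigma2, count);
  }
  // Guard against an empty input; the kernels never hit this case because nchannel >= 1.
  return count > 0 ? sigma2 / count : 0.f;
}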

/* Merge the mean/variance of two partitions. This is the key step of Chan's parallel algorithm:
 * (lhs_mean, lhs_sigma2, lhs_count) is merged into (rhs_mean, rhs_sigma2, rhs_count).
 *
 * See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance for more details.
 *
 *  TODO(sxjscience) Explore the possibility of int lhs_count and rhs_count
 */
template<typename DType, typename IType>
__device__ __inline__ void ChanMergePartition(const DType lhs_mean,
                                              const DType lhs_sigma2,
                                              const IType lhs_count,
                                              DType& rhs_mean,         //NOLINT
                                              DType& rhs_sigma2,       //NOLINT
                                              IType& rhs_count) {      //NOLINT
  DType delta = rhs_mean - lhs_mean;
  DType nA = static_cast<DType>(lhs_count);
  DType nB = static_cast<DType>(rhs_count);
  rhs_count = nA + nB;
  if (rhs_count > DType(0)) {
    nA = nA / rhs_count;
    nB = nB / rhs_count;
    rhs_mean = nA * lhs_mean + nB * rhs_mean;
    rhs_sigma2 = rhs_sigma2 + lhs_sigma2 + delta * delta * nA * nB * rhs_count;
  } else {
    rhs_mean = DType(0);
    rhs_sigma2 = DType(0);
  }
}
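
/* Illustrative sketch (for exposition only; nothing in this file calls it, and the helper name
 * is hypothetical): computing the statistics of an array by running Welford's algorithm on its
 * two halves and then folding the first half into the second with ChanMergePartition. The
 * merged (mean, sigma2, count) matches what a single sequential Welford pass over the whole
 * array would produce, which is exactly why the warp/block reductions below can merge
 * per-thread partial results.
 */
__device__ __forceinline__ void TwoPartitionStatsSketch(const float* vals, int n,
                                                        float& mean,      //NOLINT
                                                        float& sigma2,    //NOLINT
                                                        float& count) {   //NOLINT
  float lhs_mean = 0.f, lhs_sigma2 = 0.f, lhs_count = 0.f;
  mean = 0.f;
  sigma2 = 0.f;
  count = 0.f;
  // First half accumulates into the lhs triplet, second half into the rhs/output triplet.
  for (int i = 0; i < n / 2; ++i) {
    StepWelfordOnlineSum(vals[i], lhs_mean, lhs_sigma2, lhs_count);
  }
  for (int i = n / 2; i < n; ++i) {
    StepWelfordOnlineSum(vals[i], mean, sigma2, count);
  }
  // Fold the left partition into the right one, as the reductions below do per warp lane.
  ChanMergePartition(lhs_mean, lhs_sigma2, lhs_count, mean, sigma2, count);
}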

/* Split the input column into multiple partitions and compute the mean/sigma2 of each
 * partition. Each thread keeps its own (mean, sigma2, count), which can later be merged to
 * obtain the mean and sigma2 of the whole column.
 */
template<typename AType, typename DType, typename IType>
__device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ col_vals,
                                                      const int nchannel,
                                                      AType& mean,         //NOLINT
                                                      AType& sigma2,       //NOLINT
                                                      IType& count) {      //NOLINT
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int nthread = blockDim.x * blockDim.y;
  // Each thread takes charge of 4 consecutive numbers. This should improve the loading speed
  // via vectorized types like float4.
  // Also, to minimize branch divergence, we split the for-loop into two parts.
  int l = 4 * tid;
  for (; l + 3 < nchannel; l += 4 * nthread) {
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      StepWelfordOnlineSum(static_cast<AType>(col_vals[l + i]), mean, sigma2, count);
    }
  }
  for (; l < nchannel; ++l) {
    StepWelfordOnlineSum(static_cast<AType>(col_vals[l]), mean, sigma2, count);
  }
}
122 
123 template<>
124 __device__ __forceinline__
BlockWelfordOnlineSum(const mshadow::half::half_t * __restrict__ col_vals,const int nchannel,float & mean,float & sigma2,int & count)125 void BlockWelfordOnlineSum<float, mshadow::half::half_t, int>
126                                           (const mshadow::half::half_t* __restrict__ col_vals,
127                                            const int nchannel,
128                                            float& mean,                    //NOLINT
129                                            float& sigma2,                  //NOLINT
130                                            int& count) {                 //NOLINT
131   int tid = threadIdx.x + threadIdx.y * blockDim.x;
132   const int nthread = blockDim.x * blockDim.y;
133   // We cast the input half pointer to half2 to optimize the loading speed.
134   // Here, we need to notice that CUDA forces memory alignment, i.e.,
135   // ASSERT static_cast<size_t>(ptr) % sizeof(dtype) == 0.
136   // Thus, we need to shift the address of the half pointer to be aligned by half2.
137   int align_shift = (reinterpret_cast<size_t>(col_vals) % 4) != 0;
138   int padding = (nchannel - align_shift) % 2;
139   int half2_size = (nchannel - align_shift) / 2;
140   const __half2* half2_col_vals = reinterpret_cast<const __half2*>(col_vals + align_shift);
141   if (threadIdx.x == 0 && threadIdx.y == 0) {
142     if (align_shift) {
143       StepWelfordOnlineSum(__half2float(col_vals[0].cuhalf_), mean, sigma2, count);
144     }
145     if (padding) {
146       StepWelfordOnlineSum(__half2float(col_vals[nchannel - 1].cuhalf_), mean, sigma2, count);
147     }
148   }
149 
150   for (int l = tid; l < half2_size; l += nthread) {
151     float2 ele_val =  __half22float2(half2_col_vals[l]);
152     StepWelfordOnlineSum(ele_val.x, mean, sigma2, count);
153     StepWelfordOnlineSum(ele_val.y, mean, sigma2, count);
154   }
155 }
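
/* Illustrative sketch (for exposition only; not used by the operator, and the helper name is
 * hypothetical): how the half-precision column above is split into an optional unaligned
 * leading element, a run of aligned __half2 loads, and an optional trailing element. For
 * example, a column whose byte address satisfies addr % 4 == 2 with nchannel == 8 yields
 * align_shift == 1, half2_size == 3 and padding == 1: element 0 and element 7 are handled by
 * thread (0, 0), while elements 1..6 are read as three __half2 values.
 */
static inline void SplitHalfColumnSketch(size_t addr, int nchannel,
                                         int* align_shift, int* half2_size, int* padding) {
  *align_shift = (addr % 4) != 0;
  *padding = (nchannel - *align_shift) % 2;
  *half2_size = (nchannel - *align_shift) / 2;
}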

/* Fused CUDA kernel for the forward pass of layer normalization.
 * It computes the LayerNorm when axis=-1, i.e., the contiguous-reduction scenario.
 * Shape of the input tensors:
 *      in_data = (nbatch, nchannel)
 *      gamma = (nchannel,)
 *      beta = (nchannel,)
 *      out_data = (nbatch, nchannel)
 *      mean_data = (nbatch,)
 *      std_data = (nbatch,)
 *  It is always launched with blockDim.x = WARP_SIZE.
 *  Also, when blockDim.y > 1, it requires dynamic shared memory of size:
 *      sizeof(AType) * blockDim.x * blockDim.y + sizeof(int) * blockDim.x * blockDim.y / 2
 */
template<typename AType, typename DType, typename IType>
__global__ void LayerNormFusedForwardKernelContig(const int nbatch,
                                                  const int nchannel,
                                                  const AType eps,
                                                  const DType* __restrict__ in_data,
                                                  const DType* __restrict__ gamma,
                                                  const DType* __restrict__ beta,
                                                  DType* __restrict__ out_data,
                                                  DType* __restrict__ mean_data,
                                                  DType* __restrict__ std_data) {
  int bid = blockIdx.x + blockIdx.y * gridDim.x;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int nthread = blockDim.x * blockDim.y;
  IType count = 0;
  AType mean = 0;
  AType sigma2 = 0;

  if (bid < nbatch) {
    extern __shared__ char buf[];  // Shared memory
    const DType* col_vals = in_data + bid * nchannel;
    BlockWelfordOnlineSum(col_vals, nchannel, mean, sigma2, count);

    // Merge the mean/sigma2 within a warp.
    // Use Chan's parallel algorithm to merge all (mean, sigma2, count) triplets
    // within a warp of threads.
    // After the loop, the thread with threadIdx.x == 0 holds the aggregated
    // (mean, sigma2, count).
    for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
      AType meanB = warp_shfl_xor(mean, mask);
      AType sigma2B = warp_shfl_xor(sigma2, mask);
      IType countB = warp_shfl_xor(count, mask);
      ChanMergePartition(meanB, sigma2B, countB, mean, sigma2, count);
    }
    if (blockDim.y > 1) {
      // Inter-warp reduction. Copy the upper-half of the warps to shared memory
      // and merge with the lower-half warp
      AType* mean_buf = reinterpret_cast<AType*>(buf);
      AType* sigma2_buf =
        reinterpret_cast<AType*>(buf + sizeof(AType) * blockDim.y / 2 * blockDim.x);
      IType* count_buf = reinterpret_cast<IType*>(buf + sizeof(AType) * blockDim.y * blockDim.x);
      for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          mean_buf[idx] = mean;
          sigma2_buf[idx] = sigma2;
          count_buf[idx] = count;
        }
        __syncthreads();
        if (threadIdx.y < offset) {
          const int idx = threadIdx.y * blockDim.x + threadIdx.x;
          ChanMergePartition(mean_buf[idx], sigma2_buf[idx], count_buf[idx], mean, sigma2, count);
        }
        __syncthreads();
      }
      // Broadcast the result to all threads
      if (threadIdx.y == 0) {
        mean_buf[threadIdx.x] = mean;
        sigma2_buf[threadIdx.x] = sigma2;
      }
      __syncthreads();
      mean = mean_buf[threadIdx.x];
      sigma2 = sigma2_buf[threadIdx.x] / nchannel;
    } else {
      sigma2 /= nchannel;
    }
    // Calculate the out_data: gamma * (x - mean) / sqrt(var + eps) + beta
    AType std_eps = sqrt(sigma2 + eps);
    AType invstd_eps = AType(1.0) / std_eps;
    DType* out_col_val = out_data + bid * nchannel;

    if (gamma != nullptr && beta != nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = gamma[i] * static_cast<DType>(invstd_eps *
                                                       (static_cast<AType>(col_vals[i]) - mean))
                                                         + beta[i];
      }
    } else if (gamma == nullptr && beta != nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean))
                                                       + beta[i];
      }
    } else if (gamma != nullptr && beta == nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = gamma[i] * static_cast<DType>(invstd_eps *
                                                       (static_cast<AType>(col_vals[i]) - mean));
      }
    } else {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean));
      }
    }
    // Write the mean_data and std_data
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean_data[bid] = static_cast<DType>(mean);
      std_data[bid] = static_cast<DType>(std_eps);
    }
  }
}
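
/* Illustrative sketch (for exposition only; nothing in this file calls it, and the helper name
 * is hypothetical): the butterfly all-reduce pattern used in the kernels above and below.
 * Assuming warp_width is a power of two no larger than the warp size, every participating lane
 * ends up holding the sum of the values all lanes started with, because each xor step halves
 * the number of distinct partial sums. The forward kernel uses the same loop shape but merges
 * (mean, sigma2, count) triplets with ChanMergePartition instead of adding plain values.
 */
__device__ __forceinline__ float WarpAllReduceSumSketch(float val, int warp_width) {
  for (int mask = warp_width / 2; mask > 0; mask >>= 1) {
    val += warp_shfl_xor(val, mask);
  }
  return val;
}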

template<bool safe_acc = false>
void LayerNormGPUContig(const LayerNormParam param,
                        const OpContext& ctx, const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 3U);
  mxnet::TShape data_shape(2, 0);
  mxnet::TShape mean_shape(1, 0);
  size_t in_ndim = inputs[layernorm::kData].ndim();
  data_shape[0] = mean_shape[0] = inputs[layernorm::kData].shape_.ProdShape(0, in_ndim - 1);
  data_shape[1] = inputs[layernorm::kData].shape_[in_ndim - 1];
  const TBlob in_data = inputs[layernorm::kData].reshape(data_shape);
  const TBlob gamma = inputs[layernorm::kGamma];
  const TBlob beta = inputs[layernorm::kBeta];
  const TBlob out_data = outputs[layernorm::kOut].reshape(data_shape);
  const TBlob mean_data = outputs[layernorm::kMean].reshape(mean_shape);
  const TBlob std_data = outputs[layernorm::kStd].reshape(mean_shape);
  // Make sure the inputs are contiguous
  CHECK_EQ(in_data.CheckContiguous(), true);
  CHECK_EQ(gamma.CheckContiguous(), true);
  CHECK_EQ(beta.CheckContiguous(), true);
  CHECK_EQ(out_data.CheckContiguous(), true);
  CHECK_EQ(mean_data.CheckContiguous(), true);
  CHECK_EQ(std_data.CheckContiguous(), true);

  // Launch the kernel. The dynamic shared memory size is
  //   sizeof(AType) * blockDim.y * blockDim.x + sizeof(int) * (blockDim.y / 2) * blockDim.x
  int nbatch = data_shape[0];
  int nchannel = data_shape[1];
  float eps = param.eps;
  int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
  int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
  int nthread_y;
  const dim3 dimGrid(ngrid_x, ngrid_y);
  if (nchannel <= 128) {
    nthread_y = 1;
  } else if (nchannel <= 512) {
    nthread_y = 2;
  } else {
    nthread_y = 4;
  }
  cudaStream_t stream = Stream<gpu>::GetStream(ctx.get_stream<gpu>());
  const dim3 dimBlock(32, nthread_y);
  MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
    typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
    int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(AType)
                                  + (nthread_y / 2) * 32 * sizeof(int) : 0;
    CheckLaunchParam(dimGrid, dimBlock);
    LayerNormFusedForwardKernelContig<AType, DType, int> <<<dimGrid, dimBlock, nshared, stream>>>
     (nbatch, nchannel, static_cast<AType>(eps),
      in_data.dptr<DType>(), gamma.dptr<DType>(), beta.dptr<DType>(),
      out_data.dptr<DType>(), mean_data.dptr<DType>(), std_data.dptr<DType>());
  });
  MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedForwardKernelContig);
}
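
/* Illustrative sketch (for exposition only; the launch code above computes this inline, and the
 * helper name is hypothetical): the dynamic shared-memory size used by the contiguous forward
 * kernel as a standalone helper. For example, float32 accumulation (atype_size == 4) with
 * nthread_y == 4 gives 4 * 32 * 4 + 2 * 32 * 4 = 768 bytes.
 */
static inline int ForwardContigSharedBytesSketch(int nthread_y, size_t atype_size) {
  return nthread_y > 1
           ? static_cast<int>(nthread_y * 32 * atype_size + (nthread_y / 2) * 32 * sizeof(int))
           : 0;
}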

template<>
void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx, const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
  if (req[0] == kNullOp) return;
  CHECK_NE(req[0], kAddTo);
  int axis = param.axis;
  if (axis < 0) {
    axis += static_cast<int>(inputs[0].ndim());
  }
  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
  if (axis == inputs[0].ndim() - 1) {
    // Try to use the accelerated CUDA kernels
    bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
    if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
      common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LayerNorm with float16 inputs. "
                      "See https://mxnet.apache.org/api/faq/env_var "
                      "for more details.");
    }
    if (safe_acc) {
      return LayerNormGPUContig<true>(param, ctx, inputs, req, outputs);
    } else {
      return LayerNormGPUContig<false>(param, ctx, inputs, req, outputs);
    }
  }
  return LayerNormComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}


/* Fused CUDA kernel for calculating the gradient w.r.t gamma/beta in LayerNorm when axis=-1
 * (the contiguous case).
 * The gradients of gamma and beta are:
 *   d_gamma = sum(out_grad * (x - mean) / std, axis=0)
 *   d_beta = sum(out_grad, axis=0)
 *
 * We compute the gradient (mainly a reduction over a non-contiguous axis) in two steps to
 * improve the parallelism.
 *
 * In the first step, we divide the rows uniformly into K parts and use K independent
 * threadblocks to calculate the partial reduction result of each part, as illustrated below:
 *
 *      1st Block          2nd Block          3rd Block              k-th Block
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 *     part_gamma[0]     part_gamma[1]      part_gamma[2]           part_gamma[k-1]
 *     part_beta[0]      part_beta[1]       part_beta[2]            part_beta[k-1]
 *
 *
 * In the second step, we sum up the row-values in part_gamma and part_beta.
 *
 * `LayerNormFusedBackwardKernel_PartGammaBeta` implements the first step and
 * `LayerNormFusedBackwardKernel_GammaBeta` implements the second step.
 */
template<typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch,
                                                           const int nchannel,
                                                           const DType* __restrict__ in_data,
                                                           const DType* __restrict__ out_grad,
                                                           const DType* __restrict__ mean_data,
                                                           const DType* __restrict__ std_data,
                                                           AType* __restrict__ part_gamma_grad,
                                                           AType* __restrict__ part_beta_grad) {
  extern __shared__ char buf[];
  AType* d_buf = reinterpret_cast<AType*>(buf);
  const int npart = gridDim.y;
  const int block_row_num = (nbatch + npart - 1) / npart;
  // The rows are divided into `npart` parts. Each threadblock calculates the reduction result
  // within the corresponding row ranges.
  // The shared-memory row stride is blockDim.x + 1; the extra element helps avoid bank conflicts.
  int row_stride = blockDim.x + 1;
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  int r_begin = blockIdx.y * block_row_num;
  int r_end = min((blockIdx.y + 1) * block_row_num, nbatch);
  AType* buf_gamma_grad = d_buf;
  AType* buf_beta_grad = d_buf + blockDim.y * row_stride;
  AType local_gamma_grad = 0;
  AType local_beta_grad = 0;

  if (c < nchannel) {
    for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) {
      int r = r_b + threadIdx.y;
      if (r < r_end) {
        AType local_mean = static_cast<AType>(mean_data[r]);
        AType local_std = static_cast<AType>(std_data[r]);
        int read_idx = r * nchannel + c;
        AType local_in_data = static_cast<AType>(in_data[read_idx]);
        AType local_out_grad = static_cast<AType>(out_grad[read_idx]);
        local_gamma_grad += (local_in_data - local_mean) / local_std * local_out_grad;
        local_beta_grad += local_out_grad;
      }
    }
  }
  buf_gamma_grad[threadIdx.y * row_stride + threadIdx.x] = local_gamma_grad;
  buf_beta_grad[threadIdx.y * row_stride + threadIdx.x] = local_beta_grad;
  __syncthreads();
  // Tree-reduce over threadIdx.y; the last remaining pair of rows is summed in the write below.
  for (int offset = blockDim.y/2;  offset > 1;  offset >>= 1) {
    if (threadIdx.y < offset) {
      int idx1 = threadIdx.y * row_stride + threadIdx.x;
      int idx2 = (threadIdx.y + offset) * row_stride + threadIdx.x;
      buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
      buf_beta_grad[idx1] += buf_beta_grad[idx2];
    }
    __syncthreads();
  }
  if (threadIdx.y == 0 && c < nchannel) {
    part_gamma_grad[blockIdx.y * nchannel + c] = buf_gamma_grad[threadIdx.x]
                                                   + buf_gamma_grad[threadIdx.x + row_stride];
    part_beta_grad[blockIdx.y * nchannel + c] = buf_beta_grad[threadIdx.x]
                                                   + buf_beta_grad[threadIdx.x + row_stride];
  }
}

template<bool gamma_addto, bool beta_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch,
                                                       const int nchannel,
                                                       const int npart,
                                                       const AType* __restrict__ part_gamma_grad,
                                                       const AType* __restrict__ part_beta_grad,
                                                       DType* gamma_grad,
                                                       DType* beta_grad) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  if (c < nchannel) {
    extern __shared__ char buf[];
    AType* buf_gamma_grad = reinterpret_cast<AType*>(buf);
    AType* buf_beta_grad = reinterpret_cast<AType*>(buf) + blockDim.x * blockDim.y;
    buf_gamma_grad[tid] = 0;
    buf_beta_grad[tid] = 0;
    for (int r = threadIdx.y; r < npart; r += blockDim.y) {
      buf_gamma_grad[tid] += part_gamma_grad[r * nchannel + c];
      buf_beta_grad[tid] += part_beta_grad[r * nchannel + c];
    }
    __syncthreads();
    // Begin for inter-warp reduce
    if (npart > 1) {
      for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
        if (threadIdx.y < offset) {
          int idx1 = tid;
          int idx2 = tid + offset * blockDim.x;
          buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
          buf_beta_grad[idx1] += buf_beta_grad[idx2];
        }
        __syncthreads();
      }
    }
    if (threadIdx.y == 0) {
      if (gamma_grad) {
        if (gamma_addto) {
          gamma_grad[c] += static_cast<DType>(buf_gamma_grad[threadIdx.x]);
        } else {
          gamma_grad[c] = static_cast<DType>(buf_gamma_grad[threadIdx.x]);
        }
      }
      if (beta_grad) {
        if (beta_addto) {
          beta_grad[c] += static_cast<DType>(buf_beta_grad[threadIdx.x]);
        } else {
          beta_grad[c] = static_cast<DType>(buf_beta_grad[threadIdx.x]);
        }
      }
    }
  }
}
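
/* Illustrative sketch (for exposition only; nothing in this file calls it, and the helper name
 * is hypothetical): the two-step gamma/beta reduction that the pair of kernels above performs,
 * written as plain host code for a single channel. Step 1 produces one partial gamma/beta
 * gradient per block of rows; step 2 sums the partials. `npart` mirrors gridDim.y of
 * LayerNormFusedBackwardKernel_PartGammaBeta.
 */
static inline void TwoStepGammaBetaSketch(const float* x, const float* og,
                                          const float* mean, const float* std_data,
                                          int nbatch, int npart,
                                          float* gamma_grad, float* beta_grad) {
  std::vector<float> part_gamma(npart, 0.f);
  std::vector<float> part_beta(npart, 0.f);
  const int rows_per_part = (nbatch + npart - 1) / npart;
  // Step 1: partial reductions over disjoint row blocks.
  for (int p = 0; p < npart; ++p) {
    const int r_begin = p * rows_per_part;
    const int r_end = (r_begin + rows_per_part < nbatch) ? r_begin + rows_per_part : nbatch;
    for (int r = r_begin; r < r_end; ++r) {
      part_gamma[p] += (x[r] - mean[r]) / std_data[r] * og[r];
      part_beta[p] += og[r];
    }
  }
  // Step 2: reduce the per-part results.
  *gamma_grad = 0.f;
  *beta_grad = 0.f;
  for (int p = 0; p < npart; ++p) {
    *gamma_grad += part_gamma[p];
    *beta_grad += part_beta[p];
  }
}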

/* Fused CUDA kernel for calculating the gradient of LayerNorm w.r.t. the input data when
 * axis=-1 (contiguous case). Each threadblock handles one row: it first computes the two
 * per-row means
 *   sum_val0 = mean(out_grad * gamma / std, axis=-1)
 *   sum_val1 = mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
 * and then writes
 *   data_grad = out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1.
 */
template<int LOAD_UNROLL, bool data_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_Data(const int nbatch,
                                                  const int nchannel,
                                                  const DType* __restrict__ in_data,
                                                  const DType* __restrict__ out_grad,
                                                  const DType* __restrict__ mean_data,
                                                  const DType* __restrict__ std_data,
                                                  const DType* __restrict__ gamma,
                                                  DType* data_grad) {
  int bid = blockIdx.x + blockIdx.y * gridDim.x;
  const int nthread = blockDim.x * blockDim.y;
  if (bid < nbatch) {
    // Shared memory with size blockDim.y * blockDim.x * sizeof(AType)
    extern __shared__ char buf[];
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // 1. Calculate: mean(out_grad * gamma / std, axis=-1)
    //               mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
    AType sum_val0 = 0;  // Stores mean(out_grad * gamma / std, axis=-1)
    AType sum_val1 = 0;  // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
    AType mean = static_cast<AType>(mean_data[bid]);
    AType invstd_eps = AType(1) / static_cast<AType>(std_data[bid]);
    int l = LOAD_UNROLL * tid;
    for (; l + LOAD_UNROLL - 1 < nchannel; l += nthread * LOAD_UNROLL) {
#pragma unroll
      for (int i = 0; i < LOAD_UNROLL; ++i) {
        AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l + i]);
        AType ele_x = static_cast<AType>(in_data[bid * nchannel + l + i]);
        AType ele_gamma = static_cast<AType>(gamma[l + i]);
        sum_val0 += ele_og * ele_gamma * invstd_eps;
        sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
      }
    }
    for (; l < nchannel; ++l) {
      AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l]);
      AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
      AType ele_gamma = static_cast<AType>(gamma[l]);
      sum_val0 += ele_og * ele_gamma * invstd_eps;
      sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
    }
    // Intra-warp reduction (all-reduce)
    for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
      sum_val0 += warp_shfl_xor(sum_val0, mask);
      sum_val1 += warp_shfl_xor(sum_val1, mask);
    }
    // Inter-warp reduction (all-reduce)
    if (blockDim.y > 1) {
      AType* sum_val0_buf = reinterpret_cast<AType*>(buf);
      AType* sum_val1_buf =
        reinterpret_cast<AType*>(buf + blockDim.y / 2 * blockDim.x * sizeof(AType));
      for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          sum_val0_buf[idx] = sum_val0;
          sum_val1_buf[idx] = sum_val1;
        }
        __syncthreads();
        if (threadIdx.y < offset) {
          const int idx = threadIdx.y * blockDim.x + threadIdx.x;
          sum_val0 += sum_val0_buf[idx];
          sum_val1 += sum_val1_buf[idx];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        sum_val0_buf[threadIdx.x] = sum_val0;
        sum_val1_buf[threadIdx.x] = sum_val1;
      }
      __syncthreads();
      sum_val0 = sum_val0_buf[threadIdx.x];
      sum_val1 = sum_val1_buf[threadIdx.x];
    }
    sum_val0 /= nchannel;
    sum_val1 /= nchannel;
    // 2. Calculate the gradient as
    //      out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1
    for (int l = tid; l < nchannel; l += nthread) {
      AType ele_out_grad = static_cast<AType>(out_grad[bid * nchannel + l]);
      AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
      AType ele_gamma = static_cast<AType>(gamma[l]);
      if (data_addto) {
        data_grad[bid * nchannel + l] +=
          static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps
                               - sum_val0 - (ele_x - mean) * invstd_eps * sum_val1);
      } else {
        data_grad[bid * nchannel + l] =
          static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps - sum_val0
                                               - (ele_x - mean) * invstd_eps * sum_val1);
      }
    }
  }
}
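
/* Illustrative sketch (for exposition only; nothing in this file calls it, and the helper name
 * is hypothetical): the per-row data gradient that LayerNormFusedBackwardKernel_Data computes,
 * written as sequential host code for one row of length nchannel. It mirrors the two-pass
 * structure above: first the two per-row means sum_val0 and sum_val1, then the elementwise
 * gradient.
 */
static inline void RowDataGradSketch(const float* x, const float* og, const float* gamma,
                                     float mean, float std_dev, int nchannel,
                                     float* data_grad) {
  const float invstd = 1.f / std_dev;
  float sum_val0 = 0.f;  // mean(out_grad * gamma / std)
  float sum_val1 = 0.f;  // mean(out_grad * gamma / std * (x - mean) / std)
  for (int i = 0; i < nchannel; ++i) {
    sum_val0 += og[i] * gamma[i] * invstd;
    sum_val1 += og[i] * gamma[i] * (x[i] - mean) * invstd * invstd;
  }
  sum_val0 /= nchannel;
  sum_val1 /= nchannel;
  for (int i = 0; i < nchannel; ++i) {
    data_grad[i] = og[i] * gamma[i] * invstd - sum_val0 - (x[i] - mean) * invstd * sum_val1;
  }
}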

void GetGammaBetaGradKernelParams(const int nbatch, const int nchannel,
                                  dim3* part_grad_block_dim, dim3* part_grad_grid_dim,
                                  dim3* gb_block_dim, dim3* gb_grid_dim,
                                  int* npart) {
  *npart = 16;
  *part_grad_block_dim = dim3(32, 16);
  *part_grad_grid_dim = dim3((nchannel + 32 - 1) / 32, *npart);
  *gb_block_dim = dim3(32, *npart);
  *gb_grid_dim = dim3((nchannel + 32 - 1) / 32);
  CheckLaunchParam(*part_grad_grid_dim, *part_grad_block_dim);
  CheckLaunchParam(*gb_grid_dim, *gb_block_dim);
}
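/* Worked example (for exposition only): with nchannel = 1000 the parameters above become
 * npart = 16, part_grad_block_dim = (32, 16), part_grad_grid_dim = (32, 16),
 * gb_block_dim = (32, 16) and gb_grid_dim = (32,), since (1000 + 31) / 32 == 32.
 */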

template<bool safe_acc = false>
void LayerNormGradGPUContig(const LayerNormParam param,
                            const OpContext& ctx, const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 5U);
  const TBlob out_grad = inputs[0];
  const TBlob in_data = inputs[1];
  const TBlob gamma = inputs[2];
  const TBlob mean_data = inputs[3];
  const TBlob std_data = inputs[4];
  const TBlob data_grad = outputs[0];
  const TBlob gamma_grad = outputs[1];
  const TBlob beta_grad = outputs[2];

  // Make sure the inputs are contiguous
  CHECK_EQ(out_grad.CheckContiguous(), true);
  CHECK_EQ(in_data.CheckContiguous(), true);
  CHECK_EQ(gamma.CheckContiguous(), true);
  CHECK_EQ(mean_data.CheckContiguous(), true);
  CHECK_EQ(std_data.CheckContiguous(), true);
  int nbatch = in_data.shape_.ProdShape(0, in_data.ndim() - 1);
  int nchannel = in_data.shape_[in_data.ndim() - 1];
  int data_grad_req = req[0];
  int gamma_grad_req = req[1];
  int beta_grad_req = req[2];
  CHECK_NE(data_grad_req, kWriteInplace);
  CHECK_NE(gamma_grad_req, kWriteInplace);
  CHECK_NE(beta_grad_req, kWriteInplace);
  Stream<gpu> *s = ctx.get_stream<gpu>();
  cudaStream_t stream = Stream<gpu>::GetStream(s);

  // Calculate the gradient for gamma/beta
  CHECK_EQ(gamma_grad.CheckContiguous(), true);
  CHECK_EQ(beta_grad.CheckContiguous(), true);
  dim3 part_grad_block_dim, part_grad_grid_dim, gb_block_dim, gb_grid_dim;
  int npart;
  GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim,
                               &gb_block_dim, &gb_grid_dim, &npart);
  if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
    MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
      typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
      Tensor<gpu, 1, AType> workspace =
        ctx.requested[0].get_space_typed<gpu, 1, AType>(Shape1(2 * npart * nchannel), s);
      AType* part_gamma_grad_ptr = workspace.dptr_;
      AType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel;
      const int nshared_K1 = 2 * (part_grad_block_dim.x + 1)
                               * part_grad_block_dim.y * sizeof(AType);
      const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(AType);
      DType* gamma_grad_ptr = (gamma_grad_req != kNullOp) ? gamma_grad.dptr<DType>() : nullptr;
      DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? beta_grad.dptr<DType>() : nullptr;
      LayerNormFusedBackwardKernel_PartGammaBeta
        <<<part_grad_grid_dim, part_grad_block_dim, nshared_K1, stream>>>
        (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(),
         mean_data.dptr<DType>(), std_data.dptr<DType>(), part_gamma_grad_ptr, part_beta_grad_ptr);
      MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_PartGammaBeta);
      if (gamma_grad_req == kAddTo && beta_grad_req != kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<true, false>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else if (gamma_grad_req != kAddTo && beta_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<false, true>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else if (gamma_grad_req == kAddTo && beta_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<true, true>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else {
        LayerNormFusedBackwardKernel_GammaBeta<false, false>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      }
    });
    MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_GammaBeta);
  }

  // Calculate the gradient for data
  CHECK_EQ(data_grad.CheckContiguous(), true);
  int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
  int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
  const dim3 data_grid_dim(ngrid_x, ngrid_y);
  int nthread_y;
  if (nchannel <= 32) {
    nthread_y = 1;
  } else if (nchannel <= 128) {
    nthread_y = 2;
  } else if (nchannel <= 512) {
    nthread_y = 4;
  } else {
    nthread_y = 8;
  }
  const dim3 data_block_dim(32, nthread_y);
  const int LOAD_UNROLL = 4;
  if (data_grad_req != kNullOp) {
    MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
      typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
      int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0;
      CheckLaunchParam(data_grid_dim, data_block_dim);
      if (data_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, true, AType>
          <<<data_grid_dim, data_block_dim, nshared, stream>>>
          (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(), mean_data.dptr<DType>(),
           std_data.dptr<DType>(), gamma.dptr<DType>(), data_grad.dptr<DType>());
      } else {
        LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, false, AType>
          <<<data_grid_dim, data_block_dim, nshared, stream>>>
          (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(), mean_data.dptr<DType>(),
           std_data.dptr<DType>(), gamma.dptr<DType>(), data_grad.dptr<DType>());
      }
    });
    MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_Data);
  }
}

template<>
void LayerNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx, const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
  int axis = param.axis;
  if (axis < 0) {
    axis += static_cast<int>(inputs[0].ndim());
  }
  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
  if (axis == inputs[0].ndim() - 1) {
    // Use the accelerated CUDA kernels
    bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
    if (safe_acc) {
      return LayerNormGradGPUContig<true>(param, ctx, inputs, req, outputs);
    } else {
      return LayerNormGradGPUContig<false>(param, ctx, inputs, req, outputs);
    }
  }
  return LayerNormGradComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}


NNVM_REGISTER_OP(LayerNorm)
.set_attr<FCompute>("FCompute<gpu>", LayerNormCompute<gpu>);

NNVM_REGISTER_OP(_backward_LayerNorm)
.set_attr<FCompute>("FCompute<gpu>", LayerNormGradCompute<gpu>);

}  // namespace op
}  // namespace mxnet