/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file layer_norm.cu
 * \brief Implements Ba et al., Layer Normalization (https://arxiv.org/abs/1607.06450).
 */
#include "./layer_norm-inl.h"

using namespace mshadow::cuda;

namespace mxnet {
namespace op {

template <typename DType>
__device__ __forceinline__ DType warp_shfl(DType value, int src_lane,
                                           int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

template <typename DType>
__device__ __forceinline__ DType warp_shfl_xor(DType value, int laneMask,
                                               int width = 32, unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}
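
/* Note: warp_shfl / warp_shfl_xor wrap the CUDA warp-shuffle intrinsics so that the same code
 * compiles both before and after CUDA 9, which replaced __shfl/__shfl_xor with the *_sync
 * variants. warp_shfl_xor(value, laneMask) exchanges `value` with the lane whose id differs by
 * `laneMask`; it is the primitive behind the butterfly (all-reduce) reductions in the kernels
 * below.
 */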


/* A single updating step of Welford's online algorithm for calculating the mean and variance.
 * The value 'curr' will be accumulated into the (mean, sigma2, count) triplet.
 */
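/* Welford update applied by StepWelfordOnlineSum for each incoming value x:
 *   count  += 1
 *   delta   = x - mean
 *   mean   += delta / count
 *   sigma2 += delta * (x - mean)
 * Here sigma2 accumulates the sum of squared deviations (often called M2); it still has to be
 * divided by the element count to obtain the variance, which the forward kernel does after all
 * partial results have been merged.
 */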
template<typename DType, typename IType>
__device__ __forceinline__ void StepWelfordOnlineSum(const DType curr,
                                                     DType& mean,     //NOLINT
                                                     DType& sigma2,   //NOLINT
                                                     IType& count) {  //NOLINT
  count += IType(1);
  DType delta = curr - mean;
  mean += delta / count;
  sigma2 += delta * (curr - mean);
}

/* Merge the mean/variance of two partitions. It's the key step of Chan's parallel algorithm.
 * The (lhs_mean, lhs_sigma2, lhs_count) will be merged into (rhs_mean, rhs_sigma2, rhs_count).
 *
 * See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance for more details.
 *
 * TODO(sxjscience) Explore the possibility of int lhs_count and rhs_count
 */
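/* Chan's merge of two partial results (mean_A, M2_A, n_A) and (mean_B, M2_B, n_B):
 *   n     = n_A + n_B
 *   delta = mean_B - mean_A
 *   mean  = (n_A * mean_A + n_B * mean_B) / n
 *   M2    = M2_A + M2_B + delta^2 * n_A * n_B / n
 * In the code below, nA and nB are first divided by rhs_count (= n), so the last term is written
 * as delta * delta * nA * nB * rhs_count, which is algebraically the same.
 */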
template<typename DType, typename IType>
__device__ __inline__ void ChanMergePartition(const DType lhs_mean,
                                              const DType lhs_sigma2,
                                              const IType lhs_count,
                                              DType& rhs_mean,     //NOLINT
                                              DType& rhs_sigma2,   //NOLINT
                                              IType& rhs_count) {  //NOLINT
  DType delta = rhs_mean - lhs_mean;
  DType nA = static_cast<DType>(lhs_count);
  DType nB = static_cast<DType>(rhs_count);
  rhs_count = nA + nB;
  if (rhs_count > DType(0)) {
    nA = nA / rhs_count;
    nB = nB / rhs_count;
    rhs_mean = nA * lhs_mean + nB * rhs_mean;
    rhs_sigma2 = rhs_sigma2 + lhs_sigma2 + delta * delta * nA * nB * rhs_count;
  } else {
    rhs_mean = DType(0);
    rhs_sigma2 = DType(0);
  }
}

/* Split the input column into multiple partitions and compute the mean/sigma2 of each partition.
 * Each thread keeps its own (mean, sigma2, count); these per-thread results can be further merged
 * to get the mean and sigma2 of the whole column.
 */
template<typename AType, typename DType, typename IType>
__device__ __forceinline__ void BlockWelfordOnlineSum(const DType* __restrict__ col_vals,
                                                      const int nchannel,
                                                      AType& mean,     //NOLINT
                                                      AType& sigma2,   //NOLINT
                                                      IType& count) {  //NOLINT
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int nthread = blockDim.x * blockDim.y;
  // Each thread takes charge of 4 consecutive numbers. This should optimize the loading speed
  // using vectorized types like float4.
  // Also, to minimize branch divergence, we split the for-loop into two parts.
  int l = 4 * tid;
  for (; l + 3 < nchannel; l += 4 * nthread) {
#pragma unroll
    for (int i = 0; i < 4; ++i) {
      StepWelfordOnlineSum(static_cast<AType>(col_vals[l + i]), mean, sigma2, count);
    }
  }
  for (; l < nchannel; ++l) {
    StepWelfordOnlineSum(static_cast<AType>(col_vals[l]), mean, sigma2, count);
  }
}
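
/* Illustration of the access pattern above, for a hypothetical nthread = 2 and nchannel = 11:
 *   thread 0: elements 0-3 in the unrolled loop, then 8-10 in the tail loop
 *   thread 1: elements 4-7 in the unrolled loop, nothing left for the tail loop
 * Every thread reads 4 consecutive elements per iteration and then strides by 4 * nthread.
 */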

template<>
__device__ __forceinline__
void BlockWelfordOnlineSum<float, mshadow::half::half_t, int>
  (const mshadow::half::half_t* __restrict__ col_vals,
   const int nchannel,
   float& mean,    //NOLINT
   float& sigma2,  //NOLINT
   int& count) {   //NOLINT
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  const int nthread = blockDim.x * blockDim.y;
  // We cast the input half pointer to half2 to optimize the loading speed.
  // Note that CUDA requires the loads to be aligned, i.e.,
  //   reinterpret_cast<size_t>(ptr) % sizeof(dtype) == 0.
  // Thus, we need to shift the address of the half pointer so that it is aligned for half2.
  int align_shift = (reinterpret_cast<size_t>(col_vals) % 4) != 0;
  int padding = (nchannel - align_shift) % 2;
  int half2_size = (nchannel - align_shift) / 2;
  const __half2* half2_col_vals = reinterpret_cast<const __half2*>(col_vals + align_shift);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    if (align_shift) {
      StepWelfordOnlineSum(__half2float(col_vals[0].cuhalf_), mean, sigma2, count);
    }
    if (padding) {
      StepWelfordOnlineSum(__half2float(col_vals[nchannel - 1].cuhalf_), mean, sigma2, count);
    }
  }

  for (int l = tid; l < half2_size; l += nthread) {
    float2 ele_val = __half22float2(half2_col_vals[l]);
    StepWelfordOnlineSum(ele_val.x, mean, sigma2, count);
    StepWelfordOnlineSum(ele_val.y, mean, sigma2, count);
  }
}
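
/* Example of how the element ranges above work out (hypothetical values):
 *   nchannel = 7, col_vals misaligned for half2 -> align_shift = 1, half2_size = 3, padding = 0:
 *     element 0 is handled separately, elements 1-6 are read as three half2 loads.
 *   nchannel = 7, col_vals aligned for half2    -> align_shift = 0, half2_size = 3, padding = 1:
 *     elements 0-5 are read as three half2 loads, element 6 is handled separately.
 */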

/* Fused CUDA kernel for the forward pass of layer normalization.
 * It computes the LayerNorm when axis=-1, i.e., the contiguous reduction scenario.
 * Shape of the input tensors:
 *   in_data = (nbatch, nchannel)
 *   gamma = (nchannel,)
 *   beta = (nchannel,)
 *   out_data = (nbatch, nchannel)
 *   mean_data = (nbatch,)
 *   std_data = (nbatch,)
 * It's always launched with blockDim.x = WARP_SIZE (32).
 * Also, when blockDim.y > 1, it requires shared memory of size:
 *   sizeof(AType) * blockDim.y * blockDim.x + sizeof(int) * blockDim.y / 2 * blockDim.x
 */
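/* Layout of the dynamic shared memory used for the inter-warp reduction (blockDim.y > 1),
 * as carved out of `buf` below (each region holds blockDim.y / 2 * blockDim.x entries):
 *   mean_buf   : AType, at offset 0
 *   sigma2_buf : AType, at offset sizeof(AType) * (blockDim.y / 2) * blockDim.x
 *   count_buf  : IType, at offset sizeof(AType) * blockDim.y * blockDim.x
 */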
template<typename AType, typename DType, typename IType>
__global__ void LayerNormFusedForwardKernelContig(const int nbatch,
                                                   const int nchannel,
                                                   const AType eps,
                                                   const DType* __restrict__ in_data,
                                                   const DType* __restrict__ gamma,
                                                   const DType* __restrict__ beta,
                                                   DType* __restrict__ out_data,
                                                   DType* __restrict__ mean_data,
                                                   DType* __restrict__ std_data) {
  int bid = blockIdx.x + blockIdx.y * gridDim.x;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int nthread = blockDim.x * blockDim.y;
  IType count = 0;
  AType mean = 0;
  AType sigma2 = 0;

  if (bid < nbatch) {
    extern __shared__ char buf[];  // Shared memory
    const DType* col_vals = in_data + bid * nchannel;
    BlockWelfordOnlineSum(col_vals, nchannel, mean, sigma2, count);

    // Merge the mean/sigma2 within a warp.
    // Use Chan's parallel algorithm to merge all (mean, sigma2, count) triplets
    // within a warp of threads.
    // After the loop, threadIdx.x == 0 will store the result of
    // the aggregated (mean, sigma2, count).
    for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
      AType meanB = warp_shfl_xor(mean, mask);
      AType sigma2B = warp_shfl_xor(sigma2, mask);
      IType countB = warp_shfl_xor(count, mask);
      ChanMergePartition(meanB, sigma2B, countB, mean, sigma2, count);
    }
    if (blockDim.y > 1) {
      // Inter-warp reduction. Copy the upper-half of the warps to shared memory
      // and merge with the lower-half warp.
      AType* mean_buf = reinterpret_cast<AType*>(buf);
      AType* sigma2_buf =
        reinterpret_cast<AType*>(buf + sizeof(AType) * blockDim.y / 2 * blockDim.x);
      IType* count_buf = reinterpret_cast<IType*>(buf + sizeof(AType) * blockDim.y * blockDim.x);
      for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          mean_buf[idx] = mean;
          sigma2_buf[idx] = sigma2;
          count_buf[idx] = count;
        }
        __syncthreads();
        if (threadIdx.y < offset) {
          const int idx = threadIdx.y * blockDim.x + threadIdx.x;
          ChanMergePartition(mean_buf[idx], sigma2_buf[idx], count_buf[idx], mean, sigma2, count);
        }
        __syncthreads();
      }
      // Broadcast the result to all threads
      if (threadIdx.y == 0) {
        mean_buf[threadIdx.x] = mean;
        sigma2_buf[threadIdx.x] = sigma2;
      }
      __syncthreads();
      mean = mean_buf[threadIdx.x];
      sigma2 = sigma2_buf[threadIdx.x] / nchannel;
    } else {
      sigma2 /= nchannel;
    }
    // Calculate the out_data: gamma * (x - mean) / sqrt(var + eps) + beta
    AType std_eps = sqrt(sigma2 + eps);
    AType invstd_eps = DType(1.0) / std_eps;
    DType* out_col_val = out_data + bid * nchannel;

    if (gamma != nullptr && beta != nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = gamma[i] * static_cast<DType>(invstd_eps *
                                                       (static_cast<AType>(col_vals[i]) - mean))
                           + beta[i];
      }
    } else if (gamma == nullptr && beta != nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean))
                           + beta[i];
      }
    } else if (gamma != nullptr && beta == nullptr) {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = gamma[i] * static_cast<DType>(invstd_eps *
                                                       (static_cast<AType>(col_vals[i]) - mean));
      }
    } else {
      for (int i = tid; i < nchannel; i += nthread) {
        out_col_val[i] = static_cast<DType>(invstd_eps * (static_cast<AType>(col_vals[i]) - mean));
      }
    }
    // Write the mean_data and std_data
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean_data[bid] = static_cast<DType>(mean);
      std_data[bid] = static_cast<DType>(std_eps);
    }
  }
}

template<bool safe_acc = false>
void LayerNormGPUContig(const LayerNormParam param,
                        const OpContext& ctx, const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 3U);
  mxnet::TShape data_shape(2, 0);
  mxnet::TShape mean_shape(1, 0);
  size_t in_ndim = inputs[layernorm::kData].ndim();
  data_shape[0] = mean_shape[0] = inputs[layernorm::kData].shape_.ProdShape(0, in_ndim - 1);
  data_shape[1] = inputs[layernorm::kData].shape_[in_ndim - 1];
  const TBlob in_data = inputs[layernorm::kData].reshape(data_shape);
  const TBlob gamma = inputs[layernorm::kGamma];
  const TBlob beta = inputs[layernorm::kBeta];
  const TBlob out_data = outputs[layernorm::kOut].reshape(data_shape);
  const TBlob mean_data = outputs[layernorm::kMean].reshape(mean_shape);
  const TBlob std_data = outputs[layernorm::kStd].reshape(mean_shape);
  // Make sure the inputs are contiguous
  CHECK_EQ(in_data.CheckContiguous(), true);
  CHECK_EQ(gamma.CheckContiguous(), true);
  CHECK_EQ(beta.CheckContiguous(), true);
  CHECK_EQ(out_data.CheckContiguous(), true);
  CHECK_EQ(mean_data.CheckContiguous(), true);
  CHECK_EQ(std_data.CheckContiguous(), true);

  // Launch the kernel. The dynamic shared memory size is
  //   sizeof(AType) * blockDim.y * blockDim.x + sizeof(int) * blockDim.y / 2 * blockDim.x.
  int nbatch = data_shape[0];
  int nchannel = data_shape[1];
  float eps = param.eps;
  int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
  int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
  int nthread_y;
  const dim3 dimGrid(ngrid_x, ngrid_y);
  if (nchannel <= 128) {
    nthread_y = 1;
  } else if (nchannel <= 512) {
    nthread_y = 2;
  } else {
    nthread_y = 4;
  }
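  // For example (hypothetical shapes, assuming nbatch <= kMaxGridDim): with data_shape =
  // (8192, 1024) we get dimGrid = (8192, 1), nthread_y = 4, dimBlock = (32, 4), and with
  // float32 inputs nshared = 4 * 32 * sizeof(float) + 2 * 32 * sizeof(int) = 768 bytes.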
  cudaStream_t stream = Stream<gpu>::GetStream(ctx.get_stream<gpu>());
  const dim3 dimBlock(32, nthread_y);
  MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
    typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
    int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(AType)
                                  + (nthread_y / 2) * 32 * sizeof(int) : 0;
    CheckLaunchParam(dimGrid, dimBlock);
    LayerNormFusedForwardKernelContig<AType, DType, int> <<<dimGrid, dimBlock, nshared, stream>>>
      (nbatch, nchannel, static_cast<AType>(eps),
       in_data.dptr<DType>(), gamma.dptr<DType>(), beta.dptr<DType>(),
       out_data.dptr<DType>(), mean_data.dptr<DType>(), std_data.dptr<DType>());
  });
  MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedForwardKernelContig);
}

template<>
void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx, const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
  if (req[0] == kNullOp) return;
  CHECK_NE(req[0], kAddTo);
  int axis = param.axis;
  if (axis < 0) {
    axis += static_cast<int>(inputs[0].ndim());
  }
  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
  if (axis == inputs[0].ndim() - 1) {
    // Try to use the accelerated CUDA kernels
    bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
    if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
      common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LayerNorm with float16 inputs. "
                      "See https://mxnet.apache.org/api/faq/env_var "
                      "for more details.");
    }
    if (safe_acc) {
      return LayerNormGPUContig<true>(param, ctx, inputs, req, outputs);
    } else {
      return LayerNormGPUContig<false>(param, ctx, inputs, req, outputs);
    }
  }
  return LayerNormComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}


/* Fused CUDA kernel for calculating the gradient w.r.t. gamma/beta in LayerNorm when axis=-1
 * (Contiguous case).
 * The gradients of gamma and beta are:
 *   d_gamma = sum(out_grad * (x - mean) / std, axis=0)
 *   d_beta = sum(out_grad, axis=0)
 *
 * We compute the gradient (mainly a reduction over a non-contiguous axis) in two steps to
 * improve the parallelism.
 *
 * In the first step, we divide the rows uniformly into K parts. K independent threadblocks are
 * used to calculate the partial reduction result of each part. Illustrated below:
 *
 *      1st Block           2nd Block           3rd Block               k-th Block
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 * | --------------- | ---------------- | --------------- | ... | ---------------- |
 *   part_gamma[0]      part_gamma[1]      part_gamma[2]          part_gamma[k-1]
 *   part_beta[0]       part_beta[1]       part_beta[2]           part_beta[k-1]
 *
 *
 * In the second step, we sum up the row-values in part_gamma and part_beta.
 *
 * `LayerNormFusedBackwardKernel_PartGammaBeta` implements the first step and
 * `LayerNormFusedBackwardKernel_GammaBeta` implements the second step.
 */
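/* Shapes involved in the two-step reduction (see GetGammaBetaGradKernelParams below):
 *   part_gamma_grad, part_beta_grad : (npart, nchannel), stored in the requested GPU workspace
 *   gamma_grad, beta_grad           : (nchannel,)
 */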
template<typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch,
                                                           const int nchannel,
                                                           const DType* __restrict__ in_data,
                                                           const DType* __restrict__ out_grad,
                                                           const DType* __restrict__ mean_data,
                                                           const DType* __restrict__ std_data,
                                                           AType* __restrict__ part_gamma_grad,
                                                           AType* __restrict__ part_beta_grad) {
  extern __shared__ char buf[];
  AType* d_buf = reinterpret_cast<AType*>(buf);
  const int npart = gridDim.y;
  const int block_row_num = (nbatch + npart - 1) / npart;
  // The rows are divided into `npart` parts. Each threadblock calculates the reduction result
  // within the corresponding row ranges.
  int row_stride = blockDim.x + 1;
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  int r_begin = blockIdx.y * block_row_num;
  int r_end = min((blockIdx.y + 1) * block_row_num, nbatch);
  AType* buf_gamma_grad = d_buf;
  AType* buf_beta_grad = d_buf + blockDim.y * row_stride;
  AType local_gamma_grad = 0;
  AType local_beta_grad = 0;

  if (c < nchannel) {
    for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) {
      int r = r_b + threadIdx.y;
      if (r < r_end) {
        AType local_mean = static_cast<AType>(mean_data[r]);
        AType local_std = static_cast<AType>(std_data[r]);
        int read_idx = r * nchannel + c;
        AType local_in_data = static_cast<AType>(in_data[read_idx]);
        AType local_out_grad = static_cast<AType>(out_grad[read_idx]);
        local_gamma_grad += (local_in_data - local_mean) / local_std * local_out_grad;
        local_beta_grad += local_out_grad;
      }
    }
  }
  buf_gamma_grad[threadIdx.y * row_stride + threadIdx.x] = local_gamma_grad;
  buf_beta_grad[threadIdx.y * row_stride + threadIdx.x] = local_beta_grad;
  __syncthreads();
  for (int offset = blockDim.y / 2; offset > 1; offset >>= 1) {
    if (threadIdx.y < offset) {
      int idx1 = threadIdx.y * row_stride + threadIdx.x;
      int idx2 = (threadIdx.y + offset) * row_stride + threadIdx.x;
      buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
      buf_beta_grad[idx1] += buf_beta_grad[idx2];
    }
    __syncthreads();
  }
  // The loop above stops at offset == 2, so rows 0 and 1 of the shared buffers still hold
  // partial sums; they are added together here when writing the per-part result.
  if (threadIdx.y == 0 && c < nchannel) {
    part_gamma_grad[blockIdx.y * nchannel + c] = buf_gamma_grad[threadIdx.x]
                                                 + buf_gamma_grad[threadIdx.x + row_stride];
    part_beta_grad[blockIdx.y * nchannel + c] = buf_beta_grad[threadIdx.x]
                                                + buf_beta_grad[threadIdx.x + row_stride];
  }
}

template<bool gamma_addto, bool beta_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_GammaBeta(const int nbatch,
                                                        const int nchannel,
                                                        const int npart,
                                                        const AType* __restrict__ part_gamma_grad,
                                                        const AType* __restrict__ part_beta_grad,
                                                        DType* gamma_grad,
                                                        DType* beta_grad) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  if (c < nchannel) {
    extern __shared__ char buf[];
    AType* buf_gamma_grad = reinterpret_cast<AType*>(buf);
    AType* buf_beta_grad = reinterpret_cast<AType*>(buf) + blockDim.x * blockDim.y;
    buf_gamma_grad[tid] = 0;
    buf_beta_grad[tid] = 0;
    for (int r = threadIdx.y; r < npart; r += blockDim.y) {
      buf_gamma_grad[tid] += part_gamma_grad[r * nchannel + c];
      buf_beta_grad[tid] += part_beta_grad[r * nchannel + c];
    }
    __syncthreads();
    // Inter-warp reduction
    if (npart > 1) {
      for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
        if (threadIdx.y < offset) {
          int idx1 = tid;
          int idx2 = tid + offset * blockDim.x;
          buf_gamma_grad[idx1] += buf_gamma_grad[idx2];
          buf_beta_grad[idx1] += buf_beta_grad[idx2];
        }
        __syncthreads();
      }
    }
    if (threadIdx.y == 0) {
      if (gamma_grad) {
        if (gamma_addto) {
          gamma_grad[c] += static_cast<DType>(buf_gamma_grad[threadIdx.x]);
        } else {
          gamma_grad[c] = static_cast<DType>(buf_gamma_grad[threadIdx.x]);
        }
      }
      if (beta_grad) {
        if (beta_addto) {
          beta_grad[c] += static_cast<DType>(buf_beta_grad[threadIdx.x]);
        } else {
          beta_grad[c] = static_cast<DType>(buf_beta_grad[threadIdx.x]);
        }
      }
    }
  }
}

/* Fused CUDA kernel for calculating the gradient w.r.t. the data in LayerNorm when axis=-1
 * (Contiguous case).
 * The gradient of the data is:
 *   data_grad = out_grad * gamma / std
 *               - mean(out_grad * gamma / std, axis=-1)
 *               - (x - mean) / std * mean(out_grad * gamma * (x - mean) / std^2, axis=-1)
 */
template<int LOAD_UNROLL, bool data_addto, typename AType, typename DType>
__global__ void LayerNormFusedBackwardKernel_Data(const int nbatch,
                                                  const int nchannel,
                                                  const DType* __restrict__ in_data,
                                                  const DType* __restrict__ out_grad,
                                                  const DType* __restrict__ mean_data,
                                                  const DType* __restrict__ std_data,
                                                  const DType* __restrict__ gamma,
                                                  DType* data_grad) {
  int bid = blockIdx.x + blockIdx.y * gridDim.x;
  const int nthread = blockDim.x * blockDim.y;
  if (bid < nbatch) {
    // Shared memory with size blockDim.y * blockDim.x * sizeof(AType)
    extern __shared__ char buf[];
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // 1. Calculate: mean(out_grad * gamma / std, axis=-1)
    //               mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
    AType sum_val0 = 0;  // Stores mean(out_grad * gamma / std, axis=-1)
    AType sum_val1 = 0;  // Stores mean(out_grad * gamma / std * (x - mean) / std, axis=-1)
    AType mean = static_cast<AType>(mean_data[bid]);
    AType invstd_eps = AType(1) / static_cast<AType>(std_data[bid]);
    int l = LOAD_UNROLL * tid;
    for (; l + LOAD_UNROLL - 1 < nchannel; l += nthread * LOAD_UNROLL) {
#pragma unroll
      for (int i = 0; i < LOAD_UNROLL; ++i) {
        AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l + i]);
        AType ele_x = static_cast<AType>(in_data[bid * nchannel + l + i]);
        AType ele_gamma = static_cast<AType>(gamma[l + i]);
        sum_val0 += ele_og * ele_gamma * invstd_eps;
        sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
      }
    }
    for (; l < nchannel; ++l) {
      AType ele_og = static_cast<AType>(out_grad[bid * nchannel + l]);
      AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
      AType ele_gamma = static_cast<AType>(gamma[l]);
      sum_val0 += ele_og * ele_gamma * invstd_eps;
      sum_val1 += ele_og * ele_gamma * (ele_x - mean) * invstd_eps * invstd_eps;
    }
    // Intra-warp reduction (all-reduce)
    for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) {
      sum_val0 += warp_shfl_xor(sum_val0, mask);
      sum_val1 += warp_shfl_xor(sum_val1, mask);
    }
    // Inter-warp reduction (all-reduce)
    if (blockDim.y > 1) {
      AType* sum_val0_buf = reinterpret_cast<AType*>(buf);
      AType* sum_val1_buf =
        reinterpret_cast<AType*>(buf + blockDim.y / 2 * blockDim.x * sizeof(AType));
      for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          sum_val0_buf[idx] = sum_val0;
          sum_val1_buf[idx] = sum_val1;
        }
        __syncthreads();
        if (threadIdx.y < offset) {
          const int idx = threadIdx.y * blockDim.x + threadIdx.x;
          sum_val0 += sum_val0_buf[idx];
          sum_val1 += sum_val1_buf[idx];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        sum_val0_buf[threadIdx.x] = sum_val0;
        sum_val1_buf[threadIdx.x] = sum_val1;
      }
      __syncthreads();
      sum_val0 = sum_val0_buf[threadIdx.x];
      sum_val1 = sum_val1_buf[threadIdx.x];
    }
    sum_val0 /= nchannel;
    sum_val1 /= nchannel;
    // 2. Calculate the gradient as
    //      out_grad * gamma / std - sum_val0 - (x - mean) / std * sum_val1
    for (int l = tid; l < nchannel; l += nthread) {
      AType ele_out_grad = static_cast<AType>(out_grad[bid * nchannel + l]);
      AType ele_x = static_cast<AType>(in_data[bid * nchannel + l]);
      AType ele_gamma = static_cast<AType>(gamma[l]);
      if (data_addto) {
        data_grad[bid * nchannel + l] +=
          static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps
                             - sum_val0 - (ele_x - mean) * invstd_eps * sum_val1);
      } else {
        data_grad[bid * nchannel + l] =
          static_cast<DType>(ele_out_grad * ele_gamma * invstd_eps - sum_val0
                             - (ele_x - mean) * invstd_eps * sum_val1);
      }
    }
  }
}

void GetGammaBetaGradKernelParams(const int nbatch, const int nchannel,
                                  dim3* part_grad_block_dim, dim3* part_grad_grid_dim,
                                  dim3* gb_block_dim, dim3* gb_grid_dim,
                                  int* npart) {
  *npart = 16;
  *part_grad_block_dim = dim3(32, 16);
  *part_grad_grid_dim = dim3((nchannel + 32 - 1) / 32, *npart);
  *gb_block_dim = dim3(32, *npart);
  *gb_grid_dim = dim3((nchannel + 32 - 1) / 32);
  CheckLaunchParam(*part_grad_grid_dim, *part_grad_block_dim);
  CheckLaunchParam(*gb_grid_dim, *gb_block_dim);
}
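// For example, a hypothetical nchannel = 768 yields npart = 16,
// part_grad_block_dim = (32, 16), part_grad_grid_dim = (24, 16),
// gb_block_dim = (32, 16) and gb_grid_dim = (24), i.e. 16 row-partitions and 24 column tiles.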

template<bool safe_acc = false>
void LayerNormGradGPUContig(const LayerNormParam param,
                            const OpContext& ctx, const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 5U);
  const TBlob out_grad = inputs[0];
  const TBlob in_data = inputs[1];
  const TBlob gamma = inputs[2];
  const TBlob mean_data = inputs[3];
  const TBlob std_data = inputs[4];
  const TBlob data_grad = outputs[0];
  const TBlob gamma_grad = outputs[1];
  const TBlob beta_grad = outputs[2];

  // Make sure the inputs are contiguous
  CHECK_EQ(out_grad.CheckContiguous(), true);
  CHECK_EQ(in_data.CheckContiguous(), true);
  CHECK_EQ(gamma.CheckContiguous(), true);
  CHECK_EQ(mean_data.CheckContiguous(), true);
  CHECK_EQ(std_data.CheckContiguous(), true);
  int nbatch = in_data.shape_.ProdShape(0, in_data.ndim() - 1);
  int nchannel = in_data.shape_[in_data.ndim() - 1];
  int data_grad_req = req[0];
  int gamma_grad_req = req[1];
  int beta_grad_req = req[2];
  CHECK_NE(data_grad_req, kWriteInplace);
  CHECK_NE(gamma_grad_req, kWriteInplace);
  CHECK_NE(beta_grad_req, kWriteInplace);
  Stream<gpu> *s = ctx.get_stream<gpu>();
  cudaStream_t stream = Stream<gpu>::GetStream(s);

  // Calculate the gradient for gamma/beta
  CHECK_EQ(gamma_grad.CheckContiguous(), true);
  CHECK_EQ(beta_grad.CheckContiguous(), true);
  dim3 part_grad_block_dim, part_grad_grid_dim, gb_block_dim, gb_grid_dim;
  int npart;
  GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim,
                               &gb_block_dim, &gb_grid_dim, &npart);
  if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
    MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
      typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
      Tensor<gpu, 1, AType> workspace =
        ctx.requested[0].get_space_typed<gpu, 1, AType>(Shape1(2 * npart * nchannel), s);
      AType* part_gamma_grad_ptr = workspace.dptr_;
      AType* part_beta_grad_ptr = workspace.dptr_ + npart * nchannel;
      const int nshared_K1 = 2 * (part_grad_block_dim.x + 1)
                               * part_grad_block_dim.y * sizeof(AType);
      const int nshared_K2 = 2 * gb_block_dim.x * gb_block_dim.y * sizeof(AType);
      DType* gamma_grad_ptr = (gamma_grad_req != kNullOp) ? gamma_grad.dptr<DType>() : nullptr;
      DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? beta_grad.dptr<DType>() : nullptr;
      LayerNormFusedBackwardKernel_PartGammaBeta
        <<<part_grad_grid_dim, part_grad_block_dim, nshared_K1, stream>>>
        (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(),
         mean_data.dptr<DType>(), std_data.dptr<DType>(), part_gamma_grad_ptr, part_beta_grad_ptr);
      MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_PartGammaBeta);
      if (gamma_grad_req == kAddTo && beta_grad_req != kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<true, false>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else if (gamma_grad_req != kAddTo && beta_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<false, true>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else if (gamma_grad_req == kAddTo && beta_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_GammaBeta<true, true>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      } else {
        LayerNormFusedBackwardKernel_GammaBeta<false, false>
          <<<gb_grid_dim, gb_block_dim, nshared_K2, stream>>>
          (nbatch, nchannel, npart, part_gamma_grad_ptr, part_beta_grad_ptr,
           gamma_grad_ptr, beta_grad_ptr);
      }
    });
    MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_GammaBeta);
  }

  // Calculate the gradient for data
  CHECK_EQ(data_grad.CheckContiguous(), true);
  int ngrid_x = (nbatch > kMaxGridDim) ? (nbatch + kBaseGridNum - 1) / kBaseGridNum : nbatch;
  int ngrid_y = (nbatch > kMaxGridDim) ? kBaseGridNum : 1;
  const dim3 data_grid_dim(ngrid_x, ngrid_y);
  int nthread_y;
  if (nchannel <= 32) {
    nthread_y = 1;
  } else if (nchannel <= 128) {
    nthread_y = 2;
  } else if (nchannel <= 512) {
    nthread_y = 4;
  } else {
    nthread_y = 8;
  }
  const dim3 data_block_dim(32, nthread_y);
  const int LOAD_UNROLL = 4;
  if (data_grad_req != kNullOp) {
    MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
      typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
      int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0;
      CheckLaunchParam(data_grid_dim, data_block_dim);
      if (data_grad_req == kAddTo) {
        LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, true, AType>
          <<<data_grid_dim, data_block_dim, nshared, stream>>>
          (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(), mean_data.dptr<DType>(),
           std_data.dptr<DType>(), gamma.dptr<DType>(), data_grad.dptr<DType>());
      } else {
        LayerNormFusedBackwardKernel_Data<LOAD_UNROLL, false, AType>
          <<<data_grid_dim, data_block_dim, nshared, stream>>>
          (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(), mean_data.dptr<DType>(),
           std_data.dptr<DType>(), gamma.dptr<DType>(), data_grad.dptr<DType>());
      }
    });
    MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_Data);
  }
}

template<>
void LayerNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx, const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
  int axis = param.axis;
  if (axis < 0) {
    axis += static_cast<int>(inputs[0].ndim());
  }
  CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
  if (axis == inputs[0].ndim() - 1) {
    // Use the accelerated CUDA kernels
    bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
    if (safe_acc) {
      return LayerNormGradGPUContig<true>(param, ctx, inputs, req, outputs);
    } else {
      return LayerNormGradGPUContig<false>(param, ctx, inputs, req, outputs);
    }
  }
  return LayerNormGradComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}


NNVM_REGISTER_OP(LayerNorm)
.set_attr<FCompute>("FCompute<gpu>", LayerNormCompute<gpu>);

NNVM_REGISTER_OP(_backward_LayerNorm)
.set_attr<FCompute>("FCompute<gpu>", LayerNormGradCompute<gpu>);

}  // namespace op
}  // namespace mxnet