Searched refs:sum_val0 (Results 1 – 3 of 3) sorted by relevance
514 AType sum_val0 = 0; // Stores mean(out_grad * gamma / std, axis=-1) in LayerNormFusedBackwardKernel_Data() local
525 sum_val0 += ele_og * ele_gamma * invstd_eps; in LayerNormFusedBackwardKernel_Data()
533 sum_val0 += ele_og * ele_gamma * invstd_eps; in LayerNormFusedBackwardKernel_Data()
538 sum_val0 += warp_shfl_xor(sum_val0, mask); in LayerNormFusedBackwardKernel_Data()
549 sum_val0_buf[idx] = sum_val0; in LayerNormFusedBackwardKernel_Data()
555 sum_val0 += sum_val0_buf[idx]; in LayerNormFusedBackwardKernel_Data()
561 sum_val0_buf[threadIdx.x] = sum_val0; in LayerNormFusedBackwardKernel_Data()
565 sum_val0 = sum_val0_buf[threadIdx.x]; in LayerNormFusedBackwardKernel_Data()
568 sum_val0 /= nchannel; in LayerNormFusedBackwardKernel_Data()
579 - sum_val0 - (ele_x - mean) * invstd_eps * sum_val1); in LayerNormFusedBackwardKernel_Data()
[all …]

980 v_float32x4 sum_val0 = v_setzero_f32(), sum_val1 = v_setzero_f32(); in operator ()() local
992 sum_val0 += v0; in operator ()()
996 v_store(dstData + x0, sum_val0*ikarea); in operator ()()