1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <type_traits>
17
18 #include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"
19
20 namespace dnnl {
21 namespace impl {
22 namespace cpu {
23 namespace x64 {
24
jit_prelu_backward_kernel_t(const cpu_prelu_bwd_pd_t * pd,const cpu_isa_t & isa,const int vlen,const size_t number_vmm_single_compute)25 jit_prelu_backward_kernel_t::jit_prelu_backward_kernel_t(
26 const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
27 const size_t number_vmm_single_compute)
28 : jit_prelu_base_kernel_t(isa, vlen,
29 prelu::get_bcast_type(memory_desc_wrapper(pd->diff_src_md(0)),
30 memory_desc_wrapper(pd->diff_weights_md(0))),
31 memory_desc_wrapper(pd->diff_src_md(0)), number_vmm_single_compute)
32 , pd_(pd)
33 , src_dt_(pd->src_md(0)->data_type)
34 , wei_dt_(pd->weights_md(0)->data_type)
35 , diff_src_dt_(pd->diff_src_md(0)->data_type)
36 , diff_dst_dt_(pd->diff_dst_md(0)->data_type)
37 , diff_wei_dt_(bcast_ == prelu::bcast::full
38 ? pd->diff_weights_md(0)->data_type
39 : data_type::f32)
40 , diff_src_block_tail_(prelu::get_block_tail_size(pd->diff_src_md(0)))
41 , diff_wei_block_tail_(prelu::get_block_tail_size(pd->diff_weights_md(0))) {
42 }
43
44 #define PARAM_OFF(x) offsetof(call_params_t, x)
45
load_kernel_call_params()46 void jit_prelu_backward_kernel_t::load_kernel_call_params() {
47 mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
48 mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
49 mov(reg_src_diff_, ptr[abi_param1 + PARAM_OFF(src_diff)]);
50 mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
51 mov(reg_dst_diff_, ptr[abi_param1 + PARAM_OFF(dst_diff)]);
52 mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
53 }
54
55 #undef PARAM_OFF
56
data_ptr(int arg_num,size_t offt)57 Xbyak::Address jit_prelu_backward_kernel_t::data_ptr(int arg_num, size_t offt) {
58 const auto get_addr
59 = [&](const Xbyak::Reg64 ®_base, const data_type_t dt) {
60 const auto dt_size = types::data_type_size(dt);
61 return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
62 };
63
64 switch (arg_num) {
65 case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_);
66 case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_);
67 case DNNL_ARG_DIFF_SRC: return get_addr(reg_src_diff_, diff_src_dt_);
68 case DNNL_ARG_DIFF_WEIGHTS:
69 return get_addr(reg_weights_diff_, diff_wei_dt_);
70 case DNNL_ARG_DIFF_DST: return get_addr(reg_dst_diff_, diff_dst_dt_);
71
72 default: assert(!"unsupported arg_num"); break;
73 }
74 return Xbyak::Address(0);
75 }
76
any_tensor_bf16() const77 bool jit_prelu_backward_kernel_t::any_tensor_bf16() const {
78 return utils::one_of(data_type::bf16, src_dt_, wei_dt_, diff_src_dt_,
79 diff_dst_dt_, diff_wei_dt_);
80 }
81
82 template <typename Vmm>
jit_uni_prelu_backward_kernel_t(const cpu_prelu_bwd_pd_t * pd,const cpu_isa_t & isa)83 jit_uni_prelu_backward_kernel_t<Vmm>::jit_uni_prelu_backward_kernel_t(
84 const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
85 : jit_prelu_backward_kernel_t(pd, isa, prelu::vmm_traits_t<Vmm>::vlen,
86 std::is_same<Vmm, Xbyak::Zmm>::value ? 4u : 6u)
87 , saturation_needed_diff_src_(utils::one_of(
88 diff_src_dt_, data_type::u8, data_type::s8, data_type::s32))
89 , saturation_needed_diff_weights_(utils::one_of(
90 diff_wei_dt_, data_type::u8, data_type::s8, data_type::s32))
91 , vmm_zeros_(reserve_vmm())
92 , saturation_ubound_diff_src_(
93 saturation_needed_diff_src_ ? reserve_vmm() : 0)
94 , saturation_ubound_diff_weights_(saturation_needed_diff_weights_
95 ? (diff_wei_dt_ == diff_src_dt_
96 ? saturation_ubound_diff_src_.getIdx()
97 : reserve_vmm())
98 : 0)
99 , tail_vmm_mask_(
100 tail_size_ && utils::one_of(isa, avx, avx2) ? reserve_vmm() : 0)
101 , vmm_ones_(reserve_vmm())
102 , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
103 prelu::bcast::per_oc_blocked)
104 ? reserve_vmm()
105 : 0)
106 , weights_diff_acc_vmm_(
107 utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
108 prelu::bcast::per_oc_blocked)
109 ? reserve_vmm()
110 : 0)
111 , io_(this, isa,
112 {src_dt_, wei_dt_, diff_src_dt_, diff_dst_dt_, diff_wei_dt_}, {},
113 io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
114 tail_vmm_mask_.getIdx(), reg_tmp_},
115 io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {}
116
117 template <typename Vmm>
118 jit_uni_prelu_backward_kernel_t<Vmm>::~jit_uni_prelu_backward_kernel_t()
119 = default;
120
121 template <typename Vmm>
prepare_kernel_const_vars()122 void jit_uni_prelu_backward_kernel_t<Vmm>::prepare_kernel_const_vars() {
123 uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);
124
125 io_.init_bf16();
126 if (tail_size_) io_.prepare_tail_mask();
127 if (saturation_needed_diff_src_ || saturation_needed_diff_weights_) {
128 io_.init_saturate_f32({diff_src_dt_, diff_wei_dt_});
129 }
130 // load ones
131 this->mov(this->reg_tmp_, float2int(1));
132 const Xbyak::Xmm xmm_ones_ {vmm_ones_.getIdx()};
133 this->uni_vmovq(xmm_ones_, this->reg_tmp_);
134 this->uni_vbroadcastss(vmm_ones_, xmm_ones_);
135
136 if (bcast_ == prelu::bcast::per_oc_blocked) {
137 io_.at(wei_dt_)->load(
138 ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
139 vmovups(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
140 } else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
141 io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
142 uni_vxorps(weights_diff_acc_vmm_, weights_diff_acc_vmm_,
143 weights_diff_acc_vmm_);
144 uni_vmovss(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
145 }
146 }
147
148 template <typename Vmm>
compute_dst(size_t unrolling_factor,bool tail)149 void jit_uni_prelu_backward_kernel_t<Vmm>::compute_dst(
150 size_t unrolling_factor, bool tail) {
151
152 static constexpr size_t dst_diff_idx = 0;
153 static constexpr size_t src_idx = 1;
154 static constexpr size_t src_le_zero_idx = 2;
155 static constexpr size_t src_gt_zero_idx = 3;
156 static constexpr size_t weights_diff_idx = 4;
157 static constexpr size_t weights_idx = 5;
158
159 for (size_t unroll_group = 0; unroll_group < unrolling_factor;
160 ++unroll_group) {
161
162 const Vmm dst_diff_vmm {get_compute_vmm(dst_diff_idx, unroll_group)};
163 const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
164 const Vmm src_le_zero_vmm {
165 get_compute_vmm(src_le_zero_idx, unroll_group)};
166 const Vmm src_gt_zero_vmm {
167 get_compute_vmm(src_gt_zero_idx, unroll_group)};
168 const Vmm weights_diff_vmm {
169 get_compute_vmm(weights_diff_idx, unroll_group)};
170 const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};
171
172 const auto offset = unroll_group * simd_w_;
173 io_.at(diff_dst_dt_)
174 ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
175 io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
176 static constexpr int VCMPLEPS = 2;
177 uni_vcmpps(src_le_zero_vmm, src_vmm, vmm_zeros_, VCMPLEPS);
178 uni_vandps(src_le_zero_vmm, src_le_zero_vmm, vmm_ones_);
179 static constexpr int VCMPGTPS = 14;
180 uni_vcmpps(src_gt_zero_vmm, src_vmm, vmm_zeros_, VCMPGTPS);
181 uni_vandps(src_gt_zero_vmm, src_gt_zero_vmm, vmm_ones_);
182
183 //weights_diff_calculations
184 uni_vmulps(weights_diff_vmm, dst_diff_vmm, src_vmm);
185 uni_vmulps(weights_diff_vmm, weights_diff_vmm, src_le_zero_vmm);
186
187 //src_diff calculations
188 const auto weights_operand = get_or_load_weights(
189 data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);
190 uni_vfmadd231ps(src_gt_zero_vmm, src_le_zero_vmm, weights_operand);
191 const auto &src_diff_vmm = src_gt_zero_vmm;
192 uni_vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
193 io_.at(diff_src_dt_)
194 ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
195 tail);
196 if (diff_src_block_tail_ && tail)
197 prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
198 diff_src_block_tail_, reg_src_diff_, nullptr);
199
200 accumulate_weights_diff(weights_diff_vmm, src_gt_zero_vmm,
201 data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);
202 }
203 }
204
205 template <>
compute_dst(size_t unrolling_factor,bool tail)206 void jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>::compute_dst(
207 size_t unrolling_factor, bool tail) {
208
209 size_t opmask_counter = 2;
210 auto get_next_opmask = [opmask_counter]() mutable {
211 static constexpr size_t opmask_range_begin = 2;
212 static constexpr size_t opmask_range_end = 8;
213 const auto opmask = Xbyak::Opmask(opmask_counter++);
214 if (opmask_counter == opmask_range_end)
215 opmask_counter = opmask_range_begin;
216 return opmask;
217 };
218
219 static constexpr size_t dst_diff_idx = 0;
220 static constexpr size_t src_idx = 1;
221 static constexpr size_t weights_diff_idx = 2;
222 static constexpr size_t weights_idx = 3;
223
224 for (size_t unroll_group = 0; unroll_group < unrolling_factor;
225 ++unroll_group) {
226
227 const auto offset = unroll_group * simd_w_;
228 const Xbyak::Zmm dst_diff_vmm {
229 get_compute_vmm(dst_diff_idx, unroll_group)};
230 const Xbyak::Zmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
231
232 io_.at(diff_dst_dt_)
233 ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
234 io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
235
236 const Xbyak::Opmask src_le_zero_opmask = get_next_opmask();
237 static constexpr int VCMPLEPS = 2;
238 vcmpps(src_le_zero_opmask, src_vmm, vmm_zeros_, VCMPLEPS);
239 const Xbyak::Opmask src_gt_zero_vmm_opmask = get_next_opmask();
240 static constexpr int VCMPGTPS = 14;
241 vcmpps(src_gt_zero_vmm_opmask, src_vmm, vmm_zeros_, VCMPGTPS);
242
243 // //weights_diff_calculations
244 const Xbyak::Zmm weights_diff_vmm {
245 get_compute_vmm(weights_diff_idx, unroll_group)};
246 vmulps(weights_diff_vmm | src_le_zero_opmask | T_z, dst_diff_vmm,
247 src_vmm);
248 accumulate_weights_diff(weights_diff_vmm, weights_diff_acc_vmm_,
249 data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);
250
251 //src_diff calculations
252 const Xbyak::Zmm weights_vmm {
253 get_compute_vmm(weights_idx, unroll_group)};
254 const auto &src_diff_vmm = weights_vmm;
255 const auto weights_operand = get_or_load_weights(
256 data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);
257
258 vmovaps(src_diff_vmm | src_le_zero_opmask | T_z, weights_operand);
259 vaddps(src_diff_vmm | src_gt_zero_vmm_opmask, src_diff_vmm, vmm_ones_);
260 vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
261 io_.at(diff_src_dt_)
262 ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
263 tail);
264 if (diff_src_block_tail_ && tail)
265 prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
266 diff_src_block_tail_, reg_src_diff_, nullptr);
267 }
268 }
269
270 template <typename Vmm>
accumulate_weights_diff(const Vmm & partial_sum_vmm,const Vmm & tmp_vmm,const Xbyak::Address & dst_addr,bool tail)271 void jit_uni_prelu_backward_kernel_t<Vmm>::accumulate_weights_diff(
272 const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
273 const Xbyak::Address &dst_addr, bool tail) {
274
275 if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
276 prelu::bcast::per_oc_blocked)) {
277 uni_vaddps(
278 weights_diff_acc_vmm_, weights_diff_acc_vmm_, partial_sum_vmm);
279 } else if (bcast_ == prelu::bcast::per_oc_n_spatial_c) {
280 if (std::is_same<Vmm, Xbyak::Zmm>::value || isa_ == avx2)
281 uni_vaddps(partial_sum_vmm, partial_sum_vmm, dst_addr);
282 else {
283 uni_vmovups(tmp_vmm, dst_addr);
284 uni_vaddps(partial_sum_vmm, partial_sum_vmm, tmp_vmm);
285 }
286 uni_vmovups(dst_addr, partial_sum_vmm);
287 } else {
288 io_.at(diff_wei_dt_)->store(partial_sum_vmm, dst_addr, tail);
289 if (diff_wei_block_tail_ && tail)
290 prelu::apply_zero_padding(this, tail_size_, diff_wei_dt_,
291 diff_wei_block_tail_, reg_weights_diff_, nullptr);
292 }
293 }
294
295 template <typename Vmm>
get_or_load_weights(const Xbyak::Address & src_addr,const Vmm & weights_vmm,bool tail)296 const Xbyak::Operand &jit_uni_prelu_backward_kernel_t<Vmm>::get_or_load_weights(
297 const Xbyak::Address &src_addr, const Vmm &weights_vmm, bool tail) {
298
299 if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
300 prelu::bcast::per_oc_blocked))
301 return weights_const_vmm_;
302
303 io_.at(wei_dt_)->load(src_addr, weights_vmm, tail);
304 return weights_vmm;
305 }
306
reduce(jit_generator * host,const Xbyak::Xmm & src,const Xbyak::Xmm & helper,const cpu_isa_t & isa)307 static void reduce(jit_generator *host, const Xbyak::Xmm &src,
308 const Xbyak::Xmm &helper, const cpu_isa_t &isa) {
309 UNUSED(helper);
310 if (isa == sse41) {
311 host->haddps(src, src);
312 host->haddps(src, src);
313 } else {
314 host->vhaddps(src, src, src);
315 host->vhaddps(src, src, src);
316 }
317 }
318
reduce(jit_generator * host,const Xbyak::Ymm & src,const Xbyak::Ymm & helper,const cpu_isa_t & isa)319 static void reduce(jit_generator *host, const Xbyak::Ymm &src,
320 const Xbyak::Ymm &helper, const cpu_isa_t &isa) {
321 const Xbyak::Xmm xmm_helper {helper.getIdx()};
322 const Xbyak::Xmm xmm_src {src.getIdx()};
323
324 host->vextractf128(xmm_helper, src, 1);
325 host->vaddps(xmm_src, xmm_src, xmm_helper);
326 reduce(host, xmm_src, xmm_helper, isa);
327 }
328
reduce(jit_generator * host,const Xbyak::Zmm & src,const Xbyak::Zmm & helper,const cpu_isa_t & isa)329 static void reduce(jit_generator *host, const Xbyak::Zmm &src,
330 const Xbyak::Zmm &helper, const cpu_isa_t &isa) {
331 const Xbyak::Ymm ymm_helper {helper.getIdx()};
332 const Xbyak::Ymm ymm_src {src.getIdx()};
333
334 host->vextractf64x4(ymm_helper, src, 1);
335 host->vaddps(ymm_src, ymm_src, ymm_helper);
336 reduce(host, ymm_src, ymm_helper, isa);
337 }
338
339 template <typename Vmm>
finalize()340 void jit_uni_prelu_backward_kernel_t<Vmm>::finalize() {
341 if (bcast_ == prelu::bcast::per_oc_blocked)
342 uni_vmovups(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
343 else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
344 reduce(this, weights_diff_acc_vmm_, weights_const_vmm_, isa_);
345 uni_vmovss(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
346 }
347 }
348
349 template <typename Vmm>
350 std::map<data_type_t, io::io_saturation_conf_t>
create_saturation_vmm_map() const351 jit_uni_prelu_backward_kernel_t<Vmm>::create_saturation_vmm_map() const {
352
353 std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};
354
355 if (saturation_needed_diff_src_)
356 saturation_map.emplace(diff_src_dt_,
357 io::io_saturation_conf_t {vmm_zeros_.getIdx(),
358 saturation_ubound_diff_src_.getIdx(), reg_tmp_});
359
360 if (saturation_needed_diff_weights_ && diff_src_dt_ != diff_wei_dt_)
361 saturation_map.emplace(diff_wei_dt_,
362 io::io_saturation_conf_t {vmm_zeros_.getIdx(),
363 saturation_ubound_diff_weights_.getIdx(), reg_tmp_});
364
365 return saturation_map;
366 }
367
create(const cpu_prelu_bwd_pd_t * pd)368 jit_prelu_backward_kernel_t *jit_prelu_backward_kernel_t::create(
369 const cpu_prelu_bwd_pd_t *pd) {
370
371 const auto isa = prelu::get_supported_isa();
372
373 const auto &src_dt = pd->src_md(0)->data_type;
374 const auto &wei_dt = pd->weights_md(0)->data_type;
375 const auto &diff_src_dt = pd->diff_src_md(0)->data_type;
376 const auto &diff_dst_dt = pd->diff_dst_md(0)->data_type;
377 const auto &diff_wei_dt = pd->diff_weights_md(0)->data_type;
378
379 if (is_superset(isa, avx512_common))
380 return new jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>(pd, isa);
381 else if (is_superset(isa, avx)) {
382 if (isa == avx
383 && prelu::is_s8u8({src_dt, wei_dt, diff_src_dt, diff_dst_dt,
384 diff_wei_dt}))
385 return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);
386 else
387 return new jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>(pd, isa);
388 } else if (isa == sse41)
389 return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);
390
391 return nullptr;
392 }
393
394 template class jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>;
395 template class jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>;
396 template class jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>;
397
398 } // namespace x64
399 } // namespace cpu
400 } // namespace impl
401 } // namespace dnnl
402