/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <type_traits>

#include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
jit_prelu_backward_kernel_t::jit_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
        const size_t number_vmm_single_compute)
    : jit_prelu_base_kernel_t(isa, vlen,
            prelu::get_bcast_type(memory_desc_wrapper(pd->diff_src_md(0)),
                    memory_desc_wrapper(pd->diff_weights_md(0))),
            memory_desc_wrapper(pd->diff_src_md(0)), number_vmm_single_compute)
    , pd_(pd)
    , src_dt_(pd->src_md(0)->data_type)
    , wei_dt_(pd->weights_md(0)->data_type)
    , diff_src_dt_(pd->diff_src_md(0)->data_type)
    , diff_dst_dt_(pd->diff_dst_md(0)->data_type)
    , diff_wei_dt_(bcast_ == prelu::bcast::full
                      ? pd->diff_weights_md(0)->data_type
                      : data_type::f32)
    , diff_src_block_tail_(prelu::get_block_tail_size(pd->diff_src_md(0)))
    , diff_wei_block_tail_(prelu::get_block_tail_size(pd->diff_weights_md(0))) {
}

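// Kernel call parameters are passed from the host as a single pointer (in
// abi_param1) to a call_params_t structure; each field is loaded into its
// dedicated register below.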
#define PARAM_OFF(x) offsetof(call_params_t, x)

void jit_prelu_backward_kernel_t::load_kernel_call_params() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
    mov(reg_src_diff_, ptr[abi_param1 + PARAM_OFF(src_diff)]);
    mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
    mov(reg_dst_diff_, ptr[abi_param1 + PARAM_OFF(dst_diff)]);
    mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
}

#undef PARAM_OFF

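// Builds the effective address of the current element for a given execution
// argument: base + (reg_offset_ + offt) * sizeof(data_type). reg_offset_
// counts elements, not bytes, so both offsets are scaled by the element size.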
Xbyak::Address jit_prelu_backward_kernel_t::data_ptr(int arg_num, size_t offt) {
    const auto get_addr
            = [&](const Xbyak::Reg64 &reg_base, const data_type_t dt) {
                  const auto dt_size = types::data_type_size(dt);
                  return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
              };

    switch (arg_num) {
        case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_);
        case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_);
        case DNNL_ARG_DIFF_SRC: return get_addr(reg_src_diff_, diff_src_dt_);
        case DNNL_ARG_DIFF_WEIGHTS:
            return get_addr(reg_weights_diff_, diff_wei_dt_);
        case DNNL_ARG_DIFF_DST: return get_addr(reg_dst_diff_, diff_dst_dt_);

        default: assert(!"unsupported arg_num"); break;
    }
    return Xbyak::Address(0);
}

bool jit_prelu_backward_kernel_t::any_tensor_bf16() const {
    return utils::one_of(data_type::bf16, src_dt_, wei_dt_, diff_src_dt_,
            diff_dst_dt_, diff_wei_dt_);
}

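// The constructor reserves all helper vmms up front (zeros, ones, saturation
// ubounds, tail mask, broadcast weights and the weights_diff accumulator);
// whatever remains is available to the unrolled compute loop. A Zmm unroll
// group needs 4 registers versus 6 for Xmm/Ymm, because the AVX-512 variant
// of compute_dst() keeps the two comparison masks in opmask registers.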
template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::jit_uni_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_backward_kernel_t(pd, isa, prelu::vmm_traits_t<Vmm>::vlen,
            std::is_same<Vmm, Xbyak::Zmm>::value ? 4u : 6u)
    , saturation_needed_diff_src_(utils::one_of(
              diff_src_dt_, data_type::u8, data_type::s8, data_type::s32))
    , saturation_needed_diff_weights_(utils::one_of(
              diff_wei_dt_, data_type::u8, data_type::s8, data_type::s32))
    , vmm_zeros_(reserve_vmm())
    , saturation_ubound_diff_src_(
              saturation_needed_diff_src_ ? reserve_vmm() : 0)
    , saturation_ubound_diff_weights_(saturation_needed_diff_weights_
                      ? (diff_wei_dt_ == diff_src_dt_
                                      ? saturation_ubound_diff_src_.getIdx()
                                      : reserve_vmm())
                      : 0)
    , tail_vmm_mask_(
              tail_size_ && utils::one_of(isa, avx, avx2) ? reserve_vmm() : 0)
    , vmm_ones_(reserve_vmm())
    , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                                 prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , weights_diff_acc_vmm_(
              utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                      prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , io_(this, isa,
              {src_dt_, wei_dt_, diff_src_dt_, diff_dst_dt_, diff_wei_dt_}, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {}

template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::~jit_uni_prelu_backward_kernel_t()
        = default;

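// One-time setup before the main loop: zero and one constants, bf16 emulation
// state, tail masks, integer saturation bounds and, for the per-OC broadcast
// cases, the preloaded weights plus the weights_diff accumulator.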
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::prepare_kernel_const_vars() {
    uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);

    io_.init_bf16();
    if (tail_size_) io_.prepare_tail_mask();
    if (saturation_needed_diff_src_ || saturation_needed_diff_weights_) {
        io_.init_saturate_f32({diff_src_dt_, diff_wei_dt_});
    }
    // load ones
    this->mov(this->reg_tmp_, float2int(1));
    const Xbyak::Xmm xmm_ones_ {vmm_ones_.getIdx()};
    this->uni_vmovq(xmm_ones_, this->reg_tmp_);
    this->uni_vbroadcastss(vmm_ones_, xmm_ones_);

    if (bcast_ == prelu::bcast::per_oc_blocked) {
        io_.at(wei_dt_)->load(
                ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
        vmovups(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    } else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
        uni_vxorps(weights_diff_acc_vmm_, weights_diff_acc_vmm_,
                weights_diff_acc_vmm_);
        uni_vmovss(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    }
}

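// Generic (SSE4.1/AVX/AVX2) backward step. Per unrolled vector it computes
//     diff_weights = diff_dst * src   where src <= 0, else 0
//     diff_src     = diff_dst * w     where src <= 0, else diff_dst
// The src <= 0 and src > 0 predicates are materialized as 0.0f/1.0f vectors
// (cmpps followed by andps with vmm_ones_) and applied via multiply/FMA,
// since these ISAs have no opmask registers.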
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t src_le_zero_idx = 2;
    static constexpr size_t src_gt_zero_idx = 3;
    static constexpr size_t weights_diff_idx = 4;
    static constexpr size_t weights_idx = 5;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const Vmm dst_diff_vmm {get_compute_vmm(dst_diff_idx, unroll_group)};
        const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
        const Vmm src_le_zero_vmm {
                get_compute_vmm(src_le_zero_idx, unroll_group)};
        const Vmm src_gt_zero_vmm {
                get_compute_vmm(src_gt_zero_idx, unroll_group)};
        const Vmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};

        const auto offset = unroll_group * simd_w_;
        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
        static constexpr int VCMPLEPS = 2;
        uni_vcmpps(src_le_zero_vmm, src_vmm, vmm_zeros_, VCMPLEPS);
        uni_vandps(src_le_zero_vmm, src_le_zero_vmm, vmm_ones_);
        static constexpr int VCMPGTPS = 14;
        uni_vcmpps(src_gt_zero_vmm, src_vmm, vmm_zeros_, VCMPGTPS);
        uni_vandps(src_gt_zero_vmm, src_gt_zero_vmm, vmm_ones_);

        // weights_diff calculations
        uni_vmulps(weights_diff_vmm, dst_diff_vmm, src_vmm);
        uni_vmulps(weights_diff_vmm, weights_diff_vmm, src_le_zero_vmm);

        // src_diff calculations
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);
        uni_vfmadd231ps(src_gt_zero_vmm, src_le_zero_vmm, weights_operand);
        const auto &src_diff_vmm = src_gt_zero_vmm;
        uni_vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);

        accumulate_weights_diff(weights_diff_vmm, src_gt_zero_vmm,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);
    }
}

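// AVX-512 specialization: the src <= 0 and src > 0 predicates live in opmask
// registers (k2..k7, cycled by get_next_opmask) and are applied with
// zero-masking, saving two vmms per unroll group relative to the generic path.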
template <>
void jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

    size_t opmask_counter = 2;
    auto get_next_opmask = [opmask_counter]() mutable {
        static constexpr size_t opmask_range_begin = 2;
        static constexpr size_t opmask_range_end = 8;
        const auto opmask = Xbyak::Opmask(opmask_counter++);
        if (opmask_counter == opmask_range_end)
            opmask_counter = opmask_range_begin;
        return opmask;
    };

    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t weights_diff_idx = 2;
    static constexpr size_t weights_idx = 3;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto offset = unroll_group * simd_w_;
        const Xbyak::Zmm dst_diff_vmm {
                get_compute_vmm(dst_diff_idx, unroll_group)};
        const Xbyak::Zmm src_vmm {get_compute_vmm(src_idx, unroll_group)};

        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);

        const Xbyak::Opmask src_le_zero_opmask = get_next_opmask();
        static constexpr int VCMPLEPS = 2;
        vcmpps(src_le_zero_opmask, src_vmm, vmm_zeros_, VCMPLEPS);
        const Xbyak::Opmask src_gt_zero_vmm_opmask = get_next_opmask();
        static constexpr int VCMPGTPS = 14;
        vcmpps(src_gt_zero_vmm_opmask, src_vmm, vmm_zeros_, VCMPGTPS);

        // weights_diff calculations
        const Xbyak::Zmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        vmulps(weights_diff_vmm | src_le_zero_opmask | T_z, dst_diff_vmm,
                src_vmm);
        accumulate_weights_diff(weights_diff_vmm, weights_diff_acc_vmm_,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);

        // src_diff calculations
        const Xbyak::Zmm weights_vmm {
                get_compute_vmm(weights_idx, unroll_group)};
        const auto &src_diff_vmm = weights_vmm;
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);

        vmovaps(src_diff_vmm | src_le_zero_opmask | T_z, weights_operand);
        vaddps(src_diff_vmm | src_gt_zero_vmm_opmask, src_diff_vmm, vmm_ones_);
        vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);
    }
}

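// Three accumulation strategies, depending on the broadcast type:
//  - per_oc_n_c_spatial / per_oc_blocked: keep a running sum in
//    weights_diff_acc_vmm_ and write it out once in finalize();
//  - per_oc_n_spatial_c: accumulate straight into memory (Zmm and avx2 can
//    add with a memory operand, older ISAs go through tmp_vmm);
//  - full: store the per-element result directly, zero-padding block tails.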
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::accumulate_weights_diff(
        const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
        const Xbyak::Address &dst_addr, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked)) {
        uni_vaddps(
                weights_diff_acc_vmm_, weights_diff_acc_vmm_, partial_sum_vmm);
    } else if (bcast_ == prelu::bcast::per_oc_n_spatial_c) {
        if (std::is_same<Vmm, Xbyak::Zmm>::value || isa_ == avx2)
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, dst_addr);
        else {
            uni_vmovups(tmp_vmm, dst_addr);
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, tmp_vmm);
        }
        uni_vmovups(dst_addr, partial_sum_vmm);
    } else {
        io_.at(diff_wei_dt_)->store(partial_sum_vmm, dst_addr, tail);
        if (diff_wei_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_wei_dt_,
                    diff_wei_block_tail_, reg_weights_diff_, nullptr);
    }
}

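// For the per-OC broadcast cases the weights were preloaded once into
// weights_const_vmm_ in prepare_kernel_const_vars(); otherwise the current
// vector is loaded from memory.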
template <typename Vmm>
const Xbyak::Operand &jit_uni_prelu_backward_kernel_t<Vmm>::get_or_load_weights(
        const Xbyak::Address &src_addr, const Vmm &weights_vmm, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked))
        return weights_const_vmm_;

    io_.at(wei_dt_)->load(src_addr, weights_vmm, tail);
    return weights_vmm;
}

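// Horizontal-sum helpers used to collapse the weights_diff accumulator to a
// single scalar: the Zmm overload folds into a Ymm, the Ymm overload into an
// Xmm, and the Xmm overload finishes with two horizontal adds.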
static void reduce(jit_generator *host, const Xbyak::Xmm &src,
        const Xbyak::Xmm &helper, const cpu_isa_t &isa) {
    UNUSED(helper);
    if (isa == sse41) {
        host->haddps(src, src);
        host->haddps(src, src);
    } else {
        host->vhaddps(src, src, src);
        host->vhaddps(src, src, src);
    }
}

static void reduce(jit_generator *host, const Xbyak::Ymm &src,
        const Xbyak::Ymm &helper, const cpu_isa_t &isa) {
    const Xbyak::Xmm xmm_helper {helper.getIdx()};
    const Xbyak::Xmm xmm_src {src.getIdx()};

    host->vextractf128(xmm_helper, src, 1);
    host->vaddps(xmm_src, xmm_src, xmm_helper);
    reduce(host, xmm_src, xmm_helper, isa);
}

static void reduce(jit_generator *host, const Xbyak::Zmm &src,
        const Xbyak::Zmm &helper, const cpu_isa_t &isa) {
    const Xbyak::Ymm ymm_helper {helper.getIdx()};
    const Xbyak::Ymm ymm_src {src.getIdx()};

    host->vextractf64x4(ymm_helper, src, 1);
    host->vaddps(ymm_src, ymm_src, ymm_helper);
    reduce(host, ymm_src, ymm_helper, isa);
}

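// Write back the accumulated weights_diff: a full vector for the blocked
// per-OC layout, a single reduced scalar for the per_oc_n_c_spatial case.
// Note that weights_const_vmm_ is reused here as the reduction scratch.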
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::finalize() {
    if (bcast_ == prelu::bcast::per_oc_blocked)
        uni_vmovups(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        reduce(this, weights_diff_acc_vmm_, weights_const_vmm_, isa_);
        uni_vmovss(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    }
}

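// Maps each integral destination data type to its saturation registers:
// vmm_zeros_ doubles as the lower bound. When diff_src and diff_weights share
// a data type they also share the upper-bound register, so only one entry is
// emitted (see the constructor's saturation_ubound_diff_weights_ init).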
template <typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_prelu_backward_kernel_t<Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (saturation_needed_diff_src_)
        saturation_map.emplace(diff_src_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_src_.getIdx(), reg_tmp_});

    if (saturation_needed_diff_weights_ && diff_src_dt_ != diff_wei_dt_)
        saturation_map.emplace(diff_wei_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_weights_.getIdx(), reg_tmp_});

    return saturation_map;
}

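// Factory: dispatch on the widest supported ISA. Plain AVX falls back to the
// Xmm kernel when any tensor is s8/u8, presumably because the required
// integer conversions lack 256-bit forms on AVX without AVX2.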
jit_prelu_backward_kernel_t *jit_prelu_backward_kernel_t::create(
        const cpu_prelu_bwd_pd_t *pd) {

    const auto isa = prelu::get_supported_isa();

    const auto &src_dt = pd->src_md(0)->data_type;
    const auto &wei_dt = pd->weights_md(0)->data_type;
    const auto &diff_src_dt = pd->diff_src_md(0)->data_type;
    const auto &diff_dst_dt = pd->diff_dst_md(0)->data_type;
    const auto &diff_wei_dt = pd->diff_weights_md(0)->data_type;

    if (is_superset(isa, avx512_common))
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>(pd, isa);
    else if (is_superset(isa, avx)) {
        if (isa == avx
                && prelu::is_s8u8({src_dt, wei_dt, diff_src_dt, diff_dst_dt,
                        diff_wei_dt}))
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);
        else
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>(pd, isa);
    } else if (isa == sse41)
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);

    return nullptr;
}

template class jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl