1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_REF_FUSED_CONVOLUTION_HPP
18 #define CPU_REF_FUSED_CONVOLUTION_HPP
19 
20 #include "common/primitive.hpp"
21 #include "common/primitive_iterator.hpp"
22 #include "common/reorder.hpp"
23 #include "common/stream.hpp"
24 
25 #include "cpu/cpu_convolution_pd.hpp"
26 #include "cpu/dw_convolution_utils.hpp"
27 
28 namespace dnnl {
29 namespace impl {
30 namespace cpu {
31 
32 struct ref_fused_convolution_fwd_t : public primitive_t {
33 
34     struct arg_cache_t {
35         struct arg_info_t {
36             int op_arg;
37             bool is_ctx_arg;
38             bool is_const;
39             union {
40                 size_t offset;
41                 int ctx_arg;
42             };
43             memory_desc_t md;
44         };
45 
append_ctx_argdnnl::impl::cpu::ref_fused_convolution_fwd_t::arg_cache_t46         void append_ctx_arg(int op_arg, int ctx_arg) {
47             arg_info_t arg_info;
48             arg_info.op_arg = op_arg;
49             arg_info.is_ctx_arg = true;
50             arg_info.is_const = false; // unused
51             arg_info.ctx_arg = ctx_arg;
52             arg_info.md = glob_zero_md;
53             info_.push_back(arg_info);
54         }
55 
append_inout_argdnnl::impl::cpu::ref_fused_convolution_fwd_t::arg_cache_t56         void append_inout_arg(int arg, size_t offset, const memory_desc_t *md,
57                 bool is_const) {
58             arg_info_t arg_info;
59             arg_info.op_arg = arg;
60             arg_info.is_ctx_arg = false;
61             arg_info.is_const = is_const;
62             arg_info.offset = offset;
63             arg_info.md = *md;
64             info_.push_back(arg_info);
65         }
66 
append_ctx_argdnnl::impl::cpu::ref_fused_convolution_fwd_t::arg_cache_t67         void append_ctx_arg(int arg) { append_ctx_arg(arg, arg); }
68 
infodnnl::impl::cpu::ref_fused_convolution_fwd_t::arg_cache_t69         const std::vector<arg_info_t> &info() const { return info_; }
70 
71     private:
72         std::vector<arg_info_t> info_;
73     };
74 
75     struct pd_t : public cpu_convolution_fwd_pd_t {
pd_tdnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t76         pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
77                 const typename pd_t::base_class *hint_fwd_pd)
78             : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {
79             name_ = "ref_fused_convolution:any";
80         }
81 
82         pd_t(const pd_t &other) = default;
83 
84         DECLARE_COMMON_PD_T(name_.c_str(), ref_fused_convolution_fwd_t);
85 
initdnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t86         virtual status_t init(engine_t *engine) {
87             bool ok = true && is_fwd()
88                     && (attr()->post_ops_.find(primitive_kind::sum) == -1);
89 
90             if (!ok) return status::unimplemented;
91 
92             CHECK(init_ops(engine));
93             init_name();
94             return status::success;
95         }
96 
src_mddnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t97         const memory_desc_t *src_md(int index = 0) const override {
98             return op_pds_.front()->src_md(index);
99         }
100 
dst_mddnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t101         const memory_desc_t *dst_md(int index = 0) const override {
102             return op_pds_.back()->dst_md(index);
103         }
104 
weights_mddnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t105         const memory_desc_t *weights_md(int index = 0) const override {
106             return op_pds_.front()->weights_md(index); // for now
107         }
108 
arg_mddnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t109         const memory_desc_t *arg_md(int index = 0) const override {
110             switch (index) { // for now
111                 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
112                     return op_pds_.back()->weights_md(0);
113                 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
114                     return op_pds_.back()->weights_md(1);
115                 default: return convolution_fwd_pd_t::arg_md(index);
116             }
117         }
118 
arg_usagednnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t119         arg_usage_t arg_usage(int arg) const override {
120             if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
121                 return arg_usage_t::input;
122 
123             if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
124                     && attr_post_op_dw_inputs() > 1)
125                 return arg_usage_t::input;
126 
127             return convolution_fwd_pd_t::arg_usage(arg);
128         }
129 
130         size_t user_scratchpad_size_;
131         std::vector<std::shared_ptr<primitive_desc_t>> op_pds_;
132         std::vector<arg_cache_t> args_;
133 
134     private:
135         std::string name_;
136         const unsigned int max_fusions_ = 1;
137 
append_opdnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t138         status_t append_op(std::shared_ptr<primitive_desc_t> &op_pd,
139                 size_t &sp_begin, size_t &sp_end, engine_t *engine) {
140             auto from_md = op_pds_.back()->dst_md();
141             auto to_md = op_pd->src_md();
142 
143             if (*from_md != *to_md) {
144                 //TODO: Find a test-case for this
145                 std::shared_ptr<primitive_desc_t> pd;
146                 CHECK(reorder_primitive_desc_create(
147                         pd, engine, from_md, to_md));
148                 op_pds_.emplace_back(std::move(pd));
149 
150                 arg_cache_t arg_cache;
151                 arg_cache.append_inout_arg(
152                         DNNL_ARG_FROM, sp_begin, from_md, true);
153                 arg_cache.append_inout_arg(DNNL_ARG_TO, sp_end, to_md, false);
154                 args_.push_back(arg_cache);
155 
156                 // Increment scratchpad offsets
157                 sp_begin = sp_end;
158                 sp_end += memory_desc_wrapper(to_md).size();
159 
160                 user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_,
161                         op_pds_.back()->scratchpad_size(
162                                 attr()->scratchpad_mode_));
163             }
164 
165             op_pds_.emplace_back(std::move(op_pd));
166             user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_,
167                     op_pds_.back()->scratchpad_size(attr()->scratchpad_mode_));
168             return status::success;
169         }
170 
init_opsdnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t171         status_t init_ops(engine_t *engine) {
172             using namespace data_type;
173             primitive_attr_t root_attr(*attr());
174             if (!root_attr.is_initialized()) return status::out_of_memory;
175             auto po_op_iter
176                     = attr()->post_ops_.find(primitive_kind::convolution);
177             if (po_op_iter == -1) return status::unimplemented;
178 
179             primitive_attr_t attr_1x1(*attr());
180             // erase post-ops after fusion as they will be handled separately
181             auto &e = attr_1x1.post_ops_.entry_;
182             e.erase(e.begin() + po_op_iter, e.end());
183 
184             dnnl_primitive_desc_iterator it(
185                     engine, op_desc(), &attr_1x1, nullptr);
186             if (!it.is_initialized()) return status::out_of_memory;
187             std::shared_ptr<primitive_desc_t> root_pd = *(++it);
188             if (!root_pd) return status::unimplemented;
189             op_pds_.emplace_back(root_pd);
190             // Scratchpad offsets. Simulate offset computation so that offset
191             // computation can be avoided during execution.
192             size_t inout_sp_offset_begin = 0;
193             size_t inout_sp_offset_end = 0;
194             user_scratchpad_size_
195                     = root_pd->scratchpad_size(attr()->scratchpad_mode_);
196 
197             // Create arg cache for the root pd
198             arg_cache_t arg_cache;
199             arg_cache.append_ctx_arg(DNNL_ARG_SRC);
200             arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS);
201             if (desc()->bias_desc.data_type != data_type::undef)
202                 arg_cache.append_ctx_arg(DNNL_ARG_BIAS);
203             arg_cache.append_inout_arg(DNNL_ARG_DST, inout_sp_offset_end,
204                     root_pd->dst_md(), false);
205             for (int idx = 0; idx < attr_1x1.post_ops_.len(); ++idx) {
206                 if (attr_1x1.post_ops_.contain(primitive_kind::binary, idx))
207                     arg_cache.append_ctx_arg(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
208                             | DNNL_ARG_SRC_1);
209             }
210             args_.push_back(arg_cache);
211 
212             // Increment scratchpad offsets
213             inout_sp_offset_begin = inout_sp_offset_end;
214             inout_sp_offset_end
215                     += memory_desc_wrapper(root_pd->dst_md()).size();
216 
217             const auto &po = attr()->post_ops_;
218             const auto &end = po.len();
219 
220             unsigned int fusion_ops = 0;
221             // Loop through the post-ops until we reach the end
222             // (if we have more than one op to fuse later)
223             while (po_op_iter < end) {
224                 if (fusion_ops++ > max_fusions_) return status::unimplemented;
225 
226                 const auto &prev_op_pd = op_pds_.back();
227 
228                 if (po.entry_[po_op_iter].kind != primitive_kind::convolution)
229                     return status::unimplemented;
230 
231                 if (prev_op_pd->kind() != primitive_kind::convolution)
232                     return status::unimplemented;
233 
234                 auto conv_pd = reinterpret_cast<convolution_pd_t *>(
235                         prev_op_pd.get());
236                 bool ok = true && is_fwd()
237                         && utils::everyone_is(
238                                 1, conv_pd->KD(), conv_pd->KH(), conv_pd->KW());
239                 if (!ok) return status::unimplemented;
240 
241                 convolution_desc_t cd_dw;
242                 primitive_attr_t attr_dw;
243                 CHECK(get_depthwise_conv_desc(cd_dw, *(conv_pd->dst_md()),
244                         root_attr, attr_dw, po_op_iter));
245                 dnnl_primitive_desc_iterator it(
246                         engine, (op_desc_t *)&cd_dw, &attr_dw, nullptr);
247                 if (!it.is_initialized()) return status::out_of_memory;
248 
249                 std::shared_ptr<primitive_desc_t> append_conv_pd = *(++it);
250                 if (!append_conv_pd) return status::unimplemented;
251 
252                 CHECK(append_op(append_conv_pd, inout_sp_offset_begin,
253                         inout_sp_offset_end, engine));
254 
255                 const auto &op = op_pds_.back();
256                 arg_cache_t arg_cache;
257                 arg_cache.append_inout_arg(DNNL_ARG_SRC, inout_sp_offset_begin,
258                         op->src_md(), true);
259                 arg_cache.append_ctx_arg(DNNL_ARG_DST);
260                 arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS,
261                         DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
262                 if (op->weights_md(1)->data_type != data_type::undef)
263                     arg_cache.append_ctx_arg(DNNL_ARG_BIAS,
264                             DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
265                 for (int idx = 0; idx < attr_dw.post_ops_.len(); ++idx) {
266                     if (attr_dw.post_ops_.contain(primitive_kind::binary, idx))
267                         arg_cache.append_ctx_arg(
268                                 (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
269                                         | DNNL_ARG_SRC_1),
270                                 (DNNL_ARG_ATTR_MULTIPLE_POST_OP(
271                                          idx + po_op_iter + 1)
272                                         | DNNL_ARG_SRC_1));
273                 }
274 
275                 args_.push_back(arg_cache);
276 
277                 while (++po_op_iter < end) {
278                     if (utils::one_of(po.entry_[po_op_iter].kind,
279                                 primitive_kind::convolution))
280                         break;
281                 }
282             }
283 
284             assert(!op_pds_.empty());
285 
286             CHECK(init_scratchpad_memory(inout_sp_offset_end));
287 
288             return status::success;
289         }
290 
init_scratchpad_memorydnnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t291         status_t init_scratchpad_memory(size_t inout_buffer_size) {
292 
293             auto scratchpad = scratchpad_registry().registrar();
294 
295             scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
296                     inout_buffer_size, 1, 16);
297             scratchpad.book(
298                     memory_tracking::names::key_fusion_forward_scratchpad,
299                     user_scratchpad_size_, 1, 16);
300             return status::success;
301         }
302 
init_namednnl::impl::cpu::ref_fused_convolution_fwd_t::pd_t303         void init_name() {
304             for (const auto &op_pd : op_pds_) {
305                 name_.append(":");
306                 name_.append(op_pd->name());
307             }
308             return;
309         }
310     };
311 
ref_fused_convolution_fwd_tdnnl::impl::cpu::ref_fused_convolution_fwd_t312     ref_fused_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
313 
initdnnl::impl::cpu::ref_fused_convolution_fwd_t314     status_t init(engine_t *engine) override {
315         const auto &op_pds = pd()->op_pds_;
316         for (auto &op_pd : op_pds) {
317             std::shared_ptr<primitive_t> p;
318             op_pd->create_primitive(p, engine);
319             primitives_.emplace_back(p);
320         }
321         return status::success;
322     }
323 
executednnl::impl::cpu::ref_fused_convolution_fwd_t324     status_t execute(const exec_ctx_t &ctx) const override {
325         engine_t *engine = ctx.stream()->engine();
326         const auto scratchpad = ctx.get_scratchpad_grantor();
327 
328         const auto inout_buffer = scratchpad.get_memory_storage(
329                 memory_tracking::names::key_fusion_inout_buffer);
330 
331         const auto &ctx_args = ctx.args();
332         const auto op_count = primitives_.size();
333         std::vector<std::unique_ptr<memory_t>> inout_memory;
334 
335         for (size_t i = 0; i < op_count; ++i) {
336             const auto &op = primitives_[i];
337             const auto &arg_cache = pd()->args_[i];
338 
339             exec_args_t exec_args;
340 
341             for (const auto &arg_info : arg_cache.info()) {
342                 if (arg_info.is_ctx_arg) {
343                     exec_args[arg_info.op_arg] = ctx_args.at(arg_info.ctx_arg);
344                 } else {
345                     inout_memory.emplace_back(new memory_t(engine, &arg_info.md,
346                             inout_buffer->get_sub_storage(arg_info.offset,
347                                     memory_desc_wrapper(arg_info.md).size())));
348                     exec_args[arg_info.op_arg].mem = inout_memory.back().get();
349                     exec_args[arg_info.op_arg].is_const = arg_info.is_const;
350                 }
351             }
352 
353             exec_ctx_t op_ctx(ctx, std::move(exec_args));
354 
355             nested_scratchpad_t ns(ctx,
356                     memory_tracking::names::key_fusion_forward_scratchpad, op);
357             op_ctx.set_scratchpad_grantor(ns.grantor());
358             CHECK(op->execute(op_ctx));
359         }
360 
361         return status::success;
362     }
363 
364 private:
pddnnl::impl::cpu::ref_fused_convolution_fwd_t365     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
366     std::vector<std::shared_ptr<primitive_t>> primitives_;
367 };
368 
369 } // namespace cpu
370 } // namespace impl
371 } // namespace dnnl
372 
373 #endif
374 
375 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
376