/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef RNN_HPP
#define RNN_HPP

#include <assert.h>
#include <limits.h>
#include <stdint.h>

#include <string>
#include <vector>

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "dnnl_debug.hpp"
#include "dnnl_memory.hpp"
#include "perf_report.hpp"

#define AOC array_offset_calculator

namespace rnn {

enum alg_t { VANILLA_RNN, VANILLA_LSTM, VANILLA_GRU, LBR_GRU };
alg_t str2alg(const char *str);
const char *alg2str(alg_t alg);
dnnl_alg_kind_t alg2kind(alg_t alg);

enum activation_t { UNDEF, RELU, LOGISTIC, TANH };
activation_t str2activation(const char *str);
const char *activation2str(activation_t alg);
dnnl_alg_kind_t activation2kind(activation_t alg);

dnnl_rnn_direction_t str2direction(const char *str);
const char *direction2str(dnnl_rnn_direction_t direction);

enum data_kind_t {
    SRC_LAYER,
    SRC_ITER,
    SRC_ITER_C,
    WEIGHTS_LAYER,
    WEIGHTS_ITER,
    BIAS,
    DST_ITER,
    DST_ITER_C,
    DST_LAYER,

    DIFF_SRC_LAYER,
    DIFF_SRC_ITER,
    DIFF_SRC_ITER_C,
    DIFF_WEIGHTS_LAYER,
    DIFF_WEIGHTS_ITER,
    DIFF_BIAS,
    DIFF_DST_ITER,
    DIFF_DST_ITER_C,
    DIFF_DST_LAYER,

    // FIXME: adding peephole-related weights to the appropriate places will
    // cause false-positive accuracy check failures in unrelated test cases
    // (e.g. backward vanilla RNN for bf16) due to the data fill seed being
    // dependent on the position of the tensor kind in the enum: adding
    // `WEIGHTS_PEEPHOLE` before `dst_*` and `*diff_*` results in initializing
    // the corresponding tensors differently.
    // We need a more robust way of testing RNNs.
    WEIGHTS_PEEPHOLE,
    DIFF_WEIGHTS_PEEPHOLE,
    WEIGHTS_PROJECTION,
    DIFF_WEIGHTS_PROJECTION,
};
const char *data_kind2str(data_kind_t kind);

// Gate indices
enum {
    LSTM_I = 0,
    LSTM_F = 1,
    LSTM_C = 2,
    LSTM_O = 3,
    GRU_U = 0,
    GRU_R = 1,
    GRU_O = 2,
    LBR_GRU_U_PRIME = 3,
};

// dlc differs between the cell level and the primitive level.
// This enum enables explicitly querying the intended one.
enum dlc_type_t { CELL, PRIMITIVE };
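
// Illustrative note: for a problem with direction == dnnl_bidirectional_concat
// and dic == 64, prb_t::dlc(PRIMITIVE) (defined below) yields 2 * 64 == 128,
// since both directions are concatenated on the channel dimension, while
// prb_t::dlc(CELL) yields 64.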

template <typename Telem>
struct array_offset_calculator {
    array_offset_calculator() = default;

    template <typename... Targs>
    array_offset_calculator(Telem *base_ptr, Targs... dims)
        : base_ptr_(base_ptr), dims_({dims...}) {}

    // ctor for AOC<const T> based on const AOC<T> &
    template <typename Uelem>
    array_offset_calculator(const array_offset_calculator<Uelem> &rhs)
        : base_ptr_(rhs.base_ptr_), dims_(rhs.dims_) {}

    // To make the above ctor work, AOC<const T> must be able to access the
    // private fields of AOC<T>, hence the friend declaration.
    friend struct array_offset_calculator<const Telem>;

    template <typename... Targs>
    Telem &operator()(Targs... Fargs) const {
        return *(base_ptr_ + offset(1, Fargs...));
    }

    int64_t nelems() const {
        int64_t res = 1;
        for (auto dim : dims_)
            res *= dim;
        return res;
    }

    void set_base_ptr(Telem *base_ptr) { base_ptr_ = base_ptr; }

private:
    template <typename... Targs>
    int64_t offset(int64_t d, int64_t pos) const {
        return pos;
    }

    template <typename... Targs>
    int64_t offset(int64_t d, int64_t off, int64_t pos) const {
        return off * dims_[d] + pos;
    }

    template <typename... Targs>
    int64_t offset(int64_t d, int64_t off, int64_t pos, Targs... rem) const {
        return offset(d + 1, off * dims_[d] + pos, rem...);
    }

    Telem *base_ptr_;
    std::vector<int64_t> dims_;
};
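
// Usage sketch (illustrative, not part of the interface): AOC addresses a
// dense row-major buffer whose dimensions are supplied at construction time,
// e.g.
//
//     std::vector<float> buf(2 * 3 * 4);
//     AOC<float> A(buf.data(), 2, 3, 4);
//     A(1, 2, 3) = 42.f; // addresses buf[(1 * 3 + 2) * 4 + 3]
//
// which follows the offset() recursion above.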

struct desc_t {
    int64_t sic;
    int64_t slc;
    int64_t dhc;
    int64_t dic;
    int64_t wc;
    int64_t mb;
    int64_t n_layer;
    int64_t n_iter;
    const char *name;
};
int str2desc(desc_t *desc, const char *str);
std::ostream &operator<<(std::ostream &s, const desc_t &d);

/** Configuration structure that controls initial data filling + error check.
*
* dt defines precision.
*
* For each listed data kind the values are filled as follows:
* if (rand() > f_sparsity) then:
*     v <-- f_base
* else:
*     v <-- f_min + rand() * f_step % (f_max - f_min)
*
* On the final check the resulting values should be in the [min .. max] range,
* and the relative difference should not exceed eps.
*/
struct dt_conf_t {
    struct entry_t {
        dnnl_data_type_t dt;
        int min, max; // representative
        float f_min, f_max; // fill range
        float f_mean, f_stddev; // parameters of normal distribution
        double eps; // acceptable error
    };

    dt_conf_t(const std::string &str) : str_(str) {}

    virtual const entry_t &operator[](data_kind_t kind) const = 0;

    const std::string &str() const { return str_; }
    bool is_int8() const {
        return operator[](SRC_LAYER).dt == dnnl_u8
                || operator[](SRC_LAYER).dt == dnnl_s8;
    }
    bool is_s8() const { return operator[](SRC_LAYER).dt == dnnl_s8; }

    static const dt_conf_t &create(const std::string &str);

    std::string str_;
};
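
// Usage sketch (illustrative): a concrete configuration is looked up by name
// and queried per data kind, e.g.
//
//     const dt_conf_t &cfg = dt_conf_t::create("f32");
//     dnnl_data_type_t src_dt = cfg[SRC_LAYER].dt;
//     bool int8 = cfg.is_int8(); // false for the "f32" config
//
// The set of valid configuration names is defined by dt_conf_t::create().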

struct settings_t {
    settings_t() = default;

    // ctor to save certain fields from resetting
    settings_t(const char *perf_template) : settings_t() {
        this->perf_template = perf_template;
    }

    desc_t desc {};

    std::vector<dir_t> prop {FWD_I};
    std::vector<std::string> cfg {"f32"};
    std::vector<alg_t> alg {VANILLA_RNN};
    std::vector<dnnl_rnn_direction_t> direction {
            dnnl_unidirectional_left2right};
    std::vector<activation_t> activation {RELU};
    std::vector<bool> skip_nonlinear {false};
    std::vector<bool> trivial_strides {false};
    std::vector<bool> with_peephole {false};
    std::vector<bool> with_projection {false};
    std::vector<int64_t> n_layer {0}, n_iter {0}, mb {0};
    std::vector<policy_t> scale_policy {policy_t::COMMON};
    std::vector<policy_t> scale_proj_policy {policy_t::COMMON};
    std::vector<dnnl_scratchpad_mode_t> scratchpad_mode {
            dnnl_scratchpad_mode_library};
    unsigned int flags = 0x0;
    float alpha = 0.9f, beta = 0.0f;

    const char *perf_template_csv
            = "perf,%engine%,%impl%,%name%,%prop%,%cfg%,%alg%,%activation%,%"
              "direction%"
              ","
              "%DESC%,%Gops%,%Gfreq%,%-time%,%-Gflops%,%0time%,%0Gflops%";
    const char *perf_template_def
            = "perf,%engine%,%impl%,%name%,%prb%,%Gops%,%Gfreq%,%-time%,%-"
              "Gflops%,%0time%,%0Gflops%";
    const char *perf_template = perf_template_def;

    void reset() { *this = settings_t(perf_template); }
};

struct prb_t : public desc_t {
    prb_t(const desc_t &desc, const dt_conf_t &cfg, dir_t prop, alg_t alg,
            bool with_peephole, bool with_projection,
            dnnl_rnn_direction_t direction, policy_t scale_policy,
            policy_t scale_proj_policy, unsigned int flags,
            activation_t activation, const attr_t &attr, float alpha,
            float beta, bool skip_nonlinear, bool trivial_strides,
            int64_t n_layer, int64_t n_iter, int64_t mb = 0)
        : desc_t(desc)
        , cfg(cfg)
        , prop(prop2prop_kind(prop))
        , alg(alg)
        , with_peephole(with_peephole)
        , with_projection(with_projection)
        , direction(direction)
        , wei_scales_policy(scale_policy)
        , wei_proj_scales_policy(scale_proj_policy)
        , flags(flags)
        , activation(activation)
        , attr(attr)
        , user_mb(mb)
        , alpha(alpha)
        , beta(beta)
        , skip_nonlinear(skip_nonlinear)
        , trivial_strides(trivial_strides)
        , ops(0.0)
        , linear_cscale(0.0f) {

        if (n_layer) this->n_layer = n_layer;
        if (n_iter) this->n_iter = n_iter;
        if (mb) this->mb = mb;
        count_ops();

        wei_scales = nullptr;
        wei_proj_scales = nullptr;
        linear_scales = nullptr;

        // We always allocate linear scales. Even if they are not
        // used, they get dereferenced when built in debug mode.
        linear_scales = (float *)zmalloc(sizeof(float) * n_gates(), 64);
        // Here we use the range of SRC_LAYER to set the scales
        set_tparams(cfg[SRC_LAYER].f_min, cfg[SRC_LAYER].f_max);

        switch (wei_scales_policy) {
            case policy_t::COMMON:
                wei_scales_mask = 0x0;
                wei_nscales = 1;
                break;
            case policy_t::PER_OC:
                wei_scales_mask = 0x18;
                wei_nscales = dhc * n_gates();
                break;
            default: assert(!"unsupported scaling policy");
        }
        wei_scales = (float *)zmalloc(sizeof(float) * wei_nscales, 64);

        if (with_projection) {
            switch (wei_proj_scales_policy) {
                case policy_t::PER_OC:
                    wei_proj_scales_mask = 0x8;
                    wei_proj_nscales = dic;
                    break;
                case policy_t::COMMON:
                    wei_proj_scales_mask = 0x0;
                    wei_proj_nscales = 1;
                    break;
                default: assert(!"unsupported scaling policy");
            }
            wei_proj_scales
                    = (float *)zmalloc(sizeof(float) * wei_proj_nscales, 64);
        }

        set_qparams(-1., 1.);
    }
    ~prb_t() {
        if (wei_scales) zfree(wei_scales);
        if (wei_proj_scales) zfree(wei_proj_scales);
        if (linear_scales) zfree(linear_scales);
    }

    float get_wei_scale(int idx) const {
        return wei_scales[MIN2(idx, wei_nscales - 1)];
    }

    inline float get_wei_proj_scale(int idx) const {
        return wei_proj_scales[MIN2(idx, wei_proj_nscales - 1)];
    }
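
    // Illustrative note: with policy_t::COMMON the constructor sets
    // wei_nscales == 1, so get_wei_scale(idx) returns wei_scales[0] for any
    // idx; with policy_t::PER_OC there are dhc * n_gates() scales and the
    // index is clamped to the last one by MIN2 above.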

    void count_ops() {
        // Here we count only the ops in the GEMM portion, as there is no
        // theoretical number of ops for the post-GEMM operations.
        int64_t num_cells = (int64_t)n_dir() * n_layer * n_iter;
        int64_t cell_ops = (int64_t)2 * (n_gates() * dhc) * mb * (sic + slc);
        if (with_projection) cell_ops += (int64_t)2 * dhc * mb * dic;
        int64_t prop_multiplier = prop == dnnl_backward ? 2 : 1;
        ops = prop_multiplier * num_cells * cell_ops;
    }
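
    // Worked example (illustrative): a forward unidirectional LSTM with
    // n_layer = 1, n_iter = 2, mb = 1 and sic = slc = dhc = 4 (no projection)
    // gives cell_ops = 2 * (4 * 4) * 1 * (4 + 4) = 256 and
    // ops = 1 * (1 * 1 * 2) * 256 = 512.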

    int64_t n_dir() const {
        return (direction == dnnl_bidirectional_concat
                       || direction == dnnl_bidirectional_sum)
                ? 2
                : 1;
    }
    int64_t n_states() const { return alg == VANILLA_LSTM ? 2 : 1; }
    int64_t n_gates() const {
        return alg == VANILLA_LSTM
                ? 4
                : (alg == VANILLA_GRU || alg == LBR_GRU ? 3 : 1);
    }
    int64_t n_bias() const {
        return alg == LBR_GRU ? n_gates() + 1 : n_gates();
    }
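
    // Summary of the helpers above: VANILLA_RNN has 1 gate, VANILLA_LSTM has
    // 4, VANILLA_GRU and LBR_GRU have 3; LBR_GRU additionally carries one
    // extra bias term, so n_bias() == n_gates() + 1 for it.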

    int64_t dlc(dlc_type_t type) const {
        if (type == PRIMITIVE)
            return (direction == dnnl_bidirectional_concat ? 2 : 1) * dic;
        if (type == CELL) return dic;
        assert(!"unsupported dlc type");
        return 0;
    }

    bool is_int8() const {
        return cfg[SRC_LAYER].dt == dnnl_u8 || cfg[SRC_LAYER].dt == dnnl_s8;
    }
    bool is_u8() const { return cfg[SRC_LAYER].dt == dnnl_u8; }
    bool is_s8() const { return cfg[SRC_LAYER].dt == dnnl_s8; }
    bool is_lstm_peephole() const { return with_peephole; }
    bool is_lstm_projection() const { return with_projection; }

    const dt_conf_t &cfg;
    dnnl_prop_kind_t prop;
    alg_t alg;
    bool with_peephole, with_projection;
    dnnl_rnn_direction_t direction;
    policy_t wei_scales_policy;
    policy_t wei_proj_scales_policy;
    unsigned int flags;
    activation_t activation;
    attr_t attr;
    int64_t user_mb;
    float alpha;
    float beta;

    float data_scale, data_shift;

    float *wei_scales;
    int wei_nscales;
    int wei_scales_mask;

    float *wei_proj_scales;
    int wei_proj_nscales;
    int wei_proj_scales_mask;

    bool skip_nonlinear;
    bool trivial_strides;
    double ops;
    float *linear_scales;
    float linear_cscale;

private:
    /* TODO: fuse the two functions into set_shifts_scales */
    void set_qparams(float fp_min, float fp_max);
    void set_tparams(float fp_min, float fp_max);
    prb_t(const prb_t &) = delete;
    prb_t &operator=(const prb_t &) = delete;
};
std::ostream &operator<<(std::ostream &s, const prb_t &prb);

struct perf_report_t : public base_perf_report_t {
    using base_perf_report_t::base_perf_report_t;

    void report(const prb_t *prb, const res_t *res, const char *prb_str) {
        p_ = prb;
        base_report(res, prb_str);
    }

    void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }

    void dump_cfg(std::ostream &s) const override { s << p_->cfg.str(); }

    void dump_desc(std::ostream &s) const override {
        s << static_cast<const desc_t &>(*p_);
    }

    void dump_desc_csv(std::ostream &s) const override {
        s << p_->n_layer << "," << p_->n_iter << "," << p_->mb << "," << p_->sic
          << "," << p_->slc << "," << p_->dhc << "," << p_->dic;
    }

    void dump_rnn_activation(std::ostream &s) const override {
        s << activation2str(p_->activation);
    }

    void dump_rnn_direction(std::ostream &s) const override {
        s << direction2str(p_->direction);
    }

    double ops() const override { return p_->ops; }
    const int64_t *user_mb() const override { return &p_->user_mb; }
    const char *name() const override { return p_->name; }
    const dnnl_prop_kind_t *prop() const override { return &p_->prop; }

private:
    const prb_t *p_ = nullptr;
};

void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht);

void rnn_linear_fwd(const prb_t &prb, const float *src_layer_,
        const float *src_iter_, const float *src_iter_c_,
        const float *weights_layer_, const float *weights_iter_,
        const float *weights_peephole_, const float *weights_projection_,
        const float *bias_, float *dst_layer_, float *dst_iter_,
        float *dst_iter_c_, const AOC<float> &ws_src_layer,
        const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
        const AOC<float> &ws_gates, const AOC<float> &ws_ht);

void compute_ref_fwd(const prb_t &prb, dnn_mem_t &src_layer_m,
        dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m,
        dnn_mem_t &weights_layer_m, dnn_mem_t &weights_iter_m,
        dnn_mem_t &weights_peephole_m, dnn_mem_t &weights_projection_m,
        dnn_mem_t &bias_m, dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m,
        dnn_mem_t &dst_iter_c_m);

void compute_ref_bwd(const prb_t &prb, dnn_mem_t &src_layer_m,
        dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m,
        dnn_mem_t &diff_dst_layer_m, dnn_mem_t &diff_dst_iter_m,
        dnn_mem_t &diff_dst_iter_c_m, dnn_mem_t &weights_layer_m,
        dnn_mem_t &weights_iter_m, dnn_mem_t &weights_peephole_m,
        dnn_mem_t &weights_projection_m, dnn_mem_t &bias_m,
        dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m, dnn_mem_t &dst_iter_c_m,
        dnn_mem_t &diff_src_layer_m, dnn_mem_t &diff_src_iter_m,
        dnn_mem_t &diff_src_iter_c_m, dnn_mem_t &diff_weights_layer_m,
        dnn_mem_t &diff_weights_iter_m, dnn_mem_t &diff_weights_peephole_m,
        dnn_mem_t &diff_weights_projection_m, dnn_mem_t &diff_bias_m);

int doit(const prb_t &prb, res_t *res);
int bench(int argc, char **argv);

} // namespace rnn

#endif