1 /******************************************************************************* 2 * Copyright 2018-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef RNN_HPP 18 #define RNN_HPP 19 20 #include <assert.h> 21 #include <limits.h> 22 #include <stdint.h> 23 24 #include <string> 25 #include <vector> 26 27 #include "common.hpp" 28 #include "dnn_types.hpp" 29 #include "dnnl_common.hpp" 30 #include "dnnl_debug.hpp" 31 #include "dnnl_memory.hpp" 32 #include "perf_report.hpp" 33 34 #define AOC array_offset_calculator 35 36 namespace rnn { 37 38 enum alg_t { VANILLA_RNN, VANILLA_LSTM, VANILLA_GRU, LBR_GRU }; 39 alg_t str2alg(const char *str); 40 const char *alg2str(alg_t alg); 41 dnnl_alg_kind_t alg2kind(alg_t alg); 42 43 enum activation_t { UNDEF, RELU, LOGISTIC, TANH }; 44 activation_t str2activation(const char *str); 45 const char *activation2str(activation_t alg); 46 dnnl_alg_kind_t activation2kind(activation_t alg); 47 48 dnnl_rnn_direction_t str2direction(const char *str); 49 const char *direction2str(dnnl_rnn_direction_t direction); 50 51 enum data_kind_t { 52 SRC_LAYER, 53 SRC_ITER, 54 SRC_ITER_C, 55 WEIGHTS_LAYER, 56 WEIGHTS_ITER, 57 BIAS, 58 DST_ITER, 59 DST_ITER_C, 60 DST_LAYER, 61 62 DIFF_SRC_LAYER, 63 DIFF_SRC_ITER, 64 DIFF_SRC_ITER_C, 65 DIFF_WEIGHTS_LAYER, 66 DIFF_WEIGHTS_ITER, 67 DIFF_BIAS, 68 DIFF_DST_ITER, 69 DIFF_DST_ITER_C, 70 DIFF_DST_LAYER, 71 72 // FIXME: adding peephole related weights to the appropriate places will 73 // cause false-positive accuracy check failures in unrelated test cases 74 // (e.g. backward vanilla RNN for bf16) due to the data fill seed being 75 // dependent on the position of the tensor kind in the enum: adding 76 // `WEIGHTS_PEEPHOLE` before `dst_*` and `*diff_*` results in initializing 77 // the corresponding tensors differently. 78 // We need a more robust way of testing RNN. 79 WEIGHTS_PEEPHOLE, 80 DIFF_WEIGHTS_PEEPHOLE, 81 WEIGHTS_PROJECTION, 82 DIFF_WEIGHTS_PROJECTION, 83 }; 84 const char *data_kind2str(data_kind_t kind); 85 86 // Gates indices 87 enum { 88 LSTM_I = 0, 89 LSTM_F = 1, 90 LSTM_C = 2, 91 LSTM_O = 3, 92 GRU_U = 0, 93 GRU_R = 1, 94 GRU_O = 2, 95 LBR_GRU_U_PRIME = 3, 96 }; 97 98 // dlc is different at the cell level and the primitive level 99 // This enum enable to explicitely query the intended one 100 enum dlc_type_t { CELL, PRIMITIVE }; 101 102 template <typename Telem> 103 struct array_offset_calculator { 104 array_offset_calculator() = default; 105 106 template <typename... Targs> array_offset_calculatorrnn::array_offset_calculator107 array_offset_calculator(Telem *base_ptr, Targs... dims) 108 : base_ptr_(base_ptr), dims_({dims...}) {} 109 110 // ctor for AOC<const T> based on const AOC<T> & 111 template <typename Uelem> array_offset_calculatorrnn::array_offset_calculator112 array_offset_calculator(const array_offset_calculator<Uelem> &rhs) 113 : base_ptr_(rhs.base_ptr_), dims_(rhs.dims_) {} 114 115 // to make the above ctor work AOC<const T> should be able to access 116 // private fields of AOC<T>, hence let's friend them 117 friend struct array_offset_calculator<const Telem>; 118 119 template <typename... Targs> operator ()rnn::array_offset_calculator120 Telem &operator()(Targs... Fargs) const { 121 return *(base_ptr_ + offset(1, Fargs...)); 122 } 123 nelemsrnn::array_offset_calculator124 int64_t nelems() const { 125 int64_t res = 1; 126 for (auto dim : dims_) 127 res *= dim; 128 return res; 129 } 130 set_base_ptrrnn::array_offset_calculator131 void set_base_ptr(Telem *base_ptr) { base_ptr_ = base_ptr; } 132 133 private: 134 template <typename... Targs> offsetrnn::array_offset_calculator135 int64_t offset(int64_t d, int64_t pos) const { 136 return pos; 137 } 138 139 template <typename... Targs> offsetrnn::array_offset_calculator140 int64_t offset(int64_t d, int64_t off, int64_t pos) const { 141 return off * dims_[d] + pos; 142 } 143 144 template <typename... Targs> offsetrnn::array_offset_calculator145 int64_t offset(int64_t d, int64_t off, int64_t pos, Targs... rem) const { 146 return offset(d + 1, off * dims_[d] + pos, rem...); 147 } 148 149 Telem *base_ptr_; 150 std::vector<int64_t> dims_; 151 }; 152 153 struct desc_t { 154 int64_t sic; 155 int64_t slc; 156 int64_t dhc; 157 int64_t dic; 158 int64_t wc; 159 int64_t mb; 160 int64_t n_layer; 161 int64_t n_iter; 162 const char *name; 163 }; 164 int str2desc(desc_t *desc, const char *str); 165 std::ostream &operator<<(std::ostream &s, const desc_t &d); 166 167 /** configuration structure, that controls initial data filling + error check 168 * 169 * dt defines precision 170 * 171 * for each lst data kind the values are filled as follows: 172 * if (rand() > f_sparsity) then: 173 * v <-- f_base 174 * else: 175 * v <-- f_min + rand() * f_step % (f_max - f_min) 176 * 177 * on final check the resulting values should be in [min .. max] range, the 178 * relative difference should not exceed eps 179 */ 180 struct dt_conf_t { 181 struct entry_t { 182 dnnl_data_type_t dt; 183 int min, max; // representative 184 float f_min, f_max; // fill range 185 float f_mean, f_stddev; // parameters of normal distribution 186 double eps; // acceptable error 187 }; 188 dt_conf_trnn::dt_conf_t189 dt_conf_t(const std::string &str) : str_(str) {} 190 191 virtual const entry_t &operator[](data_kind_t kind) const = 0; 192 strrnn::dt_conf_t193 const std::string &str() const { return str_; } is_int8rnn::dt_conf_t194 bool is_int8() const { 195 return operator[](SRC_LAYER).dt == dnnl_u8 196 || operator[](SRC_LAYER).dt == dnnl_s8; 197 } is_s8rnn::dt_conf_t198 bool is_s8() const { return operator[](SRC_LAYER).dt == dnnl_s8; } 199 200 static const dt_conf_t &create(const std::string &str); 201 202 std::string str_; 203 }; 204 205 struct settings_t { 206 settings_t() = default; 207 208 // ctor to save certain fields from resetting settings_trnn::settings_t209 settings_t(const char *perf_template) : settings_t() { 210 this->perf_template = perf_template; 211 } 212 213 desc_t desc {}; 214 215 std::vector<dir_t> prop {FWD_I}; 216 std::vector<std::string> cfg {"f32"}; 217 std::vector<alg_t> alg {VANILLA_RNN}; 218 std::vector<dnnl_rnn_direction_t> direction { 219 dnnl_unidirectional_left2right}; 220 std::vector<activation_t> activation {RELU}; 221 std::vector<bool> skip_nonlinear {false}; 222 std::vector<bool> trivial_strides {false}; 223 std::vector<bool> with_peephole {false}; 224 std::vector<bool> with_projection {false}; 225 std::vector<int64_t> n_layer {0}, n_iter {0}, mb {0}; 226 std::vector<policy_t> scale_policy {policy_t::COMMON}; 227 std::vector<policy_t> scale_proj_policy {policy_t::COMMON}; 228 std::vector<dnnl_scratchpad_mode_t> scratchpad_mode { 229 dnnl_scratchpad_mode_library}; 230 unsigned int flags = 0x0; 231 float alpha = 0.9f, beta = 0.0f; 232 233 const char *perf_template_csv 234 = "perf,%engine%,%impl%,%name%,%prop%,%cfg%,%alg%,%activation%,%" 235 "direction%" 236 "," 237 "%DESC%,%Gops%,%Gfreq%,%-time%,%-Gflops%,%0time%,%0Gflops%"; 238 const char *perf_template_def 239 = "perf,%engine%,%impl%,%name%,%prb%,%Gops%,%Gfreq%,%-time%,%-" 240 "Gflops%,%0time%,%0Gflops%"; 241 const char *perf_template = perf_template_def; 242 resetrnn::settings_t243 void reset() { *this = settings_t(perf_template); } 244 }; 245 246 struct prb_t : public desc_t { prb_trnn::prb_t247 prb_t(const desc_t &desc, const dt_conf_t &cfg, dir_t prop, alg_t alg, 248 bool with_peephole, bool with_projection, 249 dnnl_rnn_direction_t direction, policy_t scale_policy, 250 policy_t scale_proj_policy, unsigned int flags, 251 activation_t activation, const attr_t &attr, float alpha, 252 float beta, bool skip_nonlinear, bool trivial_strides, 253 int64_t n_layer, int64_t n_iter, int64_t mb = 0) 254 : desc_t(desc) 255 , cfg(cfg) 256 , prop(prop2prop_kind(prop)) 257 , alg(alg) 258 , with_peephole(with_peephole) 259 , with_projection(with_projection) 260 , direction(direction) 261 , wei_scales_policy(scale_policy) 262 , wei_proj_scales_policy(scale_proj_policy) 263 , flags(flags) 264 , activation(activation) 265 , attr(attr) 266 , user_mb(mb) 267 , alpha(alpha) 268 , beta(beta) 269 , skip_nonlinear(skip_nonlinear) 270 , trivial_strides(trivial_strides) 271 , ops(0.0) 272 , linear_cscale(0.0f) { 273 274 if (n_layer) this->n_layer = n_layer; 275 if (n_iter) this->n_iter = n_iter; 276 if (mb) this->mb = mb; 277 count_ops(); 278 279 wei_scales = nullptr; 280 wei_proj_scales = nullptr; 281 linear_scales = nullptr; 282 283 // We always allocate linear scales. Even if they are not 284 // used, they get dereferenced when built in debug mode. 285 linear_scales = (float *)zmalloc(sizeof(float) * n_gates(), 64); 286 // Here we use the range of SRC_LAYER to set the scales 287 set_tparams(cfg[SRC_LAYER].f_min, cfg[SRC_LAYER].f_max); 288 289 switch (wei_scales_policy) { 290 case policy_t::COMMON: 291 wei_scales_mask = 0x0; 292 wei_nscales = 1; 293 break; 294 case policy_t::PER_OC: 295 wei_scales_mask = 0x18; 296 wei_nscales = dhc * n_gates(); 297 break; 298 default: assert(!"unsupported scaling policy"); 299 } 300 wei_scales = (float *)zmalloc(sizeof(float) * wei_nscales, 64); 301 302 if (with_projection) { 303 switch (wei_proj_scales_policy) { 304 case policy_t::PER_OC: 305 wei_proj_scales_mask = 0x8; 306 wei_proj_nscales = dic; 307 break; 308 case policy_t::COMMON: 309 wei_proj_scales_mask = 0x0; 310 wei_proj_nscales = 1; 311 break; 312 default: assert(!"unsupported scaling policy"); 313 } 314 wei_proj_scales 315 = (float *)zmalloc(sizeof(float) * wei_proj_nscales, 64); 316 } 317 318 set_qparams(-1., 1.); 319 } ~prb_trnn::prb_t320 ~prb_t() { 321 if (wei_scales) zfree(wei_scales); 322 if (wei_proj_scales) zfree(wei_proj_scales); 323 if (linear_scales) zfree(linear_scales); 324 } 325 get_wei_scalernn::prb_t326 float get_wei_scale(int idx) const { 327 return wei_scales[MIN2(idx, wei_nscales - 1)]; 328 } 329 get_wei_proj_scalernn::prb_t330 inline float get_wei_proj_scale(int idx) const { 331 return wei_proj_scales[MIN2(idx, wei_proj_nscales - 1)]; 332 } 333 count_opsrnn::prb_t334 void count_ops() { 335 // Here, we count only the ops in GEMM portion as there is no 336 // theoretical number of ops for the post-gemm operations 337 int64_t num_cells = (int64_t)n_dir() * n_layer * n_iter; 338 int64_t cell_ops = (int64_t)2 * (n_gates() * dhc) * mb * (sic + slc); 339 if (with_projection) cell_ops += (int64_t)2 * dhc * mb * dic; 340 int64_t prop_multiplier = prop == dnnl_backward ? 2 : 1; 341 ops = prop_multiplier * num_cells * cell_ops; 342 } 343 n_dirrnn::prb_t344 int64_t n_dir() const { 345 return (direction == dnnl_bidirectional_concat 346 || direction == dnnl_bidirectional_sum) 347 ? 2 348 : 1; 349 } n_statesrnn::prb_t350 int64_t n_states() const { return alg == VANILLA_LSTM ? 2 : 1; } n_gatesrnn::prb_t351 int64_t n_gates() const { 352 return alg == VANILLA_LSTM 353 ? 4 354 : (alg == VANILLA_GRU || alg == LBR_GRU ? 3 : 1); 355 } n_biasrnn::prb_t356 int64_t n_bias() const { 357 return alg == LBR_GRU ? n_gates() + 1 : n_gates(); 358 } 359 dlcrnn::prb_t360 int64_t dlc(dlc_type_t type) const { 361 if (type == PRIMITIVE) 362 return (direction == dnnl_bidirectional_concat ? 2 : 1) * dic; 363 if (type == CELL) return dic; 364 assert(!"unsupported dlc type"); 365 return 0; 366 } 367 is_int8rnn::prb_t368 bool is_int8() const { 369 return cfg[SRC_LAYER].dt == dnnl_u8 || cfg[SRC_LAYER].dt == dnnl_s8; 370 } is_u8rnn::prb_t371 bool is_u8() const { return cfg[SRC_LAYER].dt == dnnl_u8; } is_s8rnn::prb_t372 bool is_s8() const { return cfg[SRC_LAYER].dt == dnnl_s8; } is_lstm_peepholernn::prb_t373 bool is_lstm_peephole() const { return with_peephole; } is_lstm_projectionrnn::prb_t374 bool is_lstm_projection() const { return with_projection; } 375 376 const dt_conf_t &cfg; 377 dnnl_prop_kind_t prop; 378 alg_t alg; 379 bool with_peephole, with_projection; 380 dnnl_rnn_direction_t direction; 381 policy_t wei_scales_policy; 382 policy_t wei_proj_scales_policy; 383 unsigned int flags; 384 activation_t activation; 385 attr_t attr; 386 int64_t user_mb; 387 float alpha; 388 float beta; 389 390 float data_scale, data_shift; 391 392 float *wei_scales; 393 int wei_nscales; 394 int wei_scales_mask; 395 396 float *wei_proj_scales; 397 int wei_proj_nscales; 398 int wei_proj_scales_mask; 399 400 bool skip_nonlinear; 401 bool trivial_strides; 402 double ops; 403 float *linear_scales; 404 float linear_cscale; 405 406 private: 407 /* Todo: fused the two functions in set_shifts_scales */ 408 void set_qparams(float fp_min, float fp_max); 409 void set_tparams(float fp_min, float fp_max); 410 prb_t(const prb_t &) = delete; 411 prb_t &operator=(const prb_t &) = delete; 412 }; 413 std::ostream &operator<<(std::ostream &s, const prb_t &prb); 414 415 struct perf_report_t : public base_perf_report_t { 416 using base_perf_report_t::base_perf_report_t; 417 reportrnn::perf_report_t418 void report(const prb_t *prb, const res_t *res, const char *prb_str) { 419 p_ = prb; 420 base_report(res, prb_str); 421 } 422 dump_algrnn::perf_report_t423 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); } 424 dump_cfgrnn::perf_report_t425 void dump_cfg(std::ostream &s) const override { s << p_->cfg.str(); } 426 dump_descrnn::perf_report_t427 void dump_desc(std::ostream &s) const override { 428 s << static_cast<const desc_t &>(*p_); 429 } 430 dump_desc_csvrnn::perf_report_t431 void dump_desc_csv(std::ostream &s) const override { 432 s << p_->n_layer << "," << p_->n_iter << "," << p_->mb << "," << p_->sic 433 << "," << p_->slc << "," << p_->dhc << "," << p_->dic; 434 } 435 dump_rnn_activationrnn::perf_report_t436 void dump_rnn_activation(std::ostream &s) const override { 437 s << activation2str(p_->activation); 438 } 439 dump_rnn_directionrnn::perf_report_t440 void dump_rnn_direction(std::ostream &s) const override { 441 s << direction2str(p_->direction); 442 } 443 opsrnn::perf_report_t444 double ops() const override { return p_->ops; } user_mbrnn::perf_report_t445 const int64_t *user_mb() const override { return &p_->user_mb; } namernn::perf_report_t446 const char *name() const override { return p_->name; } proprnn::perf_report_t447 const dnnl_prop_kind_t *prop() const override { return &p_->prop; } 448 449 private: 450 const prb_t *p_ = nullptr; 451 }; 452 453 void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer, 454 AOC<float> &ws_src_layer, AOC<float> &ws_src_iter, 455 AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht); 456 457 void rnn_linear_fwd(const prb_t &prb, const float *src_layer_, 458 const float *src_iter_, const float *src_iter_c_, 459 const float *weights_layer_, const float *weights_iter_, 460 const float *weights_peephole_, const float *weights_projection_, 461 const float *bias_, float *dst_layer_, float *dst_iter_, 462 float *dst_iter_c_, const AOC<float> &ws_src_layer, 463 const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c, 464 const AOC<float> &ws_gates, const AOC<float> &ws_ht); 465 466 void compute_ref_fwd(const prb_t &prb, dnn_mem_t &src_layer_m, 467 dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m, 468 dnn_mem_t &weights_layer_m, dnn_mem_t &weights_iter_m, 469 dnn_mem_t &weights_peephole_m, dnn_mem_t &weights_projection_m, 470 dnn_mem_t &bias_m, dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m, 471 dnn_mem_t &dst_iter_c_m); 472 473 void compute_ref_bwd(const prb_t &prb, dnn_mem_t &src_layer_m, 474 dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m, 475 dnn_mem_t &diff_dst_layer_m, dnn_mem_t &diff_dst_iter_m, 476 dnn_mem_t &diff_dst_iter_c_m, dnn_mem_t &weights_layer_m, 477 dnn_mem_t &weights_iter_m, dnn_mem_t &weights_peephole_m, 478 dnn_mem_t &weights_projection_m, dnn_mem_t &bias_m, 479 dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m, dnn_mem_t &dst_iter_c_m, 480 dnn_mem_t &diff_src_layer_m, dnn_mem_t &diff_src_iter_m, 481 dnn_mem_t &diff_src_iter_c_m, dnn_mem_t &diff_weights_layer_m, 482 dnn_mem_t &diff_weights_iter_m, dnn_mem_t &diff_weights_peephole_m, 483 dnn_mem_t &diff_weights_projection_m, dnn_mem_t &diff_bias_m); 484 485 int doit(const prb_t &prb, res_t *res); 486 int bench(int argc, char **argv); 487 488 } // namespace rnn 489 490 #endif 491