1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <set>
18 
19 #include "dnnl_common.hpp"
20 #include "rnn/rnn.hpp"
21 
22 namespace rnn {
23 
24 namespace {
25 
// Helpers for the bodies of the CFG(...) lookup functions defined below.
// They expand inside `operator[](data_kind_t kind)` and read its `kind`
// parameter directly.
//
// CASE(KIND, ENTRY): return ENTRY when the requested kind equals KIND.
// Wrapped in do { } while (0) so the macro is exactly one statement and cannot
// capture a following `else` (dangling-else) when written under an `if`.
#define CASE(KIND, ENTRY) \
    do { \
        if (kind == (KIND)) return ENTRY; \
    } while (0)
// DEFAULT(ENTRY): unconditionally return ENTRY for every remaining kind.
#define DEFAULT(ENTRY) return ENTRY
// END_LIST: the requested kind is not handled by this configuration; report it
// through SAFE_V(CRIT) and fall back to F32_ENTRY. It expands to two
// statements, so it must be the final statement of a lookup body.
#define END_LIST \
    SAFE_V(CRIT); \
    return F32_ENTRY
32 
33 std::set<const dt_conf_t *> cfg_list;
34 #define CFG(name) \
35     struct conf_##name##_t : dt_conf_t { \
36         using dt_conf_t::dt_conf_t; \
37         const entry_t &operator[](data_kind_t kind) const override; \
38     } conf_##name(STRINGIFY(name)); \
39     static auto __reg_##name = cfg_list.insert(&conf_##name); \
40     const dt_conf_t::entry_t &conf_##name##_t::operator[](data_kind_t kind) \
41             const
42 
43 // f32
44 #define MIN_F32 0.0f
45 #define MAX_F32 .999999f
46 #define MEAN_F32 .5f
47 #define STDDEV_F32 0.01f
48 #define EPS_F32 epsilon_dt(dnnl_f32)
49 const int f32_max_exact = 1 << 24;
50 dt_conf_t::entry_t F32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact, MIN_F32,
51         MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
52 
// Pure f32 configuration: the same f32 entry serves every data kind.
CFG(f32) {
    (void)kind; // the answer does not depend on the requested kind
    const dt_conf_t::entry_t &entry = F32_ENTRY;
    return entry;
}
56 
57 // bf16
58 #define MIN_BF16 0.0f
59 #define MAX_BF16 .999999f
60 #define MEAN_BF16 .5f
61 #define STDDEV_BF16 0.01f
62 #define EPS_BF16 epsilon_dt(dnnl_bf16)
63 dt_conf_t::entry_t BF16_ENTRY_BF16 {dnnl_bf16, -f32_max_exact, f32_max_exact,
64         MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16};
65 dt_conf_t::entry_t BF16_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
66         MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_BF16};
67 
// bf16 configuration: layer/iteration sources, destinations and weights are
// bf16; every other kind (c-states, bias, ...) falls back to the f32 entry
// carrying the bf16 epsilon.
CFG(bf16) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case DST_ITER:
        case DST_LAYER: return BF16_ENTRY_BF16;
        default: return BF16_ENTRY_F32;
    }
}
77 
78 // f16
79 const int f16_max_exact = 1 << 11;
80 dt_conf_t::entry_t F16_ENTRY {dnnl_f16, -f16_max_exact, f16_max_exact, 0.0f,
81         0.999999f, 0.5f, 0.01f, epsilon_dt(dnnl_f16)};
82 
// Pure f16 configuration: the same f16 entry serves every data kind.
CFG(f16) {
    (void)kind; // the answer does not depend on the requested kind
    const dt_conf_t::entry_t &entry = F16_ENTRY;
    return entry;
}
86 
87 // s8
88 #define EPS_U8 4e-3
89 #define EPS_S8 8e-3
90 
91 #define MIN_U8 0.0f
92 #define MAX_U8 127.f
93 #define MEAN_U8 28.f
94 #define STDDEV_U8 16.f
95 
96 #define MIN_S8 (-64.f)
97 #define MAX_S8 64.f
98 #define MEAN_S8 8.f
99 #define STDDEV_S8 32.f
100 #define MEAN_WEIGHT_S8 0.f
101 
102 dt_conf_t::entry_t U8_ENTRY_U8_EXACT {
103         dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, 0.f};
104 dt_conf_t::entry_t U8_ENTRY_U8 {
105         dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, EPS_U8};
106 dt_conf_t::entry_t U8_ENTRY_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8, MAX_S8,
107         MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
108 dt_conf_t::entry_t U8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
109         MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
110 
111 dt_conf_t::entry_t S8_ENTRY_S8_EXACT {
112         dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, 0.f};
113 dt_conf_t::entry_t S8_ENTRY_S8 {
114         dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, EPS_S8};
115 dt_conf_t::entry_t S8_ENTRY_WEIGHT_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8,
116         MAX_S8, MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
117 dt_conf_t::entry_t S8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
118         MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
119 
// src_iter u8, src_layer u8, dst_iter u8, dst_layer u8 (zero-eps entry);
// weights are s8; c-states, peephole weights and bias stay f32.
CFG(u8u8u8u8) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return U8_ENTRY_U8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C: return U8_ENTRY_F32;
        case DST_LAYER: return U8_ENTRY_U8_EXACT;
        default: END_LIST;
    }
}
134 
// src_iter u8, src_layer u8, dst_iter u8, dst_layer f32; weights are s8;
// c-states, peephole weights and bias stay f32.
CFG(u8u8u8f32) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return U8_ENTRY_U8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C:
        case DST_LAYER: return U8_ENTRY_F32;
        default: END_LIST;
    }
}
149 
// src_iter f32, src_layer u8, dst_iter f32, dst_layer u8 (zero-eps entry);
// weights are s8; everything else stays f32.
CFG(f32u8f32u8) {
    switch (kind) {
        case SRC_LAYER: return U8_ENTRY_U8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;
        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C: return U8_ENTRY_F32;
        case DST_LAYER: return U8_ENTRY_U8_EXACT;
        default: END_LIST;
    }
}
164 
// src_iter f32, src_layer u8, dst_iter f32, dst_layer f32; weights are s8;
// everything else stays f32.
CFG(f32u8f32f32) {
    switch (kind) {
        case SRC_LAYER: return U8_ENTRY_U8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;
        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C:
        case DST_LAYER: return U8_ENTRY_F32;
        default: END_LIST;
    }
}
179 
// src_iter s8, src_layer s8, dst_iter s8, dst_layer s8 (zero-eps entry);
// weights use the zero-centered s8 entry; c-states, peephole weights and bias
// stay f32.
CFG(s8s8s8s8) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return S8_ENTRY_S8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C: return S8_ENTRY_F32;
        case DST_LAYER: return S8_ENTRY_S8_EXACT;
        default: END_LIST;
    }
}
194 
// src_iter s8, src_layer s8, dst_iter s8, dst_layer f32; weights use the
// zero-centered s8 entry; c-states, peephole weights and bias stay f32.
CFG(s8s8s8f32) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return S8_ENTRY_S8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C:
        case DST_LAYER: return S8_ENTRY_F32;
        default: END_LIST;
    }
}
209 
// src_iter f32, src_layer s8, dst_iter f32, dst_layer s8 (zero-eps entry);
// weights use the zero-centered s8 entry; everything else stays f32.
CFG(f32s8f32s8) {
    switch (kind) {
        case SRC_LAYER: return S8_ENTRY_S8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;
        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C: return S8_ENTRY_F32;
        case DST_LAYER: return S8_ENTRY_S8_EXACT;
        default: END_LIST;
    }
}
224 
// src_iter f32, src_layer s8, dst_iter f32, dst_layer f32; weights use the
// zero-centered s8 entry; everything else stays f32.
CFG(f32s8f32f32) {
    switch (kind) {
        case SRC_LAYER: return S8_ENTRY_S8;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;
        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C:
        case DST_LAYER: return S8_ENTRY_F32;
        default: END_LIST;
    }
}
239 
240 } // namespace
241 
create(const std::string & str)242 const dt_conf_t &dt_conf_t::create(const std::string &str) {
243     for (const auto cfg : cfg_list)
244         if (cfg->str() == str) return *cfg;
245     SAFE_V(CRIT);
246     return conf_f32;
247 }
248 
249 } // namespace rnn
250