/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17 #include <set>
18
19 #include "dnnl_common.hpp"
20 #include "rnn/rnn.hpp"
21
22 namespace rnn {
23
24 namespace {
25
26 #define CASE(KIND, ENTRY) \
27 if (kind == (KIND)) return ENTRY
28 #define DEFAULT(ENTRY) return ENTRY
29 #define END_LIST \
30 SAFE_V(CRIT); \
31 return F32_ENTRY
32
33 std::set<const dt_conf_t *> cfg_list;
34 #define CFG(name) \
35 struct conf_##name##_t : dt_conf_t { \
36 using dt_conf_t::dt_conf_t; \
37 const entry_t &operator[](data_kind_t kind) const override; \
38 } conf_##name(STRINGIFY(name)); \
39 static auto __reg_##name = cfg_list.insert(&conf_##name); \
40 const dt_conf_t::entry_t &conf_##name##_t::operator[](data_kind_t kind) \
41 const
42
43 // f32
44 #define MIN_F32 0.0f
45 #define MAX_F32 .999999f
46 #define MEAN_F32 .5f
47 #define STDDEV_F32 0.01f
48 #define EPS_F32 epsilon_dt(dnnl_f32)
49 const int f32_max_exact = 1 << 24;
50 dt_conf_t::entry_t F32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact, MIN_F32,
51 MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
52
// Pure f32 configuration: every data kind maps to the same f32 entry.
CFG(f32) {
    DEFAULT(F32_ENTRY);
}
56
57 // bf16
58 #define MIN_BF16 0.0f
59 #define MAX_BF16 .999999f
60 #define MEAN_BF16 .5f
61 #define STDDEV_BF16 0.01f
62 #define EPS_BF16 epsilon_dt(dnnl_bf16)
63 dt_conf_t::entry_t BF16_ENTRY_BF16 {dnnl_bf16, -f32_max_exact, f32_max_exact,
64 MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16};
65 dt_conf_t::entry_t BF16_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
66 MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_BF16};
67
// bf16 configuration: layer/iter activations and layer/iter weights use the
// bf16 entry; every other data kind uses the f32-typed entry.
CFG(bf16) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case DST_ITER:
        case DST_LAYER: return BF16_ENTRY_BF16;
        default: return BF16_ENTRY_F32;
    }
}
77
78 // f16
79 const int f16_max_exact = 1 << 11;
80 dt_conf_t::entry_t F16_ENTRY {dnnl_f16, -f16_max_exact, f16_max_exact, 0.0f,
81 0.999999f, 0.5f, 0.01f, epsilon_dt(dnnl_f16)};
82
// Pure f16 configuration: every data kind maps to the same f16 entry.
CFG(f16) {
    DEFAULT(F16_ENTRY);
}
86
87 // s8
88 #define EPS_U8 4e-3
89 #define EPS_S8 8e-3
90
91 #define MIN_U8 0.0f
92 #define MAX_U8 127.f
93 #define MEAN_U8 28.f
94 #define STDDEV_U8 16.f
95
96 #define MIN_S8 (-64.f)
97 #define MAX_S8 64.f
98 #define MEAN_S8 8.f
99 #define STDDEV_S8 32.f
100 #define MEAN_WEIGHT_S8 0.f
101
102 dt_conf_t::entry_t U8_ENTRY_U8_EXACT {
103 dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, 0.f};
104 dt_conf_t::entry_t U8_ENTRY_U8 {
105 dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, EPS_U8};
106 dt_conf_t::entry_t U8_ENTRY_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8, MAX_S8,
107 MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
108 dt_conf_t::entry_t U8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
109 MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
110
111 dt_conf_t::entry_t S8_ENTRY_S8_EXACT {
112 dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, 0.f};
113 dt_conf_t::entry_t S8_ENTRY_S8 {
114 dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, EPS_S8};
115 dt_conf_t::entry_t S8_ENTRY_WEIGHT_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8,
116 MAX_S8, MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
117 dt_conf_t::entry_t S8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
118 MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
119
// u8u8u8u8: src_iter = u8, src_layer = u8, dst_iter = u8, dst_layer = u8
// (zero-epsilon entry). Weights (layer/iter/projection) are s8; cell states,
// peephole weights and bias are f32.
CFG(u8u8u8u8) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return U8_ENTRY_U8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;

        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C: return U8_ENTRY_F32;

        case DST_LAYER: return U8_ENTRY_U8_EXACT;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
134
// u8u8u8f32: src_iter = u8, src_layer = u8, dst_iter = u8, dst_layer = f32.
// Weights (layer/iter/projection) are s8; cell states, peephole weights and
// bias are f32.
CFG(u8u8u8f32) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return U8_ENTRY_U8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;

        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C:
        case DST_LAYER: return U8_ENTRY_F32;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
149
// f32u8f32u8: src_iter = f32, src_layer = u8, dst_iter = f32, dst_layer = u8
// (zero-epsilon entry). Weights (layer/iter/projection) are s8; cell states,
// peephole weights and bias are f32.
CFG(f32u8f32u8) {
    switch (kind) {
        case SRC_LAYER: return U8_ENTRY_U8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;

        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C: return U8_ENTRY_F32;

        case DST_LAYER: return U8_ENTRY_U8_EXACT;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
164
// f32u8f32f32: src_iter = f32, src_layer = u8, dst_iter = f32,
// dst_layer = f32. Weights (layer/iter/projection) are s8; cell states,
// peephole weights and bias are f32.
CFG(f32u8f32f32) {
    switch (kind) {
        case SRC_LAYER: return U8_ENTRY_U8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return U8_ENTRY_S8;

        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C:
        case DST_LAYER: return U8_ENTRY_F32;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
179
// s8s8s8s8: src_iter = s8, src_layer = s8, dst_iter = s8, dst_layer = s8
// (zero-epsilon entry). Weights (layer/iter/projection) use the s8 weights
// entry; cell states, peephole weights and bias are f32.
CFG(s8s8s8s8) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return S8_ENTRY_S8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;

        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C: return S8_ENTRY_F32;

        case DST_LAYER: return S8_ENTRY_S8_EXACT;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
194
// s8s8s8f32: src_iter = s8, src_layer = s8, dst_iter = s8, dst_layer = f32.
// Weights (layer/iter/projection) use the s8 weights entry; cell states,
// peephole weights and bias are f32.
CFG(s8s8s8f32) {
    switch (kind) {
        case SRC_LAYER:
        case SRC_ITER:
        case DST_ITER: return S8_ENTRY_S8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;

        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER_C:
        case DST_LAYER: return S8_ENTRY_F32;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
209
// f32s8f32s8: src_iter = f32, src_layer = s8, dst_iter = f32, dst_layer = s8
// (zero-epsilon entry). Weights (layer/iter/projection) use the s8 weights
// entry; cell states, peephole weights and bias are f32.
CFG(f32s8f32s8) {
    switch (kind) {
        case SRC_LAYER: return S8_ENTRY_S8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;

        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C: return S8_ENTRY_F32;

        case DST_LAYER: return S8_ENTRY_S8_EXACT;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
224
// f32s8f32f32: src_iter = f32, src_layer = s8, dst_iter = f32,
// dst_layer = f32. Weights (layer/iter/projection) use the s8 weights entry;
// cell states, peephole weights and bias are f32.
CFG(f32s8f32f32) {
    switch (kind) {
        case SRC_LAYER: return S8_ENTRY_S8;

        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
        case WEIGHTS_PROJECTION: return S8_ENTRY_WEIGHT_S8;

        case SRC_ITER:
        case SRC_ITER_C:
        case WEIGHTS_PEEPHOLE:
        case BIAS:
        case DST_ITER:
        case DST_ITER_C:
        case DST_LAYER: return S8_ENTRY_F32;

        default: break;
    }
    // Any data kind not listed above ends up here.
    END_LIST;
}
239
240 } // namespace
241
create(const std::string & str)242 const dt_conf_t &dt_conf_t::create(const std::string &str) {
243 for (const auto cfg : cfg_list)
244 if (cfg->str() == str) return *cfg;
245 SAFE_V(CRIT);
246 return conf_f32;
247 }
248
249 } // namespace rnn
250