1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "oneapi/dnnl/dnnl.h"
24
25 #include "dnnl_common.hpp"
26
27 #include "conv/conv_common.hpp"
28
// Largest finite magnitude of an IEEE 754 half-precision (binary16) value,
// used as the min/max of the f16 tensors in conf_f16_wino. HALF_MIN is
// parenthesized so that the negative literal expands safely inside expressions.
#define HALF_MAX 65504
#define HALF_MIN (-65504)
31
32 namespace conv {
33
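/* Data-type configuration ("cfg") definitions.
 * Each configuration is an array of per-tensor entries, in this order:
 *     SRC, WEI, BIA, DST, ACC
 * Each entry has the fields:
 *     {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps}
 */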
38
39 const int int_max_exact_half = 1 << 11;
40 const _dt_conf_t conf_f16 = {
41 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
42 0.},
43 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
44 0.},
45 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
46 0.},
47 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
48 0.},
49 {dnnl_f16},
50 };
51
52 const double int_f16_max = max_dt(dnnl_f16);
53 const double int_f16_lowest = lowest_dt(dnnl_f16);
54 const _dt_conf_t conf_f16_no_limits = {
55 {dnnl_f16, int_f16_lowest, int_f16_max, -4, 4, 0, 1, .25, 0.},
56 {dnnl_f16, int_f16_lowest, int_f16_max, -2, 2, -2, 1, 1.0, 0.},
57 {dnnl_f16, int_f16_lowest, int_f16_max, -8, 8, 0, 1, 1.0, 0.},
58 {dnnl_f16, int_f16_lowest, int_f16_max, -4, 4, 0, 1, .25, 0.},
59 {dnnl_f16},
60 };
61
62 const int int_max_exact = 1 << 24;
63 const _dt_conf_t conf_f16f16f32 = {
64 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
65 0.},
66 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
67 0.},
68 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
69 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
70 {dnnl_f16},
71 };
72
73 const _dt_conf_t conf_f32 = {
74 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
75 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
76 {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
77 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
78 {dnnl_f32},
79 };
80
81 const _dt_conf_t conf_f32_no_limits = {
82 {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, .25, 0.},
83 {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, 1.0, 0.},
84 {dnnl_f32, -FLT_MAX, FLT_MAX, -512, 512, 0, 1, 1.0, 0.},
85 {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, .25, 0.},
86 {dnnl_f32},
87 };
88
89 const _dt_conf_t conf_f32_full = {
90 {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, 1, 1.0, 0.},
91 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
92 {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
93 {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, 1, 1.0, 0.},
94 {dnnl_f32},
95 };
96
97 const _dt_conf_t conf_f32_wino = {
98 {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5},
99 {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6},
100 {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7},
101 {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5},
102 {dnnl_f32},
103 };
104
105 const _dt_conf_t conf_f16_wino = {
106 {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3},
107 {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3},
108 {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3},
109 {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3},
110 {dnnl_f16},
111 };
112
113 const _dt_conf_t conf_bf16bf16f32 = {
114 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
115 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
116 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
117 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
118 {dnnl_f32},
119 };
120
121 const _dt_conf_t conf_bf16bf16bf16 = {
122 /* eps is 1e-2 because of loss in precision of
123 * output when converted from fp32 to bf16.
124 * oneDNN output is compared against reference computed in fp32.*/
125 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
126 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
127 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
128 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
129 {dnnl_f32},
130 };
131
132 const _dt_conf_t conf_f32bf16bf16 = {
133 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
134 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
135 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
136 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
137 {dnnl_f32},
138 };
139
140 const _dt_conf_t conf_f32f32s8 = {
141 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
142 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
143 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
144 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
145 {dnnl_f32},
146 };
147
148 const _dt_conf_t conf_f16f16s8 = {
149 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
150 0.},
151 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
152 0.},
153 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
154 0.},
155 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
156 {dnnl_f16},
157 };
158
159 const _dt_conf_t conf_bf16f32bf16 = {
160 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
161 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
162 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
163 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
164 {dnnl_f32},
165 };
166
167 const _dt_conf_t conf_u8s8f32 = {
168 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
169 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
170 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
171 {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
172 {dnnl_s32},
173 };
174
175 const _dt_conf_t conf_u8s8s32 = {
176 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
177 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
178 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
179 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
180 {dnnl_s32},
181 };
182
183 const _dt_conf_t conf_u8s8s8 = {
184 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
185 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
186 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
187 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
188 {dnnl_s32},
189 };
190
191 const _dt_conf_t conf_u8s8u8 = {
192 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
193 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
194 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
195 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
196 {dnnl_s32},
197 };
198
199 const _dt_conf_t conf_s8s8f32 = {
200 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
201 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
202 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
203 {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
204 {dnnl_s32},
205 };
206
207 const _dt_conf_t conf_s8s8s32 = {
208 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
209 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
210 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
211 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
212 {dnnl_s32},
213 };
214
215 const _dt_conf_t conf_s8s8s8 = {
216 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
217 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
218 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
219 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
220 {dnnl_s32},
221 };
222
223 const _dt_conf_t conf_s8s8u8 = {
224 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
225 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
226 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
227 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
228 {dnnl_s32},
229 };
230
231 const _dt_conf_t conf_u8s8f32_wino = {
232 {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
233 {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
234 {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
235 {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
236 {dnnl_s32},
237 };
238
239 const _dt_conf_t conf_u8s8s32_wino = {
240 {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
241 {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
242 {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
243 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
244 {dnnl_s32},
245 };
246
247 const _dt_conf_t conf_u8s8s8_wino = {
248 {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
249 {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
250 {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
251 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
252 {dnnl_s32},
253 };
254
255 const _dt_conf_t conf_u8s8u8_wino = {
256 {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
257 {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
258 {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
259 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
260 {dnnl_s32},
261 };
262
str2cfg(const char * str)263 const dt_conf_t *str2cfg(const char *str) {
264 #define CASE(cfg) \
265 if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
266 CASE(f16);
267 CASE(f16_no_limits);
268 CASE(f32);
269 CASE(f32_no_limits);
270 CASE(f32_full);
271 CASE(f32_wino);
272 CASE(u8s8f32);
273 CASE(u8s8s32);
274 CASE(u8s8s8);
275 CASE(u8s8u8);
276 CASE(s8s8f32);
277 CASE(s8s8s32);
278 CASE(s8s8s8);
279 CASE(s8s8u8);
280 CASE(u8s8f32_wino);
281 CASE(u8s8s32_wino);
282 CASE(u8s8s8_wino);
283 CASE(u8s8u8_wino);
284 CASE(bf16bf16f32);
285 CASE(bf16bf16bf16);
286 CASE(f32bf16bf16);
287 CASE(bf16f32bf16);
288 CASE(f32f32s8);
289 CASE(f16f16f32);
290 CASE(f16f16s8);
291 #undef CASE
292 []() {
293 SAFE(FAIL, CRIT);
294 return 0;
295 }();
296 return (const dt_conf_t *)1;
297 }
298
operator <<(std::ostream & s,const dt_conf_t * cfg)299 std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
300 #define CASE(_cfg) \
301 if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
302 CASE(f16);
303 CASE(f16_no_limits);
304 CASE(f32);
305 CASE(f32_no_limits);
306 CASE(f32_full);
307 CASE(f32_wino);
308 CASE(u8s8f32);
309 CASE(u8s8s32);
310 CASE(u8s8s8);
311 CASE(u8s8u8);
312 CASE(s8s8f32);
313 CASE(s8s8s32);
314 CASE(s8s8s8);
315 CASE(s8s8u8);
316 CASE(u8s8f32_wino);
317 CASE(u8s8s32_wino);
318 CASE(u8s8s8_wino);
319 CASE(u8s8u8_wino);
320 CASE(f16f16f32);
321 CASE(f16f16s8);
322 CASE(bf16bf16f32);
323 CASE(bf16bf16bf16);
324 CASE(f32bf16bf16);
325 CASE(f32f32s8);
326 CASE(bf16f32bf16);
327 #undef CASE
328 SAFE_V(FAIL);
329 return s;
330 }
331
/* For the Winograd algorithm, returns the configuration named
 * "<name of cfg>_wino" if it is one of f32_wino, f16_wino or the u8s8*_wino
 * configurations. In every other case, including non-Winograd algorithms and
 * configurations without such a variant, `cfg` is returned unchanged.
 * The name of `cfg` is obtained through operator<<, which reports an error via
 * SAFE_V(FAIL) for a configuration it does not recognize. */
const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) {
    if (alg == WINO) {
        std::stringstream name_ss;
        name_ss << cfg << "_wino";
        const std::string wino_name = name_ss.str();

#define CASE(_cfg_) \
    if (wino_name == STRINGIFY(_cfg_)) return CONCAT2(conf_, _cfg_)
        CASE(f32_wino);
        CASE(f16_wino);
        CASE(u8s8f32_wino);
        CASE(u8s8s32_wino);
        CASE(u8s8s8_wino);
        CASE(u8s8u8_wino);
#undef CASE
    }
    return cfg;
}
350
351 } // namespace conv
352