1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22 
23 #include "oneapi/dnnl/dnnl.h"
24 
25 #include "dnnl_common.hpp"
26 
27 #include "conv/conv_common.hpp"
28 
// Largest and smallest finite values of IEEE 754 half precision (binary16).
// Used as the min/max bounds of the rows in conf_f16_wino.
#define HALF_MAX 65504
#define HALF_MIN (-65504)
31 
32 namespace conv {
33 
34 /* cfgs definition
35  * arrays: SRC, WEI, BIA, DST, ACC
36  * params: {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps}
37  */
38 
// 2^11. f16 has an 11-bit significand, so every integer whose magnitude does
// not exceed this value is exactly representable in f16.
const int int_max_exact_half = 1 << 11;
// All tensors are f16, with min/max bounded by +/-int_max_exact_half.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f16 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16},
};
51 
// f16 bounds as reported by max_dt() / lowest_dt(). Despite the `int_` prefix
// these are doubles.
const double int_f16_max = max_dt(dnnl_f16);
const double int_f16_lowest = lowest_dt(dnnl_f16);
// Same f_* and eps values as conf_f16, but min/max are int_f16_lowest and
// int_f16_max rather than +/-int_max_exact_half.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f16_no_limits = {
        {dnnl_f16, int_f16_lowest, int_f16_max, -4, 4, 0, 1, .25, 0.},
        {dnnl_f16, int_f16_lowest, int_f16_max, -2, 2, -2, 1, 1.0, 0.},
        {dnnl_f16, int_f16_lowest, int_f16_max, -8, 8, 0, 1, 1.0, 0.},
        {dnnl_f16, int_f16_lowest, int_f16_max, -4, 4, 0, 1, .25, 0.},
        {dnnl_f16},
};
61 
// 2^24. f32 has a 24-bit significand, so every integer whose magnitude does
// not exceed this value is exactly representable in f32.
const int int_max_exact = 1 << 24;
// f16 SRC and WEI, f32 BIA and DST, f16 ACC.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f16f16f32 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f16},
};
72 
// All tensors are f32, with min/max bounded by +/-int_max_exact.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f32 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32},
};

// Same f_* and eps values as conf_f32, with min/max widened to +/-FLT_MAX.
const _dt_conf_t conf_f32_no_limits = {
        {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, 1.0, 0.},
        {dnnl_f32, -FLT_MAX, FLT_MAX, -512, 512, 0, 1, 1.0, 0.},
        {dnnl_f32, -FLT_MAX, FLT_MAX, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32},
};

// Differs from conf_f32 only in the SRC and DST rows: f_min/f_max are -64/64
// and f_sparsity is 1.0.
const _dt_conf_t conf_f32_full = {
        {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, 1, 1.0, 0.},
        {dnnl_f32},
};
96 
// f32 table that auto_cfg() substitutes for conf_f32 when the algorithm is
// WINO. Unlike the plain f32 tables, every populated row has a non-zero eps.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f32_wino = {
        {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5},
        {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6},
        {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7},
        {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5},
        {dnnl_f32},
};

// f16 table that auto_cfg() substitutes for conf_f16 when the algorithm is
// WINO. min/max are HALF_MIN/HALF_MAX and every populated row has a non-zero
// eps.
const _dt_conf_t conf_f16_wino = {
        {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3},
        {dnnl_f16},
};
112 
// bf16 SRC and WEI, f32 BIA and DST, f32 ACC.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_bf16bf16f32 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// bf16 SRC, WEI, BIA and DST, f32 ACC.
const _dt_conf_t conf_bf16bf16bf16 = {
        /* eps is 1e-2 because precision is lost when the output is converted
         * from fp32 to bf16, while the oneDNN output is compared against a
         * reference computed in fp32. */
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
        {dnnl_f32},
};

// f32 SRC, bf16 WEI, BIA and DST, f32 ACC.
const _dt_conf_t conf_f32bf16bf16 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};
139 
// f32 SRC, WEI and BIA, s8 DST, f32 ACC.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_f32f32s8 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32},
};

// f16 SRC, WEI and BIA (same rows as conf_f16), s8 DST, f16 ACC.
const _dt_conf_t conf_f16f16s8 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
                0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f16},
};

// bf16 SRC, f32 WEI and BIA, bf16 DST, f32 ACC.
const _dt_conf_t conf_bf16f32bf16 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};
166 
// u8 SRC, s8 WEI, f32 BIA, f32 DST, s32 ACC.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_u8s8f32 = {
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, f32 BIA, s32 DST, s32 ACC.
const _dt_conf_t conf_u8s8s32 = {
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, f32 BIA, s8 DST, s32 ACC.
// Note the SRC f_max is 8 here, versus UINT8_MAX in conf_u8s8f32/conf_u8s8s32.
const _dt_conf_t conf_u8s8s8 = {
        {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, f32 BIA, u8 DST, s32 ACC.
// Note the SRC f_max is 8 here, versus UINT8_MAX in conf_u8s8f32/conf_u8s8s32.
const _dt_conf_t conf_u8s8u8 = {
        {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};
198 
// s8 SRC, s8 WEI (f_step of 4), f32 BIA, f32 DST, s32 ACC.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.
const _dt_conf_t conf_s8s8f32 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// Same SRC/WEI/BIA rows as conf_s8s8f32, with an s32 DST.
const _dt_conf_t conf_s8s8s32 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// Same SRC/WEI/BIA rows as conf_s8s8f32, with an s8 DST.
const _dt_conf_t conf_s8s8s8 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
        {dnnl_s32},
};

// Same SRC/WEI/BIA rows as conf_s8s8f32, with a u8 DST.
const _dt_conf_t conf_s8s8u8 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};
230 
// Tables that auto_cfg() substitutes for the matching conf_u8s8* table when
// the algorithm is WINO. All four share the same SRC (f_step 4), WEI
// (f_step 9) and BIA (f_step 9) rows and differ only in the DST row.
// Rows are SRC, WEI, BIA, DST, ACC; the ACC row specifies only the data type.

// f32 DST.
const _dt_conf_t conf_u8s8f32_wino = {
        {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// s32 DST.
const _dt_conf_t conf_u8s8s32_wino = {
        {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
        {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// s8 DST.
const _dt_conf_t conf_u8s8s8_wino = {
        {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 DST.
const _dt_conf_t conf_u8s8u8_wino = {
        {dnnl_u8, 0, UINT8_MAX, 0, 239, 0, 4, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -72, 71, 0, 9, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -9, 35, 0, 9, .25, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};
262 
str2cfg(const char * str)263 const dt_conf_t *str2cfg(const char *str) {
264 #define CASE(cfg) \
265     if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
266     CASE(f16);
267     CASE(f16_no_limits);
268     CASE(f32);
269     CASE(f32_no_limits);
270     CASE(f32_full);
271     CASE(f32_wino);
272     CASE(u8s8f32);
273     CASE(u8s8s32);
274     CASE(u8s8s8);
275     CASE(u8s8u8);
276     CASE(s8s8f32);
277     CASE(s8s8s32);
278     CASE(s8s8s8);
279     CASE(s8s8u8);
280     CASE(u8s8f32_wino);
281     CASE(u8s8s32_wino);
282     CASE(u8s8s8_wino);
283     CASE(u8s8u8_wino);
284     CASE(bf16bf16f32);
285     CASE(bf16bf16bf16);
286     CASE(f32bf16bf16);
287     CASE(bf16f32bf16);
288     CASE(f32f32s8);
289     CASE(f16f16f32);
290     CASE(f16f16s8);
291 #undef CASE
292     []() {
293         SAFE(FAIL, CRIT);
294         return 0;
295     }();
296     return (const dt_conf_t *)1;
297 }
298 
operator <<(std::ostream & s,const dt_conf_t * cfg)299 std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
300 #define CASE(_cfg) \
301     if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
302     CASE(f16);
303     CASE(f16_no_limits);
304     CASE(f32);
305     CASE(f32_no_limits);
306     CASE(f32_full);
307     CASE(f32_wino);
308     CASE(u8s8f32);
309     CASE(u8s8s32);
310     CASE(u8s8s8);
311     CASE(u8s8u8);
312     CASE(s8s8f32);
313     CASE(s8s8s32);
314     CASE(s8s8s8);
315     CASE(s8s8u8);
316     CASE(u8s8f32_wino);
317     CASE(u8s8s32_wino);
318     CASE(u8s8s8_wino);
319     CASE(u8s8u8_wino);
320     CASE(f16f16f32);
321     CASE(f16f16s8);
322     CASE(bf16bf16f32);
323     CASE(bf16bf16bf16);
324     CASE(f32bf16bf16);
325     CASE(f32f32s8);
326     CASE(bf16f32bf16);
327 #undef CASE
328     SAFE_V(FAIL);
329     return s;
330 }
331 
/* Picks the Winograd-specific variant of a configuration when one exists.
 *
 * @param alg Convolution algorithm. Anything other than WINO leaves `cfg`
 *            untouched.
 * @param cfg Configuration requested by the caller.
 * @return For WINO, the conf_<name>_wino table paired with `cfg` if `cfg` is
 *         one of the base tables listed below; otherwise `cfg` itself. A
 *         configuration that already is a Winograd variant has no pairing and
 *         is returned unchanged.
 *
 * The pairing is done directly on table addresses. This yields the same
 * result as composing "<name>_wino" through operator<< and comparing strings,
 * but needs no stream or string allocation and does not depend on the printer
 * recognising `cfg`. */
const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) {
    if (alg != WINO) return cfg;

    // Base table -> Winograd counterpart.
    static const struct {
        const dt_conf_t *base;
        const dt_conf_t *wino;
    } wino_variants[] = {
            {conf_f32, conf_f32_wino},
            {conf_f16, conf_f16_wino},
            {conf_u8s8f32, conf_u8s8f32_wino},
            {conf_u8s8s32, conf_u8s8s32_wino},
            {conf_u8s8s8, conf_u8s8s8_wino},
            {conf_u8s8u8, conf_u8s8u8_wino},
    };

    for (const auto &variant : wino_variants)
        if (cfg == variant.base) return variant.wino;

    // No Winograd-specific table for this configuration.
    return cfg;
}
350 
351 } // namespace conv
352