1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22 
23 #include "oneapi/dnnl/dnnl.h"
24 
25 #include "conv/conv.hpp"
26 #include "dnn_types.hpp"
27 #include "dnnl_common.hpp"
28 #include "dnnl_debug.hpp"
29 
30 namespace conv {
31 
str2alg(const char * str)32 alg_t str2alg(const char *str) {
33 #define CASE(_alg) \
34     if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
35     CASE(AUTO);
36     CASE(convolution_auto);
37     CASE(DIRECT);
38     CASE(convolution_direct);
39     CASE(WINO);
40     CASE(convolution_wino);
41 #undef CASE
42     assert(!"unknown algorithm");
43     return UNDEF;
44 }
45 
alg2str(alg_t alg)46 const char *alg2str(alg_t alg) {
47     if (alg == AUTO) return "auto";
48     if (alg == DIRECT) return "direct";
49     if (alg == WINO) return "wino";
50     assert(!"unknown algorithm");
51     return "undef";
52 }
53 
alg_kind2alg(dnnl_alg_kind_t alg)54 alg_t alg_kind2alg(dnnl_alg_kind_t alg) {
55     if (alg == dnnl_convolution_auto) return AUTO;
56     if (alg == dnnl_convolution_direct) return DIRECT;
57     if (alg == dnnl_convolution_winograd) return WINO;
58     assert(!"unknown algorithm");
59     return DIRECT;
60 }
61 
str2desc(desc_t * desc,const char * str,bool is_deconv)62 int str2desc(desc_t *desc, const char *str, bool is_deconv) {
63     /* canonical form:
64      * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
65      *
66      * where X is number, S - string
67      * note: symbol `_` is ignored
68      *
69      * implicit rules:
70      *  - if smaller dimensions are not specified => square or cubic form;
71      *  - if output is undefined => compute output;
72      *  - if padding is undefined => compute trivial padding;
73      */
74 
75     desc_t d {0};
76     d.g = 1;
77     d.mb = 2;
78     d.sd = d.sh = d.sw = 1;
79     d.pd = d.ph = d.pw = -1;
80 
81     const char *s = str;
82     assert(s);
83 
84 #define CASE_NN(prb, c) \
85     do { \
86         if (!strncmp(prb, s, strlen(prb))) { \
87             ok = 1; \
88             s += strlen(prb); \
89             char *end_s; \
90             d.c = strtol(s, &end_s, 10); \
91             s += (end_s - s); \
92             /* check any # groups, including one, works correctly */ \
93             if (!strncmp(prb, "g", 1)) d.has_groups = true; \
94             if (d.c < 0) return FAIL; \
95             /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
96         } \
97     } while (0)
98 #define CASE_N(c) CASE_NN(#c, c)
99     while (*s) {
100         int ok = 0;
101         CASE_N(g);
102         CASE_N(mb);
103         CASE_N(ic);
104         CASE_N(id);
105         CASE_N(ih);
106         CASE_N(iw);
107         CASE_N(oc);
108         CASE_N(od);
109         CASE_N(oh);
110         CASE_N(ow);
111         CASE_N(kd);
112         CASE_N(kh);
113         CASE_N(kw);
114         CASE_N(sd);
115         CASE_N(sh);
116         CASE_N(sw);
117         CASE_N(pd);
118         CASE_N(ph);
119         CASE_N(pw);
120         CASE_N(dd);
121         CASE_N(dh);
122         CASE_N(dw);
123         if (*s == 'n') {
124             d.name = s + 1;
125             break;
126         }
127         if (*s == '_') ++s;
128         if (!ok) return FAIL;
129     }
130 #undef CASE_NN
131 #undef CASE_N
132 
133     if (d.has_groups && d.g <= 0) return FAIL;
134     if (d.ic == 0 || d.oc == 0) return FAIL;
135     if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;
136 
137     auto compute_out = [](bool is_deconv, int64_t i, int64_t k, int64_t s,
138                                int64_t p, int64_t d) {
139         if (is_deconv)
140             return (i - 1) * s + (k - 1) * (d + 1) - 2 * p + 1;
141         else
142             return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
143     };
144     auto compute_pad = [](bool is_deconv, int64_t o, int64_t i, int64_t k,
145                                int64_t s, int64_t d) {
146         if (is_deconv)
147             return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
148         else
149             return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
150     };
151 
152     const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
153     const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
154     const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
155 
156     if (!no_d) {
157         if (!d.id || !d.kd) return FAIL;
158         if (!d.od) {
159             if (d.pd < 0) d.pd = 0;
160             d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
161             if (d.od <= 0) return FAIL;
162         } else if (d.pd < 0)
163             d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
164     }
165 
166     if (!no_h) {
167         if (!d.ih || !d.kh) return FAIL;
168         if (!d.oh) {
169             if (d.ph < 0) d.ph = 0;
170             d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
171             if (d.oh <= 0) return FAIL;
172         } else if (d.ph < 0)
173             d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
174     }
175 
176     if (!no_w) {
177         if (!d.iw || !d.kw) return FAIL;
178         if (!d.ow) {
179             if (d.pw < 0) d.pw = 0;
180             d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
181             if (d.ow <= 0) return FAIL;
182         } else if (d.pw < 0)
183             d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
184     }
185 
186     if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
187                 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
188                 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
189             != OK)
190         return FAIL;
191 
192     d.init_pad_r(is_deconv);
193     *desc = d;
194 
195     return OK;
196 }
197 
operator <<(std::ostream & s,const desc_t & d)198 std::ostream &operator<<(std::ostream &s, const desc_t &d) {
199     bool print_d = true, print_h = true, print_w = true;
200     print_dhw(print_d, print_h, print_w, d.ndims,
201             {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
202             {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
203             {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw});
204 
205     auto print_spatial
206             = [&](const char *d_str, int64_t d_val, const char *h_str,
207                       int64_t h_val, const char *w_str, int64_t w_val) {
208                   if (print_d) s << d_str << d_val;
209                   if (print_h) s << h_str << h_val;
210                   if (print_w) s << w_str << w_val;
211               };
212 
213     if (canonical || d.has_groups) s << "g" << d.g;
214     if (canonical || d.mb != 2) s << "mb" << d.mb;
215     s << "ic" << d.ic;
216     print_spatial("id", d.id, "ih", d.ih, "iw", d.iw);
217     s << "oc" << d.oc;
218     print_spatial("od", d.od, "oh", d.oh, "ow", d.ow);
219     print_spatial("kd", d.kd, "kh", d.kh, "kw", d.kw);
220 
221     if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1)
222         print_spatial("sd", d.sd, "sh", d.sh, "sw", d.sw);
223 
224     print_spatial("pd", d.pd, "ph", d.ph, "pw", d.pw);
225 
226     if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0)
227         print_spatial("dd", d.dd, "dh", d.dh, "dw", d.dw);
228 
229     if (d.name) s << "n" << d.name;
230 
231     return s;
232 }
233 
count_ops()234 void prb_t::count_ops() {
235     if (ops > 0) return;
236 
237     int64_t od_t = is_deconv ? this->id : this->od;
238     int64_t oh_t = is_deconv ? this->ih : this->oh;
239     int64_t ow_t = is_deconv ? this->iw : this->ow;
240     int64_t id_t = is_deconv ? this->od : this->id;
241     int64_t ih_t = is_deconv ? this->oh : this->ih;
242     int64_t iw_t = is_deconv ? this->ow : this->iw;
243     double sp_ops = 0;
244     for_(int64_t od = 0; od < od_t; ++od)
245     for_(int64_t oh = 0; oh < oh_t; ++oh)
246     for (int64_t ow = 0; ow < ow_t; ++ow) {
247         for (int64_t kd = 0; kd < this->kd; ++kd) {
248             const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1);
249             if (id < 0 || id >= id_t) continue;
250             for (int64_t kh = 0; kh < this->kh; ++kh) {
251                 const int64_t ih
252                         = oh * this->sh - this->ph + kh * (this->dh + 1);
253                 if (ih < 0 || ih >= ih_t) continue;
254                 for (int64_t kw = 0; kw < this->kw; ++kw) {
255                     const int64_t iw
256                             = ow * this->sw - this->pw + kw * (this->dw + 1);
257                     if (iw < 0 || iw >= iw_t) continue;
258                     sp_ops += 1;
259                 }
260             }
261         }
262     }
263 
264     ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
265 }
266 
generate_oscales(const attr_t::scale_t & oscale,int N)267 float *generate_oscales(const attr_t::scale_t &oscale, int N) {
268     if (oscale.is_def()) return nullptr;
269 
270     if (oscale.policy == policy_t::COMMON) {
271         float *scales = (float *)zmalloc(sizeof(float), 4);
272         SAFE_V(scales != nullptr ? OK : FAIL);
273         scales[0] = oscale.scale;
274         return scales;
275     }
276 
277     assert(oscale.policy == policy_t::PER_OC);
278 
279     float *scales = (float *)zmalloc(sizeof(float) * N, 64);
280     SAFE_V(scales != nullptr ? OK : FAIL);
281 
282     const float K = 32;
283     /* scale in [1/K .. K], with starting point at oscale.scale */
284     float s[2] = {oscale.scale, oscale.scale / 2};
285     for (int64_t i = 0; i < N; ++i) {
286         int64_t si = i % 2; // 0 -> left, 1 -> right
287         scales[i] = s[si];
288         if (si == 0) {
289             s[si] /= 2.;
290             if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K
291         } else {
292             s[si] *= 2.;
293             if (s[si] > K) s[si] /= K * K; // turn around to become ~K
294         }
295     }
296     return scales;
297 }
298 
generate_zero_points(int arg,const attr_t::zero_points_t & zero_points,int N)299 int32_t *generate_zero_points(
300         int arg, const attr_t::zero_points_t &zero_points, int N) {
301     if (zero_points.is_def(arg)) return nullptr;
302 
303     const auto &e = zero_points.get(arg);
304     if (e.policy == policy_t::COMMON) {
305         int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4);
306         SAFE_V(zp != nullptr ? OK : FAIL);
307         zp[0] = e.value;
308         return zp;
309     }
310 
311     assert(e.policy == policy_t::PER_DIM_1);
312 
313     int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64);
314     SAFE_V(zp != nullptr ? OK : FAIL);
315 
316     for (int i = 0; i < N; ++i)
317         zp[i] = e.value + i % 3;
318     return zp;
319 }
320 
operator <<(std::ostream & s,const prb_t & prb)321 std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
322     dump_global_params(s);
323     settings_t def;
324 
325     if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
326     if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
327     if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
328     if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
329     if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
330     if (canonical || prb.alg != def.alg[0])
331         s << "--alg=" << alg2str(prb.alg) << " ";
332 
333     s << prb.attr;
334     s << static_cast<const desc_t &>(prb);
335 
336     return s;
337 }
338 
339 } // namespace conv
340