1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "oneapi/dnnl/dnnl.h"
24
25 #include "conv/conv.hpp"
26 #include "dnn_types.hpp"
27 #include "dnnl_common.hpp"
28 #include "dnnl_debug.hpp"
29
30 namespace conv {
31
str2alg(const char * str)32 alg_t str2alg(const char *str) {
33 #define CASE(_alg) \
34 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
35 CASE(AUTO);
36 CASE(convolution_auto);
37 CASE(DIRECT);
38 CASE(convolution_direct);
39 CASE(WINO);
40 CASE(convolution_wino);
41 #undef CASE
42 assert(!"unknown algorithm");
43 return UNDEF;
44 }
45
alg2str(alg_t alg)46 const char *alg2str(alg_t alg) {
47 if (alg == AUTO) return "auto";
48 if (alg == DIRECT) return "direct";
49 if (alg == WINO) return "wino";
50 assert(!"unknown algorithm");
51 return "undef";
52 }
53
alg_kind2alg(dnnl_alg_kind_t alg)54 alg_t alg_kind2alg(dnnl_alg_kind_t alg) {
55 if (alg == dnnl_convolution_auto) return AUTO;
56 if (alg == dnnl_convolution_direct) return DIRECT;
57 if (alg == dnnl_convolution_winograd) return WINO;
58 assert(!"unknown algorithm");
59 return DIRECT;
60 }
61
str2desc(desc_t * desc,const char * str,bool is_deconv)62 int str2desc(desc_t *desc, const char *str, bool is_deconv) {
63 /* canonical form:
64 * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
65 *
66 * where X is number, S - string
67 * note: symbol `_` is ignored
68 *
69 * implicit rules:
70 * - if smaller dimensions are not specified => square or cubic form;
71 * - if output is undefined => compute output;
72 * - if padding is undefined => compute trivial padding;
73 */
74
75 desc_t d {0};
76 d.g = 1;
77 d.mb = 2;
78 d.sd = d.sh = d.sw = 1;
79 d.pd = d.ph = d.pw = -1;
80
81 const char *s = str;
82 assert(s);
83
84 #define CASE_NN(prb, c) \
85 do { \
86 if (!strncmp(prb, s, strlen(prb))) { \
87 ok = 1; \
88 s += strlen(prb); \
89 char *end_s; \
90 d.c = strtol(s, &end_s, 10); \
91 s += (end_s - s); \
92 /* check any # groups, including one, works correctly */ \
93 if (!strncmp(prb, "g", 1)) d.has_groups = true; \
94 if (d.c < 0) return FAIL; \
95 /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
96 } \
97 } while (0)
98 #define CASE_N(c) CASE_NN(#c, c)
99 while (*s) {
100 int ok = 0;
101 CASE_N(g);
102 CASE_N(mb);
103 CASE_N(ic);
104 CASE_N(id);
105 CASE_N(ih);
106 CASE_N(iw);
107 CASE_N(oc);
108 CASE_N(od);
109 CASE_N(oh);
110 CASE_N(ow);
111 CASE_N(kd);
112 CASE_N(kh);
113 CASE_N(kw);
114 CASE_N(sd);
115 CASE_N(sh);
116 CASE_N(sw);
117 CASE_N(pd);
118 CASE_N(ph);
119 CASE_N(pw);
120 CASE_N(dd);
121 CASE_N(dh);
122 CASE_N(dw);
123 if (*s == 'n') {
124 d.name = s + 1;
125 break;
126 }
127 if (*s == '_') ++s;
128 if (!ok) return FAIL;
129 }
130 #undef CASE_NN
131 #undef CASE_N
132
133 if (d.has_groups && d.g <= 0) return FAIL;
134 if (d.ic == 0 || d.oc == 0) return FAIL;
135 if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;
136
137 auto compute_out = [](bool is_deconv, int64_t i, int64_t k, int64_t s,
138 int64_t p, int64_t d) {
139 if (is_deconv)
140 return (i - 1) * s + (k - 1) * (d + 1) - 2 * p + 1;
141 else
142 return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
143 };
144 auto compute_pad = [](bool is_deconv, int64_t o, int64_t i, int64_t k,
145 int64_t s, int64_t d) {
146 if (is_deconv)
147 return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
148 else
149 return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
150 };
151
152 const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
153 const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
154 const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
155
156 if (!no_d) {
157 if (!d.id || !d.kd) return FAIL;
158 if (!d.od) {
159 if (d.pd < 0) d.pd = 0;
160 d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
161 if (d.od <= 0) return FAIL;
162 } else if (d.pd < 0)
163 d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
164 }
165
166 if (!no_h) {
167 if (!d.ih || !d.kh) return FAIL;
168 if (!d.oh) {
169 if (d.ph < 0) d.ph = 0;
170 d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
171 if (d.oh <= 0) return FAIL;
172 } else if (d.ph < 0)
173 d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
174 }
175
176 if (!no_w) {
177 if (!d.iw || !d.kw) return FAIL;
178 if (!d.ow) {
179 if (d.pw < 0) d.pw = 0;
180 d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
181 if (d.ow <= 0) return FAIL;
182 } else if (d.pw < 0)
183 d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
184 }
185
186 if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
187 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
188 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
189 != OK)
190 return FAIL;
191
192 d.init_pad_r(is_deconv);
193 *desc = d;
194
195 return OK;
196 }
197
operator <<(std::ostream & s,const desc_t & d)198 std::ostream &operator<<(std::ostream &s, const desc_t &d) {
199 bool print_d = true, print_h = true, print_w = true;
200 print_dhw(print_d, print_h, print_w, d.ndims,
201 {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
202 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
203 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw});
204
205 auto print_spatial
206 = [&](const char *d_str, int64_t d_val, const char *h_str,
207 int64_t h_val, const char *w_str, int64_t w_val) {
208 if (print_d) s << d_str << d_val;
209 if (print_h) s << h_str << h_val;
210 if (print_w) s << w_str << w_val;
211 };
212
213 if (canonical || d.has_groups) s << "g" << d.g;
214 if (canonical || d.mb != 2) s << "mb" << d.mb;
215 s << "ic" << d.ic;
216 print_spatial("id", d.id, "ih", d.ih, "iw", d.iw);
217 s << "oc" << d.oc;
218 print_spatial("od", d.od, "oh", d.oh, "ow", d.ow);
219 print_spatial("kd", d.kd, "kh", d.kh, "kw", d.kw);
220
221 if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1)
222 print_spatial("sd", d.sd, "sh", d.sh, "sw", d.sw);
223
224 print_spatial("pd", d.pd, "ph", d.ph, "pw", d.pw);
225
226 if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0)
227 print_spatial("dd", d.dd, "dh", d.dh, "dw", d.dw);
228
229 if (d.name) s << "n" << d.name;
230
231 return s;
232 }
233
count_ops()234 void prb_t::count_ops() {
235 if (ops > 0) return;
236
237 int64_t od_t = is_deconv ? this->id : this->od;
238 int64_t oh_t = is_deconv ? this->ih : this->oh;
239 int64_t ow_t = is_deconv ? this->iw : this->ow;
240 int64_t id_t = is_deconv ? this->od : this->id;
241 int64_t ih_t = is_deconv ? this->oh : this->ih;
242 int64_t iw_t = is_deconv ? this->ow : this->iw;
243 double sp_ops = 0;
244 for_(int64_t od = 0; od < od_t; ++od)
245 for_(int64_t oh = 0; oh < oh_t; ++oh)
246 for (int64_t ow = 0; ow < ow_t; ++ow) {
247 for (int64_t kd = 0; kd < this->kd; ++kd) {
248 const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1);
249 if (id < 0 || id >= id_t) continue;
250 for (int64_t kh = 0; kh < this->kh; ++kh) {
251 const int64_t ih
252 = oh * this->sh - this->ph + kh * (this->dh + 1);
253 if (ih < 0 || ih >= ih_t) continue;
254 for (int64_t kw = 0; kw < this->kw; ++kw) {
255 const int64_t iw
256 = ow * this->sw - this->pw + kw * (this->dw + 1);
257 if (iw < 0 || iw >= iw_t) continue;
258 sp_ops += 1;
259 }
260 }
261 }
262 }
263
264 ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
265 }
266
generate_oscales(const attr_t::scale_t & oscale,int N)267 float *generate_oscales(const attr_t::scale_t &oscale, int N) {
268 if (oscale.is_def()) return nullptr;
269
270 if (oscale.policy == policy_t::COMMON) {
271 float *scales = (float *)zmalloc(sizeof(float), 4);
272 SAFE_V(scales != nullptr ? OK : FAIL);
273 scales[0] = oscale.scale;
274 return scales;
275 }
276
277 assert(oscale.policy == policy_t::PER_OC);
278
279 float *scales = (float *)zmalloc(sizeof(float) * N, 64);
280 SAFE_V(scales != nullptr ? OK : FAIL);
281
282 const float K = 32;
283 /* scale in [1/K .. K], with starting point at oscale.scale */
284 float s[2] = {oscale.scale, oscale.scale / 2};
285 for (int64_t i = 0; i < N; ++i) {
286 int64_t si = i % 2; // 0 -> left, 1 -> right
287 scales[i] = s[si];
288 if (si == 0) {
289 s[si] /= 2.;
290 if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K
291 } else {
292 s[si] *= 2.;
293 if (s[si] > K) s[si] /= K * K; // turn around to become ~K
294 }
295 }
296 return scales;
297 }
298
generate_zero_points(int arg,const attr_t::zero_points_t & zero_points,int N)299 int32_t *generate_zero_points(
300 int arg, const attr_t::zero_points_t &zero_points, int N) {
301 if (zero_points.is_def(arg)) return nullptr;
302
303 const auto &e = zero_points.get(arg);
304 if (e.policy == policy_t::COMMON) {
305 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4);
306 SAFE_V(zp != nullptr ? OK : FAIL);
307 zp[0] = e.value;
308 return zp;
309 }
310
311 assert(e.policy == policy_t::PER_DIM_1);
312
313 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64);
314 SAFE_V(zp != nullptr ? OK : FAIL);
315
316 for (int i = 0; i < N; ++i)
317 zp[i] = e.value + i % 3;
318 return zp;
319 }
320
operator <<(std::ostream & s,const prb_t & prb)321 std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
322 dump_global_params(s);
323 settings_t def;
324
325 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
326 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
327 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
328 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
329 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
330 if (canonical || prb.alg != def.alg[0])
331 s << "--alg=" << alg2str(prb.alg) << " ";
332
333 s << prb.attr;
334 s << static_cast<const desc_t &>(prb);
335
336 return s;
337 }
338
339 } // namespace conv
340