1 /*******************************************************************************
2 * Copyright 2018-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <string.h>
20
21 #include "oneapi/dnnl/dnnl.h"
22
23 #include "dnnl_common.hpp"
24 #include "dnnl_debug.hpp"
25
26 #include "ip/ip.hpp"
27
28 namespace ip {
29
generate_oscales()30 void prb_t::generate_oscales() {
31 if (attr.oscale.is_def()) return;
32
33 if (attr.oscale.policy == policy_t::COMMON) {
34 scales = (float *)zmalloc(sizeof(float), 4);
35 SAFE_V(scales != nullptr ? OK : FAIL);
36 scales[0] = attr.oscale.scale;
37 return;
38 }
39
40 assert(attr.oscale.policy == policy_t::PER_OC);
41
42 scales = (float *)zmalloc(sizeof(float) * oc, 64);
43 SAFE_V(scales != nullptr ? OK : FAIL);
44
45 const float K = 32;
46 /* scale in [1/K .. K], with starting point at oscale.scale */
47 float s[2] = {attr.oscale.scale, attr.oscale.scale / 2};
48 for (int64_t i = 0; i < oc; ++i) {
49 int64_t si = i % 2; // 0 -> left, 1 -> right
50 scales[i] = s[si];
51 if (si == 0) {
52 s[si] /= 2.;
53 if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K
54 } else {
55 s[si] *= 2.;
56 if (s[si] > K) s[si] /= K * K; // turn around to become ~K
57 }
58 }
59 }
60
str2desc(desc_t * desc,const char * str)61 int str2desc(desc_t *desc, const char *str) {
62 // Canonical form: mbXicXidXihXiwXocXnS,
63 // where
64 // X is integer
65 // S is string
66 // note: symbol `_` is ignored.
67 // Cubic/square shapes are supported by specifying just highest dimension.
68
69 desc_t d {0};
70 d.mb = 2;
71
72 const char *s = str;
73 assert(s);
74
75 #define CASE_NN(prb, c) \
76 do { \
77 if (!strncmp(prb, s, strlen(prb))) { \
78 ok = 1; \
79 s += strlen(prb); \
80 char *end_s; \
81 d.c = strtol(s, &end_s, 10); \
82 s += (end_s - s); \
83 if (d.c < 0) return FAIL; \
84 /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
85 } \
86 } while (0)
87 #define CASE_N(c) CASE_NN(#c, c)
88 while (*s) {
89 int ok = 0;
90 CASE_N(mb);
91 CASE_N(ic);
92 CASE_N(ih);
93 CASE_N(iw);
94 CASE_N(id);
95 CASE_N(oc);
96 if (*s == 'n') {
97 d.name = s + 1;
98 break;
99 }
100 if (*s == '_') ++s;
101 if (!ok) return FAIL;
102 }
103 #undef CASE_NN
104 #undef CASE_N
105
106 if (d.ic == 0 || d.oc == 0) return FAIL;
107
108 if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
109
110 *desc = d;
111
112 return OK;
113 }
114
operator <<(std::ostream & s,const desc_t & d)115 std::ostream &operator<<(std::ostream &s, const desc_t &d) {
116 bool print_d = true, print_h = true, print_w = true;
117 print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
118
119 if (canonical || d.mb != 2) s << "mb" << d.mb;
120
121 s << "ic" << d.ic;
122
123 if (print_d) s << "id" << d.id;
124 if (print_h) s << "ih" << d.ih;
125 if (print_w) s << "iw" << d.iw;
126
127 s << "oc" << d.oc;
128
129 if (d.name) s << "n" << d.name;
130
131 return s;
132 }
133
operator <<(std::ostream & s,const prb_t & prb)134 std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
135 dump_global_params(s);
136 settings_t def;
137
138 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
139 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
140 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
141 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
142 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
143
144 s << prb.attr;
145 s << static_cast<const desc_t &>(prb);
146
147 return s;
148 }
149
150 } // namespace ip
151