1 /*******************************************************************************
2 * Copyright 2018-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <string.h>
20 
21 #include "oneapi/dnnl/dnnl.h"
22 
23 #include "dnnl_common.hpp"
24 #include "dnnl_debug.hpp"
25 
26 #include "ip/ip.hpp"
27 
28 namespace ip {
29 
generate_oscales()30 void prb_t::generate_oscales() {
31     if (attr.oscale.is_def()) return;
32 
33     if (attr.oscale.policy == policy_t::COMMON) {
34         scales = (float *)zmalloc(sizeof(float), 4);
35         SAFE_V(scales != nullptr ? OK : FAIL);
36         scales[0] = attr.oscale.scale;
37         return;
38     }
39 
40     assert(attr.oscale.policy == policy_t::PER_OC);
41 
42     scales = (float *)zmalloc(sizeof(float) * oc, 64);
43     SAFE_V(scales != nullptr ? OK : FAIL);
44 
45     const float K = 32;
46     /* scale in [1/K .. K], with starting point at oscale.scale */
47     float s[2] = {attr.oscale.scale, attr.oscale.scale / 2};
48     for (int64_t i = 0; i < oc; ++i) {
49         int64_t si = i % 2; // 0 -> left, 1 -> right
50         scales[i] = s[si];
51         if (si == 0) {
52             s[si] /= 2.;
53             if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K
54         } else {
55             s[si] *= 2.;
56             if (s[si] > K) s[si] /= K * K; // turn around to become ~K
57         }
58     }
59 }
60 
str2desc(desc_t * desc,const char * str)61 int str2desc(desc_t *desc, const char *str) {
62     // Canonical form: mbXicXidXihXiwXocXnS,
63     // where
64     //     X is integer
65     //     S is string
66     // note: symbol `_` is ignored.
67     // Cubic/square shapes are supported by specifying just highest dimension.
68 
69     desc_t d {0};
70     d.mb = 2;
71 
72     const char *s = str;
73     assert(s);
74 
75 #define CASE_NN(prb, c) \
76     do { \
77         if (!strncmp(prb, s, strlen(prb))) { \
78             ok = 1; \
79             s += strlen(prb); \
80             char *end_s; \
81             d.c = strtol(s, &end_s, 10); \
82             s += (end_s - s); \
83             if (d.c < 0) return FAIL; \
84             /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
85         } \
86     } while (0)
87 #define CASE_N(c) CASE_NN(#c, c)
88     while (*s) {
89         int ok = 0;
90         CASE_N(mb);
91         CASE_N(ic);
92         CASE_N(ih);
93         CASE_N(iw);
94         CASE_N(id);
95         CASE_N(oc);
96         if (*s == 'n') {
97             d.name = s + 1;
98             break;
99         }
100         if (*s == '_') ++s;
101         if (!ok) return FAIL;
102     }
103 #undef CASE_NN
104 #undef CASE_N
105 
106     if (d.ic == 0 || d.oc == 0) return FAIL;
107 
108     if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
109 
110     *desc = d;
111 
112     return OK;
113 }
114 
operator <<(std::ostream & s,const desc_t & d)115 std::ostream &operator<<(std::ostream &s, const desc_t &d) {
116     bool print_d = true, print_h = true, print_w = true;
117     print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
118 
119     if (canonical || d.mb != 2) s << "mb" << d.mb;
120 
121     s << "ic" << d.ic;
122 
123     if (print_d) s << "id" << d.id;
124     if (print_h) s << "ih" << d.ih;
125     if (print_w) s << "iw" << d.iw;
126 
127     s << "oc" << d.oc;
128 
129     if (d.name) s << "n" << d.name;
130 
131     return s;
132 }
133 
operator <<(std::ostream & s,const prb_t & prb)134 std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
135     dump_global_params(s);
136     settings_t def;
137 
138     if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
139     if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
140     if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
141     if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
142     if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
143 
144     s << prb.attr;
145     s << static_cast<const desc_t &>(prb);
146 
147     return s;
148 }
149 
150 } // namespace ip
151