1 /*******************************************************************************
2 * Copyright 2017-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <stdlib.h>
19 
20 #include "lrn/lrn.hpp"
21 
22 namespace lrn {
23 
str2alg(const char * str)24 alg_t str2alg(const char *str) {
25 #define CASE(_alg) \
26     if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
27     CASE(ACROSS);
28     CASE(WITHIN);
29 #undef CASE
30     assert(!"unknown algorithm");
31     return ACROSS;
32 }
33 
alg2str(alg_t alg)34 const char *alg2str(alg_t alg) {
35     if (alg == ACROSS) return "ACROSS";
36     if (alg == WITHIN) return "WITHIN";
37     assert(!"unknown algorithm");
38     return "unknown algorithm";
39 }
40 
alg2alg_kind(alg_t alg)41 dnnl_alg_kind_t alg2alg_kind(alg_t alg) {
42     if (alg == ACROSS) return dnnl_lrn_across_channels;
43     if (alg == WITHIN) return dnnl_lrn_within_channel;
44     assert(!"unknown algorithm");
45     return dnnl_alg_kind_undef;
46 }
47 
str2desc(desc_t * desc,const char * str)48 int str2desc(desc_t *desc, const char *str) {
49     // Canonical form: mbXicXidXihXiwX_lsXalphaYbetaYkY_nS,
50     // where
51     //     X is integer
52     //     Y is float
53     //     S is string
54     // note: symbol `_` is ignored.
55     // Cubic/square shapes are supported by specifying just highest dimension.
56 
57     desc_t d {0};
58     d.mb = 2;
59     d.ls = 5;
60     d.alpha = 1.f / 8192; // = 0.000122 ~~ 0.0001, but has exact representation
61     d.beta = 0.75f;
62     d.k = 1;
63 
64     const char *s = str;
65     assert(s);
66 
67     auto mstrtol = [](const char *nptr, char **endptr) {
68         return strtol(nptr, endptr, 10);
69     };
70 
71 #define CASE_NN(prb, c, cvfunc) \
72     do { \
73         if (!strncmp(prb, s, strlen(prb))) { \
74             ok = 1; \
75             s += strlen(prb); \
76             char *end_s; \
77             d.c = cvfunc(s, &end_s); \
78             s += (end_s - s); \
79             if (d.c < 0) return FAIL; \
80             /* printf("@@@debug: %s: " IFMT "\n", prb, d. c); */ \
81         } \
82     } while (0)
83 #define CASE_N(c, cvfunc) CASE_NN(#c, c, cvfunc)
84     while (*s) {
85         int ok = 0;
86         CASE_N(mb, mstrtol);
87         CASE_N(ic, mstrtol);
88         CASE_N(id, mstrtol);
89         CASE_N(ih, mstrtol);
90         CASE_N(iw, mstrtol);
91         CASE_N(ls, mstrtol);
92         CASE_N(alpha, strtof);
93         CASE_N(beta, strtof);
94         CASE_N(k, strtof);
95         if (*s == 'n') {
96             d.name = s + 1;
97             break;
98         }
99         if (*s == '_') ++s;
100         if (!ok) return FAIL;
101     }
102 #undef CASE_NN
103 #undef CASE_N
104 
105     if (d.ic == 0) return FAIL;
106 
107     if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
108 
109     *desc = d;
110 
111     return OK;
112 }
113 
operator <<(std::ostream & s,const desc_t & d)114 std::ostream &operator<<(std::ostream &s, const desc_t &d) {
115     bool print_d = true, print_h = true, print_w = true;
116     print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
117 
118     if (canonical || d.mb != 2) s << "mb" << d.mb;
119 
120     s << "ic" << d.ic;
121 
122     if (print_d) s << "id" << d.id;
123     if (print_h) s << "ih" << d.ih;
124     if (print_w) s << "iw" << d.iw;
125 
126     if (canonical || d.ls != 5) s << "ls" << d.ls;
127     if (canonical || d.alpha != 1.f / 8192) s << "alpha" << d.alpha;
128     if (canonical || d.beta != 0.75f) s << "beta" << d.beta;
129     if (canonical || d.k != 1) s << "k" << d.k;
130 
131     if (d.name) s << "n" << d.name;
132 
133     return s;
134 }
135 
operator <<(std::ostream & s,const prb_t & prb)136 std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
137     dump_global_params(s);
138     settings_t def;
139 
140     if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
141     if (canonical || prb.dt != def.dt[0]) s << "--dt=" << prb.dt << " ";
142     if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " ";
143     if (canonical || prb.alg != def.alg[0])
144         s << "--alg=" << alg2str(prb.alg) << " ";
145 
146     s << prb.attr;
147     s << static_cast<const desc_t &>(prb);
148 
149     return s;
150 }
151 
152 } // namespace lrn
153