1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_UTILS_HPP
18 #define GPU_JIT_CONV_UTILS_HPP
19
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/utils.hpp"
30
31 namespace dnnl {
32 namespace impl {
33 namespace gpu {
34 namespace jit {
35 namespace ir_utils {
36
// Verbosity thresholds for logger_t: a logger of a given level is enabled
// when the active log level is greater than or equal to that level.
constexpr int LOG_OFF = 0;
constexpr int LOG_WARNING = 100;
constexpr int LOG_TRACE = 200;

// Compile-time default log level: warnings in GEN_CONV_DEBUG builds, nothing
// otherwise.
#ifndef GEN_CONV_DEBUG
constexpr int LOG_LEVEL = LOG_OFF;
#else
constexpr int LOG_LEVEL = LOG_WARNING;
#endif
46
// Forward declarations of both get_hash() overloads. get_hash_impl() and the
// vector overload call get_hash() from inside templates, so both overloads
// must already be declared at that point to be found by ordinary lookup
// (argument-dependent lookup alone does not help for fundamental types).
template <class T>
size_t get_hash(const T &t);

template <class T>
size_t get_hash(const std::vector<T> &v);
52
53 template <typename T>
get_hash_impl(size_t & h,const T & t)54 void get_hash_impl(size_t &h, const T &t) {
55 h = hash_combine(h, get_hash(t));
56 }
57
// Recursive step of the variadic hash fold: mixes in `head` via the
// single-argument overload, then continues with the remaining arguments, so
// arguments are combined left to right.
template <typename ArgHeadT, typename... ArgsT>
void get_hash_impl(size_t &h, const ArgHeadT &head, const ArgsT &... args) {
    get_hash_impl(h, head);
    get_hash_impl(h, args...);
}
64
// Hash functor for enumerations: hashes the value converted to size_t.
template <typename E>
struct enum_hash_t {
    size_t operator()(const E &e) const noexcept {
        const auto as_size = static_cast<size_t>(e);
        return std::hash<size_t>()(as_size);
    }
};

// Fallback hasher: std::hash<T> for ordinary types...
template <typename T, typename = void>
struct get_std_hash_helper_t {
    static size_t call(const T &t) {
        const std::hash<T> hasher {};
        return hasher(t);
    }
};

// ...and enum_hash_t<T> for enumerations.
template <typename T>
struct get_std_hash_helper_t<T,
        typename std::enable_if<std::is_enum<T>::value>::type> {
    static size_t call(const T &t) {
        const enum_hash_t<T> hasher {};
        return hasher(t);
    }
};

// Hash dispatch: types without a get_hash() member go through
// get_std_hash_helper_t...
template <typename T, typename = void>
struct get_hash_helper_t {
    static size_t call(const T &t) { return get_std_hash_helper_t<T>::call(t); }
};

// ...while types exposing get_hash() use it directly.
template <typename T>
struct get_hash_helper_t<T, decltype(void(std::declval<T>().get_hash()))> {
    static size_t call(const T &t) { return t.get_hash(); }
};

// Returns the hash of a single object (see get_hash_helper_t for dispatch).
template <typename T>
size_t get_hash(const T &t) {
    using helper_t = get_hash_helper_t<T>;
    return helper_t::call(t);
}
97
// Hashes a vector by folding the hash of every element, in order, into a
// running value seeded with 0.
template <typename T>
size_t get_hash(const std::vector<T> &v) {
    size_t seed = 0;
    for (size_t i = 0; i < v.size(); i++)
        seed = hash_combine(seed, get_hash(v[i]));
    return seed;
}
105
// Hashes an arbitrary list of objects by folding their hashes left to right
// into a value seeded with 0.
template <typename... ArgsT>
size_t get_hash(const ArgsT &... args) {
    size_t combined = 0;
    get_hash_impl(combined, args...);
    return combined;
}
112
// Default equality: operator==.
template <typename T, typename U, typename = void>
struct is_equal_helper_t {
    static bool call(const T &t, const U &u) { return t == u; }
};

// Preferred equality: selected when `t.is_equal(u)` is a valid expression.
template <typename T, typename U>
struct is_equal_helper_t<T, U,
        decltype(void(std::declval<T>().is_equal(std::declval<U>())))> {
    static bool call(const T &t, const U &u) { return t.is_equal(u); }
};

// Checks equality of objects:
// 1. Uses t.is_equal(u) if is_equal() is available
// 2. Uses (t == u) otherwise
template <typename T, typename U>
bool is_equal(const T &t, const U &u) {
    using helper_t = is_equal_helper_t<T, U>;
    return helper_t::call(t, u);
}
131
132 // Checks equality of vector elements.
133 template <typename T, typename U>
is_equal(const std::vector<T> & a,const std::vector<U> & b)134 bool is_equal(const std::vector<T> &a, const std::vector<U> &b) {
135 if (a.size() != b.size()) return false;
136 for (size_t i = 0; i < a.size(); i++)
137 if (!ir_utils::is_equal(a[i], b[i])) return false;
138 return true;
139 }
140
141 // Checks equality of vector elements between each other.
142 template <typename T>
are_all_equal(const std::vector<T> & a)143 bool are_all_equal(const std::vector<T> &a) {
144 if (a.empty()) return true;
145 for (size_t i = 1; i < a.size(); i++)
146 if (!ir_utils::is_equal(a[i], a[0])) return false;
147 return true;
148 }
149
// Checks element-wise identity of two vectors via the elements' is_same()
// member, stopping at the first mismatch; sizes must match.
template <typename T, typename U>
bool is_same(const std::vector<T> &a, const std::vector<U> &b) {
    if (a.size() != b.size()) return false;
    size_t i = 0;
    while (i < a.size() && a[i].is_same(b[i]))
        i++;
    return i == a.size();
}
158
// Collects the message of a failed assertion and terminates the process when
// destroyed. ir_assert() creates one only when the asserted condition is
// false, so streaming into it and destroying it always ends in abort.
class error_stream_t {
public:
    error_stream_t(const char *file, int line, const char *assert_msg)
        : file_(file), line_(line) {
        out_ << "Assertion " << assert_msg << " failed at " << file_ << ":"
             << line_ << std::endl;
    }

    // A single object owns (and prints) the message; copying is meaningless.
    error_stream_t(const error_stream_t &) = delete;
    error_stream_t &operator=(const error_stream_t &) = delete;

    // This is to be able to use a stream object in short-circuit evaluation
    // with booleans, see ir_assert(). `explicit` still allows that contextual
    // conversion while rejecting accidental implicit ones (e.g. arithmetic
    // or `int x = stream;`).
    explicit operator bool() const { return true; }

    // Appends `t` to the pending error message.
    template <typename T>
    error_stream_t &operator<<(const T &t) {
        out_ << t;
        return *this;
    }

    // Prints the accumulated message to std::cerr and aborts the process.
    ~error_stream_t() {
        std::cerr << out_.str() << std::endl;
        std::abort();
    }

private:
    const char *file_;
    int line_;
    std::ostringstream out_;
};
187
// Checks an assertion and, on failure, evaluates the chained output operators
// to build the error message. Usage:
//     ir_assert(condition) << "Error message" << ...;
//
// The expansion is intentionally not parenthesized: `<<` binds tighter than
// `&&`, so the streamed operands attach to the error_stream_t temporary and
// are evaluated only when the left-hand side of `&&` is true. In NDEBUG
// builds the leading `(false) &&` short-circuits everything: the condition
// must still compile but is never evaluated.
#ifdef NDEBUG
#define ir_assert(cond) \
    (false) && !(cond) \
            && dnnl::impl::gpu::jit::ir_utils::error_stream_t( \
                    __FILE__, __LINE__, #cond)
#else
#define ir_assert(cond) \
    !(cond) \
            && dnnl::impl::gpu::jit::ir_utils::error_stream_t( \
                    __FILE__, __LINE__, #cond)
#endif
203
// Failure helpers built on ir_assert(false): they fail whenever assertions are
// enabled (no-ops under NDEBUG). Extra context can be streamed after them,
// e.g. ir_error_not_expected() << "kind: " << kind;
#define ir_error_not_expected() \
    ir_assert(false) << "Not expected. "
#define ir_error_not_implemented() \
    ir_assert(false) << "Not implemented. "
206
207 template <int level>
208 class logger_t {
209 public:
logger_t(std::ostream & out=std::cout)210 logger_t(std::ostream &out = std::cout) : out_(out) {}
211
operator bool() const212 operator bool() const { return true; }
213
is_enabled()214 static bool is_enabled() {
215 #ifdef GEN_CONV_DEBUG
216 static int log_level = getenv_int("log_level", LOG_LEVEL);
217 return log_level >= level;
218 #else
219 return LOG_LEVEL >= level;
220 #endif
221 }
222
223 template <typename T>
operator <<(const T & obj)224 logger_t &operator<<(const T &obj) {
225 maybe_print_header();
226 out_ << obj;
227 return *this;
228 }
229
operator <<(std::ostream & (* os)(std::ostream &))230 logger_t &operator<<(std::ostream &(*os)(std::ostream &)) {
231 maybe_print_header();
232 out_ << os;
233 return *this;
234 }
235
236 private:
maybe_print_header()237 void maybe_print_header() {
238 if (!is_first_print_) return;
239
240 switch (level) {
241 case LOG_WARNING: out_ << "[WARNING] "; break;
242 default: break;
243 }
244 is_first_print_ = false;
245 }
246
247 std::ostream &out_;
248 bool is_first_print_ = true;
249 };
250
// Logging helpers. Usage: ir_warning() << "message" << std::endl;
// Thanks to the short-circuit `&&`, the logger is constructed (and the
// streamed operands evaluated) only when the corresponding level is enabled.
// Names are fully qualified, as in ir_assert(), so the macros also expand
// correctly outside of namespace dnnl::impl::gpu::jit.
#define ir_warning() \
    dnnl::impl::gpu::jit::ir_utils::logger_t< \
            dnnl::impl::gpu::jit::ir_utils::LOG_WARNING>::is_enabled() \
            && dnnl::impl::gpu::jit::ir_utils::logger_t< \
                    dnnl::impl::gpu::jit::ir_utils::LOG_WARNING>()

#define ir_trace() \
    dnnl::impl::gpu::jit::ir_utils::logger_t< \
            dnnl::impl::gpu::jit::ir_utils::LOG_TRACE>::is_enabled() \
            && dnnl::impl::gpu::jit::ir_utils::logger_t< \
                    dnnl::impl::gpu::jit::ir_utils::LOG_TRACE>()
258
// Pretty printers for STL objects.

// Prints an unordered_set as "{a, b, c}" in the set's iteration order.
template <typename KeyT, typename HashT, typename EqualT>
inline std::ostream &operator<<(
        std::ostream &out, const std::unordered_set<KeyT, HashT, EqualT> &s) {
    const char *sep = "";
    out << "{";
    for (const auto &key : s) {
        out << sep << key;
        sep = ", ";
    }
    out << "}";
    return out;
}
270
// Prints an unordered_map as "{k1: v1, k2: v2}" in the map's iteration order.
template <typename KeyT, typename ValueT, typename HashT, typename EqualT>
inline std::ostream &operator<<(std::ostream &out,
        const std::unordered_map<KeyT, ValueT, HashT, EqualT> &m) {
    const char *sep = "";
    out << "{";
    for (const auto &kv : m) {
        out << sep << kv.first << ": " << kv.second;
        sep = ", ";
    }
    out << "}";
    return out;
}
281
// Printable view of a container: elements separated by `sep`, each preceded
// by std::setw(width). Holds only a reference to the container, so it must
// not outlive it.
template <typename ContainerT>
struct seq_print_helper_t {
    seq_print_helper_t(const ContainerT &container, const std::string &separator,
            int field_width)
        : v(container), sep(separator), width(field_width) {}

    const ContainerT &v;
    const std::string sep;
    int width;
};

// Factory that deduces the container type.
template <typename T>
seq_print_helper_t<T> make_seq_print_helper(
        const T &v, const std::string &sep = ", ", int width = 0) {
    return {v, sep, width};
}

// Prints the elements of `seq.v`; an empty string is written in place of the
// separator before the first element.
template <typename T>
inline std::ostream &operator<<(
        std::ostream &out, const seq_print_helper_t<T> &seq) {
    const std::string no_sep;
    const auto first = seq.v.begin();
    for (auto it = first; it != seq.v.end(); ++it) {
        const std::string &sep = (it == first) ? no_sep : seq.sep;
        out << sep << std::setw(seq.width) << *it;
    }
    return out;
}

// Prints a vector as "[a, b, c]".
template <typename T>
inline std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
    return out << "[" << make_seq_print_helper(v) << "]";
}
315
getenv_bool(const char * s,bool def)316 inline bool getenv_bool(const char *s, bool def) {
317 return getenv_int(s, def ? 1 : 0) == 1;
318 }
319
getenv_str(const char * s,const std::string & def)320 inline std::string getenv_str(const char *s, const std::string &def) {
321 char buf[1024];
322 int ret = getenv(s, buf, sizeof(buf));
323 if (ret > 0) return buf;
324 return def;
325 }
326
// Formats a boolean as "True" or "False".
inline std::string to_string(bool b) {
    if (b) return "True";
    return "False";
}
330
// Returns the largest element of `divisors` that divides `n` evenly.
// Zero divisors are ignored (n % 0 is undefined behavior). If nothing matches,
// the assertion fails; when assertions are compiled out (NDEBUG), T(-1) is
// returned in that case.
template <typename T>
inline T max_divisor(T n, std::initializer_list<T> divisors) {
    // 0 never divides anything, so it can mark "no match yet" for both signed
    // and unsigned T. A -1 sentinel breaks for unsigned T: it is the largest
    // value there and would always win the maximum.
    T ret = 0;
    for (auto d : divisors) {
        if (d == 0 || n % d != 0) continue;
        if (ret == 0 || d > ret) ret = d;
    }
    ir_assert(ret != 0) << "No divisor of " << n << " found.";
    return ret != 0 ? ret : T(-1);
}
340
// Returns the largest power of two dividing `n`, i.e. its lowest set bit
// (0 for n == 0). Equivalent of the BLSI instruction (extract lowest set
// isolated bit): subtracting 1 flips the trailing zeros and the lowest one,
// so the complement of (n - 1) keeps exactly that bit in common with n.
template <typename T>
inline T max_pow2_divisor(T n) {
    const auto below_mask = ~(n - 1);
    return n & below_mask;
}
346
// Divides `a` by `b`, asserting (outside NDEBUG builds) that `b` is non-zero
// and divides `a` exactly. The quotient is converted to T.
template <typename T, typename U>
inline T safe_divide(T a, U b) {
    ir_assert(b != 0 && a % b == 0) << "Can't divide: " << a << " / " << b;
    const T quotient = a / b;
    return quotient;
}
352
// Returns the position of the first element of `c` equal to `value`, or -1 if
// there is none. `c` must support size() and operator[].
template <typename ContainerT, typename T>
inline int find_index(const ContainerT &c, const T &value) {
    const int n = int(c.size());
    int i = 0;
    while (i < n && !(c[i] == value))
        ++i;
    return (i < n) ? i : -1;
}
360
// Recursive worker for for_each(): sets idx[pos] to each value in
// [0, bounds[pos]) in turn and recurses into the next position. Once every
// position is set (pos == bounds.size()), calls f(idx).
template <typename T, typename F>
void for_each_impl(size_t pos, std::vector<T> &idx,
        const std::vector<T> &bounds, const F &f) {
    if (pos != bounds.size()) {
        T i = 0;
        while (i < bounds[pos]) {
            idx[pos] = i;
            for_each_impl(pos + 1, idx, bounds, f);
            i++;
        }
        return;
    }
    f(idx);
}

// Calls f(idx) for every index tuple with 0 <= idx[k] < bounds[k]; the last
// position varies fastest. Empty `bounds` yields exactly one call with an
// empty idx; any non-positive bound yields no calls.
template <typename T, typename F>
void for_each(const std::vector<T> &bounds, const F &f) {
    std::vector<T> idx(bounds.size());
    for_each_impl(0, idx, bounds, f);
}
380
381 } // namespace ir_utils
382 } // namespace jit
383 } // namespace gpu
384 } // namespace impl
385 } // namespace dnnl
386
387 #endif
388