1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "gpu/jit/conv/ir_core.hpp"
18 
19 #include <algorithm>
20 
21 namespace dnnl {
22 namespace impl {
23 namespace gpu {
24 namespace jit {
25 
26 expr_t const_fold_non_recursive(const expr_t &expr);
27 object_t const_fold(const object_t &obj);
28 
to_string(type_kind_t kind)29 std::string to_string(type_kind_t kind) {
30 #define CASE(_kind) \
31     case type_kind_t::_kind: return #_kind
32     switch (kind) {
33         CASE(undef);
34         CASE(u8);
35         CASE(s8);
36         CASE(u16);
37         CASE(s16);
38         CASE(u32);
39         CASE(s32);
40         CASE(u64);
41         CASE(s64);
42         CASE(bf16);
43         CASE(f16);
44         CASE(f32);
45         CASE(byte);
46         CASE(dword);
47         CASE(qword);
48         CASE(oword);
49         CASE(hword);
50         case type_kind_t::_bool: return "bool";
51         default: ir_error_not_expected();
52     }
53 #undef CASE
54     return {};
55 }
56 
size() const57 int type_t::size() const {
58     if (is_ptr()) return sizeof(uint64_t);
59 
60     if (elems() != 1) return elems() * scalar().size();
61 
62     switch (kind()) {
63         case type_kind_t::u8:
64         case type_kind_t::s8:
65         case type_kind_t::byte: return 1;
66         case type_kind_t::u16:
67         case type_kind_t::s16:
68         case type_kind_t::bf16:
69         case type_kind_t::f16: return 2;
70         case type_kind_t::u32:
71         case type_kind_t::s32:
72         case type_kind_t::f32:
73         case type_kind_t::dword: return 4;
74         case type_kind_t::u64:
75         case type_kind_t::s64:
76         case type_kind_t::qword: return 8;
77         case type_kind_t::oword: return 16;
78         case type_kind_t::hword: return 32;
79         default: ir_error_not_expected();
80     }
81     return 0;
82 }
83 
to_dnnl(const type_t & type)84 data_type_t to_dnnl(const type_t &type) {
85     ir_assert(type.elems() == 1) << type;
86     ir_assert(!type.is_ptr() == 1) << type;
87     switch (type.kind()) {
88         case type_kind_t::bf16: return data_type::bf16;
89         case type_kind_t::f16: return data_type::f16;
90         case type_kind_t::f32: return data_type::f32;
91         case type_kind_t::s32: return data_type::s32;
92         case type_kind_t::s8: return data_type::s8;
93         case type_kind_t::u8: return data_type::u8;
94         default: ir_error_not_expected();
95     }
96     return data_type::undef;
97 }
98 
to_string(op_kind_t kind)99 std::string to_string(op_kind_t kind) {
100     switch (kind) {
101         case op_kind_t::_minus: return "-";
102 
103         case op_kind_t::_add: return "+";
104         case op_kind_t::_sub: return "-";
105         case op_kind_t::_mul: return "*";
106         case op_kind_t::_div: return "/";
107         case op_kind_t::_mod: return "%";
108         case op_kind_t::_shl: return "<<";
109         case op_kind_t::_shr: return ">>";
110         case op_kind_t::_min: return "min";
111         case op_kind_t::_max: return "max";
112 
113         case op_kind_t::_lt: return "<";
114         case op_kind_t::_le: return "<=";
115         case op_kind_t::_gt: return ">";
116         case op_kind_t::_ge: return ">=";
117         case op_kind_t::_eq: return "==";
118         case op_kind_t::_ne: return "!=";
119 
120         case op_kind_t::_and: return "&&";
121 
122         case op_kind_t::_add3: return "add3";
123         case op_kind_t::_mad: return "mad";
124 
125         default: ir_error_not_expected() << "Unknown op_kind_t value.";
126     }
127     return "";
128 }
129 
is_cmp_op(op_kind_t op_kind)130 bool is_cmp_op(op_kind_t op_kind) {
131     switch (op_kind) {
132         case op_kind_t::_ge:
133         case op_kind_t::_gt:
134         case op_kind_t::_le:
135         case op_kind_t::_lt:
136         case op_kind_t::_eq:
137         case op_kind_t::_ne: return true;
138         default: return false;
139     }
140 }
141 
negate_cmp_op(op_kind_t op_kind)142 op_kind_t negate_cmp_op(op_kind_t op_kind) {
143     switch (op_kind) {
144         case op_kind_t::_ge: return op_kind_t::_le;
145         case op_kind_t::_gt: return op_kind_t::_lt;
146         case op_kind_t::_le: return op_kind_t::_ge;
147         case op_kind_t::_lt: return op_kind_t::_gt;
148         case op_kind_t::_eq: return op_kind_t::_eq;
149         case op_kind_t::_ne: return op_kind_t::_ne;
150         default: ir_error_not_expected();
151     }
152     return op_kind_t::undef;
153 }
154 
unary_op_type(op_kind_t op_kind,const expr_t & a)155 type_t unary_op_type(op_kind_t op_kind, const expr_t &a) {
156     switch (op_kind) {
157         case op_kind_t::_minus: {
158             auto &t = a.type();
159             if (!t.is_int()) return t;
160             if (t.size() < int(sizeof(int32_t))) return type_t::s32(t.elems());
161             return t;
162         }
163         default:
164             ir_error_not_expected() << "Unknown op_kind_t value: " << op_kind;
165     }
166     return type_t::undef();
167 }
168 
// Returns the common type of two integer types, following the C++ rules for
// integer promotion and the usual arithmetic conversions, applied to the
// element types.
//
// Both arguments must be integer types (checked with ir_assert). The result
// has the element count of _a; the element count of _b is not inspected here.
type_t common_int_type(const type_t &_a, const type_t &_b) {
    ir_assert(_a.is_int() && _b.is_int()) << "Unexpected types.";

    int elems = _a.elems();

    // Work on element types. The promotion below must look at the element
    // width: type_t::size() of a vector type covers all of its elements, so
    // checking the vector size would make the result depend on the number of
    // elements.
    type_t a = _a.scalar();
    type_t b = _b.scalar();

    // Promote to s32 first.
    if (a.size() < int(sizeof(int32_t))) a = type_t::s32();
    if (b.size() < int(sizeof(int32_t))) b = type_t::s32();

    // Integer promotion, follow C++ rules.
    int common_bits = 8 * std::max(a.size(), b.size());
    if (a.is_signed() == b.is_signed()) {
        if (a.is_signed()) return type_t::s(common_bits, elems);
        return type_t::u(common_bits, elems);
    }

    // Mixed signedness: the result is signed only when the signed operand is
    // strictly wider than the unsigned one; otherwise the unsigned type wins.
    const bool a_is_signed = a.is_signed();
    const type_t &signed_t = a_is_signed ? a : b;
    const type_t &unsigned_t = a_is_signed ? b : a;
    if (signed_t.size() > unsigned_t.size())
        return type_t::s(common_bits, elems);

    return type_t::u(common_bits, elems);
}
198 
// Returns the common type of two types with the same number of elements
// (checked with ir_assert):
//   - undef if either type is undef;
//   - the floating-point type if exactly one of them is floating-point;
//   - the larger of the two if both are floating-point (b on equal sizes);
//   - a if both are boolean;
//   - otherwise the result of common_int_type().
type_t common_type(const type_t &a, const type_t &b) {
    ir_assert(a.elems() == b.elems())
            << "Types must have the same number of components.";
    if (a.is_undef() || b.is_undef()) return type_t::undef();

    const bool a_fp = a.is_fp();
    const bool b_fp = b.is_fp();
    if (a_fp != b_fp) return a_fp ? a : b;
    if (a_fp) return (a.size() > b.size() ? a : b);

    if (a.is_bool() && b.is_bool()) return a;
    return common_int_type(a, b);
}
209 
common_type(const expr_t & a,const expr_t & b)210 type_t common_type(const expr_t &a, const expr_t &b) {
211     return common_type(a.type(), b.type());
212 }
213 
binary_op_type(op_kind_t op_kind,const type_t & a,const type_t & b)214 type_t binary_op_type(op_kind_t op_kind, const type_t &a, const type_t &b) {
215     if (a.is_undef() || b.is_undef()) return type_t::undef();
216     ir_assert(a.elems() == b.elems())
217             << "Types must have the same number of components.";
218     if (is_cmp_op(op_kind)) return type_t::_bool(a.elems());
219     if (utils::one_of(op_kind, op_kind_t::_shl, op_kind_t::_shr)) {
220         ir_assert(a.is_unsigned())
221                 << "a must be unsigned for shift left/right.";
222         return type_t::u32(a.elems());
223     }
224     return common_type(a, b);
225 }
226 
binary_op_type(op_kind_t op_kind,const expr_t & a,const expr_t & b)227 type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
228     return binary_op_type(op_kind, a.type(), b.type());
229 }
230 
ternary_op_type(op_kind_t op_kind,const expr_t & a,const expr_t & b,const expr_t & c)231 type_t ternary_op_type(
232         op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c) {
233     switch (op_kind) {
234         case op_kind_t::_add3:
235             return binary_op_type(op_kind_t::_add, a.type(),
236                     binary_op_type(op_kind_t::_add, b, c));
237         case op_kind_t::_mad:
238             return binary_op_type(op_kind_t::_add, a.type(),
239                     binary_op_type(op_kind_t::_mul, b, c));
240         default: ir_error_not_expected();
241     }
242     return type_t::undef();
243 }
244 
nary_op_type(op_kind_t op_kind,const std::vector<expr_t> & args)245 type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args) {
246     ir_assert(!args.empty());
247     if (args.size() == 1) return args[0].type();
248 
249     auto type = args[0].type();
250     for (size_t i = 1; i < args.size(); i++)
251         type = common_type(type, args[i].type());
252 
253     return type;
254 }
255 
normalize(expr_t & base,expr_t & off,op_kind_t op_kind)256 void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) {
257     ir_assert(base.type().is_ptr()) << "base is not a pointer: " << base;
258     ir_assert(off.type().is_int()) << "off is not an integer: " << off;
259     ir_assert(utils::one_of(op_kind, op_kind_t::_add, op_kind_t::_sub))
260             << "Can't apply this operation to pointer: " << to_string(op_kind);
261 
262     if (!base.is<ptr_t>()) {
263         if (op_kind == op_kind_t::_sub) off = const_fold(-off);
264         return;
265     }
266 
267     auto &base_off = base.as<ptr_t>().off;
268     base = base.as<ptr_t>().base;
269     off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off));
270 }
271 
shift_ptr(op_kind_t op_kind,const expr_t & a,const expr_t & b)272 expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
273     expr_t base = a;
274     expr_t off = b;
275     ptr_t::normalize(base, off, op_kind);
276     return ptr_t::make(base, off);
277 }
278 
normalize_ptr(const type_t & type,expr_t & base_expr,expr_t & off)279 void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) {
280     if (base_expr.is<ptr_t>()) {
281         auto &base = base_expr.as<ptr_t>().base;
282         auto &base_off = base_expr.as<ptr_t>().off;
283 
284         base_expr = base;
285         off = const_fold_non_recursive(base_off + off);
286     }
287     ir_assert(to_cpp<int64_t>(off) % type.size() == 0)
288             << "Incompatible offset: " << off;
289 }
290 
// Indexing operator.
//
// For a shuffle_t expression, off must be a constant position within the
// shuffle indices (both checked with ir_assert); the vector element selected
// by the index stored at that position is returned.
// For any other expression, the result is this expression shifted by off via
// shift_ptr() with op_kind_t::_add.
expr_t expr_t::operator[](const expr_t &off) const {
    if (is<shuffle_t>()) {
        ir_assert(is_const(off)) << "Offset is not constant.";
        auto &shuffle = as<shuffle_t>();
        const int pos = to_cpp<int>(off);
        // Reject positions outside of the index list instead of reading past
        // its bounds.
        ir_assert(pos >= 0 && pos < int(shuffle.idx.size()))
                << "Offset is out of range: " << off;
        int idx = shuffle.idx[pos];
        return shuffle.vec[idx];
    }
    return shift_ptr(op_kind_t::_add, *this, off);
}
300 
expr_t(bool value)301 expr_t::expr_t(bool value) : object_t(new bool_imm_t(value)) {}
expr_t(float value)302 expr_t::expr_t(float value) : object_t(new float_imm_t(value)) {}
expr_t(int16_t value)303 expr_t::expr_t(int16_t value) : object_t(new int_imm_t(value)) {}
expr_t(int32_t value)304 expr_t::expr_t(int32_t value) : object_t(new int_imm_t(value)) {}
expr_t(int64_t value)305 expr_t::expr_t(int64_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint16_t value)306 expr_t::expr_t(uint16_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint32_t value)307 expr_t::expr_t(uint32_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint64_t value)308 expr_t::expr_t(uint64_t value) : object_t(new int_imm_t(value)) {}
309 
// Unary minus: builds a unary_op_t with op_kind_t::_minus and applies
// non-recursive constant folding to it.
expr_t operator-(const expr_t &a) {
    const auto negated = unary_op_t::make(op_kind_t::_minus, a);
    return const_fold_non_recursive(negated);
}
313 
314 #define DEFINE_BINARY_OPERATOR(op, op_kind) \
315     expr_t operator op(const expr_t &a, const expr_t &b) { \
316         if (a.type().is_ptr()) return shift_ptr(op_kind, a, b); \
317         return const_fold_non_recursive(binary_op_t::make(op_kind, a, b)); \
318     }
319 
320 DEFINE_BINARY_OPERATOR(+, op_kind_t::_add)
321 DEFINE_BINARY_OPERATOR(-, op_kind_t::_sub)
322 DEFINE_BINARY_OPERATOR(*, op_kind_t::_mul)
323 DEFINE_BINARY_OPERATOR(/, op_kind_t::_div)
324 DEFINE_BINARY_OPERATOR(%, op_kind_t::_mod)
325 DEFINE_BINARY_OPERATOR(<<, op_kind_t::_shl)
326 DEFINE_BINARY_OPERATOR(>>, op_kind_t::_shr)
327 
328 DEFINE_BINARY_OPERATOR(==, op_kind_t::_eq)
329 DEFINE_BINARY_OPERATOR(!=, op_kind_t::_ne)
330 DEFINE_BINARY_OPERATOR(>, op_kind_t::_gt)
331 DEFINE_BINARY_OPERATOR(>=, op_kind_t::_ge)
332 DEFINE_BINARY_OPERATOR(<, op_kind_t::_lt)
333 DEFINE_BINARY_OPERATOR(<=, op_kind_t::_le)
334 
335 DEFINE_BINARY_OPERATOR(&, op_kind_t::_and)
336 
337 #undef DEFINE_BINARY_OPERATOR
338 
339 #define DEFINE_BINARY_ASSIGN_OPERATOR(op) \
340     expr_t &expr_t::operator op##=(const expr_t &rhs) { \
341         auto tmp = (*this)op rhs; \
342         *this = tmp; \
343         return *this; \
344     }
345 
346 DEFINE_BINARY_ASSIGN_OPERATOR(+)
347 DEFINE_BINARY_ASSIGN_OPERATOR(-)
348 DEFINE_BINARY_ASSIGN_OPERATOR(*)
349 DEFINE_BINARY_ASSIGN_OPERATOR(/)
350 DEFINE_BINARY_ASSIGN_OPERATOR(%)
351 DEFINE_BINARY_ASSIGN_OPERATOR(&)
352 
353 #undef DEFINE_BINARY_ASSIGN_OPERATOR
354 
_mutate(ir_mutator_t & mutator) const355 object_t object_impl_t::_mutate(ir_mutator_t &mutator) const {
356     return *this;
357 }
358 
359 #define DECL_MUTATE_LEAF(name) \
360     object_t ir_mutator_t::_mutate(const name &obj) { return obj; }
361 
362 DECL_MUTATE_LEAF(bool_imm_t)
DECL_MUTATE_LEAF(float_imm_t)363 DECL_MUTATE_LEAF(float_imm_t)
364 DECL_MUTATE_LEAF(func_impl_t)
365 DECL_MUTATE_LEAF(int_imm_t)
366 DECL_MUTATE_LEAF(var_t)
367 
368 #undef DECL_MUTATE_LEAF
369 
370 object_t ir_mutator_t::_mutate(const alloc_t &obj) {
371     auto buf = mutate(obj.buf);
372     auto body = mutate(obj.body);
373 
374     if (buf.is_same(obj.buf) && body.is_same(obj.body)) return obj;
375 
376     return alloc_t::make(buf, obj.size, obj.kind, obj.attr, body);
377 }
378 
_mutate(const binary_op_t & obj)379 object_t ir_mutator_t::_mutate(const binary_op_t &obj) {
380     auto a = mutate(obj.a);
381     auto b = mutate(obj.b);
382 
383     if (a.is_same(obj.a) && b.is_same(obj.b)) return obj;
384 
385     return binary_op_t::make(obj.op_kind, a, b);
386 }
387 
_mutate(const cast_t & obj)388 object_t ir_mutator_t::_mutate(const cast_t &obj) {
389     auto expr = mutate(obj.expr);
390 
391     if (expr.is_same(obj.expr)) return obj;
392 
393     return cast_t::make(obj.type, expr, obj.saturate);
394 }
395 
_mutate(const for_t & obj)396 object_t ir_mutator_t::_mutate(const for_t &obj) {
397     auto var = mutate(obj.var);
398     auto init = mutate(obj.init);
399     auto bound = mutate(obj.bound);
400     auto body = mutate(obj.body);
401 
402     if (var.is_same(obj.var) && init.is_same(obj.init)
403             && bound.is_same(obj.bound) && body.is_same(obj.body))
404         return obj;
405 
406     return for_t::make(var, init, bound, body, obj.unroll);
407 }
408 
_mutate(const func_call_t & obj)409 object_t ir_mutator_t::_mutate(const func_call_t &obj) {
410     auto func = mutate(obj.func);
411     auto args = mutate(obj.args);
412 
413     if (func.is_same(obj.func) && ir_utils::is_same(args, obj.args)) return obj;
414 
415     return func_call_t::make(func, args, obj.attr);
416 }
417 
_mutate(const if_t & obj)418 object_t ir_mutator_t::_mutate(const if_t &obj) {
419     auto cond = mutate(obj.cond);
420     auto body = mutate(obj.body);
421     auto else_body = mutate(obj.else_body);
422 
423     if (cond.is_same(obj.cond) && body.is_same(obj.body)
424             && else_body.is_same(obj.else_body))
425         return obj;
426 
427     return if_t::make(cond, body, else_body);
428 }
429 
_mutate(const iif_t & obj)430 object_t ir_mutator_t::_mutate(const iif_t &obj) {
431     auto cond = mutate(obj.cond);
432     auto true_expr = mutate(obj.true_expr);
433     auto false_expr = mutate(obj.false_expr);
434 
435     if (cond.is_same(obj.cond) && true_expr.is_same(obj.true_expr)
436             && false_expr.is_same(obj.false_expr))
437         return obj;
438 
439     return iif_t::make(cond, true_expr, false_expr);
440 }
441 
_mutate(const let_t & obj)442 object_t ir_mutator_t::_mutate(const let_t &obj) {
443     auto var = mutate(obj.var);
444     auto value = mutate(obj.value);
445     auto body = mutate(obj.body);
446 
447     if (var.is_same(obj.var) && value.is_same(obj.value)
448             && body.is_same(obj.body))
449         return obj;
450 
451     return let_t::make(var, value, body);
452 }
453 
_mutate(const load_t & obj)454 object_t ir_mutator_t::_mutate(const load_t &obj) {
455     auto buf = mutate(obj.buf);
456     auto off = mutate(obj.off);
457 
458     if (buf.is_same(obj.buf) && off.is_same(obj.off)) return obj;
459 
460     return load_t::make(obj.type, buf, off, obj.stride);
461 }
462 
_mutate(const ptr_t & obj)463 object_t ir_mutator_t::_mutate(const ptr_t &obj) {
464     auto base = mutate(obj.base);
465     auto off = mutate(obj.off);
466 
467     if (base.is_same(obj.base) && off.is_same(obj.off)) return obj;
468 
469     return ptr_t::make(base, off);
470 }
471 
_mutate(const shuffle_t & obj)472 object_t ir_mutator_t::_mutate(const shuffle_t &obj) {
473     auto vec = mutate(obj.vec);
474 
475     if (ir_utils::is_same(vec, obj.vec)) return obj;
476 
477     return shuffle_t::make(vec, obj.idx);
478 }
479 
_mutate(const stmt_group_t & obj)480 object_t ir_mutator_t::_mutate(const stmt_group_t &obj) {
481     auto body = mutate(obj.body);
482 
483     if (body.is_same(obj.body)) return obj;
484 
485     return stmt_group_t::make(obj.label, body);
486 }
487 
_mutate(const stmt_seq_t & obj)488 object_t ir_mutator_t::_mutate(const stmt_seq_t &obj) {
489     auto head = mutate(obj.head);
490     auto tail = mutate(obj.tail);
491 
492     if (head.is_same(obj.head) && tail.is_same(obj.tail)) return obj;
493 
494     return stmt_seq_t::make(head, tail);
495 }
496 
_mutate(const store_t & obj)497 object_t ir_mutator_t::_mutate(const store_t &obj) {
498     auto buf = mutate(obj.buf);
499     auto off = mutate(obj.off);
500     auto value = mutate(obj.value);
501     auto mask = mutate(obj.mask);
502 
503     if (buf.is_same(obj.buf) && off.is_same(obj.off) && value.is_same(obj.value)
504             && mask.is_same(obj.mask))
505         return obj;
506 
507     return store_t::make(buf, off, value, obj.stride, mask);
508 }
509 
_mutate(const ternary_op_t & obj)510 object_t ir_mutator_t::_mutate(const ternary_op_t &obj) {
511     auto a = mutate(obj.a);
512     auto b = mutate(obj.b);
513     auto c = mutate(obj.c);
514 
515     if (a.is_same(obj.a) && b.is_same(obj.b) && c.is_same(obj.c)) return obj;
516 
517     return ternary_op_t::make(obj.op_kind, a, b, c);
518 }
519 
_mutate(const unary_op_t & obj)520 object_t ir_mutator_t::_mutate(const unary_op_t &obj) {
521     auto a = mutate(obj.a);
522     if (a.is_same(obj.a)) return obj;
523     return unary_op_t::make(obj.op_kind, a);
524 }
525 
526 // Catch missing mutates that are not expected to dispatch to the base
527 // mutator
_mutate(const nary_op_t & obj)528 object_t ir_mutator_t::_mutate(const nary_op_t &obj) {
529     ir_error_not_expected() << "Can't handle type: nary_op_t";
530     return {};
531 }
_mutate(const pexpr_t & obj)532 object_t ir_mutator_t::_mutate(const pexpr_t &obj) {
533     ir_error_not_expected() << "Can't handle type: pexpr_t";
534     return {};
535 }
536 
537 } // namespace jit
538 } // namespace gpu
539 } // namespace impl
540 } // namespace dnnl
541