1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gpu/jit/conv/ir_core.hpp"
18
19 #include <algorithm>
20
21 namespace dnnl {
22 namespace impl {
23 namespace gpu {
24 namespace jit {
25
26 expr_t const_fold_non_recursive(const expr_t &expr);
27 object_t const_fold(const object_t &obj);
28
to_string(type_kind_t kind)29 std::string to_string(type_kind_t kind) {
30 #define CASE(_kind) \
31 case type_kind_t::_kind: return #_kind
32 switch (kind) {
33 CASE(undef);
34 CASE(u8);
35 CASE(s8);
36 CASE(u16);
37 CASE(s16);
38 CASE(u32);
39 CASE(s32);
40 CASE(u64);
41 CASE(s64);
42 CASE(bf16);
43 CASE(f16);
44 CASE(f32);
45 CASE(byte);
46 CASE(dword);
47 CASE(qword);
48 CASE(oword);
49 CASE(hword);
50 case type_kind_t::_bool: return "bool";
51 default: ir_error_not_expected();
52 }
53 #undef CASE
54 return {};
55 }
56
size() const57 int type_t::size() const {
58 if (is_ptr()) return sizeof(uint64_t);
59
60 if (elems() != 1) return elems() * scalar().size();
61
62 switch (kind()) {
63 case type_kind_t::u8:
64 case type_kind_t::s8:
65 case type_kind_t::byte: return 1;
66 case type_kind_t::u16:
67 case type_kind_t::s16:
68 case type_kind_t::bf16:
69 case type_kind_t::f16: return 2;
70 case type_kind_t::u32:
71 case type_kind_t::s32:
72 case type_kind_t::f32:
73 case type_kind_t::dword: return 4;
74 case type_kind_t::u64:
75 case type_kind_t::s64:
76 case type_kind_t::qword: return 8;
77 case type_kind_t::oword: return 16;
78 case type_kind_t::hword: return 32;
79 default: ir_error_not_expected();
80 }
81 return 0;
82 }
83
to_dnnl(const type_t & type)84 data_type_t to_dnnl(const type_t &type) {
85 ir_assert(type.elems() == 1) << type;
86 ir_assert(!type.is_ptr() == 1) << type;
87 switch (type.kind()) {
88 case type_kind_t::bf16: return data_type::bf16;
89 case type_kind_t::f16: return data_type::f16;
90 case type_kind_t::f32: return data_type::f32;
91 case type_kind_t::s32: return data_type::s32;
92 case type_kind_t::s8: return data_type::s8;
93 case type_kind_t::u8: return data_type::u8;
94 default: ir_error_not_expected();
95 }
96 return data_type::undef;
97 }
98
to_string(op_kind_t kind)99 std::string to_string(op_kind_t kind) {
100 switch (kind) {
101 case op_kind_t::_minus: return "-";
102
103 case op_kind_t::_add: return "+";
104 case op_kind_t::_sub: return "-";
105 case op_kind_t::_mul: return "*";
106 case op_kind_t::_div: return "/";
107 case op_kind_t::_mod: return "%";
108 case op_kind_t::_shl: return "<<";
109 case op_kind_t::_shr: return ">>";
110 case op_kind_t::_min: return "min";
111 case op_kind_t::_max: return "max";
112
113 case op_kind_t::_lt: return "<";
114 case op_kind_t::_le: return "<=";
115 case op_kind_t::_gt: return ">";
116 case op_kind_t::_ge: return ">=";
117 case op_kind_t::_eq: return "==";
118 case op_kind_t::_ne: return "!=";
119
120 case op_kind_t::_and: return "&&";
121
122 case op_kind_t::_add3: return "add3";
123 case op_kind_t::_mad: return "mad";
124
125 default: ir_error_not_expected() << "Unknown op_kind_t value.";
126 }
127 return "";
128 }
129
is_cmp_op(op_kind_t op_kind)130 bool is_cmp_op(op_kind_t op_kind) {
131 switch (op_kind) {
132 case op_kind_t::_ge:
133 case op_kind_t::_gt:
134 case op_kind_t::_le:
135 case op_kind_t::_lt:
136 case op_kind_t::_eq:
137 case op_kind_t::_ne: return true;
138 default: return false;
139 }
140 }
141
negate_cmp_op(op_kind_t op_kind)142 op_kind_t negate_cmp_op(op_kind_t op_kind) {
143 switch (op_kind) {
144 case op_kind_t::_ge: return op_kind_t::_le;
145 case op_kind_t::_gt: return op_kind_t::_lt;
146 case op_kind_t::_le: return op_kind_t::_ge;
147 case op_kind_t::_lt: return op_kind_t::_gt;
148 case op_kind_t::_eq: return op_kind_t::_eq;
149 case op_kind_t::_ne: return op_kind_t::_ne;
150 default: ir_error_not_expected();
151 }
152 return op_kind_t::undef;
153 }
154
unary_op_type(op_kind_t op_kind,const expr_t & a)155 type_t unary_op_type(op_kind_t op_kind, const expr_t &a) {
156 switch (op_kind) {
157 case op_kind_t::_minus: {
158 auto &t = a.type();
159 if (!t.is_int()) return t;
160 if (t.size() < int(sizeof(int32_t))) return type_t::s32(t.elems());
161 return t;
162 }
163 default:
164 ir_error_not_expected() << "Unknown op_kind_t value: " << op_kind;
165 }
166 return type_t::undef();
167 }
168
// Computes the common type of two integer types.
// Each operand whose size() is below 4 bytes is first replaced by s32, then
// reduced to its scalar type. The result has max(width) bits and the element
// count of `_a`:
//  - same signedness: that signedness;
//  - mixed signedness: unsigned unless the signed operand is strictly wider.
// NOTE(review): the "below 4 bytes" check uses size() of the full type, which
// for multi-element types is elems() * scalar size - so e.g. an 8-element
// 8-bit vector is not promoted. Confirm this is intended.
type_t common_int_type(const type_t &_a, const type_t &_b) {
    ir_assert(_a.is_int() && _b.is_int()) << "Unexpected types.";

    const int elems = _a.elems();
    const int s32_bytes = int(sizeof(int32_t));

    // Promote narrow operands to s32, then drop the vector dimension.
    type_t lhs = (_a.size() < s32_bytes ? type_t::s32() : _a).scalar();
    type_t rhs = (_b.size() < s32_bytes ? type_t::s32() : _b).scalar();

    // Integer promotion, follow C++ rules.
    const int common_bits = 8 * std::max(lhs.size(), rhs.size());

    if (lhs.is_signed() == rhs.is_signed()) {
        return lhs.is_signed() ? type_t::s(common_bits, elems)
                               : type_t::u(common_bits, elems);
    }

    // Mixed signedness: an unsigned operand at least as wide as the other
    // one decides the result.
    const bool unsigned_wins
            = (lhs.size() >= rhs.size() && lhs.is_unsigned())
            || (rhs.size() >= lhs.size() && rhs.is_unsigned());
    if (unsigned_wins) return type_t::u(common_bits, elems);

    // Otherwise a strictly wider signed operand decides the result.
    const bool signed_wins = (lhs.size() > rhs.size() && lhs.is_signed())
            || (rhs.size() > lhs.size() && rhs.is_signed());
    if (signed_wins) return type_t::s(common_bits, elems);

    return type_t::u(common_bits, elems);
}
198
// Computes the common type of two types with the same number of elements.
//  - If either is undef, the result is undef.
//  - If both are floating point, the larger one wins (`b` on a tie).
//  - If exactly one is floating point, that one wins.
//  - Two bool types yield `a`.
//  - Everything else is delegated to common_int_type().
type_t common_type(const type_t &a, const type_t &b) {
    ir_assert(a.elems() == b.elems())
            << "Types must have the same number of components.";

    if (a.is_undef() || b.is_undef()) return type_t::undef();

    const bool a_is_fp = a.is_fp();
    const bool b_is_fp = b.is_fp();
    if (a_is_fp && b_is_fp) return (a.size() > b.size() ? a : b);
    if (a_is_fp) return a;
    if (b_is_fp) return b;

    if (a.is_bool() && b.is_bool()) return a;

    return common_int_type(a, b);
}
209
common_type(const expr_t & a,const expr_t & b)210 type_t common_type(const expr_t &a, const expr_t &b) {
211 return common_type(a.type(), b.type());
212 }
213
binary_op_type(op_kind_t op_kind,const type_t & a,const type_t & b)214 type_t binary_op_type(op_kind_t op_kind, const type_t &a, const type_t &b) {
215 if (a.is_undef() || b.is_undef()) return type_t::undef();
216 ir_assert(a.elems() == b.elems())
217 << "Types must have the same number of components.";
218 if (is_cmp_op(op_kind)) return type_t::_bool(a.elems());
219 if (utils::one_of(op_kind, op_kind_t::_shl, op_kind_t::_shr)) {
220 ir_assert(a.is_unsigned())
221 << "a must be unsigned for shift left/right.";
222 return type_t::u32(a.elems());
223 }
224 return common_type(a, b);
225 }
226
binary_op_type(op_kind_t op_kind,const expr_t & a,const expr_t & b)227 type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
228 return binary_op_type(op_kind, a.type(), b.type());
229 }
230
ternary_op_type(op_kind_t op_kind,const expr_t & a,const expr_t & b,const expr_t & c)231 type_t ternary_op_type(
232 op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c) {
233 switch (op_kind) {
234 case op_kind_t::_add3:
235 return binary_op_type(op_kind_t::_add, a.type(),
236 binary_op_type(op_kind_t::_add, b, c));
237 case op_kind_t::_mad:
238 return binary_op_type(op_kind_t::_add, a.type(),
239 binary_op_type(op_kind_t::_mul, b, c));
240 default: ir_error_not_expected();
241 }
242 return type_t::undef();
243 }
244
nary_op_type(op_kind_t op_kind,const std::vector<expr_t> & args)245 type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args) {
246 ir_assert(!args.empty());
247 if (args.size() == 1) return args[0].type();
248
249 auto type = args[0].type();
250 for (size_t i = 1; i < args.size(); i++)
251 type = common_type(type, args[i].type());
252
253 return type;
254 }
255
normalize(expr_t & base,expr_t & off,op_kind_t op_kind)256 void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) {
257 ir_assert(base.type().is_ptr()) << "base is not a pointer: " << base;
258 ir_assert(off.type().is_int()) << "off is not an integer: " << off;
259 ir_assert(utils::one_of(op_kind, op_kind_t::_add, op_kind_t::_sub))
260 << "Can't apply this operation to pointer: " << to_string(op_kind);
261
262 if (!base.is<ptr_t>()) {
263 if (op_kind == op_kind_t::_sub) off = const_fold(-off);
264 return;
265 }
266
267 auto &base_off = base.as<ptr_t>().off;
268 base = base.as<ptr_t>().base;
269 off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off));
270 }
271
shift_ptr(op_kind_t op_kind,const expr_t & a,const expr_t & b)272 expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
273 expr_t base = a;
274 expr_t off = b;
275 ptr_t::normalize(base, off, op_kind);
276 return ptr_t::make(base, off);
277 }
278
normalize_ptr(const type_t & type,expr_t & base_expr,expr_t & off)279 void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) {
280 if (base_expr.is<ptr_t>()) {
281 auto &base = base_expr.as<ptr_t>().base;
282 auto &base_off = base_expr.as<ptr_t>().off;
283
284 base_expr = base;
285 off = const_fold_non_recursive(base_off + off);
286 }
287 ir_assert(to_cpp<int64_t>(off) % type.size() == 0)
288 << "Incompatible offset: " << off;
289 }
290
// Subscript operator.
//  - For a shuffle_t expression: `off` must be constant; returns the vector
//    element selected through the shuffle's index table, i.e.
//    vec[idx[off]]. No range check is done on `off` here.
//  - Otherwise: treats *this as a pointer and returns it shifted by +off.
expr_t expr_t::operator[](const expr_t &off) const {
    if (!is<shuffle_t>()) return shift_ptr(op_kind_t::_add, *this, off);

    ir_assert(is_const(off)) << "Offset is not constant.";
    const auto &shuffle = as<shuffle_t>();
    const int pos = to_cpp<int>(off);
    const int vec_idx = shuffle.idx[pos];
    return shuffle.vec[vec_idx];
}
300
expr_t(bool value)301 expr_t::expr_t(bool value) : object_t(new bool_imm_t(value)) {}
expr_t(float value)302 expr_t::expr_t(float value) : object_t(new float_imm_t(value)) {}
expr_t(int16_t value)303 expr_t::expr_t(int16_t value) : object_t(new int_imm_t(value)) {}
expr_t(int32_t value)304 expr_t::expr_t(int32_t value) : object_t(new int_imm_t(value)) {}
expr_t(int64_t value)305 expr_t::expr_t(int64_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint16_t value)306 expr_t::expr_t(uint16_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint32_t value)307 expr_t::expr_t(uint32_t value) : object_t(new int_imm_t(value)) {}
expr_t(uint64_t value)308 expr_t::expr_t(uint64_t value) : object_t(new int_imm_t(value)) {}
309
// Unary minus: builds a _minus unary_op_t node for `a` and const-folds it
// non-recursively.
expr_t operator-(const expr_t &a) {
    const auto negated = unary_op_t::make(op_kind_t::_minus, a);
    return const_fold_non_recursive(negated);
}
313
314 #define DEFINE_BINARY_OPERATOR(op, op_kind) \
315 expr_t operator op(const expr_t &a, const expr_t &b) { \
316 if (a.type().is_ptr()) return shift_ptr(op_kind, a, b); \
317 return const_fold_non_recursive(binary_op_t::make(op_kind, a, b)); \
318 }
319
320 DEFINE_BINARY_OPERATOR(+, op_kind_t::_add)
321 DEFINE_BINARY_OPERATOR(-, op_kind_t::_sub)
322 DEFINE_BINARY_OPERATOR(*, op_kind_t::_mul)
323 DEFINE_BINARY_OPERATOR(/, op_kind_t::_div)
324 DEFINE_BINARY_OPERATOR(%, op_kind_t::_mod)
325 DEFINE_BINARY_OPERATOR(<<, op_kind_t::_shl)
326 DEFINE_BINARY_OPERATOR(>>, op_kind_t::_shr)
327
328 DEFINE_BINARY_OPERATOR(==, op_kind_t::_eq)
329 DEFINE_BINARY_OPERATOR(!=, op_kind_t::_ne)
330 DEFINE_BINARY_OPERATOR(>, op_kind_t::_gt)
331 DEFINE_BINARY_OPERATOR(>=, op_kind_t::_ge)
332 DEFINE_BINARY_OPERATOR(<, op_kind_t::_lt)
333 DEFINE_BINARY_OPERATOR(<=, op_kind_t::_le)
334
335 DEFINE_BINARY_OPERATOR(&, op_kind_t::_and)
336
337 #undef DEFINE_BINARY_OPERATOR
338
339 #define DEFINE_BINARY_ASSIGN_OPERATOR(op) \
340 expr_t &expr_t::operator op##=(const expr_t &rhs) { \
341 auto tmp = (*this)op rhs; \
342 *this = tmp; \
343 return *this; \
344 }
345
346 DEFINE_BINARY_ASSIGN_OPERATOR(+)
347 DEFINE_BINARY_ASSIGN_OPERATOR(-)
348 DEFINE_BINARY_ASSIGN_OPERATOR(*)
349 DEFINE_BINARY_ASSIGN_OPERATOR(/)
350 DEFINE_BINARY_ASSIGN_OPERATOR(%)
351 DEFINE_BINARY_ASSIGN_OPERATOR(&)
352
353 #undef DEFINE_BINARY_ASSIGN_OPERATOR
354
_mutate(ir_mutator_t & mutator) const355 object_t object_impl_t::_mutate(ir_mutator_t &mutator) const {
356 return *this;
357 }
358
359 #define DECL_MUTATE_LEAF(name) \
360 object_t ir_mutator_t::_mutate(const name &obj) { return obj; }
361
362 DECL_MUTATE_LEAF(bool_imm_t)
DECL_MUTATE_LEAF(float_imm_t)363 DECL_MUTATE_LEAF(float_imm_t)
364 DECL_MUTATE_LEAF(func_impl_t)
365 DECL_MUTATE_LEAF(int_imm_t)
366 DECL_MUTATE_LEAF(var_t)
367
368 #undef DECL_MUTATE_LEAF
369
370 object_t ir_mutator_t::_mutate(const alloc_t &obj) {
371 auto buf = mutate(obj.buf);
372 auto body = mutate(obj.body);
373
374 if (buf.is_same(obj.buf) && body.is_same(obj.body)) return obj;
375
376 return alloc_t::make(buf, obj.size, obj.kind, obj.attr, body);
377 }
378
_mutate(const binary_op_t & obj)379 object_t ir_mutator_t::_mutate(const binary_op_t &obj) {
380 auto a = mutate(obj.a);
381 auto b = mutate(obj.b);
382
383 if (a.is_same(obj.a) && b.is_same(obj.b)) return obj;
384
385 return binary_op_t::make(obj.op_kind, a, b);
386 }
387
_mutate(const cast_t & obj)388 object_t ir_mutator_t::_mutate(const cast_t &obj) {
389 auto expr = mutate(obj.expr);
390
391 if (expr.is_same(obj.expr)) return obj;
392
393 return cast_t::make(obj.type, expr, obj.saturate);
394 }
395
_mutate(const for_t & obj)396 object_t ir_mutator_t::_mutate(const for_t &obj) {
397 auto var = mutate(obj.var);
398 auto init = mutate(obj.init);
399 auto bound = mutate(obj.bound);
400 auto body = mutate(obj.body);
401
402 if (var.is_same(obj.var) && init.is_same(obj.init)
403 && bound.is_same(obj.bound) && body.is_same(obj.body))
404 return obj;
405
406 return for_t::make(var, init, bound, body, obj.unroll);
407 }
408
_mutate(const func_call_t & obj)409 object_t ir_mutator_t::_mutate(const func_call_t &obj) {
410 auto func = mutate(obj.func);
411 auto args = mutate(obj.args);
412
413 if (func.is_same(obj.func) && ir_utils::is_same(args, obj.args)) return obj;
414
415 return func_call_t::make(func, args, obj.attr);
416 }
417
_mutate(const if_t & obj)418 object_t ir_mutator_t::_mutate(const if_t &obj) {
419 auto cond = mutate(obj.cond);
420 auto body = mutate(obj.body);
421 auto else_body = mutate(obj.else_body);
422
423 if (cond.is_same(obj.cond) && body.is_same(obj.body)
424 && else_body.is_same(obj.else_body))
425 return obj;
426
427 return if_t::make(cond, body, else_body);
428 }
429
_mutate(const iif_t & obj)430 object_t ir_mutator_t::_mutate(const iif_t &obj) {
431 auto cond = mutate(obj.cond);
432 auto true_expr = mutate(obj.true_expr);
433 auto false_expr = mutate(obj.false_expr);
434
435 if (cond.is_same(obj.cond) && true_expr.is_same(obj.true_expr)
436 && false_expr.is_same(obj.false_expr))
437 return obj;
438
439 return iif_t::make(cond, true_expr, false_expr);
440 }
441
_mutate(const let_t & obj)442 object_t ir_mutator_t::_mutate(const let_t &obj) {
443 auto var = mutate(obj.var);
444 auto value = mutate(obj.value);
445 auto body = mutate(obj.body);
446
447 if (var.is_same(obj.var) && value.is_same(obj.value)
448 && body.is_same(obj.body))
449 return obj;
450
451 return let_t::make(var, value, body);
452 }
453
_mutate(const load_t & obj)454 object_t ir_mutator_t::_mutate(const load_t &obj) {
455 auto buf = mutate(obj.buf);
456 auto off = mutate(obj.off);
457
458 if (buf.is_same(obj.buf) && off.is_same(obj.off)) return obj;
459
460 return load_t::make(obj.type, buf, off, obj.stride);
461 }
462
_mutate(const ptr_t & obj)463 object_t ir_mutator_t::_mutate(const ptr_t &obj) {
464 auto base = mutate(obj.base);
465 auto off = mutate(obj.off);
466
467 if (base.is_same(obj.base) && off.is_same(obj.off)) return obj;
468
469 return ptr_t::make(base, off);
470 }
471
_mutate(const shuffle_t & obj)472 object_t ir_mutator_t::_mutate(const shuffle_t &obj) {
473 auto vec = mutate(obj.vec);
474
475 if (ir_utils::is_same(vec, obj.vec)) return obj;
476
477 return shuffle_t::make(vec, obj.idx);
478 }
479
_mutate(const stmt_group_t & obj)480 object_t ir_mutator_t::_mutate(const stmt_group_t &obj) {
481 auto body = mutate(obj.body);
482
483 if (body.is_same(obj.body)) return obj;
484
485 return stmt_group_t::make(obj.label, body);
486 }
487
_mutate(const stmt_seq_t & obj)488 object_t ir_mutator_t::_mutate(const stmt_seq_t &obj) {
489 auto head = mutate(obj.head);
490 auto tail = mutate(obj.tail);
491
492 if (head.is_same(obj.head) && tail.is_same(obj.tail)) return obj;
493
494 return stmt_seq_t::make(head, tail);
495 }
496
_mutate(const store_t & obj)497 object_t ir_mutator_t::_mutate(const store_t &obj) {
498 auto buf = mutate(obj.buf);
499 auto off = mutate(obj.off);
500 auto value = mutate(obj.value);
501 auto mask = mutate(obj.mask);
502
503 if (buf.is_same(obj.buf) && off.is_same(obj.off) && value.is_same(obj.value)
504 && mask.is_same(obj.mask))
505 return obj;
506
507 return store_t::make(buf, off, value, obj.stride, mask);
508 }
509
_mutate(const ternary_op_t & obj)510 object_t ir_mutator_t::_mutate(const ternary_op_t &obj) {
511 auto a = mutate(obj.a);
512 auto b = mutate(obj.b);
513 auto c = mutate(obj.c);
514
515 if (a.is_same(obj.a) && b.is_same(obj.b) && c.is_same(obj.c)) return obj;
516
517 return ternary_op_t::make(obj.op_kind, a, b, c);
518 }
519
_mutate(const unary_op_t & obj)520 object_t ir_mutator_t::_mutate(const unary_op_t &obj) {
521 auto a = mutate(obj.a);
522 if (a.is_same(obj.a)) return obj;
523 return unary_op_t::make(obj.op_kind, a);
524 }
525
526 // Catch missing mutates that are not expected to dispatch to the base
527 // mutator
_mutate(const nary_op_t & obj)528 object_t ir_mutator_t::_mutate(const nary_op_t &obj) {
529 ir_error_not_expected() << "Can't handle type: nary_op_t";
530 return {};
531 }
_mutate(const pexpr_t & obj)532 object_t ir_mutator_t::_mutate(const pexpr_t &obj) {
533 ir_error_not_expected() << "Can't handle type: pexpr_t";
534 return {};
535 }
536
537 } // namespace jit
538 } // namespace gpu
539 } // namespace impl
540 } // namespace dnnl
541