1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_IR_CORE_HPP
18 #define GPU_JIT_CONV_IR_CORE_HPP
19
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <numeric>
#include <string>
#include <utility>
25
26 #include "common/c_types_map.hpp"
27 #include "gpu/jit/conv/ngen_proxy.hpp"
28 #include "gpu/jit/conv/utils.hpp"
29
// X-macro lists of IR object types. Every list expands to one
// HANDLE_IR_OBJECT(type) invocation per listed type, so HANDLE_IR_OBJECT must
// be defined before a list is expanded (users in this file #undef it right
// after the expansion).

// All IR expression objects.
#define HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_IR_OBJECT(binary_op_t) \
    HANDLE_IR_OBJECT(bool_imm_t) \
    HANDLE_IR_OBJECT(cast_t) \
    HANDLE_IR_OBJECT(float_imm_t) \
    HANDLE_IR_OBJECT(iif_t) \
    HANDLE_IR_OBJECT(int_imm_t) \
    HANDLE_IR_OBJECT(load_t) \
    HANDLE_IR_OBJECT(ptr_t) \
    HANDLE_IR_OBJECT(shuffle_t) \
    HANDLE_IR_OBJECT(ternary_op_t) \
    HANDLE_IR_OBJECT(unary_op_t) \
    HANDLE_IR_OBJECT(var_t)

// All IR statement objects.
#define HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(alloc_t) \
    HANDLE_IR_OBJECT(for_t) \
    HANDLE_IR_OBJECT(func_call_t) \
    HANDLE_IR_OBJECT(if_t) \
    HANDLE_IR_OBJECT(let_t) \
    HANDLE_IR_OBJECT(stmt_group_t) \
    HANDLE_IR_OBJECT(stmt_seq_t) \
    HANDLE_IR_OBJECT(store_t)

// Types that are forward-declared and for which ir_mutator_t declares a
// _mutate() overload: all expression and statement objects plus func_impl_t,
// nary_op_t and pexpr_t.
#define HANDLE_MUTATE_TARGETS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t) \
    HANDLE_IR_OBJECT(nary_op_t) \
    HANDLE_IR_OBJECT(pexpr_t)

// Types that receive the leading (visitable/mutable) values of ir_type_id_t:
// all expression and statement objects plus func_impl_t.
#define HANDLE_ALL_IR_OBJECTS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t)
67
// Unique type IDs of IR objects. Enumerator names match the IR class names so
// that IR_DECL_TYPE_ID(class_name) can refer to ir_type_id_t::class_name.
enum ir_type_id_t {
#define HANDLE_IR_OBJECT(type) type,

    // Create typeid for objects which can be visited/mutated. These need to be
    // first as the typeid is used as an index into an array to dispatch to the
    // correct mutate function.
    HANDLE_ALL_IR_OBJECTS()

    // Used to calculate the number of IR objects that can be visited/mutated.
    end_visitable_ir_objects,

    // Other IR objects. Numbering continues from end_visitable_ir_objects, so
    // expr_impl_t has the same value as end_visitable_ir_objects.
    expr_impl_t = end_visitable_ir_objects,
    nary_op_t,
    stmt_impl_t,
    grf_alloc_attr_t,
    instruction_modifier_attr_t,
    builtin_t,
    pexpr_t,
    pint_imm_t,
    factored_expr_t,
    send_t,
    dpas_t,
    mad_t,
    reduce_t,
    reorder_t,
    eltwise_t,

#undef HANDLE_IR_OBJECT
};
98
// Auxiliary macros to reduce boilerplate.

// Declares the type ID members of an IR class: the self_type alias, the static
// _type_id() returning ir_type_id_t::class_name, the static
// _dispatch_type_id() (identical to _type_id() here) and the type_id()
// override.
#define IR_DECL_TYPE_ID(class_name) \
    using self_type = class_name; \
    static int64_t _type_id() { return ir_type_id_t::class_name; } \
    static int64_t _dispatch_type_id() { return _type_id(); } \
    int64_t type_id() const override { return _type_id(); }

// Same as IR_DECL_TYPE_ID but the dispatch type ID is taken from base_name, so
// the class keeps its own type_id() while dispatch_type_id() returns the ID of
// base_name.
#define IR_DECL_DERIVED_TYPE_ID(class_name, base_name) \
    using self_type = class_name; \
    static int64_t _type_id() { return ir_type_id_t::class_name; } \
    int64_t type_id() const override { return _type_id(); } \
    static int64_t _dispatch_type_id() { return base_name::_type_id(); } \
    int64_t dispatch_type_id() const override { return _dispatch_type_id(); }

// IR_DECL_TYPE_ID for expression classes: additionally makes is_expr() return
// true.
#define IR_DECL_EXPR_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    bool is_expr() const override { return true; }

// IR_DECL_TYPE_ID for statement classes: additionally makes is_stmt() return
// true.
#define IR_DECL_STMT_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    bool is_stmt() const override { return true; }

// Defines the _mutate() override that forwards to mutator._mutate(*this); the
// _mutate() overload is selected by the static type of the enclosing class.
#define IR_DECL_MUTATE(mutator_template) \
    object_t _mutate(mutator_template &mutator) const override { \
        return mutator._mutate(*this); \
    }

// Declares all traversal hooks of an IR class (currently only the ir_mutator_t
// one).
#define IR_DECLARE_TRAVERSERS() IR_DECL_MUTATE(ir_mutator_t)
127
// Defines getter for a function argument.
// Generates three static overloads named arg_<name>:
//  - arg_<name>(const stmt_t &s): asserts that `s` is a func_call_t whose
//    function is of type self_type and returns the call argument at `index`;
//  - arg_<name>(std::vector<T> &args) and its const counterpart: return
//    args[index] (no bounds check is performed).
#define IR_DEFINE_ARG_GET(name, index) \
    static const expr_t &arg_##name(const stmt_t &s) { \
        ir_assert(s.is<func_call_t>()) << s; \
        auto &c = s.as<func_call_t>(); \
        ir_assert(c.func.is<self_type>()) << s; \
        return c.args[index]; \
    } \
    template <typename T> \
    static T &arg_##name(std::vector<T> &args) { \
        return args[index]; \
    } \
    template <typename T> \
    static const T &arg_##name(const std::vector<T> &args) { \
        return args[index]; \
    }
144
#if defined(__GNUC__)
// clang-format off
// Defines dump() method for debugging purposes, to pretty print the object.
// The method prints str() followed by a newline to stdout. It is generated
// only when __GNUC__ is defined; otherwise the macro expands to nothing.
// NOTE(review): the noinline/used attributes presumably keep the method
// emitted even when unreferenced so it can be called from a debugger - confirm.
#define IR_DEFINE_DUMP() \
    __attribute__((noinline)) \
    __attribute__((used)) \
    void dump() const { \
        printf("%s\n", str().c_str()); \
    }
// clang-format on
#else
#define IR_DEFINE_DUMP()
#endif
158
159 namespace dnnl {
160 namespace impl {
161 namespace gpu {
162 namespace jit {
163
// Element kinds supported by type_t.
enum class type_kind_t {
    // Default kind of a default-constructed type_t.
    undef,
    _bool,

    // Integer types.
    u8,
    s8,
    u16,
    s16,
    u32,
    s32,
    u64,
    s64,

    // Floating point types.
    bf16,
    f16,
    f32,

    // Message data types.
    byte,
    dword,
    qword,
    oword,
    hword
};
190
191 std::string to_string(type_kind_t kind);
192
193 class type_t {
194 public:
undef()195 static type_t undef() { return type_t(type_kind_t::undef); }
_bool(int elems=1)196 static type_t _bool(int elems = 1) {
197 return type_t(type_kind_t::_bool, elems);
198 }
199
u8(int elems=1)200 static type_t u8(int elems = 1) { return type_t(type_kind_t::u8, elems); }
s8(int elems=1)201 static type_t s8(int elems = 1) { return type_t(type_kind_t::s8, elems); }
u16(int elems=1)202 static type_t u16(int elems = 1) { return type_t(type_kind_t::u16, elems); }
s16(int elems=1)203 static type_t s16(int elems = 1) { return type_t(type_kind_t::s16, elems); }
u32(int elems=1)204 static type_t u32(int elems = 1) { return type_t(type_kind_t::u32, elems); }
s32(int elems=1)205 static type_t s32(int elems = 1) { return type_t(type_kind_t::s32, elems); }
u64(int elems=1)206 static type_t u64(int elems = 1) { return type_t(type_kind_t::u64, elems); }
s64(int elems=1)207 static type_t s64(int elems = 1) { return type_t(type_kind_t::s64, elems); }
208
209 // Returns unsigned integer type.
u(int bits,int elems=1)210 static type_t u(int bits, int elems = 1) {
211 switch (bits) {
212 case 8: return u8(elems);
213 case 16: return u16(elems);
214 case 32: return u32(elems);
215 case 64: return u64(elems);
216 default: ir_error_not_expected();
217 }
218 return type_t::undef();
219 }
220
221 // Returns signed integer type.
s(int bits,int elems=1)222 static type_t s(int bits, int elems = 1) {
223 switch (bits) {
224 case 8: return s8(elems);
225 case 16: return s16(elems);
226 case 32: return s32(elems);
227 case 64: return s64(elems);
228 default: ir_error_not_expected();
229 }
230 return type_t::undef();
231 }
232
bf16(int elems=1)233 static type_t bf16(int elems = 1) {
234 return type_t(type_kind_t::bf16, elems);
235 }
f16(int elems=1)236 static type_t f16(int elems = 1) { return type_t(type_kind_t::f16, elems); }
f32(int elems=1)237 static type_t f32(int elems = 1) { return type_t(type_kind_t::f32, elems); }
238
byte(int elems=1)239 static type_t byte(int elems = 1) {
240 return type_t(type_kind_t::byte, elems);
241 }
byte_ptr(int elems=1)242 static type_t byte_ptr(int elems = 1) {
243 return type_t(type_kind_t::byte, elems).with_ptr();
244 }
dword(int elems=1)245 static type_t dword(int elems = 1) {
246 return type_t(type_kind_t::dword, elems);
247 }
qword(int elems=1)248 static type_t qword(int elems = 1) {
249 return type_t(type_kind_t::qword, elems);
250 }
oword(int elems=1)251 static type_t oword(int elems = 1) {
252 return type_t(type_kind_t::oword, elems);
253 }
hword(int elems=1)254 static type_t hword(int elems = 1) {
255 return type_t(type_kind_t::hword, elems);
256 }
257
258 template <typename T>
from_cpp()259 static type_t from_cpp() {
260 #define CASE(cpp_type, type) \
261 if (std::is_same<T, cpp_type>::value) return type()
262
263 CASE(bool, _bool);
264 CASE(float, f32);
265 CASE(int16_t, s16);
266 CASE(int32_t, s32);
267 CASE(int64_t, s64);
268 CASE(uint16_t, u16);
269 CASE(uint32_t, u32);
270 CASE(uint64_t, u64);
271
272 #undef CASE
273
274 ir_error_not_expected();
275
276 return undef();
277 }
278
is_vector(int elems)279 static bool is_vector(int elems) { return elems != 1; }
280
type_t()281 type_t() : type_t(type_t::undef()) {}
282
type_t(type_kind_t kind,uint32_t elems=1)283 type_t(type_kind_t kind, uint32_t elems = 1) : kind_(kind), elems_(elems) {}
284
285 // Constructor from dnnl_data_type_t.
type_t(data_type_t dt)286 type_t(data_type_t dt) {
287 elems_ = 1;
288 switch (dt) {
289 #define CASE(x) \
290 case data_type::x: kind_ = type_kind_t::x; break;
291 CASE(bf16);
292 CASE(f16);
293 CASE(f32);
294 CASE(s32);
295 CASE(s8);
296 CASE(u8);
297 #undef CASE
298 default: ir_error_not_expected();
299 }
300 }
301
kind() const302 type_kind_t kind() const { return kind_; }
303
elems() const304 int elems() const { return elems_; }
305
is_ptr() const306 bool is_ptr() const { return is_ptr_; }
307
operator ==(const type_t & other) const308 bool operator==(const type_t &other) const {
309 return (kind() == other.kind()) && (elems() == other.elems())
310 && (is_ptr() == other.is_ptr());
311 }
312
operator !=(const type_t & other) const313 bool operator!=(const type_t &other) const { return !operator==(other); }
314
is_equal(const type_t & other) const315 bool is_equal(const type_t &other) const { return operator==(other); }
316
get_hash() const317 size_t get_hash() const {
318 return ir_utils::get_hash(kind(), elems(), is_ptr());
319 }
320
is_undef() const321 bool is_undef() const { return kind() == type_kind_t::undef; }
322
is_vector() const323 bool is_vector() const { return type_t::is_vector(elems()); }
324
is_bool() const325 bool is_bool() const { return kind() == type_kind_t::_bool; }
326
is_fp() const327 bool is_fp() const {
328 return utils::one_of(
329 kind(), type_kind_t::bf16, type_kind_t::f16, type_kind_t::f32);
330 }
331
is_bf16() const332 bool is_bf16() const { return kind() == type_kind_t::bf16; }
is_f16() const333 bool is_f16() const { return kind() == type_kind_t::f16; }
is_f32() const334 bool is_f32() const { return kind() == type_kind_t::f32; }
335
is_int() const336 bool is_int() const {
337 return utils::one_of(kind(), type_kind_t::u8, type_kind_t::s8,
338 type_kind_t::u16, type_kind_t::s16, type_kind_t::u32,
339 type_kind_t::s32, type_kind_t::u64, type_kind_t::s64);
340 }
341
is_s8() const342 bool is_s8() const { return kind() == type_kind_t::s8; }
is_u8() const343 bool is_u8() const { return kind() == type_kind_t::u8; }
is_x8() const344 bool is_x8() const {
345 return utils::one_of(kind(), type_kind_t::s8, type_kind_t::u8);
346 }
347
is_s16() const348 bool is_s16() const { return kind() == type_kind_t::s16; }
is_u16() const349 bool is_u16() const { return kind() == type_kind_t::u16; }
is_x16() const350 bool is_x16() const {
351 return utils::one_of(kind(), type_kind_t::s16, type_kind_t::u16);
352 }
353
is_s32() const354 bool is_s32() const { return kind() == type_kind_t::s32; }
is_u32() const355 bool is_u32() const { return kind() == type_kind_t::u32; }
is_x32() const356 bool is_x32() const {
357 return utils::one_of(kind(), type_kind_t::s32, type_kind_t::u32);
358 }
359
is_signed(int elems=-1) const360 bool is_signed(int elems = -1) const {
361 if (elems != -1 && elems_ != elems) return false;
362 return utils::one_of(kind(), type_kind_t::s8, type_kind_t::s16,
363 type_kind_t::s32, type_kind_t::s64);
364 }
365
is_unsigned(int elems=-1) const366 bool is_unsigned(int elems = -1) const {
367 if (elems != -1 && elems_ != elems) return false;
368 return utils::one_of(kind(), type_kind_t::u8, type_kind_t::u16,
369 type_kind_t::u32, type_kind_t::u64);
370 }
371
is_scalar() const372 bool is_scalar() const { return elems() == 1; }
373
374 template <typename T>
is_cpp() const375 bool is_cpp() const {
376 return *this == type_t::from_cpp<T>();
377 }
378
remove_elems() const379 type_t remove_elems() const { return with_elems(1); }
380
remove_ptr() const381 type_t remove_ptr() const {
382 type_t copy = *this;
383 copy.is_ptr_ = false;
384 return copy;
385 }
386
with_elems(int new_elems) const387 type_t with_elems(int new_elems) const {
388 type_t copy = *this;
389 copy.elems_ = new_elems;
390 return copy;
391 }
392
with_ptr() const393 type_t with_ptr() const {
394 type_t copy = *this;
395 copy.is_ptr_ = true;
396 return copy;
397 }
398
scalar() const399 type_t scalar() const { return with_elems(1); }
400
401 // Returns size in bytes.
402 int size() const;
403
str() const404 std::string str() const {
405 std::ostringstream oss;
406 oss << to_string(kind());
407 if (elems() > 1) oss << "x" << elems();
408 if (is_ptr()) oss << "*";
409 return oss.str();
410 }
411
412 IR_DEFINE_DUMP()
413
414 private:
415 type_kind_t kind_ = type_kind_t::undef;
416 int elems_ = 0;
417 bool is_ptr_ = false;
418 };
419
operator <<(std::ostream & out,const type_t & type)420 inline std::ostream &operator<<(std::ostream &out, const type_t &type) {
421 out << type.str();
422 return out;
423 }
424
// type_t to dnnl_data_type_t converter.
426 data_type_t to_dnnl(const type_t &type);
427
// Reference counter for IR objects.
// Non-copyable atomic counter that starts at zero; both operations return the
// counter value after the update.
class ref_count_t {
public:
    ref_count_t() = default;
    ref_count_t(const ref_count_t &) = delete;

    // Adds one and returns the new value.
    uint32_t increment() { return value_.fetch_add(1) + 1; }
    // Subtracts one and returns the new value.
    uint32_t decrement() { return value_.fetch_sub(1) - 1; }

private:
    std::atomic<uint32_t> value_ {0};
};
440
// Forward declarations of IR objects.
442 class object_t;
443 class ir_mutator_t;
444 #define HANDLE_IR_OBJECT(type) class type;
445 HANDLE_MUTATE_TARGETS()
446 #undef HANDLE_IR_OBJECT
447
448 // Base class for all IR objects. Implemented as an intrusive pointer, with
449 // the reference counter stored inside the object.
450 class object_impl_t {
451 public:
452 object_impl_t() = default;
453
454 object_impl_t(const object_impl_t &) = delete;
455
456 virtual ~object_impl_t() = default;
457
ref_count()458 ref_count_t &ref_count() { return ref_count_; }
459
460 // Unique type ID.
461 virtual int64_t type_id() const = 0;
462
463 // Type ID used for dispatching in ir_visitor_t and ir_mutator_t.
464 // For some IR objects
dispatch_type_id() const465 virtual int64_t dispatch_type_id() const { return type_id(); }
466
467 // Provides equality semantics.
468 virtual bool is_equal(const object_impl_t &obj) const = 0;
469
470 virtual size_t get_hash() const = 0;
471
is_expr() const472 virtual bool is_expr() const { return false; }
is_stmt() const473 virtual bool is_stmt() const { return false; }
474
475 // Downcasts the object to the IR type, returns a reference. The IR type
476 // must match the real IR type.
477 template <typename T>
as() const478 const T &as() const {
479 ir_assert(type_id() == T::_type_id());
480 return *(const T *)this;
481 }
482
483 template <typename T>
as()484 T &as() {
485 ir_assert(type_id() == T::_type_id());
486 return *(T *)this;
487 }
488
489 // Downcasts the object to the IR type, returns a pointer. If the IR type
490 // doesn't match the real IR type, returns nullptr.
491 template <typename T>
as_ptr() const492 const T *as_ptr() const {
493 if (type_id() != T::_type_id()) return nullptr;
494 return (const T *)this;
495 }
496
497 template <typename T>
as_ptr()498 T *as_ptr() {
499 if (type_id() != T::_type_id()) return nullptr;
500 return (T *)this;
501 }
502
503 // Returns true if T matches the real IR type.
504 template <typename T>
is() const505 bool is() const {
506 return type_id() == T::_type_id();
507 }
508
509 virtual std::string str() const;
510
511 virtual object_t _mutate(ir_mutator_t &mutator) const;
512 IR_DEFINE_DUMP()
513
514 private:
515 ref_count_t ref_count_;
516 };
517
518 // Base wrapper for IR objects.
519 class object_t {
520 public:
object_t(object_impl_t * impl=nullptr)521 object_t(object_impl_t *impl = nullptr) : impl_(impl) {
522 increment(impl_);
523 #ifndef NDEBUG
524 sanity_check();
525 #endif
526 }
object_t(const object_impl_t & impl)527 object_t(const object_impl_t &impl)
528 : object_t(const_cast<object_impl_t *>(&impl)) {}
object_t(const object_impl_t * impl)529 object_t(const object_impl_t *impl)
530 : object_t(const_cast<object_impl_t *>(impl)) {}
object_t(const object_t & obj)531 object_t(const object_t &obj) : object_t(obj.impl()) {}
object_t(object_t && obj)532 object_t(object_t &&obj) : impl_(obj.impl_) {
533 obj.impl_ = nullptr;
534 #ifndef NDEBUG
535 sanity_check();
536 #endif
537 }
538
~object_t()539 virtual ~object_t() { decrement_and_maybe_destroy(impl_); }
540
operator =(const object_t & other)541 object_t &operator=(const object_t &other) {
542 increment(other.impl());
543 decrement_and_maybe_destroy(impl_);
544 impl_ = other.impl();
545 #ifndef NDEBUG
546 sanity_check();
547 #endif
548 return *this;
549 }
550
operator =(object_t && other)551 object_t &operator=(object_t &&other) {
552 std::swap(impl_, other.impl_);
553 #ifndef NDEBUG
554 sanity_check();
555 #endif
556 return *this;
557 }
558
impl() const559 object_impl_t *impl() const { return impl_; }
560
is_empty() const561 bool is_empty() const { return !impl_; }
562
type_id() const563 int64_t type_id() const { return impl_->type_id(); }
564
dispatch_type_id() const565 int64_t dispatch_type_id() const { return impl_->dispatch_type_id(); }
566
567 template <typename T>
as() const568 const T &as() const {
569 ir_assert(impl_);
570 return impl_->as<T>();
571 }
572
573 template <typename T>
as()574 T &as() {
575 ir_assert(impl_);
576 return impl_->as<T>();
577 }
578
579 template <typename T>
as_ptr() const580 const T *as_ptr() const {
581 if (!impl_) return nullptr;
582 return impl_->as_ptr<T>();
583 }
584
585 template <typename T>
as_ptr()586 T *as_ptr() {
587 if (!impl_) return nullptr;
588 return impl_->as_ptr<T>();
589 }
590
591 template <typename T>
is() const592 bool is() const {
593 if (is_empty()) return false;
594 return impl_->is<T>();
595 }
596
597 // Comparison with identity semantics.
is_same(const object_t & other) const598 bool is_same(const object_t &other) const { return impl_ == other.impl(); }
599
600 // Comparison with equality semantics.
is_equal(const object_t & other) const601 bool is_equal(const object_t &other) const {
602 if (is_empty() || other.is_empty())
603 return is_empty() == other.is_empty();
604
605 return impl_->is_equal(*other.impl());
606 }
607
get_hash() const608 size_t get_hash() const {
609 if (is_empty()) return 0;
610 return impl()->get_hash();
611 }
612
is_expr() const613 bool is_expr() const { return impl_ && impl_->is_expr(); }
is_stmt() const614 bool is_stmt() const { return impl_ && impl_->is_stmt(); }
615
str() const616 std::string str() const {
617 if (is_empty()) return "(nil)";
618 return impl()->str();
619 }
620
621 IR_DEFINE_DUMP()
622
623 protected:
sanity_check() const624 virtual void sanity_check() const {}
625
626 private:
increment(object_impl_t * impl)627 static void increment(object_impl_t *impl) {
628 if (!impl) return;
629 impl->ref_count().increment();
630 }
631
decrement_and_maybe_destroy(object_impl_t * impl)632 static void decrement_and_maybe_destroy(object_impl_t *impl) {
633 if (!impl) return;
634 if (impl->ref_count().decrement() == 0) { delete impl; }
635 }
636
637 object_impl_t *impl_;
638 };
639
operator <<(std::ostream & out,const object_t & obj)640 inline std::ostream &operator<<(std::ostream &out, const object_t &obj) {
641 out << obj.str();
642 return out;
643 }
644
645 // Helper classes for containers to store object_t.
646 struct object_id_hash_t {
operator ()dnnl::impl::gpu::jit::object_id_hash_t647 size_t operator()(const object_t &obj) const {
648 return std::hash<const object_impl_t *>()(obj.impl());
649 }
650 };
651
652 struct object_eq_hash_t {
operator ()dnnl::impl::gpu::jit::object_eq_hash_t653 size_t operator()(const object_t &obj) const { return obj.get_hash(); }
654 };
655
656 struct object_id_equal_t {
operator ()dnnl::impl::gpu::jit::object_id_equal_t657 bool operator()(const object_t &a, const object_t &b) const {
658 return a.is_same(b);
659 }
660 };
661
662 struct object_eq_equal_t {
operator ()dnnl::impl::gpu::jit::object_eq_equal_t663 bool operator()(const object_t &a, const object_t &b) const {
664 return a.is_equal(b);
665 }
666 };
667
// Containers to store object_t.
// The "id" flavors compare keys by implementation pointer (object_id_hash_t /
// object_id_equal_t); the "eq" flavors compare keys via get_hash() and
// is_equal() (object_eq_hash_t / object_eq_equal_t).

// Unordered set, uses identity comparison for keys.
template <typename KeyT>
using object_set_t
        = std::unordered_set<KeyT, object_id_hash_t, object_id_equal_t>;

// Unordered set, uses equality comparison for keys.
template <typename KeyT>
using object_eq_set_t
        = std::unordered_set<KeyT, object_eq_hash_t, object_eq_equal_t>;

// Unordered map, uses identity comparison for keys.
template <typename KeyT, typename ValueT>
using object_map_t
        = std::unordered_map<KeyT, ValueT, object_id_hash_t, object_id_equal_t>;

// Unordered map, uses equality comparison for keys.
template <typename KeyT, typename ValueT>
using object_eq_map_t
        = std::unordered_map<KeyT, ValueT, object_eq_hash_t, object_eq_equal_t>;
689
690 // Helper class to mutate IR tree.
691 class ir_mutator_t {
692 public:
693 virtual ~ir_mutator_t() = default;
694
mutate(const object_t & obj)695 object_t mutate(const object_t &obj) {
696 auto impl = obj.impl();
697 if (!impl) return impl;
698 return impl->_mutate(*this);
699 }
700
701 template <typename T>
mutate(const std::vector<T> & v)702 std::vector<T> mutate(const std::vector<T> &v) {
703 std::vector<T> new_v;
704 for (auto &e : v)
705 new_v.push_back(mutate(e));
706 return new_v;
707 }
708
709 // To catch missing _mutate() handlers in ir_mutator_t.
_mutate(const object_impl_t & obj)710 object_t _mutate(const object_impl_t &obj) {
711 ir_error_not_expected() << "Can't handle type: " << object_t(&obj);
712 return {};
713 }
714
715 #define HANDLE_IR_OBJECT(type) virtual object_t _mutate(const type &obj);
716 HANDLE_MUTATE_TARGETS()
717 #undef HANDLE_IR_OBJECT
718 };
719
720 // Base class for IR expression objects.
721 class expr_impl_t : public object_impl_t {
722 public:
723 IR_DECL_TYPE_ID(expr_impl_t)
724
expr_impl_t(const type_t & type)725 expr_impl_t(const type_t &type) : type(type) {}
726
727 type_t type;
728 };
729
730 // Wrapper for IR expression objects.
731 class expr_t : public object_t {
732 public:
733 using object_t::object_t;
734
735 expr_t() = default;
expr_t(const object_t & obj)736 expr_t(const object_t &obj) : object_t(obj) {}
expr_t(object_t && obj)737 expr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)738 expr_t &operator=(const object_t &obj) {
739 object_t::operator=(obj);
740 return *this;
741 }
operator =(object_t && obj)742 expr_t &operator=(object_t &&obj) {
743 object_t::operator=(obj);
744 return *this;
745 }
746
747 explicit expr_t(bool v);
748 expr_t(float v);
749 expr_t(int16_t v);
750 expr_t(int32_t v);
751 expr_t(int64_t v);
752 expr_t(uint16_t v);
753 expr_t(uint32_t v);
754 expr_t(uint64_t v);
755
type() const756 const type_t &type() const {
757 ir_assert(!is_empty());
758 return ((const expr_impl_t *)impl())->type;
759 }
760
761 #define DECLARE_BINARY_ASSIGN_OPERATOR(op) \
762 expr_t &operator op##=(const expr_t &rhs);
763
764 DECLARE_BINARY_ASSIGN_OPERATOR(+)
765 DECLARE_BINARY_ASSIGN_OPERATOR(-)
766 DECLARE_BINARY_ASSIGN_OPERATOR(*)
767 DECLARE_BINARY_ASSIGN_OPERATOR(/)
768 DECLARE_BINARY_ASSIGN_OPERATOR(%)
769 DECLARE_BINARY_ASSIGN_OPERATOR(&)
770
771 #undef DECLARE_BINARY_ASSIGN_OPERATOR
772
773 // Returns a pointer shifted by `off` bytes relative to this pointer. The
774 // base expression must be a pointer.
775 expr_t operator[](const expr_t &off) const;
776
777 private:
sanity_check() const778 void sanity_check() const override {
779 ir_assert(dynamic_cast<const expr_impl_t *>(impl()) == impl())
780 << object_t(impl());
781 }
782 };
783
784 // Helper functions.
785 inline bool is_const(const expr_t &e);
786 inline bool is_var(const expr_t &e);
787
// Unary and binary operators.
// The last group (_add3, _mad) takes three operands, see the per-enumerator
// comments.
enum class op_kind_t {
    undef,

    _minus,
    _add,
    _sub,
    _mul,
    _div,
    _mod,
    _shl,
    _shr,
    _min,
    _max,

    _lt,
    _le,
    _gt,
    _ge,
    _ne,
    _eq,

    _and,

    _add3, // a + b + c
    _mad, // a + b * c
};
815
816 std::string to_string(op_kind_t kind);
817
operator <<(std::ostream & out,op_kind_t kind)818 inline std::ostream &operator<<(std::ostream &out, op_kind_t kind) {
819 out << to_string(kind);
820 return out;
821 }
822
823 bool is_cmp_op(op_kind_t op_kind);
824
825 op_kind_t negate_cmp_op(op_kind_t op_kind);
826
827 type_t unary_op_type(op_kind_t op_kind, const expr_t &a);
828
829 type_t common_int_type(const type_t &_a, const type_t &_b);
830
831 type_t common_type(const type_t &a, const type_t &b);
832
833 type_t common_type(const expr_t &a, const expr_t &b);
834
835 type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b);
836
837 type_t ternary_op_type(
838 op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c);
839
840 type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args);
841
842 // Binary operation: (a op b).
843 class binary_op_t : public expr_impl_t {
844 public:
IR_DECL_EXPR_TYPE_ID(binary_op_t)845 IR_DECL_EXPR_TYPE_ID(binary_op_t)
846
847 static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
848 return expr_t(new binary_op_t(op_kind, a, b));
849 }
850
is_equal(const object_impl_t & obj) const851 bool is_equal(const object_impl_t &obj) const override {
852 if (!obj.is<self_type>()) return false;
853 auto &other = obj.as<self_type>();
854
855 return (op_kind == other.op_kind) && a.is_equal(other.a)
856 && b.is_equal(other.b);
857 }
858
get_hash() const859 size_t get_hash() const override {
860 return ir_utils::get_hash(op_kind, a, b);
861 }
862
863 IR_DECLARE_TRAVERSERS()
864
865 op_kind_t op_kind;
866 expr_t a;
867 expr_t b;
868
869 private:
binary_op_t(op_kind_t op_kind,const expr_t & a,const expr_t & b)870 binary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b)
871 : expr_impl_t(binary_op_type(op_kind, a, b))
872 , op_kind(op_kind)
873 , a(a)
874 , b(b) {}
875 };
876
877 // Boolean immediate value.
878 class bool_imm_t : public expr_impl_t {
879 public:
880 friend class expr_t;
IR_DECL_EXPR_TYPE_ID(bool_imm_t)881 IR_DECL_EXPR_TYPE_ID(bool_imm_t)
882
883 static expr_t make(bool value) { return expr_t(new bool_imm_t(value)); }
884
is_equal(const object_impl_t & obj) const885 bool is_equal(const object_impl_t &obj) const override {
886 if (!obj.is<self_type>()) return false;
887 auto &other = obj.as<self_type>();
888
889 return value == other.value;
890 }
891
get_hash() const892 size_t get_hash() const override { return ir_utils::get_hash(value); }
893
894 bool value;
895
896 private:
bool_imm_t(bool value)897 bool_imm_t(bool value) : expr_impl_t(type_t::_bool()), value(value) {}
898 };
899
900 // Cast between data types.
901 class cast_t : public expr_impl_t {
902 public:
IR_DECL_EXPR_TYPE_ID(cast_t)903 IR_DECL_EXPR_TYPE_ID(cast_t)
904
905 static expr_t make(
906 const type_t &type, const expr_t &expr, bool saturate = false) {
907 return expr_t(new cast_t(type, expr, saturate));
908 }
909
is_equal(const object_impl_t & obj) const910 bool is_equal(const object_impl_t &obj) const override {
911 if (!obj.is<self_type>()) return false;
912 auto &other = obj.as<self_type>();
913
914 return (type == other.type) && expr.is_equal(other.expr)
915 && (saturate == other.saturate);
916 }
917
get_hash() const918 size_t get_hash() const override {
919 return ir_utils::get_hash(type, expr, saturate);
920 }
921
922 IR_DECLARE_TRAVERSERS()
923
924 expr_t expr;
925 bool saturate;
926
927 private:
cast_t(const type_t & type,const expr_t & expr,bool saturate)928 cast_t(const type_t &type, const expr_t &expr, bool saturate)
929 : expr_impl_t(type), expr(expr), saturate(saturate) {
930 ir_assert(type.elems() == expr.type().elems())
931 << "Number of elements must match.";
932 }
933 };
934
935 // Floating-point immediate value.
936 class float_imm_t : public expr_impl_t {
937 public:
938 friend class expr_t;
IR_DECL_EXPR_TYPE_ID(float_imm_t)939 IR_DECL_EXPR_TYPE_ID(float_imm_t)
940
941 static expr_t make(float value) { return expr_t(new float_imm_t(value)); }
942
is_equal(const object_impl_t & obj) const943 bool is_equal(const object_impl_t &obj) const override {
944 if (!obj.is<self_type>()) return false;
945 auto &other = obj.as<self_type>();
946
947 return value == other.value;
948 }
949
get_hash() const950 size_t get_hash() const override { return ir_utils::get_hash(value); }
951
952 float value;
953
954 private:
float_imm_t(float value)955 float_imm_t(float value) : expr_impl_t(type_t::f32()), value(value) {}
956 };
957
958 // Immediate if or the conditional (ternary) operator.
959 // C++ equivalent: (cond ? true_expr : false_expr).
960 class iif_t : public expr_impl_t {
961 public:
962 IR_DECL_EXPR_TYPE_ID(iif_t);
963
make(const expr_t & cond,const expr_t & true_expr,const expr_t & false_expr)964 static expr_t make(const expr_t &cond, const expr_t &true_expr,
965 const expr_t &false_expr) {
966 return expr_t(new iif_t(cond, true_expr, false_expr));
967 }
968
is_equal(const object_impl_t & obj) const969 bool is_equal(const object_impl_t &obj) const override {
970 if (!obj.is<self_type>()) return false;
971 auto &other = obj.as<self_type>();
972
973 return cond.is_equal(other.cond) && true_expr.is_equal(other.true_expr)
974 && false_expr.is_equal(other.false_expr);
975 }
976
get_hash() const977 size_t get_hash() const override {
978 return ir_utils::get_hash(cond, true_expr, false_expr);
979 }
980
981 IR_DECLARE_TRAVERSERS()
982
983 expr_t cond;
984 expr_t true_expr;
985 expr_t false_expr;
986
987 private:
iif_t(const expr_t & cond,const expr_t & true_expr,const expr_t & false_expr)988 iif_t(const expr_t &cond, const expr_t &true_expr, const expr_t &false_expr)
989 : expr_impl_t(common_type(true_expr.type(), false_expr.type()))
990 , cond(cond)
991 , true_expr(true_expr)
992 , false_expr(false_expr) {}
993 };
994
995 // Integer immediate value.
996 class int_imm_t : public expr_impl_t {
997 public:
998 friend class expr_t;
999 IR_DECL_EXPR_TYPE_ID(int_imm_t);
1000
1001 template <typename T>
make(T value,const type_t & type=type_t::undef ())1002 static expr_t make(T value, const type_t &type = type_t::undef()) {
1003 return expr_t(new int_imm_t(value, type));
1004 }
1005
is_equal(const object_impl_t & obj) const1006 bool is_equal(const object_impl_t &obj) const override {
1007 if (!obj.is<self_type>()) return false;
1008 auto &other = obj.as<self_type>();
1009
1010 return value == other.value;
1011 }
1012
get_hash() const1013 size_t get_hash() const override { return ir_utils::get_hash(value); }
1014
shrink_type(const expr_t & e)1015 static expr_t shrink_type(const expr_t &e) {
1016 auto &imm = e.as<int_imm_t>();
1017 type_t new_type = shrink_type(imm.value);
1018 if (new_type == imm.type) return e;
1019 return make(imm.value, new_type);
1020 }
1021
1022 template <typename T>
try_shrink_type(int64_t v)1023 static bool try_shrink_type(int64_t v) {
1024 if (v >= std::numeric_limits<T>::min()
1025 && v <= std::numeric_limits<T>::max())
1026 return true;
1027 return false;
1028 }
1029
1030 int64_t value;
1031
1032 private:
int_imm_t(int64_t value,const type_t & type=type_t::undef ())1033 int_imm_t(int64_t value, const type_t &type = type_t::undef())
1034 : expr_impl_t(type.is_undef() ? shrink_type(value) : type)
1035 , value(value) {}
1036
shrink_type(int64_t v)1037 static type_t shrink_type(int64_t v) {
1038 if (try_shrink_type<int32_t>(v)) return type_t::s32();
1039 return type_t::s64();
1040 }
1041 };
1042
// Updates `base` and `off` so that after return:
// - base contains a variable of a pointer type
// - off contains an offset
1046 void normalize_ptr(const type_t &type, expr_t &base, expr_t &off);
1047
1048 // Load from a GRF buffer.
1049 // C++ equivalent (when type is scalar):
1050 // load = *(type *)(&buf[off]);
1051 // C++ equivalent (when type is vector):
1052 // int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1053 // for (int i = 0; i < elems; i++) {
1054 // load[i] = *(scalar_type *)(&buf[off + i * _stride]);
1055 // }
1056 class load_t : public expr_impl_t {
1057 public:
IR_DECL_EXPR_TYPE_ID(load_t)1058 IR_DECL_EXPR_TYPE_ID(load_t)
1059
1060 // offset and stride are expressed in bytes.
1061 // default stride means unit stride (in terms of value.type().scalar()
1062 // elements).
1063 static expr_t make(const type_t &type, const expr_t &buf, const expr_t &off,
1064 int stride = default_stride) {
1065 return expr_t(new load_t(type, buf, off, stride));
1066 }
1067
is_equal(const object_impl_t & obj) const1068 bool is_equal(const object_impl_t &obj) const override {
1069 if (!obj.is<self_type>()) return false;
1070 auto &other = obj.as<self_type>();
1071
1072 return type.is_equal(other.type) && buf.is_equal(other.buf)
1073 && off.is_equal(other.off) && (stride == other.stride);
1074 }
1075
get_hash() const1076 size_t get_hash() const override {
1077 return ir_utils::get_hash(type, buf, off, stride);
1078 }
1079
has_default_stride() const1080 bool has_default_stride() const { return stride == default_stride; }
1081
1082 IR_DECLARE_TRAVERSERS()
1083
1084 static const int default_stride = -1;
1085
1086 expr_t buf;
1087 expr_t off;
1088 int stride;
1089
1090 private:
load_t(const type_t & _type,const expr_t & _buf,const expr_t & _off,int _stride)1091 load_t(const type_t &_type, const expr_t &_buf, const expr_t &_off,
1092 int _stride)
1093 : expr_impl_t(_type), buf(_buf), off(_off), stride(_stride) {
1094 normalize_ptr(type, buf, off);
1095 ir_assert(is_var(buf)) << buf;
1096 ir_assert(buf.type().is_ptr()) << buf;
1097 if (stride == type.scalar().size()) stride = default_stride;
1098 }
1099 };
1100
1101 // N-ary expression: (a[0] op a[1] op ... op a[n - 1]),
1102 // where <op> is either addition or multiplication.
1103 class nary_op_t : public expr_impl_t {
1104 public:
IR_DECL_EXPR_TYPE_ID(nary_op_t)1105 IR_DECL_EXPR_TYPE_ID(nary_op_t)
1106
1107 static expr_t make(op_kind_t op_kind, const std::vector<expr_t> &args) {
1108 return expr_t(new nary_op_t(op_kind, args));
1109 }
1110
is_equal(const object_impl_t & obj) const1111 bool is_equal(const object_impl_t &obj) const override {
1112 if (!obj.is<self_type>()) return false;
1113 auto &other = obj.as<self_type>();
1114
1115 return (op_kind == other.op_kind)
1116 && ir_utils::is_equal(args, other.args);
1117 }
1118
get_hash() const1119 size_t get_hash() const override {
1120 return ir_utils::get_hash(op_kind, args);
1121 }
1122
str() const1123 std::string str() const override {
1124 std::ostringstream oss;
1125 oss << "(";
1126 for (size_t i = 0; i < args.size(); i++) {
1127 oss << (i != 0 ? " " + to_string(op_kind) + " " : "") << args[i];
1128 }
1129
1130 oss << ")";
1131 return oss.str();
1132 }
1133
1134 IR_DECLARE_TRAVERSERS()
1135
1136 op_kind_t op_kind;
1137 std::vector<expr_t> args;
1138
1139 private:
nary_op_t(op_kind_t op_kind,const std::vector<expr_t> & args)1140 nary_op_t(op_kind_t op_kind, const std::vector<expr_t> &args)
1141 : expr_impl_t(nary_op_type(op_kind, args))
1142 , op_kind(op_kind)
1143 , args(args) {}
1144 };
1145
1146 // Pointer expression: (base_ptr + off).
1147 class ptr_t : public expr_impl_t {
1148 public:
IR_DECL_EXPR_TYPE_ID(ptr_t)1149 IR_DECL_EXPR_TYPE_ID(ptr_t)
1150
1151 // off - offset in bytes.
1152 static expr_t make(const expr_t &base, const expr_t &off) {
1153 return expr_t(new ptr_t(base, off));
1154 }
1155
is_equal(const object_impl_t & obj) const1156 bool is_equal(const object_impl_t &obj) const override {
1157 if (!obj.is<self_type>()) return false;
1158 auto &other = obj.as<self_type>();
1159
1160 return base.is_equal(other.base) && off.is_equal(other.off);
1161 }
1162
get_hash() const1163 size_t get_hash() const override { return ir_utils::get_hash(base, off); }
1164
1165 // Normalizes (base op off) pointer so that the new base is a variable and
1166 // off is an offset expression.
1167 // Example:
1168 // Before call: base = (base0 + off0), off = off1
1169 // After call: base = base0, off = off0 + off1
1170 static void normalize(
1171 expr_t &base, expr_t &off, op_kind_t op_kind = op_kind_t::_add);
1172
1173 IR_DECLARE_TRAVERSERS()
1174
1175 expr_t base;
1176 expr_t off;
1177
1178 private:
ptr_t(const expr_t & base,const expr_t & off)1179 ptr_t(const expr_t &base, const expr_t &off)
1180 : expr_impl_t(base.type()), base(base), off(off) {
1181 normalize(this->base, this->off);
1182 }
1183 };
1184
1185 class shuffle_t : public expr_impl_t {
1186 public:
IR_DECL_EXPR_TYPE_ID(shuffle_t)1187 IR_DECL_EXPR_TYPE_ID(shuffle_t)
1188
1189 static expr_t make(
1190 const std::vector<expr_t> &vec, const std::vector<int> &idx) {
1191 if (idx.size() == 1) return vec[idx[0]];
1192 return expr_t(new shuffle_t(vec, idx));
1193 }
1194
make(const std::vector<expr_t> & _vec,bool find_equal=true)1195 static expr_t make(
1196 const std::vector<expr_t> &_vec, bool find_equal = true) {
1197 std::vector<expr_t> vec;
1198 std::vector<int> idx;
1199 for (auto &v : _vec) {
1200 bool found = false;
1201 int size = int(vec.size());
1202 if (find_equal) {
1203 for (int i = 0; i < size; i++) {
1204 if (v.is_equal(vec[i])) {
1205 idx.push_back(i);
1206 found = true;
1207 break;
1208 }
1209 }
1210 }
1211 if (!found) {
1212 vec.push_back(v);
1213 idx.push_back(size);
1214 }
1215 }
1216 return make(vec, idx);
1217 }
1218
make_broadcast(const expr_t & expr,int elems)1219 static expr_t make_broadcast(const expr_t &expr, int elems) {
1220 ir_assert(expr.type().is_scalar()) << expr;
1221 return make({expr}, std::vector<int>(elems, 0));
1222 }
1223
1224 // Slices the existing shuffle expression. For inputs (S, beg, end) returns
1225 // (S[beg], S[beg + 1], ..., S[end - 1]) vector.
make(const expr_t & _shuffle,int beg,int end)1226 static expr_t make(const expr_t &_shuffle, int beg, int end) {
1227 auto &shuffle = _shuffle.as<shuffle_t>();
1228 ir_assert(beg >= 0 && beg <= shuffle.elems());
1229 ir_assert(end >= 0 && end <= shuffle.elems());
1230 ir_assert(beg < end);
1231 std::vector<expr_t> vec;
1232 std::vector<int> idx(end - beg, -1);
1233 for (int i = beg; i < end; i++) {
1234 if (idx[i - beg] != -1) continue;
1235 int old_idx = shuffle.idx[i];
1236 vec.push_back(shuffle.vec[old_idx]);
1237 for (int j = i; j < end; j++) {
1238 if (shuffle.idx[j] == old_idx)
1239 idx[j - beg] = int(vec.size()) - 1;
1240 }
1241 }
1242 return make(vec, idx);
1243 }
1244
is_equal(const object_impl_t & obj) const1245 bool is_equal(const object_impl_t &obj) const override {
1246 if (!obj.is<self_type>()) return false;
1247 auto &other = obj.as<self_type>();
1248
1249 return ir_utils::is_equal(vec, other.vec)
1250 && ir_utils::is_equal(idx, other.idx);
1251 }
1252
get_hash() const1253 size_t get_hash() const override { return ir_utils::get_hash(vec, idx); }
1254
elems() const1255 int elems() const { return int(idx.size()); }
1256
is_vector() const1257 bool is_vector() const {
1258 for (int i = 0; i < elems(); i++)
1259 if (idx[i] != i) return false;
1260 return true;
1261 }
1262
is_broadcast() const1263 bool is_broadcast() const { return vec.size() == 1; }
1264
1265 IR_DECLARE_TRAVERSERS()
1266
1267 std::vector<expr_t> vec;
1268 std::vector<int> idx;
1269
1270 private:
shuffle_t(const std::vector<expr_t> & vec,const std::vector<int> & idx)1271 shuffle_t(const std::vector<expr_t> &vec, const std::vector<int> &idx)
1272 : expr_impl_t(shuffle_type(vec, idx)), vec(vec), idx(idx) {
1273 ir_assert(idx.size() > 1) << "Unexpected empty or scalar shuffle.";
1274 }
1275
shuffle_type(const std::vector<expr_t> & vec,const std::vector<int> & idx)1276 static type_t shuffle_type(
1277 const std::vector<expr_t> &vec, const std::vector<int> &idx) {
1278 ir_assert(!vec.empty() && !idx.empty());
1279
1280 auto elem_type = vec[0].type();
1281 for (auto &v : vec)
1282 elem_type = common_type(elem_type, v.type());
1283
1284 for (size_t i = 0; i < idx.size(); i++) {
1285 ir_assert(idx[i] >= 0 && idx[i] < int(vec.size()))
1286 << "Incorrect index.";
1287 MAYBE_UNUSED(i);
1288 }
1289
1290 int elems = int(idx.size());
1291 return elem_type.with_elems(elems);
1292 }
1293 };
1294
1295 // Ternary operation: op(a, b, c).
1296 class ternary_op_t : public expr_impl_t {
1297 public:
IR_DECL_EXPR_TYPE_ID(ternary_op_t)1298 IR_DECL_EXPR_TYPE_ID(ternary_op_t)
1299
1300 static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1301 const expr_t &c) {
1302 return expr_t(new ternary_op_t(op_kind, a, b, c));
1303 }
1304
is_equal(const object_impl_t & obj) const1305 bool is_equal(const object_impl_t &obj) const override {
1306 if (!obj.is<self_type>()) return false;
1307 auto &other = obj.as<self_type>();
1308
1309 return (op_kind == other.op_kind) && a.is_equal(other.a)
1310 && b.is_equal(other.b) && c.is_equal(other.c);
1311 }
1312
get_hash() const1313 size_t get_hash() const override {
1314 return ir_utils::get_hash(op_kind, a, b, c);
1315 }
1316
1317 IR_DECLARE_TRAVERSERS()
1318
1319 op_kind_t op_kind;
1320 expr_t a;
1321 expr_t b;
1322 expr_t c;
1323
1324 private:
ternary_op_t(op_kind_t op_kind,const expr_t & a,const expr_t & b,const expr_t & c)1325 ternary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1326 const expr_t &c)
1327 : expr_impl_t(ternary_op_type(op_kind, a, b, c))
1328 , op_kind(op_kind)
1329 , a(a)
1330 , b(b)
1331 , c(c) {}
1332 };
1333
// Builds a ternary operation of kind `_mad` with operands (a, b, c).
inline expr_t ternary_mad(const expr_t &a, const expr_t &b, const expr_t &c) {
    const op_kind_t kind = op_kind_t::_mad;
    return ternary_op_t::make(kind, a, b, c);
}
1337
// Builds a ternary operation of kind `_add3` with operands (a, b, c).
inline expr_t ternary_add3(const expr_t &a, const expr_t &b, const expr_t &c) {
    const op_kind_t kind = op_kind_t::_add3;
    return ternary_op_t::make(kind, a, b, c);
}
1341
1342 // Unary operation: (op a).
1343 class unary_op_t : public expr_impl_t {
1344 public:
IR_DECL_EXPR_TYPE_ID(unary_op_t)1345 IR_DECL_EXPR_TYPE_ID(unary_op_t)
1346
1347 static expr_t make(op_kind_t op_kind, const expr_t &a) {
1348 return expr_t(new unary_op_t(op_kind, a));
1349 }
1350
is_equal(const object_impl_t & obj) const1351 bool is_equal(const object_impl_t &obj) const override {
1352 if (!obj.is<self_type>()) return false;
1353 auto &other = obj.as<self_type>();
1354
1355 return (op_kind == other.op_kind) && a.is_equal(other.a);
1356 }
1357
get_hash() const1358 size_t get_hash() const override { return ir_utils::get_hash(op_kind, a); }
1359
1360 IR_DECLARE_TRAVERSERS()
1361
1362 op_kind_t op_kind;
1363 expr_t a;
1364
1365 private:
unary_op_t(op_kind_t op_kind,const expr_t & a)1366 unary_op_t(op_kind_t op_kind, const expr_t &a)
1367 : expr_impl_t(unary_op_type(op_kind, a)), op_kind(op_kind), a(a) {}
1368 };
1369
1370 class var_t : public expr_impl_t {
1371 public:
IR_DECL_EXPR_TYPE_ID(var_t)1372 IR_DECL_EXPR_TYPE_ID(var_t)
1373
1374 static expr_t make(const type_t &type, const std::string &name) {
1375 return expr_t(new var_t(type, name));
1376 }
1377
is_equal(const object_impl_t & obj) const1378 bool is_equal(const object_impl_t &obj) const override {
1379 // Do not allow variable cloning.
1380 return this == &obj;
1381 }
1382
get_hash() const1383 size_t get_hash() const override { return ir_utils::get_hash(name); }
1384
1385 IR_DECLARE_TRAVERSERS()
1386
1387 std::string name;
1388
1389 private:
var_t(const type_t & type,const std::string & name)1390 var_t(const type_t &type, const std::string &name)
1391 : expr_impl_t(type), name(name) {}
1392 };
1393
1394 // Convertor from C++ type to IR expression.
1395 template <typename T>
to_expr(T value,const type_t & type)1396 expr_t to_expr(T value, const type_t &type) {
1397 #define CASE(ir_type, cpp_type) \
1398 if (type == type_t::ir_type()) return expr_t((cpp_type)value)
1399
1400 CASE(_bool, bool);
1401 CASE(f32, float);
1402 CASE(s16, int16_t);
1403 CASE(s32, int32_t);
1404 CASE(s64, int64_t);
1405 CASE(u16, uint16_t);
1406 CASE(u32, uint32_t);
1407 CASE(u64, uint64_t);
1408
1409 #undef CASE
1410
1411 ir_error_not_expected() << type;
1412
1413 return expr_t();
1414 }
1415
1416 template <typename T>
to_expr(T value)1417 expr_t to_expr(T value) {
1418 return to_expr(value, type_t::from_cpp<T>());
1419 }
1420
is_binary_op(const expr_t & e)1421 inline bool is_binary_op(const expr_t &e) {
1422 return e.is<binary_op_t>();
1423 }
1424
is_binary_op(const expr_t & e,op_kind_t op_kind)1425 inline bool is_binary_op(const expr_t &e, op_kind_t op_kind) {
1426 if (!is_binary_op(e)) return false;
1427 return e.as<binary_op_t>().op_kind == op_kind;
1428 }
1429
is_binary_cmp_op(const expr_t & e)1430 inline bool is_binary_cmp_op(const expr_t &e) {
1431 if (!is_binary_op(e)) return false;
1432 return is_cmp_op(e.as<binary_op_t>().op_kind);
1433 }
1434
is_const(const expr_t & e)1435 inline bool is_const(const expr_t &e) {
1436 return e.is<bool_imm_t>() || e.is<int_imm_t>() || e.is<float_imm_t>();
1437 }
1438
is_shuffle_const(const expr_t & e)1439 inline bool is_shuffle_const(const expr_t &e) {
1440 auto *shuffle = e.as_ptr<shuffle_t>();
1441 if (!shuffle) return false;
1442 for (auto &v : shuffle->vec)
1443 if (!is_const(v)) return false;
1444 return true;
1445 }
1446
is_var(const expr_t & e)1447 inline bool is_var(const expr_t &e) {
1448 return e.is<var_t>();
1449 }
1450
1451 // Convertor from IR expression to C++ constant.
1452 template <typename T>
to_cpp(const expr_t & e)1453 T to_cpp(const expr_t &e) {
1454 ir_assert(is_const(e)) << "Expression must be constant.";
1455
1456 if (e.is<int_imm_t>()) return (T)e.as<int_imm_t>().value;
1457 if (e.is<float_imm_t>()) return (T)e.as<float_imm_t>().value;
1458 if (e.is<bool_imm_t>()) return (T)e.as<bool_imm_t>().value;
1459
1460 ir_error_not_expected();
1461 return 0;
1462 }
1463
1464 expr_t operator-(const expr_t &a);
1465
1466 #define DECLARE_BINARY_OPERATOR(op, op_kind) \
1467 expr_t operator op(const expr_t &a, const expr_t &b);
1468
1469 DECLARE_BINARY_OPERATOR(+, op_kind_t::_add)
1470 DECLARE_BINARY_OPERATOR(-, op_kind_t::_sub)
1471 DECLARE_BINARY_OPERATOR(*, op_kind_t::_mul)
1472 DECLARE_BINARY_OPERATOR(/, op_kind_t::_div)
1473 DECLARE_BINARY_OPERATOR(%, op_kind_t::_mod)
1474 DECLARE_BINARY_OPERATOR(<<, op_kind_t::_shl)
1475 DECLARE_BINARY_OPERATOR(>>, op_kind_t::_shr)
1476
1477 DECLARE_BINARY_OPERATOR(==, op_kind_t::_eq)
1478 DECLARE_BINARY_OPERATOR(!=, op_kind_t::_ne)
1479 DECLARE_BINARY_OPERATOR(>, op_kind_t::_gt)
1480 DECLARE_BINARY_OPERATOR(>=, op_kind_t::_ge)
1481 DECLARE_BINARY_OPERATOR(<, op_kind_t::_lt)
1482 DECLARE_BINARY_OPERATOR(<=, op_kind_t::_le)
1483
1484 DECLARE_BINARY_OPERATOR(&, op_kind_t::_and)
1485
1486 #undef DECLARE_BINARY_OPERATOR
1487
1488 // Returns a shifted pointer with base `a` (pointer) and offset `b` (in bytes).
1489 // shift_ptr(op, a, b) returns &(a op b) in C++ terms (op is either addition or
1490 // subtraction).
1491 expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b);
1492
// Base class for IR statement objects.
// Concrete statements (alloc_t, store_t, for_t, ...) derive from this class,
// and stmt_t::sanity_check() verifies that a wrapped object is of this type.
class stmt_impl_t : public object_impl_t {
public:
    IR_DECL_TYPE_ID(stmt_impl_t)
};
1498
1499 // Wrapper for IR statement objects.
1500 class stmt_t : public object_t {
1501 public:
1502 using object_t::object_t;
1503
1504 stmt_t() = default;
stmt_t(const object_t & obj)1505 stmt_t(const object_t &obj) : object_t(obj) {}
stmt_t(object_t && obj)1506 stmt_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1507 stmt_t &operator=(const object_t &obj) {
1508 object_t::operator=(obj);
1509 return *this;
1510 }
operator =(object_t && obj)1511 stmt_t &operator=(object_t &&obj) {
1512 object_t::operator=(obj);
1513 return *this;
1514 }
1515
1516 stmt_t append(const stmt_t &s) const;
1517
1518 private:
sanity_check() const1519 void sanity_check() const override {
1520 ir_assert(dynamic_cast<const stmt_impl_t *>(impl()) == impl())
1521 << object_t(impl());
1522 }
1523 };
1524
// Kind of memory an allocation refers to.
enum class alloc_kind_t {
    undef = 0,
    grf = 1, // GRF - general register file.
    slm = 2, // SLM - shared local memory.
    global = 3, // Global memory.
};
1531
// Base class for allocation attribute objects (e.g. grf_alloc_attr_t).
// alloc_attr_t::sanity_check() verifies that a wrapped object is of this type.
class alloc_attr_impl_t : public object_impl_t {};
1533
1534 class alloc_attr_t : public object_t {
1535 public:
1536 using object_t::object_t;
1537
1538 alloc_attr_t() = default;
alloc_attr_t(const object_t & obj)1539 alloc_attr_t(const object_t &obj) : object_t(obj) {}
alloc_attr_t(object_t && obj)1540 alloc_attr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1541 alloc_attr_t &operator=(const object_t &obj) {
1542 object_t::operator=(obj);
1543 return *this;
1544 }
operator =(object_t && obj)1545 alloc_attr_t &operator=(object_t &&obj) {
1546 object_t::operator=(obj);
1547 return *this;
1548 }
1549
1550 private:
sanity_check() const1551 void sanity_check() const override {
1552 ir_assert(dynamic_cast<const alloc_attr_impl_t *>(impl()) == impl())
1553 << object_t(impl());
1554 }
1555 };
1556
1557 // Allocation attribute for GRF.
1558 class grf_alloc_attr_t : public alloc_attr_impl_t {
1559 public:
IR_DECL_TYPE_ID(grf_alloc_attr_t)1560 IR_DECL_TYPE_ID(grf_alloc_attr_t)
1561
1562 static alloc_attr_t make(const ngen_proxy::Bundle &bundle) {
1563 return alloc_attr_t(new grf_alloc_attr_t(bundle));
1564 }
1565
is_equal(const object_impl_t & obj) const1566 bool is_equal(const object_impl_t &obj) const override {
1567 if (!obj.is<self_type>()) return false;
1568 auto &other = obj.as<self_type>();
1569
1570 return bundle == other.bundle;
1571 }
1572
get_hash() const1573 size_t get_hash() const override {
1574 return ir_utils::get_hash(bundle.bundle_id, bundle.bank_id);
1575 }
1576
1577 ngen_proxy::Bundle bundle;
1578
1579 private:
grf_alloc_attr_t(const ngen_proxy::Bundle & bundle)1580 grf_alloc_attr_t(const ngen_proxy::Bundle &bundle) : bundle(bundle) {}
1581 };
1582
1583 // Allocation for SLM and GRF buffers.
1584 // C++ equivalent:
1585 // {
1586 // byte *buf = new byte[size];
1587 // body;
1588 // }
1589 class alloc_t : public stmt_impl_t {
1590 public:
IR_DECL_STMT_TYPE_ID(alloc_t)1591 IR_DECL_STMT_TYPE_ID(alloc_t)
1592
1593 static stmt_t make(const expr_t &buf, int size, alloc_kind_t kind,
1594 const alloc_attr_t &attr = {}, const stmt_t &body = {}) {
1595 return stmt_t(new alloc_t(buf, size, kind, attr, body));
1596 }
1597
is_equal(const object_impl_t & obj) const1598 bool is_equal(const object_impl_t &obj) const override {
1599 if (!obj.is<self_type>()) return false;
1600 auto &other = obj.as<self_type>();
1601
1602 return buf.is_equal(other.buf) && (size == other.size)
1603 && (kind == other.kind) && attr.is_equal(other.attr)
1604 && body.is_equal(other.body);
1605 }
1606
get_hash() const1607 size_t get_hash() const override {
1608 return ir_utils::get_hash(buf, size, kind, attr, body);
1609 }
1610
1611 IR_DECLARE_TRAVERSERS()
1612
1613 expr_t buf;
1614 int size;
1615 alloc_kind_t kind;
1616 alloc_attr_t attr;
1617 stmt_t body;
1618
1619 private:
alloc_t(const expr_t & buf,int size,alloc_kind_t kind,const alloc_attr_t & attr,const stmt_t & body)1620 alloc_t(const expr_t &buf, int size, alloc_kind_t kind,
1621 const alloc_attr_t &attr, const stmt_t &body)
1622 : buf(buf), size(size), kind(kind), attr(attr), body(body) {
1623 ir_assert(buf.type().is_ptr()) << buf;
1624 }
1625 };
1626
1627 // Store to a GRF buffer.
1628 // C++ equivalent (when value is scalar):
1629 // *(value_type *)(&buf[off]) = value;
1630 // C++ equivalent (when value is vector):
1631 // int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1632 // for (int i = 0; i < elems; i++) {
1633 // *(scalar_type *)(&buf[off + i * _stride]) = value[i];
1634 // }
1635 class store_t : public stmt_impl_t {
1636 public:
IR_DECL_STMT_TYPE_ID(store_t)1637 IR_DECL_STMT_TYPE_ID(store_t)
1638
1639 // offset and stride are expressed in bytes.
1640 // default stride means unit stride (in terms of value.type().scalar()
1641 // elements).
1642 static stmt_t make(const expr_t &buf, const expr_t &off,
1643 const expr_t &value, int stride = default_stride,
1644 const expr_t &mask = expr_t()) {
1645 return stmt_t(new store_t(buf, off, value, stride, mask));
1646 }
1647
is_equal(const object_impl_t & obj) const1648 bool is_equal(const object_impl_t &obj) const override {
1649 if (!obj.is<self_type>()) return false;
1650 auto &other = obj.as<self_type>();
1651
1652 return buf.is_equal(other.buf) && off.is_equal(other.off)
1653 && value.is_equal(other.value) && (stride == other.stride)
1654 && mask.is_equal(other.mask);
1655 }
1656
get_hash() const1657 size_t get_hash() const override {
1658 return ir_utils::get_hash(buf, off, value, stride, mask);
1659 }
1660
has_default_stride() const1661 bool has_default_stride() const { return stride == default_stride; }
1662
1663 IR_DECLARE_TRAVERSERS()
1664
1665 static const int default_stride = -1;
1666
1667 expr_t buf;
1668 expr_t off;
1669 expr_t value;
1670 int stride;
1671 expr_t mask;
1672
1673 private:
store_t(const expr_t & _buf,const expr_t & _off,const expr_t & _value,int _stride,const expr_t & _mask)1674 store_t(const expr_t &_buf, const expr_t &_off, const expr_t &_value,
1675 int _stride, const expr_t &_mask)
1676 : buf(_buf), off(_off), value(_value), stride(_stride), mask(_mask) {
1677 normalize_ptr(value.type(), buf, off);
1678 ir_assert(is_var(buf)) << buf;
1679 ir_assert(buf.type().is_ptr()) << buf;
1680 if (stride == value.type().scalar().size()) stride = default_stride;
1681 if (!mask.is_empty())
1682 ir_assert(mask.type() == type_t::_bool(value.type().elems()));
1683 }
1684 };
1685
1686 // Loop statement with unit increment.
1687 // C++ equivalent:
1688 // for (var = init; var < bound; var++) {
1689 // body;
1690 // }
1691 // unroll specifies the unroll factor, unroll = 1 means no unrolling.
1692 class for_t : public stmt_impl_t {
1693 public:
IR_DECL_STMT_TYPE_ID(for_t)1694 IR_DECL_STMT_TYPE_ID(for_t)
1695
1696 static stmt_t make(const expr_t &var, const expr_t &init,
1697 const expr_t &bound, const stmt_t &body = {}, int unroll = 1) {
1698 return stmt_t(new for_t(var, init, bound, body, unroll));
1699 }
1700
is_equal(const object_impl_t & obj) const1701 bool is_equal(const object_impl_t &obj) const override {
1702 if (!obj.is<self_type>()) return false;
1703 auto &other = obj.as<self_type>();
1704
1705 return var.is_equal(other.var) && init.is_equal(other.init)
1706 && bound.is_equal(other.bound) && body.is_equal(other.body)
1707 && (unroll == other.unroll);
1708 }
1709
get_hash() const1710 size_t get_hash() const override {
1711 return ir_utils::get_hash(var, init, bound, body, unroll);
1712 }
1713
1714 IR_DECLARE_TRAVERSERS()
1715
1716 expr_t var;
1717 expr_t init;
1718 expr_t bound;
1719 stmt_t body;
1720 int unroll;
1721
1722 private:
for_t(const expr_t & var,const expr_t & init,const expr_t & bound,const stmt_t & body,int unroll)1723 for_t(const expr_t &var, const expr_t &init, const expr_t &bound,
1724 const stmt_t &body, int unroll)
1725 : var(var), init(init), bound(bound), body(body), unroll(unroll) {}
1726 };
1727
1728 // If-else statement.
1729 // C++ equivalent:
1730 // if (cond) {
1731 // body;
1732 // } else {
1733 // else_body;
1734 // }
1735 class if_t : public stmt_impl_t {
1736 public:
IR_DECL_STMT_TYPE_ID(if_t)1737 IR_DECL_STMT_TYPE_ID(if_t)
1738
1739 static stmt_t make(const expr_t &cond, const stmt_t &body,
1740 const stmt_t &else_body = stmt_t()) {
1741 return stmt_t(new if_t(cond, body, else_body));
1742 }
1743
is_equal(const object_impl_t & obj) const1744 bool is_equal(const object_impl_t &obj) const override {
1745 if (!obj.is<self_type>()) return false;
1746 auto &other = obj.as<self_type>();
1747
1748 return cond.is_equal(other.cond) && body.is_equal(other.body)
1749 && else_body.is_equal(other.else_body);
1750 }
1751
get_hash() const1752 size_t get_hash() const override {
1753 return ir_utils::get_hash(cond, body, else_body);
1754 }
1755
1756 IR_DECLARE_TRAVERSERS()
1757
1758 expr_t cond;
1759 stmt_t body;
1760 stmt_t else_body;
1761
1762 private:
if_t(const expr_t & cond,const stmt_t & body,const stmt_t & else_body)1763 if_t(const expr_t &cond, const stmt_t &body, const stmt_t &else_body)
1764 : cond(cond), body(body), else_body(else_body) {}
1765 };
1766
1767 // Let statement, used to bind a variable to a value within a scope.
1768 // C++ equivalent:
1769 // {
1770 // var = value;
1771 // body;
1772 // }
1773 class let_t : public stmt_impl_t {
1774 public:
IR_DECL_STMT_TYPE_ID(let_t)1775 IR_DECL_STMT_TYPE_ID(let_t)
1776
1777 static stmt_t make(
1778 const expr_t &var, const expr_t &value, const stmt_t &body = {}) {
1779 return stmt_t(new let_t(var, value, body));
1780 }
1781
is_equal(const object_impl_t & obj) const1782 bool is_equal(const object_impl_t &obj) const override {
1783 if (!obj.is<self_type>()) return false;
1784 auto &other = obj.as<self_type>();
1785
1786 return var.is_equal(other.var) && value.is_equal(other.value)
1787 && body.is_equal(other.body);
1788 }
1789
get_hash() const1790 size_t get_hash() const override {
1791 return ir_utils::get_hash(var, value, body);
1792 }
1793
1794 IR_DECLARE_TRAVERSERS()
1795
1796 expr_t var;
1797 expr_t value;
1798 stmt_t body;
1799
1800 private:
let_t(const expr_t & var,const expr_t & value,const stmt_t & body)1801 let_t(const expr_t &var, const expr_t &value, const stmt_t &body)
1802 : var(var), value(value), body(body) {}
1803 };
1804
1805 // Statement label, specific to GEMM/convolution.
1806 class stmt_label_t {
1807 public:
kernel(int index=-1)1808 static stmt_label_t kernel(int index = -1) {
1809 return stmt_label_t(kind_t::_kernel, index);
1810 }
compute_loop(int index=-1)1811 static stmt_label_t compute_loop(int index = -1) {
1812 return stmt_label_t(kind_t::_compute_loop, index);
1813 }
c_store(int index=-1)1814 static stmt_label_t c_store(int index = -1) {
1815 return stmt_label_t(kind_t::_c_store, index);
1816 }
c_zero_out(int index=-1)1817 static stmt_label_t c_zero_out(int index = -1) {
1818 return stmt_label_t(kind_t::_c_zero_out, index);
1819 }
b_reduced_zero_out(int index=-1)1820 static stmt_label_t b_reduced_zero_out(int index = -1) {
1821 return stmt_label_t(kind_t::_b_reduced_zero_out, index);
1822 }
g2s_load(int index=-1)1823 static stmt_label_t g2s_load(int index = -1) {
1824 return stmt_label_t(kind_t::_g2s_load, index);
1825 }
g2s_store(int index=-1)1826 static stmt_label_t g2s_store(int index = -1) {
1827 return stmt_label_t(kind_t::_g2s_store, index);
1828 }
g2r_load(int index=-1)1829 static stmt_label_t g2r_load(int index = -1) {
1830 return stmt_label_t(kind_t::_g2r_load, index);
1831 }
s2r_load(int index=-1)1832 static stmt_label_t s2r_load(int index = -1) {
1833 return stmt_label_t(kind_t::_s2r_load, index);
1834 }
prefetch(int index=-1)1835 static stmt_label_t prefetch(int index = -1) {
1836 return stmt_label_t(kind_t::_prefetch, index);
1837 }
mul(int index=-1)1838 static stmt_label_t mul(int index = -1) {
1839 return stmt_label_t(kind_t::_mul, index);
1840 }
1841
operator ==(const stmt_label_t & other) const1842 bool operator==(const stmt_label_t &other) const {
1843 if (kind_ != other.kind_) return false;
1844 if (index_ == -1 || other.index_ == -1) return true;
1845 return index_ == other.index_;
1846 }
1847
get_hash() const1848 size_t get_hash() const { return ir_utils::get_hash(kind_, index_); }
1849
str() const1850 std::string str() const {
1851 switch (kind_) {
1852 #define CASE(kind) \
1853 case kind_t::_##kind: return #kind
1854 CASE(kernel);
1855 CASE(compute_loop);
1856 CASE(c_store);
1857 CASE(c_zero_out);
1858 CASE(g2r_load);
1859 CASE(g2s_load);
1860 CASE(g2s_store);
1861 CASE(s2r_load);
1862 CASE(prefetch);
1863 CASE(mul);
1864 #undef CASE
1865 default: ir_error_not_expected();
1866 }
1867 return {};
1868 }
1869
1870 private:
1871 enum class kind_t {
1872 _undef,
1873 _kernel, // All kernel.
1874 _compute_loop, // Compute loop.
1875 _c_store, // GRF to GMEM store of C.
1876 _c_zero_out, // Zeroing-out of C.
1877 _b_reduced_zero_out, // Zeroing-out of B reduced buffer.
1878 _g2r_load, // GMEM to GRF load for further multiplication.
1879 _g2s_load, // GMEM to GRF load for GMEM -> SLM copy.
1880 _g2s_store, // GRF to SLM store for GMEM -> SLM copy.
1881 _s2r_load, // SLM to GRF load for further multiplication.
1882 _prefetch, // GMEM prefetch.
1883 _mul, // Multiplication.
1884 };
1885
stmt_label_t()1886 stmt_label_t() : kind_(kind_t::_undef), index_(-1) {}
stmt_label_t(kind_t kind,int index)1887 stmt_label_t(kind_t kind, int index) : kind_(kind), index_(index) {}
1888
1889 kind_t kind_;
1890 int index_; // Used to differentiate groups with the same kind.
1891 };
1892
operator <<(std::ostream & out,const stmt_label_t & label)1893 inline std::ostream &operator<<(std::ostream &out, const stmt_label_t &label) {
1894 out << label.str();
1895 return out;
1896 }
1897
1898 // Statement group, used to assign a label to a group of statements.
1899 class stmt_group_t : public stmt_impl_t {
1900 public:
IR_DECL_STMT_TYPE_ID(stmt_group_t)1901 IR_DECL_STMT_TYPE_ID(stmt_group_t)
1902
1903 static stmt_t make(const stmt_label_t &label, const stmt_t &body) {
1904 return stmt_t(new stmt_group_t(label, body));
1905 }
1906
is_equal(const object_impl_t & obj) const1907 bool is_equal(const object_impl_t &obj) const override {
1908 if (!obj.is<self_type>()) return false;
1909 auto &other = obj.as<self_type>();
1910
1911 return (label == other.label) && body.is_equal(other.body);
1912 }
1913
get_hash() const1914 size_t get_hash() const override { return ir_utils::get_hash(label, body); }
1915
1916 IR_DECLARE_TRAVERSERS()
1917
1918 stmt_label_t label;
1919 stmt_t body;
1920
1921 private:
stmt_group_t(const stmt_label_t & label,const stmt_t & body)1922 stmt_group_t(const stmt_label_t &label, const stmt_t &body)
1923 : label(label), body(body) {}
1924 };
1925
1926 // Statement sequence, allows combining two statements.
1927 // C++ equivalent:
1928 // {
1929 // head;
1930 // tail;
1931 // }
1932 class stmt_seq_t : public stmt_impl_t {
1933 public:
IR_DECL_STMT_TYPE_ID(stmt_seq_t)1934 IR_DECL_STMT_TYPE_ID(stmt_seq_t)
1935
1936 static stmt_t make(const stmt_t &head, const stmt_t &tail) {
1937 return stmt_t(new stmt_seq_t(head, tail));
1938 }
1939
is_equal(const object_impl_t & obj) const1940 bool is_equal(const object_impl_t &obj) const override {
1941 if (!obj.is<self_type>()) return false;
1942 auto &other = obj.as<self_type>();
1943
1944 return head.is_equal(other.head) && tail.is_equal(other.tail);
1945 }
1946
get_hash() const1947 size_t get_hash() const override { return ir_utils::get_hash(head, tail); }
1948
1949 IR_DECLARE_TRAVERSERS()
1950
1951 stmt_t head;
1952 stmt_t tail;
1953
1954 private:
stmt_seq_t(const stmt_t & head,const stmt_t & tail)1955 stmt_seq_t(const stmt_t &head, const stmt_t &tail)
1956 : head(head), tail(tail) {}
1957 };
1958
// Returns the sequence (*this; s). An empty statement has nothing to
// prepend, so in that case `s` is returned as is.
inline stmt_t stmt_t::append(const stmt_t &s) const {
    if (!is_empty()) return stmt_seq_t::make(*this, s);
    return s;
}
1963
// Function call attribute.
// Base class for attribute objects (e.g. instruction_modifier_attr_t);
// func_call_attr_t::sanity_check() verifies that a wrapped object is of this
// type.
class func_call_attr_impl_t : public object_impl_t {};
1966
1967 class func_call_attr_t : public object_t {
1968 public:
1969 using object_t::object_t;
1970
1971 func_call_attr_t() = default;
func_call_attr_t(const object_t & obj)1972 func_call_attr_t(const object_t &obj) : object_t(obj) {}
func_call_attr_t(object_t && obj)1973 func_call_attr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1974 func_call_attr_t &operator=(const object_t &obj) {
1975 object_t::operator=(obj);
1976 return *this;
1977 }
operator =(object_t && obj)1978 func_call_attr_t &operator=(object_t &&obj) {
1979 object_t::operator=(obj);
1980 return *this;
1981 }
1982
1983 // Returns a function call with the attribute applied. The input statement
1984 // must be a function call.
1985 stmt_t apply_to(const stmt_t &s) const;
1986
1987 private:
sanity_check() const1988 void sanity_check() const override {
1989 ir_assert(dynamic_cast<const func_call_attr_impl_t *>(impl()) == impl())
1990 << object_t(impl());
1991 }
1992 };
1993
1994 // Instruction modifier, relies on nGEN API.
1995 class instruction_modifier_attr_t : public func_call_attr_impl_t {
1996 public:
IR_DECL_TYPE_ID(instruction_modifier_attr_t)1997 IR_DECL_TYPE_ID(instruction_modifier_attr_t)
1998
1999 static func_call_attr_t make(const ngen_proxy::InstructionModifier &mod) {
2000 return func_call_attr_t(new instruction_modifier_attr_t(mod));
2001 }
2002
is_equal(const object_impl_t & obj) const2003 bool is_equal(const object_impl_t &obj) const override {
2004 if (!obj.is<self_type>()) return false;
2005 auto &other = obj.as<self_type>();
2006
2007 return mod == other.mod;
2008 }
2009
get_hash() const2010 size_t get_hash() const override { return ir_utils::get_hash(mod); }
2011
str() const2012 std::string str() const override {
2013 std::ostringstream oss;
2014 oss << "{";
2015 bool is_first = true;
2016 auto append = [&](const std::string &s) {
2017 if (!is_first) oss << ", ";
2018 oss << s;
2019 is_first = false;
2020 };
2021 if (mod.is_atomic) append("Atomic");
2022 if (!mod.sbid.is_empty()) {
2023 append(std::string("$") + std::to_string(mod.sbid.token));
2024 }
2025 oss << "}";
2026 return oss.str();
2027 }
2028
2029 ngen_proxy::InstructionModifier mod;
2030
2031 private:
instruction_modifier_attr_t(const ngen_proxy::InstructionModifier & mod)2032 instruction_modifier_attr_t(const ngen_proxy::InstructionModifier &mod)
2033 : mod(mod) {}
2034 };
2035
// Base class for function IR objects.
class func_impl_t : public object_impl_t {
public:
    IR_DECL_TYPE_ID(func_impl_t)

    // Returns a function call statement invoking this function with the given
    // arguments and an optional call attribute (empty by default). Defined
    // out of class, after func_call_t.
    stmt_t call(const std::vector<expr_t> &args,
            const func_call_attr_t &attr = {}) const;
};
2044
2045 // Wrapper for IR function objects.
2046 class func_t : public object_t {
2047 public:
2048 using object_t::object_t;
2049
2050 func_t() = default;
func_t(const object_t & obj)2051 func_t(const object_t &obj) : object_t(obj) {}
func_t(object_t && obj)2052 func_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)2053 func_t &operator=(const object_t &obj) {
2054 object_t::operator=(obj);
2055 return *this;
2056 }
operator =(object_t && obj)2057 func_t &operator=(object_t &&obj) {
2058 object_t::operator=(obj);
2059 return *this;
2060 }
2061
call(const std::vector<expr_t> & args={},const func_call_attr_t & attr={}) const2062 stmt_t call(const std::vector<expr_t> &args = {},
2063 const func_call_attr_t &attr = {}) const {
2064 return ((const func_impl_t *)impl())->call(args, attr);
2065 }
2066
2067 private:
sanity_check() const2068 void sanity_check() const override {
2069 ir_assert(dynamic_cast<const func_impl_t *>(impl()) == impl())
2070 << object_t(impl());
2071 }
2072 };
2073
2074 // Function call.
2075 class func_call_t : public stmt_impl_t {
2076 public:
IR_DECL_STMT_TYPE_ID(func_call_t)2077 IR_DECL_STMT_TYPE_ID(func_call_t)
2078
2079 static stmt_t make(const func_t &func, const std::vector<expr_t> &args,
2080 const func_call_attr_t &attr = {}) {
2081 return stmt_t(new func_call_t(func, args, attr));
2082 }
2083
is_equal(const object_impl_t & obj) const2084 bool is_equal(const object_impl_t &obj) const override {
2085 if (!obj.is<self_type>()) return false;
2086 auto &other = obj.as<self_type>();
2087
2088 return func.is_equal(other.func) && ir_utils::is_equal(args, other.args)
2089 && attr.is_equal(other.attr);
2090 }
2091
get_hash() const2092 size_t get_hash() const override {
2093 return ir_utils::get_hash(func, args, attr);
2094 }
2095
2096 IR_DECLARE_TRAVERSERS()
2097
2098 func_t func;
2099 std::vector<expr_t> args;
2100 func_call_attr_t attr;
2101
2102 private:
func_call_t(const func_t & func,const std::vector<expr_t> & args,const func_call_attr_t & attr)2103 func_call_t(const func_t &func, const std::vector<expr_t> &args,
2104 const func_call_attr_t &attr)
2105 : func(func), args(args), attr(attr) {
2106 ir_assert(!func.is_empty());
2107 }
2108 };
2109
call(const std::vector<expr_t> & args,const func_call_attr_t & attr) const2110 inline stmt_t func_impl_t::call(
2111 const std::vector<expr_t> &args, const func_call_attr_t &attr) const {
2112 return func_call_t::make(this, args, attr);
2113 }
2114
// Returns a new function call with the same function and arguments as `s` and
// with this attribute attached. `s` must be a function call that has no
// attribute yet, since merging of attributes is not supported.
inline stmt_t func_call_attr_t::apply_to(const stmt_t &s) const {
    const func_call_t &c = s.as<func_call_t>();
    ir_assert(c.attr.is_empty())
            << "Merging of attributes is not supported: " << s;

    const stmt_t with_attr = func_call_t::make(c.func, c.args, *this);
    return with_attr;
}
2121
2122 template <typename F>
is_func_call(const stmt_t & s)2123 inline bool is_func_call(const stmt_t &s) {
2124 auto *c = s.as_ptr<func_call_t>();
2125 if (!c) return false;
2126 return c->func.is<F>();
2127 }
2128
2129 // Generic function with a name.
2130 class builtin_t : public func_impl_t {
2131 public:
IR_DECL_DERIVED_TYPE_ID(builtin_t,func_impl_t)2132 IR_DECL_DERIVED_TYPE_ID(builtin_t, func_impl_t)
2133
2134 static func_t make(const std::string &name) {
2135 return func_t(new builtin_t(name));
2136 }
2137
is_equal(const object_impl_t & obj) const2138 bool is_equal(const object_impl_t &obj) const override {
2139 if (!obj.is<self_type>()) return false;
2140 auto &other = obj.as<self_type>();
2141
2142 return name == other.name;
2143 }
2144
get_hash() const2145 size_t get_hash() const override { return ir_utils::get_hash(name); }
2146
str() const2147 std::string str() const override { return name; }
2148
2149 std::string name;
2150
2151 private:
builtin_t(const std::string & name)2152 builtin_t(const std::string &name) : name(name) {}
2153 };
2154
2155 } // namespace jit
2156 } // namespace gpu
2157 } // namespace impl
2158 } // namespace dnnl
2159
2160 #endif
2161