1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_JIT_CONV_IR_CORE_HPP
18 #define GPU_JIT_CONV_IR_CORE_HPP
19 
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <numeric>
#include <string>
#include <utility>
25 
26 #include "common/c_types_map.hpp"
27 #include "gpu/jit/conv/ngen_proxy.hpp"
28 #include "gpu/jit/conv/utils.hpp"
29 
// All IR expression objects.
// Expands to one HANDLE_IR_OBJECT(type) invocation per expression class. The
// user must #define HANDLE_IR_OBJECT before expanding this list and #undef it
// afterwards.
#define HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_IR_OBJECT(binary_op_t) \
    HANDLE_IR_OBJECT(bool_imm_t) \
    HANDLE_IR_OBJECT(cast_t) \
    HANDLE_IR_OBJECT(float_imm_t) \
    HANDLE_IR_OBJECT(iif_t) \
    HANDLE_IR_OBJECT(int_imm_t) \
    HANDLE_IR_OBJECT(load_t) \
    HANDLE_IR_OBJECT(ptr_t) \
    HANDLE_IR_OBJECT(shuffle_t) \
    HANDLE_IR_OBJECT(ternary_op_t) \
    HANDLE_IR_OBJECT(unary_op_t) \
    HANDLE_IR_OBJECT(var_t)
44 
// All IR statement objects.
// Expands to one HANDLE_IR_OBJECT(type) invocation per statement class; the
// user supplies the definition of HANDLE_IR_OBJECT.
#define HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(alloc_t) \
    HANDLE_IR_OBJECT(for_t) \
    HANDLE_IR_OBJECT(func_call_t) \
    HANDLE_IR_OBJECT(if_t) \
    HANDLE_IR_OBJECT(let_t) \
    HANDLE_IR_OBJECT(stmt_group_t) \
    HANDLE_IR_OBJECT(stmt_seq_t) \
    HANDLE_IR_OBJECT(store_t)
55 
// Every type that gets a forward class declaration and a virtual
// ir_mutator_t::_mutate() overload: all expression and statement objects plus
// func_impl_t, nary_op_t and pexpr_t.
#define HANDLE_MUTATE_TARGETS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t) \
    HANDLE_IR_OBJECT(nary_op_t) \
    HANDLE_IR_OBJECT(pexpr_t)
62 
// Types that receive the leading (dispatchable) IDs in ir_type_id_t: all
// expression and statement objects plus func_impl_t.
#define HANDLE_ALL_IR_OBJECTS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t)
67 
// Numeric type IDs for IR classes. Each enumerator has the same spelling as
// the class it identifies; the IR_DECL_*TYPE_ID macros refer to it as
// ir_type_id_t::class_name.
enum ir_type_id_t {
#define HANDLE_IR_OBJECT(type) type,

    // Create typeid for objects which can be visited/mutated. These need to be
    // first as the typeid is used as an index into an array to dispatch to the
    // correct mutate function.
    HANDLE_ALL_IR_OBJECTS()

    // Used to calculate the number of IR objects that can be visited/mutated.
    end_visitable_ir_objects,

    // Other IR objects. The first one deliberately reuses the value of
    // end_visitable_ir_objects so that the sentinel does not consume an ID.
    expr_impl_t = end_visitable_ir_objects,
    nary_op_t,
    stmt_impl_t,
    grf_alloc_attr_t,
    instruction_modifier_attr_t,
    builtin_t,
    pexpr_t,
    pint_imm_t,
    factored_expr_t,
    send_t,
    dpas_t,
    mad_t,
    reduce_t,
    reorder_t,
    eltwise_t,

#undef HANDLE_IR_OBJECT
};
98 
// Auxiliary macros to reduce boilerplate.

// Declares the type-ID members of an IR class: the self_type alias, the
// static _type_id() (the ir_type_id_t enumerator named after the class), a
// static _dispatch_type_id() equal to _type_id(), and the type_id() override.
#define IR_DECL_TYPE_ID(class_name) \
    using self_type = class_name; \
    static int64_t _type_id() { return ir_type_id_t::class_name; } \
    static int64_t _dispatch_type_id() { return _type_id(); } \
    int64_t type_id() const override { return _type_id(); }
105 
// Same as IR_DECL_TYPE_ID, except that the dispatch type ID is the type ID of
// base_name: the class keeps its own type_id() but overrides
// dispatch_type_id() to report the base class's ID.
#define IR_DECL_DERIVED_TYPE_ID(class_name, base_name) \
    using self_type = class_name; \
    static int64_t _type_id() { return ir_type_id_t::class_name; } \
    int64_t type_id() const override { return _type_id(); } \
    static int64_t _dispatch_type_id() { return base_name::_type_id(); } \
    int64_t dispatch_type_id() const override { return _dispatch_type_id(); }
112 
// IR_DECL_TYPE_ID for expression classes: additionally overrides is_expr()
// to return true.
#define IR_DECL_EXPR_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    bool is_expr() const override { return true; }
116 
// IR_DECL_TYPE_ID for statement classes: additionally overrides is_stmt()
// to return true.
#define IR_DECL_STMT_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    bool is_stmt() const override { return true; }
120 
// Overrides _mutate() so that it forwards to mutator._mutate(*this). Because
// *this has the static type of the class using the macro, overload resolution
// picks the mutator's handler for that concrete class.
#define IR_DECL_MUTATE(mutator_template) \
    object_t _mutate(mutator_template &mutator) const override { \
        return mutator._mutate(*this); \
    }

// Declares all traversal hooks of an IR class; currently this is only the
// ir_mutator_t hook.
#define IR_DECLARE_TRAVERSERS() IR_DECL_MUTATE(ir_mutator_t)
127 
// Defines getter for a function argument.
// Generates three static overloads named arg_<name>:
//  - from a statement: asserts that the statement is a func_call_t whose
//    callee is self_type, then returns args[index] of that call;
//  - from a mutable or a const argument vector: returns args[index].
// None of the overloads checks index against the number of arguments.
#define IR_DEFINE_ARG_GET(name, index) \
    static const expr_t &arg_##name(const stmt_t &s) { \
        ir_assert(s.is<func_call_t>()) << s; \
        auto &c = s.as<func_call_t>(); \
        ir_assert(c.func.is<self_type>()) << s; \
        return c.args[index]; \
    } \
    template <typename T> \
    static T &arg_##name(std::vector<T> &args) { \
        return args[index]; \
    } \
    template <typename T> \
    static const T &arg_##name(const std::vector<T> &args) { \
        return args[index]; \
    }
144 
#if defined(__GNUC__)
// clang-format off
// Defines dump() method for debugging purposes, to pretty print the object.
// dump() prints str() followed by a newline to stdout, so the enclosing class
// must provide str(). The noinline/used attributes presumably keep the method
// in the binary so it can be called from a debugger even when nothing in the
// code references it.
#define IR_DEFINE_DUMP() \
    __attribute__((noinline)) \
    __attribute__((used)) \
    void dump() const { \
        printf("%s\n", str().c_str()); \
    }
// clang-format on
#else
// Compilers that do not define __GNUC__ get no dump() method.
#define IR_DEFINE_DUMP()
#endif
158 
159 namespace dnnl {
160 namespace impl {
161 namespace gpu {
162 namespace jit {
163 
// Element kinds that a type_t can have.
enum class type_kind_t {
    undef,
    _bool,

    // Integer types.
    u8,
    s8,
    u16,
    s16,
    u32,
    s32,
    u64,
    s64,

    // Floating point types.
    bf16,
    f16,
    f32,

    // Message data types.
    byte,
    dword,
    qword,
    oword,
    hword
};

// Returns the textual name of a type kind (defined elsewhere); type_t::str()
// uses it as the leading part of the printed type.
std::string to_string(type_kind_t kind);
192 
193 class type_t {
194 public:
undef()195     static type_t undef() { return type_t(type_kind_t::undef); }
_bool(int elems=1)196     static type_t _bool(int elems = 1) {
197         return type_t(type_kind_t::_bool, elems);
198     }
199 
u8(int elems=1)200     static type_t u8(int elems = 1) { return type_t(type_kind_t::u8, elems); }
s8(int elems=1)201     static type_t s8(int elems = 1) { return type_t(type_kind_t::s8, elems); }
u16(int elems=1)202     static type_t u16(int elems = 1) { return type_t(type_kind_t::u16, elems); }
s16(int elems=1)203     static type_t s16(int elems = 1) { return type_t(type_kind_t::s16, elems); }
u32(int elems=1)204     static type_t u32(int elems = 1) { return type_t(type_kind_t::u32, elems); }
s32(int elems=1)205     static type_t s32(int elems = 1) { return type_t(type_kind_t::s32, elems); }
u64(int elems=1)206     static type_t u64(int elems = 1) { return type_t(type_kind_t::u64, elems); }
s64(int elems=1)207     static type_t s64(int elems = 1) { return type_t(type_kind_t::s64, elems); }
208 
209     // Returns unsigned integer type.
u(int bits,int elems=1)210     static type_t u(int bits, int elems = 1) {
211         switch (bits) {
212             case 8: return u8(elems);
213             case 16: return u16(elems);
214             case 32: return u32(elems);
215             case 64: return u64(elems);
216             default: ir_error_not_expected();
217         }
218         return type_t::undef();
219     }
220 
221     // Returns signed integer type.
s(int bits,int elems=1)222     static type_t s(int bits, int elems = 1) {
223         switch (bits) {
224             case 8: return s8(elems);
225             case 16: return s16(elems);
226             case 32: return s32(elems);
227             case 64: return s64(elems);
228             default: ir_error_not_expected();
229         }
230         return type_t::undef();
231     }
232 
bf16(int elems=1)233     static type_t bf16(int elems = 1) {
234         return type_t(type_kind_t::bf16, elems);
235     }
f16(int elems=1)236     static type_t f16(int elems = 1) { return type_t(type_kind_t::f16, elems); }
f32(int elems=1)237     static type_t f32(int elems = 1) { return type_t(type_kind_t::f32, elems); }
238 
byte(int elems=1)239     static type_t byte(int elems = 1) {
240         return type_t(type_kind_t::byte, elems);
241     }
byte_ptr(int elems=1)242     static type_t byte_ptr(int elems = 1) {
243         return type_t(type_kind_t::byte, elems).with_ptr();
244     }
dword(int elems=1)245     static type_t dword(int elems = 1) {
246         return type_t(type_kind_t::dword, elems);
247     }
qword(int elems=1)248     static type_t qword(int elems = 1) {
249         return type_t(type_kind_t::qword, elems);
250     }
oword(int elems=1)251     static type_t oword(int elems = 1) {
252         return type_t(type_kind_t::oword, elems);
253     }
hword(int elems=1)254     static type_t hword(int elems = 1) {
255         return type_t(type_kind_t::hword, elems);
256     }
257 
258     template <typename T>
from_cpp()259     static type_t from_cpp() {
260 #define CASE(cpp_type, type) \
261     if (std::is_same<T, cpp_type>::value) return type()
262 
263         CASE(bool, _bool);
264         CASE(float, f32);
265         CASE(int16_t, s16);
266         CASE(int32_t, s32);
267         CASE(int64_t, s64);
268         CASE(uint16_t, u16);
269         CASE(uint32_t, u32);
270         CASE(uint64_t, u64);
271 
272 #undef CASE
273 
274         ir_error_not_expected();
275 
276         return undef();
277     }
278 
is_vector(int elems)279     static bool is_vector(int elems) { return elems != 1; }
280 
type_t()281     type_t() : type_t(type_t::undef()) {}
282 
type_t(type_kind_t kind,uint32_t elems=1)283     type_t(type_kind_t kind, uint32_t elems = 1) : kind_(kind), elems_(elems) {}
284 
285     // Constructor from dnnl_data_type_t.
type_t(data_type_t dt)286     type_t(data_type_t dt) {
287         elems_ = 1;
288         switch (dt) {
289 #define CASE(x) \
290     case data_type::x: kind_ = type_kind_t::x; break;
291             CASE(bf16);
292             CASE(f16);
293             CASE(f32);
294             CASE(s32);
295             CASE(s8);
296             CASE(u8);
297 #undef CASE
298             default: ir_error_not_expected();
299         }
300     }
301 
kind() const302     type_kind_t kind() const { return kind_; }
303 
elems() const304     int elems() const { return elems_; }
305 
is_ptr() const306     bool is_ptr() const { return is_ptr_; }
307 
operator ==(const type_t & other) const308     bool operator==(const type_t &other) const {
309         return (kind() == other.kind()) && (elems() == other.elems())
310                 && (is_ptr() == other.is_ptr());
311     }
312 
operator !=(const type_t & other) const313     bool operator!=(const type_t &other) const { return !operator==(other); }
314 
is_equal(const type_t & other) const315     bool is_equal(const type_t &other) const { return operator==(other); }
316 
get_hash() const317     size_t get_hash() const {
318         return ir_utils::get_hash(kind(), elems(), is_ptr());
319     }
320 
is_undef() const321     bool is_undef() const { return kind() == type_kind_t::undef; }
322 
is_vector() const323     bool is_vector() const { return type_t::is_vector(elems()); }
324 
is_bool() const325     bool is_bool() const { return kind() == type_kind_t::_bool; }
326 
is_fp() const327     bool is_fp() const {
328         return utils::one_of(
329                 kind(), type_kind_t::bf16, type_kind_t::f16, type_kind_t::f32);
330     }
331 
is_bf16() const332     bool is_bf16() const { return kind() == type_kind_t::bf16; }
is_f16() const333     bool is_f16() const { return kind() == type_kind_t::f16; }
is_f32() const334     bool is_f32() const { return kind() == type_kind_t::f32; }
335 
is_int() const336     bool is_int() const {
337         return utils::one_of(kind(), type_kind_t::u8, type_kind_t::s8,
338                 type_kind_t::u16, type_kind_t::s16, type_kind_t::u32,
339                 type_kind_t::s32, type_kind_t::u64, type_kind_t::s64);
340     }
341 
is_s8() const342     bool is_s8() const { return kind() == type_kind_t::s8; }
is_u8() const343     bool is_u8() const { return kind() == type_kind_t::u8; }
is_x8() const344     bool is_x8() const {
345         return utils::one_of(kind(), type_kind_t::s8, type_kind_t::u8);
346     }
347 
is_s16() const348     bool is_s16() const { return kind() == type_kind_t::s16; }
is_u16() const349     bool is_u16() const { return kind() == type_kind_t::u16; }
is_x16() const350     bool is_x16() const {
351         return utils::one_of(kind(), type_kind_t::s16, type_kind_t::u16);
352     }
353 
is_s32() const354     bool is_s32() const { return kind() == type_kind_t::s32; }
is_u32() const355     bool is_u32() const { return kind() == type_kind_t::u32; }
is_x32() const356     bool is_x32() const {
357         return utils::one_of(kind(), type_kind_t::s32, type_kind_t::u32);
358     }
359 
is_signed(int elems=-1) const360     bool is_signed(int elems = -1) const {
361         if (elems != -1 && elems_ != elems) return false;
362         return utils::one_of(kind(), type_kind_t::s8, type_kind_t::s16,
363                 type_kind_t::s32, type_kind_t::s64);
364     }
365 
is_unsigned(int elems=-1) const366     bool is_unsigned(int elems = -1) const {
367         if (elems != -1 && elems_ != elems) return false;
368         return utils::one_of(kind(), type_kind_t::u8, type_kind_t::u16,
369                 type_kind_t::u32, type_kind_t::u64);
370     }
371 
is_scalar() const372     bool is_scalar() const { return elems() == 1; }
373 
374     template <typename T>
is_cpp() const375     bool is_cpp() const {
376         return *this == type_t::from_cpp<T>();
377     }
378 
remove_elems() const379     type_t remove_elems() const { return with_elems(1); }
380 
remove_ptr() const381     type_t remove_ptr() const {
382         type_t copy = *this;
383         copy.is_ptr_ = false;
384         return copy;
385     }
386 
with_elems(int new_elems) const387     type_t with_elems(int new_elems) const {
388         type_t copy = *this;
389         copy.elems_ = new_elems;
390         return copy;
391     }
392 
with_ptr() const393     type_t with_ptr() const {
394         type_t copy = *this;
395         copy.is_ptr_ = true;
396         return copy;
397     }
398 
scalar() const399     type_t scalar() const { return with_elems(1); }
400 
401     // Returns size in bytes.
402     int size() const;
403 
str() const404     std::string str() const {
405         std::ostringstream oss;
406         oss << to_string(kind());
407         if (elems() > 1) oss << "x" << elems();
408         if (is_ptr()) oss << "*";
409         return oss.str();
410     }
411 
412     IR_DEFINE_DUMP()
413 
414 private:
415     type_kind_t kind_ = type_kind_t::undef;
416     int elems_ = 0;
417     bool is_ptr_ = false;
418 };
419 
operator <<(std::ostream & out,const type_t & type)420 inline std::ostream &operator<<(std::ostream &out, const type_t &type) {
421     out << type.str();
422     return out;
423 }
424 
// type_t to dnnl_data_type_t converter (defined elsewhere).
data_type_t to_dnnl(const type_t &type);
427 
// Reference counter for IR objects. The count is held in a std::atomic,
// starts at zero and cannot be copy-constructed.
class ref_count_t {
public:
    ref_count_t() = default;
    ref_count_t(const ref_count_t &) = delete;

    // Adds one reference; returns the count after the update.
    uint32_t increment() {
        const uint32_t updated = ++value_;
        return updated;
    }

    // Drops one reference; returns the count after the update.
    uint32_t decrement() {
        const uint32_t updated = --value_;
        return updated;
    }

private:
    std::atomic<uint32_t> value_ {0};
};
440 
// Forward declarations of IR objects: the generic wrapper, the mutator and
// one class per entry of HANDLE_MUTATE_TARGETS().
class object_t;
class ir_mutator_t;
#define HANDLE_IR_OBJECT(type) class type;
HANDLE_MUTATE_TARGETS()
#undef HANDLE_IR_OBJECT
447 
448 // Base class for all IR objects. Implemented as an intrusive pointer, with
449 // the reference counter stored inside the object.
450 class object_impl_t {
451 public:
452     object_impl_t() = default;
453 
454     object_impl_t(const object_impl_t &) = delete;
455 
456     virtual ~object_impl_t() = default;
457 
ref_count()458     ref_count_t &ref_count() { return ref_count_; }
459 
460     // Unique type ID.
461     virtual int64_t type_id() const = 0;
462 
463     // Type ID used for dispatching in ir_visitor_t and ir_mutator_t.
464     // For some IR objects
dispatch_type_id() const465     virtual int64_t dispatch_type_id() const { return type_id(); }
466 
467     // Provides equality semantics.
468     virtual bool is_equal(const object_impl_t &obj) const = 0;
469 
470     virtual size_t get_hash() const = 0;
471 
is_expr() const472     virtual bool is_expr() const { return false; }
is_stmt() const473     virtual bool is_stmt() const { return false; }
474 
475     // Downcasts the object to the IR type, returns a reference. The IR type
476     // must match the real IR type.
477     template <typename T>
as() const478     const T &as() const {
479         ir_assert(type_id() == T::_type_id());
480         return *(const T *)this;
481     }
482 
483     template <typename T>
as()484     T &as() {
485         ir_assert(type_id() == T::_type_id());
486         return *(T *)this;
487     }
488 
489     // Downcasts the object to the IR type, returns a pointer. If the IR type
490     // doesn't match the real IR type, returns nullptr.
491     template <typename T>
as_ptr() const492     const T *as_ptr() const {
493         if (type_id() != T::_type_id()) return nullptr;
494         return (const T *)this;
495     }
496 
497     template <typename T>
as_ptr()498     T *as_ptr() {
499         if (type_id() != T::_type_id()) return nullptr;
500         return (T *)this;
501     }
502 
503     // Returns true if T matches the real IR type.
504     template <typename T>
is() const505     bool is() const {
506         return type_id() == T::_type_id();
507     }
508 
509     virtual std::string str() const;
510 
511     virtual object_t _mutate(ir_mutator_t &mutator) const;
512     IR_DEFINE_DUMP()
513 
514 private:
515     ref_count_t ref_count_;
516 };
517 
518 // Base wrapper for IR objects.
519 class object_t {
520 public:
object_t(object_impl_t * impl=nullptr)521     object_t(object_impl_t *impl = nullptr) : impl_(impl) {
522         increment(impl_);
523 #ifndef NDEBUG
524         sanity_check();
525 #endif
526     }
object_t(const object_impl_t & impl)527     object_t(const object_impl_t &impl)
528         : object_t(const_cast<object_impl_t *>(&impl)) {}
object_t(const object_impl_t * impl)529     object_t(const object_impl_t *impl)
530         : object_t(const_cast<object_impl_t *>(impl)) {}
object_t(const object_t & obj)531     object_t(const object_t &obj) : object_t(obj.impl()) {}
object_t(object_t && obj)532     object_t(object_t &&obj) : impl_(obj.impl_) {
533         obj.impl_ = nullptr;
534 #ifndef NDEBUG
535         sanity_check();
536 #endif
537     }
538 
~object_t()539     virtual ~object_t() { decrement_and_maybe_destroy(impl_); }
540 
operator =(const object_t & other)541     object_t &operator=(const object_t &other) {
542         increment(other.impl());
543         decrement_and_maybe_destroy(impl_);
544         impl_ = other.impl();
545 #ifndef NDEBUG
546         sanity_check();
547 #endif
548         return *this;
549     }
550 
operator =(object_t && other)551     object_t &operator=(object_t &&other) {
552         std::swap(impl_, other.impl_);
553 #ifndef NDEBUG
554         sanity_check();
555 #endif
556         return *this;
557     }
558 
impl() const559     object_impl_t *impl() const { return impl_; }
560 
is_empty() const561     bool is_empty() const { return !impl_; }
562 
type_id() const563     int64_t type_id() const { return impl_->type_id(); }
564 
dispatch_type_id() const565     int64_t dispatch_type_id() const { return impl_->dispatch_type_id(); }
566 
567     template <typename T>
as() const568     const T &as() const {
569         ir_assert(impl_);
570         return impl_->as<T>();
571     }
572 
573     template <typename T>
as()574     T &as() {
575         ir_assert(impl_);
576         return impl_->as<T>();
577     }
578 
579     template <typename T>
as_ptr() const580     const T *as_ptr() const {
581         if (!impl_) return nullptr;
582         return impl_->as_ptr<T>();
583     }
584 
585     template <typename T>
as_ptr()586     T *as_ptr() {
587         if (!impl_) return nullptr;
588         return impl_->as_ptr<T>();
589     }
590 
591     template <typename T>
is() const592     bool is() const {
593         if (is_empty()) return false;
594         return impl_->is<T>();
595     }
596 
597     // Comparison with identity semantics.
is_same(const object_t & other) const598     bool is_same(const object_t &other) const { return impl_ == other.impl(); }
599 
600     // Comparison with equality semantics.
is_equal(const object_t & other) const601     bool is_equal(const object_t &other) const {
602         if (is_empty() || other.is_empty())
603             return is_empty() == other.is_empty();
604 
605         return impl_->is_equal(*other.impl());
606     }
607 
get_hash() const608     size_t get_hash() const {
609         if (is_empty()) return 0;
610         return impl()->get_hash();
611     }
612 
is_expr() const613     bool is_expr() const { return impl_ && impl_->is_expr(); }
is_stmt() const614     bool is_stmt() const { return impl_ && impl_->is_stmt(); }
615 
str() const616     std::string str() const {
617         if (is_empty()) return "(nil)";
618         return impl()->str();
619     }
620 
621     IR_DEFINE_DUMP()
622 
623 protected:
sanity_check() const624     virtual void sanity_check() const {}
625 
626 private:
increment(object_impl_t * impl)627     static void increment(object_impl_t *impl) {
628         if (!impl) return;
629         impl->ref_count().increment();
630     }
631 
decrement_and_maybe_destroy(object_impl_t * impl)632     static void decrement_and_maybe_destroy(object_impl_t *impl) {
633         if (!impl) return;
634         if (impl->ref_count().decrement() == 0) { delete impl; }
635     }
636 
637     object_impl_t *impl_;
638 };
639 
operator <<(std::ostream & out,const object_t & obj)640 inline std::ostream &operator<<(std::ostream &out, const object_t &obj) {
641     out << obj.str();
642     return out;
643 }
644 
645 // Helper classes for containers to store object_t.
646 struct object_id_hash_t {
operator ()dnnl::impl::gpu::jit::object_id_hash_t647     size_t operator()(const object_t &obj) const {
648         return std::hash<const object_impl_t *>()(obj.impl());
649     }
650 };
651 
652 struct object_eq_hash_t {
operator ()dnnl::impl::gpu::jit::object_eq_hash_t653     size_t operator()(const object_t &obj) const { return obj.get_hash(); }
654 };
655 
656 struct object_id_equal_t {
operator ()dnnl::impl::gpu::jit::object_id_equal_t657     bool operator()(const object_t &a, const object_t &b) const {
658         return a.is_same(b);
659     }
660 };
661 
662 struct object_eq_equal_t {
operator ()dnnl::impl::gpu::jit::object_eq_equal_t663     bool operator()(const object_t &a, const object_t &b) const {
664         return a.is_equal(b);
665     }
666 };
667 
// Containers to store object_t.
// The *_eq_* variants compare keys by value (get_hash()/is_equal()); the
// other variants compare keys by the identity of the implementation pointer.

// Unordered set, uses identity comparison for keys.
template <typename KeyT>
using object_set_t
        = std::unordered_set<KeyT, object_id_hash_t, object_id_equal_t>;

// Unordered set, uses equality comparison for keys.
template <typename KeyT>
using object_eq_set_t
        = std::unordered_set<KeyT, object_eq_hash_t, object_eq_equal_t>;

// Unordered map, uses identity comparison for keys.
template <typename KeyT, typename ValueT>
using object_map_t
        = std::unordered_map<KeyT, ValueT, object_id_hash_t, object_id_equal_t>;

// Unordered map, uses equality comparison for keys.
template <typename KeyT, typename ValueT>
using object_eq_map_t
        = std::unordered_map<KeyT, ValueT, object_eq_hash_t, object_eq_equal_t>;
689 
690 // Helper class to mutate IR tree.
691 class ir_mutator_t {
692 public:
693     virtual ~ir_mutator_t() = default;
694 
mutate(const object_t & obj)695     object_t mutate(const object_t &obj) {
696         auto impl = obj.impl();
697         if (!impl) return impl;
698         return impl->_mutate(*this);
699     }
700 
701     template <typename T>
mutate(const std::vector<T> & v)702     std::vector<T> mutate(const std::vector<T> &v) {
703         std::vector<T> new_v;
704         for (auto &e : v)
705             new_v.push_back(mutate(e));
706         return new_v;
707     }
708 
709     // To catch missing _mutate() handlers in ir_mutator_t.
_mutate(const object_impl_t & obj)710     object_t _mutate(const object_impl_t &obj) {
711         ir_error_not_expected() << "Can't handle type: " << object_t(&obj);
712         return {};
713     }
714 
715 #define HANDLE_IR_OBJECT(type) virtual object_t _mutate(const type &obj);
716     HANDLE_MUTATE_TARGETS()
717 #undef HANDLE_IR_OBJECT
718 };
719 
// Base class for IR expression objects. Every expression carries the type of
// the value it produces.
class expr_impl_t : public object_impl_t {
public:
    IR_DECL_TYPE_ID(expr_impl_t)

    expr_impl_t(const type_t &type) : type(type) {}

    // Type of the expression; expr_t::type() returns a reference to it.
    type_t type;
};
729 
730 // Wrapper for IR expression objects.
731 class expr_t : public object_t {
732 public:
733     using object_t::object_t;
734 
735     expr_t() = default;
expr_t(const object_t & obj)736     expr_t(const object_t &obj) : object_t(obj) {}
expr_t(object_t && obj)737     expr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)738     expr_t &operator=(const object_t &obj) {
739         object_t::operator=(obj);
740         return *this;
741     }
operator =(object_t && obj)742     expr_t &operator=(object_t &&obj) {
743         object_t::operator=(obj);
744         return *this;
745     }
746 
747     explicit expr_t(bool v);
748     expr_t(float v);
749     expr_t(int16_t v);
750     expr_t(int32_t v);
751     expr_t(int64_t v);
752     expr_t(uint16_t v);
753     expr_t(uint32_t v);
754     expr_t(uint64_t v);
755 
type() const756     const type_t &type() const {
757         ir_assert(!is_empty());
758         return ((const expr_impl_t *)impl())->type;
759     }
760 
761 #define DECLARE_BINARY_ASSIGN_OPERATOR(op) \
762     expr_t &operator op##=(const expr_t &rhs);
763 
764     DECLARE_BINARY_ASSIGN_OPERATOR(+)
765     DECLARE_BINARY_ASSIGN_OPERATOR(-)
766     DECLARE_BINARY_ASSIGN_OPERATOR(*)
767     DECLARE_BINARY_ASSIGN_OPERATOR(/)
768     DECLARE_BINARY_ASSIGN_OPERATOR(%)
769     DECLARE_BINARY_ASSIGN_OPERATOR(&)
770 
771 #undef DECLARE_BINARY_ASSIGN_OPERATOR
772 
773     // Returns a pointer shifted by `off` bytes relative to this pointer. The
774     // base expression must be a pointer.
775     expr_t operator[](const expr_t &off) const;
776 
777 private:
sanity_check() const778     void sanity_check() const override {
779         ir_assert(dynamic_cast<const expr_impl_t *>(impl()) == impl())
780                 << object_t(impl());
781     }
782 };
783 
// Helper functions. Declared here and defined inline later.
// NOTE(review): judging by the names these test whether an expression is a
// constant or a variable; confirm against the definitions.
inline bool is_const(const expr_t &e);
inline bool is_var(const expr_t &e);
787 
// Unary, binary and ternary operators.
enum class op_kind_t {
    undef,

    // Arithmetic operators.
    _minus,
    _add,
    _sub,
    _mul,
    _div,
    _mod,
    _shl,
    _shr,
    _min,
    _max,

    // Comparison operators.
    _lt,
    _le,
    _gt,
    _ge,
    _ne,
    _eq,

    _and,

    // Ternary operators.
    _add3, // a + b + c
    _mad, // a + b * c
};

// Returns the textual form of an operator kind (defined elsewhere).
std::string to_string(op_kind_t kind);
817 
operator <<(std::ostream & out,op_kind_t kind)818 inline std::ostream &operator<<(std::ostream &out, op_kind_t kind) {
819     out << to_string(kind);
820     return out;
821 }
822 
// Operator and type helpers; all of them are defined elsewhere.
// NOTE(review): the one-line descriptions below are inferred from the names
// and from the call sites in this header; confirm against the definitions.

// Presumably tells whether op_kind is one of the comparison operators.
bool is_cmp_op(op_kind_t op_kind);

// Presumably returns the comparison operator with the opposite outcome.
op_kind_t negate_cmp_op(op_kind_t op_kind);

type_t unary_op_type(op_kind_t op_kind, const expr_t &a);

type_t common_int_type(const type_t &_a, const type_t &_b);

// The type overload is used by iif_t to derive its result type from the types
// of its two branches.
type_t common_type(const type_t &a, const type_t &b);

type_t common_type(const expr_t &a, const expr_t &b);

// Used by binary_op_t to derive the type of (a op b).
type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b);

type_t ternary_op_type(
        op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c);

type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args);
841 
842 // Binary operation: (a op b).
843 class binary_op_t : public expr_impl_t {
844 public:
IR_DECL_EXPR_TYPE_ID(binary_op_t)845     IR_DECL_EXPR_TYPE_ID(binary_op_t)
846 
847     static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
848         return expr_t(new binary_op_t(op_kind, a, b));
849     }
850 
is_equal(const object_impl_t & obj) const851     bool is_equal(const object_impl_t &obj) const override {
852         if (!obj.is<self_type>()) return false;
853         auto &other = obj.as<self_type>();
854 
855         return (op_kind == other.op_kind) && a.is_equal(other.a)
856                 && b.is_equal(other.b);
857     }
858 
get_hash() const859     size_t get_hash() const override {
860         return ir_utils::get_hash(op_kind, a, b);
861     }
862 
863     IR_DECLARE_TRAVERSERS()
864 
865     op_kind_t op_kind;
866     expr_t a;
867     expr_t b;
868 
869 private:
binary_op_t(op_kind_t op_kind,const expr_t & a,const expr_t & b)870     binary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b)
871         : expr_impl_t(binary_op_type(op_kind, a, b))
872         , op_kind(op_kind)
873         , a(a)
874         , b(b) {}
875 };
876 
877 // Boolean immediate value.
878 class bool_imm_t : public expr_impl_t {
879 public:
880     friend class expr_t;
IR_DECL_EXPR_TYPE_ID(bool_imm_t)881     IR_DECL_EXPR_TYPE_ID(bool_imm_t)
882 
883     static expr_t make(bool value) { return expr_t(new bool_imm_t(value)); }
884 
is_equal(const object_impl_t & obj) const885     bool is_equal(const object_impl_t &obj) const override {
886         if (!obj.is<self_type>()) return false;
887         auto &other = obj.as<self_type>();
888 
889         return value == other.value;
890     }
891 
get_hash() const892     size_t get_hash() const override { return ir_utils::get_hash(value); }
893 
894     bool value;
895 
896 private:
bool_imm_t(bool value)897     bool_imm_t(bool value) : expr_impl_t(type_t::_bool()), value(value) {}
898 };
899 
900 // Cast between data types.
901 class cast_t : public expr_impl_t {
902 public:
IR_DECL_EXPR_TYPE_ID(cast_t)903     IR_DECL_EXPR_TYPE_ID(cast_t)
904 
905     static expr_t make(
906             const type_t &type, const expr_t &expr, bool saturate = false) {
907         return expr_t(new cast_t(type, expr, saturate));
908     }
909 
is_equal(const object_impl_t & obj) const910     bool is_equal(const object_impl_t &obj) const override {
911         if (!obj.is<self_type>()) return false;
912         auto &other = obj.as<self_type>();
913 
914         return (type == other.type) && expr.is_equal(other.expr)
915                 && (saturate == other.saturate);
916     }
917 
get_hash() const918     size_t get_hash() const override {
919         return ir_utils::get_hash(type, expr, saturate);
920     }
921 
922     IR_DECLARE_TRAVERSERS()
923 
924     expr_t expr;
925     bool saturate;
926 
927 private:
cast_t(const type_t & type,const expr_t & expr,bool saturate)928     cast_t(const type_t &type, const expr_t &expr, bool saturate)
929         : expr_impl_t(type), expr(expr), saturate(saturate) {
930         ir_assert(type.elems() == expr.type().elems())
931                 << "Number of elements must match.";
932     }
933 };
934 
935 // Floating-point immediate value.
936 class float_imm_t : public expr_impl_t {
937 public:
938     friend class expr_t;
IR_DECL_EXPR_TYPE_ID(float_imm_t)939     IR_DECL_EXPR_TYPE_ID(float_imm_t)
940 
941     static expr_t make(float value) { return expr_t(new float_imm_t(value)); }
942 
is_equal(const object_impl_t & obj) const943     bool is_equal(const object_impl_t &obj) const override {
944         if (!obj.is<self_type>()) return false;
945         auto &other = obj.as<self_type>();
946 
947         return value == other.value;
948     }
949 
get_hash() const950     size_t get_hash() const override { return ir_utils::get_hash(value); }
951 
952     float value;
953 
954 private:
float_imm_t(float value)955     float_imm_t(float value) : expr_impl_t(type_t::f32()), value(value) {}
956 };
957 
958 // Immediate if or the conditional (ternary) operator.
959 // C++ equivalent: (cond ? true_expr : false_expr).
960 class iif_t : public expr_impl_t {
961 public:
962     IR_DECL_EXPR_TYPE_ID(iif_t);
963 
make(const expr_t & cond,const expr_t & true_expr,const expr_t & false_expr)964     static expr_t make(const expr_t &cond, const expr_t &true_expr,
965             const expr_t &false_expr) {
966         return expr_t(new iif_t(cond, true_expr, false_expr));
967     }
968 
is_equal(const object_impl_t & obj) const969     bool is_equal(const object_impl_t &obj) const override {
970         if (!obj.is<self_type>()) return false;
971         auto &other = obj.as<self_type>();
972 
973         return cond.is_equal(other.cond) && true_expr.is_equal(other.true_expr)
974                 && false_expr.is_equal(other.false_expr);
975     }
976 
get_hash() const977     size_t get_hash() const override {
978         return ir_utils::get_hash(cond, true_expr, false_expr);
979     }
980 
981     IR_DECLARE_TRAVERSERS()
982 
983     expr_t cond;
984     expr_t true_expr;
985     expr_t false_expr;
986 
987 private:
iif_t(const expr_t & cond,const expr_t & true_expr,const expr_t & false_expr)988     iif_t(const expr_t &cond, const expr_t &true_expr, const expr_t &false_expr)
989         : expr_impl_t(common_type(true_expr.type(), false_expr.type()))
990         , cond(cond)
991         , true_expr(true_expr)
992         , false_expr(false_expr) {}
993 };
994 
995 // Integer immediate value.
996 class int_imm_t : public expr_impl_t {
997 public:
998     friend class expr_t;
999     IR_DECL_EXPR_TYPE_ID(int_imm_t);
1000 
1001     template <typename T>
make(T value,const type_t & type=type_t::undef ())1002     static expr_t make(T value, const type_t &type = type_t::undef()) {
1003         return expr_t(new int_imm_t(value, type));
1004     }
1005 
is_equal(const object_impl_t & obj) const1006     bool is_equal(const object_impl_t &obj) const override {
1007         if (!obj.is<self_type>()) return false;
1008         auto &other = obj.as<self_type>();
1009 
1010         return value == other.value;
1011     }
1012 
get_hash() const1013     size_t get_hash() const override { return ir_utils::get_hash(value); }
1014 
shrink_type(const expr_t & e)1015     static expr_t shrink_type(const expr_t &e) {
1016         auto &imm = e.as<int_imm_t>();
1017         type_t new_type = shrink_type(imm.value);
1018         if (new_type == imm.type) return e;
1019         return make(imm.value, new_type);
1020     }
1021 
1022     template <typename T>
try_shrink_type(int64_t v)1023     static bool try_shrink_type(int64_t v) {
1024         if (v >= std::numeric_limits<T>::min()
1025                 && v <= std::numeric_limits<T>::max())
1026             return true;
1027         return false;
1028     }
1029 
1030     int64_t value;
1031 
1032 private:
int_imm_t(int64_t value,const type_t & type=type_t::undef ())1033     int_imm_t(int64_t value, const type_t &type = type_t::undef())
1034         : expr_impl_t(type.is_undef() ? shrink_type(value) : type)
1035         , value(value) {}
1036 
shrink_type(int64_t v)1037     static type_t shrink_type(int64_t v) {
1038         if (try_shrink_type<int32_t>(v)) return type_t::s32();
1039         return type_t::s64();
1040     }
1041 };
1042 
// Updates `base` and `off` in place so that after return:
// - base contains a variable of a pointer type
// - off contains an offset
// Declaration only; load_t and store_t call it from their constructors with
// the accessed type before asserting that the buffer is a pointer variable.
void normalize_ptr(const type_t &type, expr_t &base, expr_t &off);
1047 
1048 // Load from a GRF buffer.
1049 // C++ equivalent (when type is scalar):
1050 //     load = *(type *)(&buf[off]);
1051 // C++ equivalent (when type is vector):
1052 //     int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1053 //     for (int i = 0; i < elems; i++) {
1054 //         load[i] = *(scalar_type *)(&buf[off + i * _stride]);
1055 //     }
1056 class load_t : public expr_impl_t {
1057 public:
IR_DECL_EXPR_TYPE_ID(load_t)1058     IR_DECL_EXPR_TYPE_ID(load_t)
1059 
1060     // offset and stride are expressed in bytes.
1061     // default stride means unit stride (in terms of value.type().scalar()
1062     // elements).
1063     static expr_t make(const type_t &type, const expr_t &buf, const expr_t &off,
1064             int stride = default_stride) {
1065         return expr_t(new load_t(type, buf, off, stride));
1066     }
1067 
is_equal(const object_impl_t & obj) const1068     bool is_equal(const object_impl_t &obj) const override {
1069         if (!obj.is<self_type>()) return false;
1070         auto &other = obj.as<self_type>();
1071 
1072         return type.is_equal(other.type) && buf.is_equal(other.buf)
1073                 && off.is_equal(other.off) && (stride == other.stride);
1074     }
1075 
get_hash() const1076     size_t get_hash() const override {
1077         return ir_utils::get_hash(type, buf, off, stride);
1078     }
1079 
has_default_stride() const1080     bool has_default_stride() const { return stride == default_stride; }
1081 
1082     IR_DECLARE_TRAVERSERS()
1083 
1084     static const int default_stride = -1;
1085 
1086     expr_t buf;
1087     expr_t off;
1088     int stride;
1089 
1090 private:
load_t(const type_t & _type,const expr_t & _buf,const expr_t & _off,int _stride)1091     load_t(const type_t &_type, const expr_t &_buf, const expr_t &_off,
1092             int _stride)
1093         : expr_impl_t(_type), buf(_buf), off(_off), stride(_stride) {
1094         normalize_ptr(type, buf, off);
1095         ir_assert(is_var(buf)) << buf;
1096         ir_assert(buf.type().is_ptr()) << buf;
1097         if (stride == type.scalar().size()) stride = default_stride;
1098     }
1099 };
1100 
1101 // N-ary expression: (a[0] op a[1] op ... op a[n - 1]),
1102 // where <op> is either addition or multiplication.
1103 class nary_op_t : public expr_impl_t {
1104 public:
IR_DECL_EXPR_TYPE_ID(nary_op_t)1105     IR_DECL_EXPR_TYPE_ID(nary_op_t)
1106 
1107     static expr_t make(op_kind_t op_kind, const std::vector<expr_t> &args) {
1108         return expr_t(new nary_op_t(op_kind, args));
1109     }
1110 
is_equal(const object_impl_t & obj) const1111     bool is_equal(const object_impl_t &obj) const override {
1112         if (!obj.is<self_type>()) return false;
1113         auto &other = obj.as<self_type>();
1114 
1115         return (op_kind == other.op_kind)
1116                 && ir_utils::is_equal(args, other.args);
1117     }
1118 
get_hash() const1119     size_t get_hash() const override {
1120         return ir_utils::get_hash(op_kind, args);
1121     }
1122 
str() const1123     std::string str() const override {
1124         std::ostringstream oss;
1125         oss << "(";
1126         for (size_t i = 0; i < args.size(); i++) {
1127             oss << (i != 0 ? " " + to_string(op_kind) + " " : "") << args[i];
1128         }
1129 
1130         oss << ")";
1131         return oss.str();
1132     }
1133 
1134     IR_DECLARE_TRAVERSERS()
1135 
1136     op_kind_t op_kind;
1137     std::vector<expr_t> args;
1138 
1139 private:
nary_op_t(op_kind_t op_kind,const std::vector<expr_t> & args)1140     nary_op_t(op_kind_t op_kind, const std::vector<expr_t> &args)
1141         : expr_impl_t(nary_op_type(op_kind, args))
1142         , op_kind(op_kind)
1143         , args(args) {}
1144 };
1145 
1146 // Pointer expression: (base_ptr + off).
1147 class ptr_t : public expr_impl_t {
1148 public:
IR_DECL_EXPR_TYPE_ID(ptr_t)1149     IR_DECL_EXPR_TYPE_ID(ptr_t)
1150 
1151     // off - offset in bytes.
1152     static expr_t make(const expr_t &base, const expr_t &off) {
1153         return expr_t(new ptr_t(base, off));
1154     }
1155 
is_equal(const object_impl_t & obj) const1156     bool is_equal(const object_impl_t &obj) const override {
1157         if (!obj.is<self_type>()) return false;
1158         auto &other = obj.as<self_type>();
1159 
1160         return base.is_equal(other.base) && off.is_equal(other.off);
1161     }
1162 
get_hash() const1163     size_t get_hash() const override { return ir_utils::get_hash(base, off); }
1164 
1165     // Normalizes (base op off) pointer so that the new base is a variable and
1166     // off is an offset expression.
1167     // Example:
1168     //     Before call: base = (base0 + off0), off = off1
1169     //     After call:  base = base0, off = off0 + off1
1170     static void normalize(
1171             expr_t &base, expr_t &off, op_kind_t op_kind = op_kind_t::_add);
1172 
1173     IR_DECLARE_TRAVERSERS()
1174 
1175     expr_t base;
1176     expr_t off;
1177 
1178 private:
ptr_t(const expr_t & base,const expr_t & off)1179     ptr_t(const expr_t &base, const expr_t &off)
1180         : expr_impl_t(base.type()), base(base), off(off) {
1181         normalize(this->base, this->off);
1182     }
1183 };
1184 
1185 class shuffle_t : public expr_impl_t {
1186 public:
IR_DECL_EXPR_TYPE_ID(shuffle_t)1187     IR_DECL_EXPR_TYPE_ID(shuffle_t)
1188 
1189     static expr_t make(
1190             const std::vector<expr_t> &vec, const std::vector<int> &idx) {
1191         if (idx.size() == 1) return vec[idx[0]];
1192         return expr_t(new shuffle_t(vec, idx));
1193     }
1194 
make(const std::vector<expr_t> & _vec,bool find_equal=true)1195     static expr_t make(
1196             const std::vector<expr_t> &_vec, bool find_equal = true) {
1197         std::vector<expr_t> vec;
1198         std::vector<int> idx;
1199         for (auto &v : _vec) {
1200             bool found = false;
1201             int size = int(vec.size());
1202             if (find_equal) {
1203                 for (int i = 0; i < size; i++) {
1204                     if (v.is_equal(vec[i])) {
1205                         idx.push_back(i);
1206                         found = true;
1207                         break;
1208                     }
1209                 }
1210             }
1211             if (!found) {
1212                 vec.push_back(v);
1213                 idx.push_back(size);
1214             }
1215         }
1216         return make(vec, idx);
1217     }
1218 
make_broadcast(const expr_t & expr,int elems)1219     static expr_t make_broadcast(const expr_t &expr, int elems) {
1220         ir_assert(expr.type().is_scalar()) << expr;
1221         return make({expr}, std::vector<int>(elems, 0));
1222     }
1223 
1224     // Slices the existing shuffle expression. For inputs (S, beg, end) returns
1225     // (S[beg], S[beg + 1], ..., S[end - 1]) vector.
make(const expr_t & _shuffle,int beg,int end)1226     static expr_t make(const expr_t &_shuffle, int beg, int end) {
1227         auto &shuffle = _shuffle.as<shuffle_t>();
1228         ir_assert(beg >= 0 && beg <= shuffle.elems());
1229         ir_assert(end >= 0 && end <= shuffle.elems());
1230         ir_assert(beg < end);
1231         std::vector<expr_t> vec;
1232         std::vector<int> idx(end - beg, -1);
1233         for (int i = beg; i < end; i++) {
1234             if (idx[i - beg] != -1) continue;
1235             int old_idx = shuffle.idx[i];
1236             vec.push_back(shuffle.vec[old_idx]);
1237             for (int j = i; j < end; j++) {
1238                 if (shuffle.idx[j] == old_idx)
1239                     idx[j - beg] = int(vec.size()) - 1;
1240             }
1241         }
1242         return make(vec, idx);
1243     }
1244 
is_equal(const object_impl_t & obj) const1245     bool is_equal(const object_impl_t &obj) const override {
1246         if (!obj.is<self_type>()) return false;
1247         auto &other = obj.as<self_type>();
1248 
1249         return ir_utils::is_equal(vec, other.vec)
1250                 && ir_utils::is_equal(idx, other.idx);
1251     }
1252 
get_hash() const1253     size_t get_hash() const override { return ir_utils::get_hash(vec, idx); }
1254 
elems() const1255     int elems() const { return int(idx.size()); }
1256 
is_vector() const1257     bool is_vector() const {
1258         for (int i = 0; i < elems(); i++)
1259             if (idx[i] != i) return false;
1260         return true;
1261     }
1262 
is_broadcast() const1263     bool is_broadcast() const { return vec.size() == 1; }
1264 
1265     IR_DECLARE_TRAVERSERS()
1266 
1267     std::vector<expr_t> vec;
1268     std::vector<int> idx;
1269 
1270 private:
shuffle_t(const std::vector<expr_t> & vec,const std::vector<int> & idx)1271     shuffle_t(const std::vector<expr_t> &vec, const std::vector<int> &idx)
1272         : expr_impl_t(shuffle_type(vec, idx)), vec(vec), idx(idx) {
1273         ir_assert(idx.size() > 1) << "Unexpected empty or scalar shuffle.";
1274     }
1275 
shuffle_type(const std::vector<expr_t> & vec,const std::vector<int> & idx)1276     static type_t shuffle_type(
1277             const std::vector<expr_t> &vec, const std::vector<int> &idx) {
1278         ir_assert(!vec.empty() && !idx.empty());
1279 
1280         auto elem_type = vec[0].type();
1281         for (auto &v : vec)
1282             elem_type = common_type(elem_type, v.type());
1283 
1284         for (size_t i = 0; i < idx.size(); i++) {
1285             ir_assert(idx[i] >= 0 && idx[i] < int(vec.size()))
1286                     << "Incorrect index.";
1287             MAYBE_UNUSED(i);
1288         }
1289 
1290         int elems = int(idx.size());
1291         return elem_type.with_elems(elems);
1292     }
1293 };
1294 
1295 // Ternary operation: op(a, b, c).
1296 class ternary_op_t : public expr_impl_t {
1297 public:
IR_DECL_EXPR_TYPE_ID(ternary_op_t)1298     IR_DECL_EXPR_TYPE_ID(ternary_op_t)
1299 
1300     static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1301             const expr_t &c) {
1302         return expr_t(new ternary_op_t(op_kind, a, b, c));
1303     }
1304 
is_equal(const object_impl_t & obj) const1305     bool is_equal(const object_impl_t &obj) const override {
1306         if (!obj.is<self_type>()) return false;
1307         auto &other = obj.as<self_type>();
1308 
1309         return (op_kind == other.op_kind) && a.is_equal(other.a)
1310                 && b.is_equal(other.b) && c.is_equal(other.c);
1311     }
1312 
get_hash() const1313     size_t get_hash() const override {
1314         return ir_utils::get_hash(op_kind, a, b, c);
1315     }
1316 
1317     IR_DECLARE_TRAVERSERS()
1318 
1319     op_kind_t op_kind;
1320     expr_t a;
1321     expr_t b;
1322     expr_t c;
1323 
1324 private:
ternary_op_t(op_kind_t op_kind,const expr_t & a,const expr_t & b,const expr_t & c)1325     ternary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1326             const expr_t &c)
1327         : expr_impl_t(ternary_op_type(op_kind, a, b, c))
1328         , op_kind(op_kind)
1329         , a(a)
1330         , b(b)
1331         , c(c) {}
1332 };
1333 
// Shorthand for a ternary operation of kind `_mad` over (a, b, c).
inline expr_t ternary_mad(const expr_t &a, const expr_t &b, const expr_t &c) {
    const op_kind_t kind = op_kind_t::_mad;
    return ternary_op_t::make(kind, a, b, c);
}
1337 
// Shorthand for a ternary operation of kind `_add3` over (a, b, c).
inline expr_t ternary_add3(const expr_t &a, const expr_t &b, const expr_t &c) {
    const op_kind_t kind = op_kind_t::_add3;
    return ternary_op_t::make(kind, a, b, c);
}
1341 
1342 // Unary operation: (op a).
1343 class unary_op_t : public expr_impl_t {
1344 public:
IR_DECL_EXPR_TYPE_ID(unary_op_t)1345     IR_DECL_EXPR_TYPE_ID(unary_op_t)
1346 
1347     static expr_t make(op_kind_t op_kind, const expr_t &a) {
1348         return expr_t(new unary_op_t(op_kind, a));
1349     }
1350 
is_equal(const object_impl_t & obj) const1351     bool is_equal(const object_impl_t &obj) const override {
1352         if (!obj.is<self_type>()) return false;
1353         auto &other = obj.as<self_type>();
1354 
1355         return (op_kind == other.op_kind) && a.is_equal(other.a);
1356     }
1357 
get_hash() const1358     size_t get_hash() const override { return ir_utils::get_hash(op_kind, a); }
1359 
1360     IR_DECLARE_TRAVERSERS()
1361 
1362     op_kind_t op_kind;
1363     expr_t a;
1364 
1365 private:
unary_op_t(op_kind_t op_kind,const expr_t & a)1366     unary_op_t(op_kind_t op_kind, const expr_t &a)
1367         : expr_impl_t(unary_op_type(op_kind, a)), op_kind(op_kind), a(a) {}
1368 };
1369 
1370 class var_t : public expr_impl_t {
1371 public:
IR_DECL_EXPR_TYPE_ID(var_t)1372     IR_DECL_EXPR_TYPE_ID(var_t)
1373 
1374     static expr_t make(const type_t &type, const std::string &name) {
1375         return expr_t(new var_t(type, name));
1376     }
1377 
is_equal(const object_impl_t & obj) const1378     bool is_equal(const object_impl_t &obj) const override {
1379         // Do not allow variable cloning.
1380         return this == &obj;
1381     }
1382 
get_hash() const1383     size_t get_hash() const override { return ir_utils::get_hash(name); }
1384 
1385     IR_DECLARE_TRAVERSERS()
1386 
1387     std::string name;
1388 
1389 private:
var_t(const type_t & type,const std::string & name)1390     var_t(const type_t &type, const std::string &name)
1391         : expr_impl_t(type), name(name) {}
1392 };
1393 
1394 // Convertor from C++ type to IR expression.
1395 template <typename T>
to_expr(T value,const type_t & type)1396 expr_t to_expr(T value, const type_t &type) {
1397 #define CASE(ir_type, cpp_type) \
1398     if (type == type_t::ir_type()) return expr_t((cpp_type)value)
1399 
1400     CASE(_bool, bool);
1401     CASE(f32, float);
1402     CASE(s16, int16_t);
1403     CASE(s32, int32_t);
1404     CASE(s64, int64_t);
1405     CASE(u16, uint16_t);
1406     CASE(u32, uint32_t);
1407     CASE(u64, uint64_t);
1408 
1409 #undef CASE
1410 
1411     ir_error_not_expected() << type;
1412 
1413     return expr_t();
1414 }
1415 
1416 template <typename T>
to_expr(T value)1417 expr_t to_expr(T value) {
1418     return to_expr(value, type_t::from_cpp<T>());
1419 }
1420 
is_binary_op(const expr_t & e)1421 inline bool is_binary_op(const expr_t &e) {
1422     return e.is<binary_op_t>();
1423 }
1424 
is_binary_op(const expr_t & e,op_kind_t op_kind)1425 inline bool is_binary_op(const expr_t &e, op_kind_t op_kind) {
1426     if (!is_binary_op(e)) return false;
1427     return e.as<binary_op_t>().op_kind == op_kind;
1428 }
1429 
is_binary_cmp_op(const expr_t & e)1430 inline bool is_binary_cmp_op(const expr_t &e) {
1431     if (!is_binary_op(e)) return false;
1432     return is_cmp_op(e.as<binary_op_t>().op_kind);
1433 }
1434 
is_const(const expr_t & e)1435 inline bool is_const(const expr_t &e) {
1436     return e.is<bool_imm_t>() || e.is<int_imm_t>() || e.is<float_imm_t>();
1437 }
1438 
is_shuffle_const(const expr_t & e)1439 inline bool is_shuffle_const(const expr_t &e) {
1440     auto *shuffle = e.as_ptr<shuffle_t>();
1441     if (!shuffle) return false;
1442     for (auto &v : shuffle->vec)
1443         if (!is_const(v)) return false;
1444     return true;
1445 }
1446 
is_var(const expr_t & e)1447 inline bool is_var(const expr_t &e) {
1448     return e.is<var_t>();
1449 }
1450 
1451 // Convertor from IR expression to C++ constant.
1452 template <typename T>
to_cpp(const expr_t & e)1453 T to_cpp(const expr_t &e) {
1454     ir_assert(is_const(e)) << "Expression must be constant.";
1455 
1456     if (e.is<int_imm_t>()) return (T)e.as<int_imm_t>().value;
1457     if (e.is<float_imm_t>()) return (T)e.as<float_imm_t>().value;
1458     if (e.is<bool_imm_t>()) return (T)e.as<bool_imm_t>().value;
1459 
1460     ir_error_not_expected();
1461     return 0;
1462 }
1463 
// Unary minus for IR expressions (declaration only).
expr_t operator-(const expr_t &a);

// Declares a binary operator overload taking two IR expressions and returning
// an IR expression.
// Note: the `op_kind` argument is not used in the generated declaration; at
// this point it only records which operation kind each operator is paired
// with.
#define DECLARE_BINARY_OPERATOR(op, op_kind) \
    expr_t operator op(const expr_t &a, const expr_t &b);

// Arithmetic and shift operators.
DECLARE_BINARY_OPERATOR(+, op_kind_t::_add)
DECLARE_BINARY_OPERATOR(-, op_kind_t::_sub)
DECLARE_BINARY_OPERATOR(*, op_kind_t::_mul)
DECLARE_BINARY_OPERATOR(/, op_kind_t::_div)
DECLARE_BINARY_OPERATOR(%, op_kind_t::_mod)
DECLARE_BINARY_OPERATOR(<<, op_kind_t::_shl)
DECLARE_BINARY_OPERATOR(>>, op_kind_t::_shr)

// Comparison operators. Note that they return expr_t, not bool.
DECLARE_BINARY_OPERATOR(==, op_kind_t::_eq)
DECLARE_BINARY_OPERATOR(!=, op_kind_t::_ne)
DECLARE_BINARY_OPERATOR(>, op_kind_t::_gt)
DECLARE_BINARY_OPERATOR(>=, op_kind_t::_ge)
DECLARE_BINARY_OPERATOR(<, op_kind_t::_lt)
DECLARE_BINARY_OPERATOR(<=, op_kind_t::_le)

// The `&` operator, paired with op_kind_t::_and.
DECLARE_BINARY_OPERATOR(&, op_kind_t::_and)

#undef DECLARE_BINARY_OPERATOR
1487 
// Returns a shifted pointer with base `a` (pointer) and offset `b` (in bytes).
// shift_ptr(op, a, b) returns &(a op b) in C++ terms, where `op` (passed as
// `op_kind`) is either addition or subtraction.
// Declaration only.
expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b);
1492 
// Base class for IR statement objects.
// stmt_t::sanity_check() asserts that the implementation wrapped by a stmt_t
// derives from this class.
class stmt_impl_t : public object_impl_t {
public:
    IR_DECL_TYPE_ID(stmt_impl_t)
};
1498 
1499 // Wrapper for IR statement objects.
1500 class stmt_t : public object_t {
1501 public:
1502     using object_t::object_t;
1503 
1504     stmt_t() = default;
stmt_t(const object_t & obj)1505     stmt_t(const object_t &obj) : object_t(obj) {}
stmt_t(object_t && obj)1506     stmt_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1507     stmt_t &operator=(const object_t &obj) {
1508         object_t::operator=(obj);
1509         return *this;
1510     }
operator =(object_t && obj)1511     stmt_t &operator=(object_t &&obj) {
1512         object_t::operator=(obj);
1513         return *this;
1514     }
1515 
1516     stmt_t append(const stmt_t &s) const;
1517 
1518 private:
sanity_check() const1519     void sanity_check() const override {
1520         ir_assert(dynamic_cast<const stmt_impl_t *>(impl()) == impl())
1521                 << object_t(impl());
1522     }
1523 };
1524 
// Kind of memory an allocation refers to (stored in alloc_t::kind).
enum class alloc_kind_t {
    undef,
    grf, // GRF - general register file.
    slm, // SLM - shared local memory.
    global, // Global memory.
};
1531 
// Base class for allocation attribute implementations.
// alloc_attr_t::sanity_check() asserts that a wrapped implementation derives
// from this class.
class alloc_attr_impl_t : public object_impl_t {};
1533 
1534 class alloc_attr_t : public object_t {
1535 public:
1536     using object_t::object_t;
1537 
1538     alloc_attr_t() = default;
alloc_attr_t(const object_t & obj)1539     alloc_attr_t(const object_t &obj) : object_t(obj) {}
alloc_attr_t(object_t && obj)1540     alloc_attr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1541     alloc_attr_t &operator=(const object_t &obj) {
1542         object_t::operator=(obj);
1543         return *this;
1544     }
operator =(object_t && obj)1545     alloc_attr_t &operator=(object_t &&obj) {
1546         object_t::operator=(obj);
1547         return *this;
1548     }
1549 
1550 private:
sanity_check() const1551     void sanity_check() const override {
1552         ir_assert(dynamic_cast<const alloc_attr_impl_t *>(impl()) == impl())
1553                 << object_t(impl());
1554     }
1555 };
1556 
1557 // Allocation attribute for GRF.
1558 class grf_alloc_attr_t : public alloc_attr_impl_t {
1559 public:
IR_DECL_TYPE_ID(grf_alloc_attr_t)1560     IR_DECL_TYPE_ID(grf_alloc_attr_t)
1561 
1562     static alloc_attr_t make(const ngen_proxy::Bundle &bundle) {
1563         return alloc_attr_t(new grf_alloc_attr_t(bundle));
1564     }
1565 
is_equal(const object_impl_t & obj) const1566     bool is_equal(const object_impl_t &obj) const override {
1567         if (!obj.is<self_type>()) return false;
1568         auto &other = obj.as<self_type>();
1569 
1570         return bundle == other.bundle;
1571     }
1572 
get_hash() const1573     size_t get_hash() const override {
1574         return ir_utils::get_hash(bundle.bundle_id, bundle.bank_id);
1575     }
1576 
1577     ngen_proxy::Bundle bundle;
1578 
1579 private:
grf_alloc_attr_t(const ngen_proxy::Bundle & bundle)1580     grf_alloc_attr_t(const ngen_proxy::Bundle &bundle) : bundle(bundle) {}
1581 };
1582 
1583 // Allocation for SLM and GRF buffers.
1584 // C++ equivalent:
1585 //     {
1586 //         byte *buf = new byte[size];
1587 //         body;
1588 //      }
1589 class alloc_t : public stmt_impl_t {
1590 public:
IR_DECL_STMT_TYPE_ID(alloc_t)1591     IR_DECL_STMT_TYPE_ID(alloc_t)
1592 
1593     static stmt_t make(const expr_t &buf, int size, alloc_kind_t kind,
1594             const alloc_attr_t &attr = {}, const stmt_t &body = {}) {
1595         return stmt_t(new alloc_t(buf, size, kind, attr, body));
1596     }
1597 
is_equal(const object_impl_t & obj) const1598     bool is_equal(const object_impl_t &obj) const override {
1599         if (!obj.is<self_type>()) return false;
1600         auto &other = obj.as<self_type>();
1601 
1602         return buf.is_equal(other.buf) && (size == other.size)
1603                 && (kind == other.kind) && attr.is_equal(other.attr)
1604                 && body.is_equal(other.body);
1605     }
1606 
get_hash() const1607     size_t get_hash() const override {
1608         return ir_utils::get_hash(buf, size, kind, attr, body);
1609     }
1610 
1611     IR_DECLARE_TRAVERSERS()
1612 
1613     expr_t buf;
1614     int size;
1615     alloc_kind_t kind;
1616     alloc_attr_t attr;
1617     stmt_t body;
1618 
1619 private:
alloc_t(const expr_t & buf,int size,alloc_kind_t kind,const alloc_attr_t & attr,const stmt_t & body)1620     alloc_t(const expr_t &buf, int size, alloc_kind_t kind,
1621             const alloc_attr_t &attr, const stmt_t &body)
1622         : buf(buf), size(size), kind(kind), attr(attr), body(body) {
1623         ir_assert(buf.type().is_ptr()) << buf;
1624     }
1625 };
1626 
1627 // Store to a GRF buffer.
1628 // C++ equivalent (when value is scalar):
1629 //     *(value_type *)(&buf[off]) = value;
1630 // C++ equivalent (when value is vector):
1631 //     int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1632 //     for (int i = 0; i < elems; i++) {
1633 //         *(scalar_type *)(&buf[off + i * _stride]) = value[i];
1634 //     }
1635 class store_t : public stmt_impl_t {
1636 public:
IR_DECL_STMT_TYPE_ID(store_t)1637     IR_DECL_STMT_TYPE_ID(store_t)
1638 
1639     // offset and stride are expressed in bytes.
1640     // default stride means unit stride (in terms of value.type().scalar()
1641     // elements).
1642     static stmt_t make(const expr_t &buf, const expr_t &off,
1643             const expr_t &value, int stride = default_stride,
1644             const expr_t &mask = expr_t()) {
1645         return stmt_t(new store_t(buf, off, value, stride, mask));
1646     }
1647 
is_equal(const object_impl_t & obj) const1648     bool is_equal(const object_impl_t &obj) const override {
1649         if (!obj.is<self_type>()) return false;
1650         auto &other = obj.as<self_type>();
1651 
1652         return buf.is_equal(other.buf) && off.is_equal(other.off)
1653                 && value.is_equal(other.value) && (stride == other.stride)
1654                 && mask.is_equal(other.mask);
1655     }
1656 
get_hash() const1657     size_t get_hash() const override {
1658         return ir_utils::get_hash(buf, off, value, stride, mask);
1659     }
1660 
has_default_stride() const1661     bool has_default_stride() const { return stride == default_stride; }
1662 
1663     IR_DECLARE_TRAVERSERS()
1664 
1665     static const int default_stride = -1;
1666 
1667     expr_t buf;
1668     expr_t off;
1669     expr_t value;
1670     int stride;
1671     expr_t mask;
1672 
1673 private:
store_t(const expr_t & _buf,const expr_t & _off,const expr_t & _value,int _stride,const expr_t & _mask)1674     store_t(const expr_t &_buf, const expr_t &_off, const expr_t &_value,
1675             int _stride, const expr_t &_mask)
1676         : buf(_buf), off(_off), value(_value), stride(_stride), mask(_mask) {
1677         normalize_ptr(value.type(), buf, off);
1678         ir_assert(is_var(buf)) << buf;
1679         ir_assert(buf.type().is_ptr()) << buf;
1680         if (stride == value.type().scalar().size()) stride = default_stride;
1681         if (!mask.is_empty())
1682             ir_assert(mask.type() == type_t::_bool(value.type().elems()));
1683     }
1684 };
1685 
1686 // Loop statement with unit increment.
1687 // C++ equivalent:
1688 //    for (var = init; var < bound; var++) {
1689 //        body;
1690 //    }
1691 // unroll specifies the unroll factor, unroll = 1 means no unrolling.
1692 class for_t : public stmt_impl_t {
1693 public:
IR_DECL_STMT_TYPE_ID(for_t)1694     IR_DECL_STMT_TYPE_ID(for_t)
1695 
1696     static stmt_t make(const expr_t &var, const expr_t &init,
1697             const expr_t &bound, const stmt_t &body = {}, int unroll = 1) {
1698         return stmt_t(new for_t(var, init, bound, body, unroll));
1699     }
1700 
is_equal(const object_impl_t & obj) const1701     bool is_equal(const object_impl_t &obj) const override {
1702         if (!obj.is<self_type>()) return false;
1703         auto &other = obj.as<self_type>();
1704 
1705         return var.is_equal(other.var) && init.is_equal(other.init)
1706                 && bound.is_equal(other.bound) && body.is_equal(other.body)
1707                 && (unroll == other.unroll);
1708     }
1709 
get_hash() const1710     size_t get_hash() const override {
1711         return ir_utils::get_hash(var, init, bound, body, unroll);
1712     }
1713 
1714     IR_DECLARE_TRAVERSERS()
1715 
1716     expr_t var;
1717     expr_t init;
1718     expr_t bound;
1719     stmt_t body;
1720     int unroll;
1721 
1722 private:
for_t(const expr_t & var,const expr_t & init,const expr_t & bound,const stmt_t & body,int unroll)1723     for_t(const expr_t &var, const expr_t &init, const expr_t &bound,
1724             const stmt_t &body, int unroll)
1725         : var(var), init(init), bound(bound), body(body), unroll(unroll) {}
1726 };
1727 
1728 // If-else statement.
1729 // C++ equivalent:
1730 //     if (cond) {
1731 //         body;
1732 //     } else {
1733 //         else_body;
1734 //     }
1735 class if_t : public stmt_impl_t {
1736 public:
IR_DECL_STMT_TYPE_ID(if_t)1737     IR_DECL_STMT_TYPE_ID(if_t)
1738 
1739     static stmt_t make(const expr_t &cond, const stmt_t &body,
1740             const stmt_t &else_body = stmt_t()) {
1741         return stmt_t(new if_t(cond, body, else_body));
1742     }
1743 
is_equal(const object_impl_t & obj) const1744     bool is_equal(const object_impl_t &obj) const override {
1745         if (!obj.is<self_type>()) return false;
1746         auto &other = obj.as<self_type>();
1747 
1748         return cond.is_equal(other.cond) && body.is_equal(other.body)
1749                 && else_body.is_equal(other.else_body);
1750     }
1751 
get_hash() const1752     size_t get_hash() const override {
1753         return ir_utils::get_hash(cond, body, else_body);
1754     }
1755 
1756     IR_DECLARE_TRAVERSERS()
1757 
1758     expr_t cond;
1759     stmt_t body;
1760     stmt_t else_body;
1761 
1762 private:
if_t(const expr_t & cond,const stmt_t & body,const stmt_t & else_body)1763     if_t(const expr_t &cond, const stmt_t &body, const stmt_t &else_body)
1764         : cond(cond), body(body), else_body(else_body) {}
1765 };
1766 
1767 // Let statement, used to bind a variable to a value within a scope.
1768 // C++ equivalent:
1769 //     {
1770 //         var = value;
1771 //         body;
1772 //     }
1773 class let_t : public stmt_impl_t {
1774 public:
IR_DECL_STMT_TYPE_ID(let_t)1775     IR_DECL_STMT_TYPE_ID(let_t)
1776 
1777     static stmt_t make(
1778             const expr_t &var, const expr_t &value, const stmt_t &body = {}) {
1779         return stmt_t(new let_t(var, value, body));
1780     }
1781 
is_equal(const object_impl_t & obj) const1782     bool is_equal(const object_impl_t &obj) const override {
1783         if (!obj.is<self_type>()) return false;
1784         auto &other = obj.as<self_type>();
1785 
1786         return var.is_equal(other.var) && value.is_equal(other.value)
1787                 && body.is_equal(other.body);
1788     }
1789 
get_hash() const1790     size_t get_hash() const override {
1791         return ir_utils::get_hash(var, value, body);
1792     }
1793 
1794     IR_DECLARE_TRAVERSERS()
1795 
1796     expr_t var;
1797     expr_t value;
1798     stmt_t body;
1799 
1800 private:
let_t(const expr_t & var,const expr_t & value,const stmt_t & body)1801     let_t(const expr_t &var, const expr_t &value, const stmt_t &body)
1802         : var(var), value(value), body(body) {}
1803 };
1804 
1805 // Statement label, specific to GEMM/convolution.
1806 class stmt_label_t {
1807 public:
kernel(int index=-1)1808     static stmt_label_t kernel(int index = -1) {
1809         return stmt_label_t(kind_t::_kernel, index);
1810     }
compute_loop(int index=-1)1811     static stmt_label_t compute_loop(int index = -1) {
1812         return stmt_label_t(kind_t::_compute_loop, index);
1813     }
c_store(int index=-1)1814     static stmt_label_t c_store(int index = -1) {
1815         return stmt_label_t(kind_t::_c_store, index);
1816     }
c_zero_out(int index=-1)1817     static stmt_label_t c_zero_out(int index = -1) {
1818         return stmt_label_t(kind_t::_c_zero_out, index);
1819     }
b_reduced_zero_out(int index=-1)1820     static stmt_label_t b_reduced_zero_out(int index = -1) {
1821         return stmt_label_t(kind_t::_b_reduced_zero_out, index);
1822     }
g2s_load(int index=-1)1823     static stmt_label_t g2s_load(int index = -1) {
1824         return stmt_label_t(kind_t::_g2s_load, index);
1825     }
g2s_store(int index=-1)1826     static stmt_label_t g2s_store(int index = -1) {
1827         return stmt_label_t(kind_t::_g2s_store, index);
1828     }
g2r_load(int index=-1)1829     static stmt_label_t g2r_load(int index = -1) {
1830         return stmt_label_t(kind_t::_g2r_load, index);
1831     }
s2r_load(int index=-1)1832     static stmt_label_t s2r_load(int index = -1) {
1833         return stmt_label_t(kind_t::_s2r_load, index);
1834     }
prefetch(int index=-1)1835     static stmt_label_t prefetch(int index = -1) {
1836         return stmt_label_t(kind_t::_prefetch, index);
1837     }
mul(int index=-1)1838     static stmt_label_t mul(int index = -1) {
1839         return stmt_label_t(kind_t::_mul, index);
1840     }
1841 
operator ==(const stmt_label_t & other) const1842     bool operator==(const stmt_label_t &other) const {
1843         if (kind_ != other.kind_) return false;
1844         if (index_ == -1 || other.index_ == -1) return true;
1845         return index_ == other.index_;
1846     }
1847 
get_hash() const1848     size_t get_hash() const { return ir_utils::get_hash(kind_, index_); }
1849 
str() const1850     std::string str() const {
1851         switch (kind_) {
1852 #define CASE(kind) \
1853     case kind_t::_##kind: return #kind
1854             CASE(kernel);
1855             CASE(compute_loop);
1856             CASE(c_store);
1857             CASE(c_zero_out);
1858             CASE(g2r_load);
1859             CASE(g2s_load);
1860             CASE(g2s_store);
1861             CASE(s2r_load);
1862             CASE(prefetch);
1863             CASE(mul);
1864 #undef CASE
1865             default: ir_error_not_expected();
1866         }
1867         return {};
1868     }
1869 
1870 private:
1871     enum class kind_t {
1872         _undef,
1873         _kernel, // All kernel.
1874         _compute_loop, // Compute loop.
1875         _c_store, // GRF to GMEM store of C.
1876         _c_zero_out, // Zeroing-out of C.
1877         _b_reduced_zero_out, // Zeroing-out of B reduced buffer.
1878         _g2r_load, // GMEM to GRF load for further multiplication.
1879         _g2s_load, // GMEM to GRF load for GMEM -> SLM copy.
1880         _g2s_store, // GRF to SLM store for GMEM -> SLM copy.
1881         _s2r_load, // SLM to GRF load for further multiplication.
1882         _prefetch, // GMEM prefetch.
1883         _mul, // Multiplication.
1884     };
1885 
stmt_label_t()1886     stmt_label_t() : kind_(kind_t::_undef), index_(-1) {}
stmt_label_t(kind_t kind,int index)1887     stmt_label_t(kind_t kind, int index) : kind_(kind), index_(index) {}
1888 
1889     kind_t kind_;
1890     int index_; // Used to differentiate groups with the same kind.
1891 };
1892 
operator <<(std::ostream & out,const stmt_label_t & label)1893 inline std::ostream &operator<<(std::ostream &out, const stmt_label_t &label) {
1894     out << label.str();
1895     return out;
1896 }
1897 
1898 // Statement group, used to assign a label to a group of statements.
1899 class stmt_group_t : public stmt_impl_t {
1900 public:
IR_DECL_STMT_TYPE_ID(stmt_group_t)1901     IR_DECL_STMT_TYPE_ID(stmt_group_t)
1902 
1903     static stmt_t make(const stmt_label_t &label, const stmt_t &body) {
1904         return stmt_t(new stmt_group_t(label, body));
1905     }
1906 
is_equal(const object_impl_t & obj) const1907     bool is_equal(const object_impl_t &obj) const override {
1908         if (!obj.is<self_type>()) return false;
1909         auto &other = obj.as<self_type>();
1910 
1911         return (label == other.label) && body.is_equal(other.body);
1912     }
1913 
get_hash() const1914     size_t get_hash() const override { return ir_utils::get_hash(label, body); }
1915 
1916     IR_DECLARE_TRAVERSERS()
1917 
1918     stmt_label_t label;
1919     stmt_t body;
1920 
1921 private:
stmt_group_t(const stmt_label_t & label,const stmt_t & body)1922     stmt_group_t(const stmt_label_t &label, const stmt_t &body)
1923         : label(label), body(body) {}
1924 };
1925 
1926 // Statement sequence, allows combining two statements.
1927 // C++ equivalent:
1928 //     {
1929 //         head;
1930 //         tail;
1931 //     }
1932 class stmt_seq_t : public stmt_impl_t {
1933 public:
IR_DECL_STMT_TYPE_ID(stmt_seq_t)1934     IR_DECL_STMT_TYPE_ID(stmt_seq_t)
1935 
1936     static stmt_t make(const stmt_t &head, const stmt_t &tail) {
1937         return stmt_t(new stmt_seq_t(head, tail));
1938     }
1939 
is_equal(const object_impl_t & obj) const1940     bool is_equal(const object_impl_t &obj) const override {
1941         if (!obj.is<self_type>()) return false;
1942         auto &other = obj.as<self_type>();
1943 
1944         return head.is_equal(other.head) && tail.is_equal(other.tail);
1945     }
1946 
get_hash() const1947     size_t get_hash() const override { return ir_utils::get_hash(head, tail); }
1948 
1949     IR_DECLARE_TRAVERSERS()
1950 
1951     stmt_t head;
1952     stmt_t tail;
1953 
1954 private:
stmt_seq_t(const stmt_t & head,const stmt_t & tail)1955     stmt_seq_t(const stmt_t &head, const stmt_t &tail)
1956         : head(head), tail(tail) {}
1957 };
1958 
// Returns a statement that executes this statement followed by s.
// When either side is empty the other one is returned as is, so no
// stmt_seq_t with an empty head or an empty tail is created.
inline stmt_t stmt_t::append(const stmt_t &s) const {
    if (is_empty()) return s;
    if (s.is_empty()) return *this;
    return stmt_seq_t::make(*this, s);
}
1963 
// Function call attribute.
// Base class for attribute implementations: func_call_attr_t asserts that the
// object it wraps derives from this class.
class func_call_attr_impl_t : public object_impl_t {};
1966 
1967 class func_call_attr_t : public object_t {
1968 public:
1969     using object_t::object_t;
1970 
1971     func_call_attr_t() = default;
func_call_attr_t(const object_t & obj)1972     func_call_attr_t(const object_t &obj) : object_t(obj) {}
func_call_attr_t(object_t && obj)1973     func_call_attr_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)1974     func_call_attr_t &operator=(const object_t &obj) {
1975         object_t::operator=(obj);
1976         return *this;
1977     }
operator =(object_t && obj)1978     func_call_attr_t &operator=(object_t &&obj) {
1979         object_t::operator=(obj);
1980         return *this;
1981     }
1982 
1983     // Returns a function call with the attribute applied. The input statement
1984     // must be a function call.
1985     stmt_t apply_to(const stmt_t &s) const;
1986 
1987 private:
sanity_check() const1988     void sanity_check() const override {
1989         ir_assert(dynamic_cast<const func_call_attr_impl_t *>(impl()) == impl())
1990                 << object_t(impl());
1991     }
1992 };
1993 
1994 // Instruction modifier, relies on nGEN API.
1995 class instruction_modifier_attr_t : public func_call_attr_impl_t {
1996 public:
IR_DECL_TYPE_ID(instruction_modifier_attr_t)1997     IR_DECL_TYPE_ID(instruction_modifier_attr_t)
1998 
1999     static func_call_attr_t make(const ngen_proxy::InstructionModifier &mod) {
2000         return func_call_attr_t(new instruction_modifier_attr_t(mod));
2001     }
2002 
is_equal(const object_impl_t & obj) const2003     bool is_equal(const object_impl_t &obj) const override {
2004         if (!obj.is<self_type>()) return false;
2005         auto &other = obj.as<self_type>();
2006 
2007         return mod == other.mod;
2008     }
2009 
get_hash() const2010     size_t get_hash() const override { return ir_utils::get_hash(mod); }
2011 
str() const2012     std::string str() const override {
2013         std::ostringstream oss;
2014         oss << "{";
2015         bool is_first = true;
2016         auto append = [&](const std::string &s) {
2017             if (!is_first) oss << ", ";
2018             oss << s;
2019             is_first = false;
2020         };
2021         if (mod.is_atomic) append("Atomic");
2022         if (!mod.sbid.is_empty()) {
2023             append(std::string("$") + std::to_string(mod.sbid.token));
2024         }
2025         oss << "}";
2026         return oss.str();
2027     }
2028 
2029     ngen_proxy::InstructionModifier mod;
2030 
2031 private:
instruction_modifier_attr_t(const ngen_proxy::InstructionModifier & mod)2032     instruction_modifier_attr_t(const ngen_proxy::InstructionModifier &mod)
2033         : mod(mod) {}
2034 };
2035 
// Base class for function IR objects.
// Concrete functions (e.g. builtin_t) derive from this class.
class func_impl_t : public object_impl_t {
public:
    IR_DECL_TYPE_ID(func_impl_t)

    // Returns a func_call_t statement that calls this function with the given
    // arguments and an optional call attribute (empty by default).
    stmt_t call(const std::vector<expr_t> &args,
            const func_call_attr_t &attr = {}) const;
};
2044 
2045 // Wrapper for IR function objects.
2046 class func_t : public object_t {
2047 public:
2048     using object_t::object_t;
2049 
2050     func_t() = default;
func_t(const object_t & obj)2051     func_t(const object_t &obj) : object_t(obj) {}
func_t(object_t && obj)2052     func_t(object_t &&obj) : object_t(obj) {}
operator =(const object_t & obj)2053     func_t &operator=(const object_t &obj) {
2054         object_t::operator=(obj);
2055         return *this;
2056     }
operator =(object_t && obj)2057     func_t &operator=(object_t &&obj) {
2058         object_t::operator=(obj);
2059         return *this;
2060     }
2061 
call(const std::vector<expr_t> & args={},const func_call_attr_t & attr={}) const2062     stmt_t call(const std::vector<expr_t> &args = {},
2063             const func_call_attr_t &attr = {}) const {
2064         return ((const func_impl_t *)impl())->call(args, attr);
2065     }
2066 
2067 private:
sanity_check() const2068     void sanity_check() const override {
2069         ir_assert(dynamic_cast<const func_impl_t *>(impl()) == impl())
2070                 << object_t(impl());
2071     }
2072 };
2073 
2074 // Function call.
2075 class func_call_t : public stmt_impl_t {
2076 public:
IR_DECL_STMT_TYPE_ID(func_call_t)2077     IR_DECL_STMT_TYPE_ID(func_call_t)
2078 
2079     static stmt_t make(const func_t &func, const std::vector<expr_t> &args,
2080             const func_call_attr_t &attr = {}) {
2081         return stmt_t(new func_call_t(func, args, attr));
2082     }
2083 
is_equal(const object_impl_t & obj) const2084     bool is_equal(const object_impl_t &obj) const override {
2085         if (!obj.is<self_type>()) return false;
2086         auto &other = obj.as<self_type>();
2087 
2088         return func.is_equal(other.func) && ir_utils::is_equal(args, other.args)
2089                 && attr.is_equal(other.attr);
2090     }
2091 
get_hash() const2092     size_t get_hash() const override {
2093         return ir_utils::get_hash(func, args, attr);
2094     }
2095 
2096     IR_DECLARE_TRAVERSERS()
2097 
2098     func_t func;
2099     std::vector<expr_t> args;
2100     func_call_attr_t attr;
2101 
2102 private:
func_call_t(const func_t & func,const std::vector<expr_t> & args,const func_call_attr_t & attr)2103     func_call_t(const func_t &func, const std::vector<expr_t> &args,
2104             const func_call_attr_t &attr)
2105         : func(func), args(args), attr(attr) {
2106         ir_assert(!func.is_empty());
2107     }
2108 };
2109 
call(const std::vector<expr_t> & args,const func_call_attr_t & attr) const2110 inline stmt_t func_impl_t::call(
2111         const std::vector<expr_t> &args, const func_call_attr_t &attr) const {
2112     return func_call_t::make(this, args, attr);
2113 }
2114 
// Returns a copy of the function call s with this attribute attached.
// s must be a function call that has no attribute yet: merging of attributes
// is not supported.
inline stmt_t func_call_attr_t::apply_to(const stmt_t &s) const {
    const auto &call = s.as<func_call_t>();
    ir_assert(call.attr.is_empty())
            << "Merging of attributes is not supported: " << s;
    return func_call_t::make(call.func, call.args, *this);
}
2121 
2122 template <typename F>
is_func_call(const stmt_t & s)2123 inline bool is_func_call(const stmt_t &s) {
2124     auto *c = s.as_ptr<func_call_t>();
2125     if (!c) return false;
2126     return c->func.is<F>();
2127 }
2128 
2129 // Generic function with a name.
2130 class builtin_t : public func_impl_t {
2131 public:
IR_DECL_DERIVED_TYPE_ID(builtin_t,func_impl_t)2132     IR_DECL_DERIVED_TYPE_ID(builtin_t, func_impl_t)
2133 
2134     static func_t make(const std::string &name) {
2135         return func_t(new builtin_t(name));
2136     }
2137 
is_equal(const object_impl_t & obj) const2138     bool is_equal(const object_impl_t &obj) const override {
2139         if (!obj.is<self_type>()) return false;
2140         auto &other = obj.as<self_type>();
2141 
2142         return name == other.name;
2143     }
2144 
get_hash() const2145     size_t get_hash() const override { return ir_utils::get_hash(name); }
2146 
str() const2147     std::string str() const override { return name; }
2148 
2149     std::string name;
2150 
2151 private:
builtin_t(const std::string & name)2152     builtin_t(const std::string &name) : name(name) {}
2153 };
2154 
2155 } // namespace jit
2156 } // namespace gpu
2157 } // namespace impl
2158 } // namespace dnnl
2159 
2160 #endif
2161