1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file ir.cc
22 */
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <limits>
#include <memory>
#include "../pass/ir_util.h"
29
30 namespace tvm {
31 namespace ir {
32
33 // constructors
make(DataType t,uint64_t value)34 Expr UIntImm::make(DataType t, uint64_t value) {
35 CHECK(t.is_uint() && t.lanes() == 1)
36 << "ValueError: UIntImm can only take scalar";
37 NodePtr<UIntImm> node = make_node<UIntImm>();
38 node->type = t;
39 node->value = value;
40 return Expr(node);
41 }
42
make(DataType t,double value)43 Expr FloatImm::make(DataType t, double value) {
44 CHECK_EQ(t.lanes(), 1)
45 << "ValueError: FloatImm can only take scalar";
46 NodePtr<FloatImm> node = make_node<FloatImm>();
47 node->type = t;
48 node->value = value;
49 return Expr(node);
50 }
51
make(std::string value)52 Expr StringImm::make(std::string value) {
53 NodePtr<StringImm> node = make_node<StringImm>();
54 node->type = Handle();
55 node->value = std::move(value);
56 return Expr(node);
57 }
58
make(DataType t,Expr value)59 Expr Cast::make(DataType t, Expr value) {
60 CHECK(value.defined());
61 CHECK_EQ(t.lanes(), value.type().lanes());
62 NodePtr<Cast> node = make_node<Cast>();
63 node->type = t;
64 node->value = std::move(value);
65 return Expr(node);
66 }
67
make(Expr a,Expr b)68 Expr And::make(Expr a, Expr b) {
69 CHECK(a.defined()) << "ValueError: a is undefined";
70 CHECK(b.defined()) << "ValueError: b is undefined";
71 CHECK(a.type().is_bool());
72 CHECK(b.type().is_bool());
73 CHECK(a.type() == b.type()) << "TypeError: mismatched types";
74
75 NodePtr<And> node = make_node<And>();
76 node->type = Bool(a.type().lanes());
77 node->a = std::move(a);
78 node->b = std::move(b);
79 return Expr(node);
80 }
81
make(Expr a,Expr b)82 Expr Or::make(Expr a, Expr b) {
83 CHECK(a.defined()) << "ValueError: a is undefined";
84 CHECK(b.defined()) << "ValueError: b is undefined";
85 CHECK(a.type().is_bool());
86 CHECK(b.type().is_bool());
87 CHECK(a.type() == b.type()) << "TypeError: mismatched types";
88
89 NodePtr<Or> node = make_node<Or>();
90 node->type = Bool(a.type().lanes());
91 node->a = std::move(a);
92 node->b = std::move(b);
93 return Expr(node);
94 }
95
make(Expr a)96 Expr Not::make(Expr a) {
97 CHECK(a.defined()) << "ValueError: a is undefined";
98 CHECK(a.type().is_bool());
99
100 NodePtr<Not> node = make_node<Not>();
101 node->type = Bool(a.type().lanes());
102 node->a = std::move(a);
103 return Expr(node);
104 }
105
make(Expr condition,Expr true_value,Expr false_value)106 Expr Select::make(Expr condition, Expr true_value, Expr false_value) {
107 CHECK(condition.defined()) << "ValueError: condition is undefined";
108 CHECK(true_value.defined()) << "ValueError: true_value is undefined";
109 CHECK(false_value.defined()) << "ValueError: true_value is undefined";
110 CHECK(condition.type().is_bool());
111 CHECK_EQ(condition.type().lanes(), true_value.type().lanes());
112 CHECK(false_value.type() == true_value.type()) << "TypeError: mismatched types";
113
114 NodePtr<Select> node = make_node<Select>();
115 node->type = true_value.type();
116 node->condition = std::move(condition);
117 node->true_value = std::move(true_value);
118 node->false_value = std::move(false_value);
119 return Expr(node);
120 }
121
make(DataType type,Var buffer_var,Expr index,Expr predicate)122 Expr Load::make(DataType type, Var buffer_var, Expr index, Expr predicate) {
123 CHECK(buffer_var.defined());
124 CHECK(predicate.defined());
125 CHECK(index.defined());
126 CHECK_EQ(type.lanes(), index.type().lanes());
127 CHECK_EQ(type.lanes(), predicate.type().lanes());
128
129 NodePtr<Load> node = make_node<Load>();
130 node->type = type;
131 node->buffer_var = std::move(buffer_var);
132 node->index = std::move(index);
133 node->predicate = std::move(predicate);
134
135 return Expr(node);
136 }
137
make(Expr base,Expr stride,int lanes)138 Expr Ramp::make(Expr base, Expr stride, int lanes) {
139 CHECK(base.defined());
140 CHECK(stride.defined());
141 CHECK(base.type().is_scalar());
142 CHECK(stride.type().is_scalar());
143 CHECK_GT(lanes, 1);
144 CHECK_EQ(stride.type(), base.type());
145
146 NodePtr<Ramp> node = make_node<Ramp>();
147 node->type = base.type().with_lanes(lanes);
148 node->base = base;
149 node->stride = stride;
150 node->lanes = lanes;
151 return Expr(node);
152 }
153
make(Expr value,int lanes)154 Expr Broadcast::make(Expr value, int lanes) {
155 CHECK(value.defined());
156 CHECK(value.type().is_scalar());
157 CHECK_GT(lanes, 1);
158
159 NodePtr<Broadcast> node = make_node<Broadcast>();
160 node->type = value.type().with_lanes(lanes);
161 node->value = std::move(value);
162 node->lanes = lanes;
163 return Expr(node);
164 }
165
make(Var var,Expr value,Expr body)166 Expr Let::make(Var var, Expr value, Expr body) {
167 CHECK(value.defined());
168 CHECK(body.defined());
169 CHECK_EQ(value.type(), var.type());
170
171 NodePtr<Let> node = make_node<Let>();
172 node->type = body.type();
173 node->var = std::move(var);
174 node->value = std::move(value);
175 node->body = std::move(body);
176 return Expr(node);
177 }
178
179 const char* Call::vectorizable_intrinsics[] = {
180 "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
181 "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
182 ir::Call::likely, ir::Call::popcount
183 };
184
is_vectorizable() const185 bool Call::is_vectorizable() const {
186 size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
187 for (size_t i = 0; i < cnt; ++i) {
188 if (name == Call::vectorizable_intrinsics[i]) {
189 return true;
190 }
191 }
192 return false;
193 }
194
make(DataType type,std::string name,Array<Expr> args,CallType call_type,FunctionRef func,int value_index)195 Expr Call::make(DataType type,
196 std::string name,
197 Array<Expr> args,
198 CallType call_type,
199 FunctionRef func,
200 int value_index) {
201 for (size_t i = 0; i < args.size(); ++i) {
202 CHECK(args[i].defined());
203 }
204
205 if (call_type == Halide) {
206 for (size_t i = 0; i < args.size(); ++i) {
207 CHECK(args[i].type().is_int());
208 }
209 }
210
211 NodePtr<Call> node = make_node<Call>();
212 node->type = type;
213 node->name = std::move(name);
214 node->args = std::move(args);
215 node->call_type = call_type;
216 node->func = std::move(func);
217 node->value_index = value_index;
218 return Expr(node);
219 }
220
make(Array<Expr> vectors,Array<Expr> indices)221 Expr Shuffle::make(Array<Expr> vectors,
222 Array<Expr> indices) {
223 CHECK_NE(vectors.size(), 0U);
224 CHECK_NE(indices.size(), 0U);
225
226 Type base_type = vectors[0].type().element_of();
227 int total_lanes = 0;
228
229 for (Expr val : vectors) {
230 CHECK(val.type().element_of() == base_type);
231 total_lanes += val.type().lanes();
232 }
233 CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
234
235 NodePtr<Shuffle> node = make_node<Shuffle>();
236 node->type = base_type.with_lanes(static_cast<int>(indices.size()));
237 node->vectors = std::move(vectors);
238 node->indices = std::move(indices);
239 return Expr(node);
240 }
241
make_concat(Array<Expr> vectors)242 Expr Shuffle::make_concat(Array<Expr> vectors) {
243 CHECK_NE(vectors.size(), 0);
244 if (vectors.size() == 1) {
245 return vectors[0];
246 }
247 Array<Expr> indices;
248 int index = 0;
249 for (const Expr& e : vectors) {
250 for (int i = 0; i < e.type().lanes(); ++i) {
251 indices.push_back(IntImm::make(Int(32), index++));
252 }
253 }
254 return make(vectors, indices);
255 }
256
make_extract_element(Expr vector,int index)257 Expr Shuffle::make_extract_element(Expr vector, int index) {
258 return make({vector}, {Integer(index)});
259 }
260
make(Array<Var> lhs,Array<Var> rhs,Array<Expr> result,Array<Expr> identity_element)261 CommReducer CommReducerNode::make(Array<Var> lhs,
262 Array<Var> rhs,
263 Array<Expr> result,
264 Array<Expr> identity_element) {
265 auto node = make_node<CommReducerNode>();
266 node->lhs = lhs;
267 node->rhs = rhs;
268 node->result = result;
269 node->identity_element = identity_element;
270 return CommReducer(node);
271 }
272
operator ()(Array<Expr> a,Array<Expr> b) const273 Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
274 CHECK_EQ(a.size(), b.size());
275 CHECK_EQ(lhs.size(), a.size());
276 CHECK_EQ(rhs.size(), b.size());
277 Map<Var, Expr> value_map;
278 for (size_t i = 0; i < a.size(); ++i) {
279 value_map.Set(lhs[i], a[i]);
280 value_map.Set(rhs[i], b[i]);
281 }
282 return UpdateArray(result, [&value_map] (const Expr& e) {
283 return Substitute(e, value_map);
284 });
285 }
286
make(CommReducer combiner,Array<Expr> source,Array<IterVar> axis,Expr condition,int value_index)287 Expr Reduce::make(CommReducer combiner, Array<Expr> source,
288 Array<IterVar> axis, Expr condition, int value_index) {
289 for (size_t i = 0; i < axis.size(); ++i) {
290 CHECK_EQ(axis[i]->iter_type, kCommReduce)
291 << "Can only take axis created by reduce_axis";
292 }
293 if (!condition.defined()) {
294 condition = const_true();
295 }
296 auto n = make_node<Reduce>();
297 CHECK(source.defined());
298 for (size_t i = 0; i < axis.size(); ++i) {
299 CHECK(axis[i].defined());
300 }
301 n->type = source[value_index].type();
302 n->combiner = std::move(combiner);
303 n->source = std::move(source);
304 n->axis = std::move(axis);
305 n->condition = condition;
306 n->value_index = value_index;
307 return Expr(n);
308 }
309
make()310 Expr Any::make() {
311 auto n = make_node<Any>();
312 return Expr(n);
313 }
314
make(Var var,Expr value,Stmt body)315 Stmt LetStmt::make(Var var, Expr value, Stmt body) {
316 CHECK(value.defined());
317 CHECK(body.defined());
318 CHECK_EQ(value.type(), var.type());
319
320 NodePtr<LetStmt> node = make_node<LetStmt>();
321 node->var = std::move(var);
322 node->value = std::move(value);
323 node->body = std::move(body);
324 return Stmt(node);
325 }
326
make(NodeRef node,std::string attr_key,Expr value,Stmt body)327 Stmt AttrStmt::make(NodeRef node,
328 std::string attr_key,
329 Expr value,
330 Stmt body) {
331 auto n = make_node<AttrStmt>();
332 n->node = node;
333 n->attr_key = std::move(attr_key);
334 n->value = std::move(value);
335 n->body = std::move(body);
336 return Stmt(n);
337 }
338
make(Expr condition,Expr message,Stmt body)339 Stmt AssertStmt::make(Expr condition, Expr message, Stmt body) {
340 CHECK(condition.defined());
341 CHECK(message.type() == Int(32) ||
342 message.as<StringImm>())
343 << "TypeError: AssertStmt message must be an int or string:"
344 << message << "\n";
345
346 NodePtr<AssertStmt> node = make_node<AssertStmt>();
347 node->condition = std::move(condition);
348 node->message = std::move(message);
349 node->body = std::move(body);
350 return Stmt(node);
351 }
352
make(FunctionRef func,bool is_producer,Stmt body)353 Stmt ProducerConsumer::make(FunctionRef func, bool is_producer, Stmt body) {
354 CHECK(body.defined());
355
356 NodePtr<ProducerConsumer> node = make_node<ProducerConsumer>();
357 node->func = std::move(func);
358 node->is_producer = is_producer;
359 node->body = std::move(body);
360 return Stmt(node);
361 }
362
make(Var loop_var,Expr min,Expr extent,ForType for_type,DeviceAPI device_api,Stmt body)363 Stmt For::make(Var loop_var,
364 Expr min,
365 Expr extent,
366 ForType for_type,
367 DeviceAPI device_api,
368 Stmt body) {
369 CHECK(min.defined());
370 CHECK(extent.defined());
371 CHECK(min.type().is_scalar());
372 CHECK(extent.type().is_scalar());
373 CHECK(loop_var.type().is_scalar());
374 CHECK(body.defined());
375
376 NodePtr<For> node = make_node<For>();
377 node->loop_var = std::move(loop_var);
378 node->min = std::move(min);
379 node->extent = std::move(extent);
380 node->for_type = for_type;
381 node->device_api = device_api;
382 node->body = std::move(body);
383 return Stmt(node);
384 }
385
make(Var buffer_var,Expr value,Expr index,Expr predicate)386 Stmt Store::make(Var buffer_var, Expr value, Expr index, Expr predicate) {
387 CHECK(value.defined());
388 CHECK(index.defined());
389 CHECK(predicate.defined());
390 CHECK_EQ(value.type().lanes(), index.type().lanes());
391 CHECK_EQ(value.type().lanes(), predicate.type().lanes());
392
393 NodePtr<Store> node = make_node<Store>();
394 node->buffer_var = std::move(buffer_var);
395 node->value = std::move(value);
396 node->index = std::move(index);
397 node->predicate = std::move(predicate);
398 return Stmt(node);
399 }
400
make(FunctionRef func,int value_index,Expr value,Array<Expr> args)401 Stmt Provide::make(FunctionRef func, int value_index, Expr value, Array<Expr> args) {
402 CHECK(value_index >=0 && value_index < func->num_outputs())
403 << "value index output function return value bound";
404 CHECK(value.defined()) << "Provide of undefined value\n";
405
406 for (size_t i = 0; i < args.size(); ++i) {
407 CHECK(args[i].defined()) << "Provide to undefined location\n";
408 }
409
410 NodePtr<Provide> node = make_node<Provide>();
411 node->func = std::move(func);
412 node->value_index = value_index;
413 node->value = std::move(value);
414 node->args = std::move(args);
415 return Stmt(node);
416 }
417
make(Var buffer_var,DataType type,Array<Expr> extents,Expr condition,Stmt body,Expr new_expr,std::string free_function)418 Stmt Allocate::make(Var buffer_var,
419 DataType type,
420 Array<Expr> extents,
421 Expr condition,
422 Stmt body,
423 Expr new_expr,
424 std::string free_function) {
425 for (size_t i = 0; i < extents.size(); ++i) {
426 CHECK(extents[i].defined());
427 CHECK(extents[i].type().is_scalar());
428 }
429 CHECK(body.defined());
430 CHECK(condition.defined());
431 CHECK(condition.type().is_bool());
432
433 NodePtr<Allocate> node = make_node<Allocate>();
434 node->buffer_var = std::move(buffer_var);
435 node->type = type;
436 node->extents = std::move(extents);
437 node->condition = std::move(condition);
438 node->body = std::move(body);
439 node->new_expr = std::move(new_expr);
440 node->free_function = std::move(free_function);
441 return Stmt(node);
442 }
443
constant_allocation_size(const Array<Expr> & extents)444 int32_t Allocate::constant_allocation_size(const Array<Expr>& extents) {
445 int64_t result = 1;
446 for (size_t i = 0; i < extents.size(); ++i) {
447 if (const IntImm *int_size = extents[i].as<IntImm>()) {
448 result *= int_size->value;
449 if (result > std::numeric_limits<int32_t>::max()) {
450 return 0;
451 }
452 } else {
453 return 0;
454 }
455 }
456 return static_cast<int32_t>(result);
457 }
458
make(Var buffer_var)459 Stmt Free::make(Var buffer_var) {
460 NodePtr<Free> node = make_node<Free>();
461 node->buffer_var = buffer_var;
462 return Stmt(node);
463 }
464
make(FunctionRef func,int value_index,DataType type,Region bounds,Expr condition,Stmt body)465 Stmt Realize::make(FunctionRef func,
466 int value_index,
467 DataType type,
468 Region bounds,
469 Expr condition,
470 Stmt body) {
471 for (size_t i = 0; i < bounds.size(); ++i) {
472 CHECK(bounds[i]->min.defined());
473 CHECK(bounds[i]->extent.defined());
474 CHECK(bounds[i]->min.type().is_scalar());
475 CHECK(bounds[i]->extent.type().is_scalar());
476 }
477 CHECK(body.defined());
478 CHECK(condition.defined());
479 CHECK(condition.type().is_bool());
480
481 NodePtr<Realize> node = make_node<Realize>();
482 node->func = std::move(func);
483 node->value_index = value_index;
484 node->type = type;
485 node->bounds = std::move(bounds);
486 node->condition = std::move(condition);
487 node->body = std::move(body);
488 return Stmt(node);
489 }
490
make(FunctionRef func,int value_index,DataType type,Region bounds)491 Stmt Prefetch::make(FunctionRef func, int value_index, DataType type, Region bounds) {
492 for (size_t i = 0; i < bounds.size(); ++i) {
493 CHECK(bounds[i]->min.defined());
494 CHECK(bounds[i]->extent.defined());
495 CHECK(bounds[i]->min.type().is_scalar());
496 CHECK(bounds[i]->extent.type().is_scalar());
497 }
498
499 NodePtr<Prefetch> node = make_node<Prefetch>();
500 node->func = std::move(func);
501 node->value_index = value_index;
502 node->type = type;
503 node->bounds = std::move(bounds);
504 return Stmt(node);
505 }
506
make(Stmt first,Stmt rest)507 Stmt Block::make(Stmt first, Stmt rest) {
508 CHECK(first.defined());
509 CHECK(rest.defined());
510 NodePtr<Block> node = make_node<Block>();
511
512 // canonicalize.
513 if (const Block* b = first.as<Block>()) {
514 node->first = b->first;
515 node->rest = Block::make(b->rest, rest);
516 } else {
517 node->first = std::move(first);
518 node->rest = std::move(rest);
519 }
520 return Stmt(node);
521 }
522
make(const std::vector<Stmt> & stmts)523 Stmt Block::make(const std::vector<Stmt>& stmts) {
524 if (stmts.empty()) {
525 return Stmt();
526 }
527 Stmt result = stmts.back();
528 for (size_t i = stmts.size() - 1; i != 0; --i) {
529 result = Block::make(stmts[i - 1], result);
530 }
531 return result;
532 }
533
make(Expr condition,Stmt then_case,Stmt else_case)534 Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) {
535 CHECK(condition.defined());
536 CHECK(then_case.defined());
537 // else_case may be null.
538
539 NodePtr<IfThenElse> node = make_node<IfThenElse>();
540 node->condition = std::move(condition);
541 node->then_case = std::move(then_case);
542 node->else_case = std::move(else_case);
543 return Stmt(node);
544 }
545
make(Expr value)546 Stmt Evaluate::make(Expr value) {
547 CHECK(value.defined());
548
549 NodePtr<Evaluate> node = make_node<Evaluate>();
550 node->value = std::move(value);
551 return Stmt(node);
552 }
553
554 // Printers
555 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc0202(const ObjectRef& node, IRPrinter* p) 556 .set_dispatch<UIntImm>([](const ObjectRef& node, IRPrinter* p) {
557 auto* op = static_cast<const UIntImm*>(node.get());
558 p->stream << "(" << op->type << ")" << op->value;
559 });
560
561 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc0302(const ObjectRef& node, IRPrinter* p) 562 .set_dispatch<FloatImm>([](const ObjectRef& node, IRPrinter* p) {
563 auto* op = static_cast<const FloatImm*>(node.get());
564 auto& stream = p->stream;
565 switch (op->type.bits()) {
566 case 64:
567 stream << op->value;
568 break;
569 case 32:
570 stream << op->value << 'f';
571 break;
572 case 16:
573 stream << op->value << 'h';
574 break;
575 default:
576 LOG(FATAL) << "Unknown float type bits=" << op->type.bits();
577 }
578 });
579
580 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc0402(const ObjectRef& node, IRPrinter* p) 581 .set_dispatch<StringImm>([](const ObjectRef& node, IRPrinter* p) {
582 auto* op = static_cast<const StringImm*>(node.get());
583 auto& stream = p->stream;
584 stream << '"';
585 for (size_t i = 0; i < op->value.size(); ++i) {
586 unsigned char c = op->value[i];
587 if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
588 stream << c;
589 } else {
590 stream << '\\';
591 switch (c) {
592 case '"':
593 stream << '"';
594 break;
595 case '\\':
596 stream << '\\';
597 break;
598 case '\t':
599 stream << 't';
600 break;
601 case '\r':
602 stream << 'r';
603 break;
604 case '\n':
605 stream << 'n';
606 break;
607 default:
608 const char* hex_digits = "0123456789ABCDEF";
609 stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
610 }
611 }
612 }
613 stream << '"';
614 });
615
616 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc0502(const ObjectRef& node, IRPrinter* p) 617 .set_dispatch<Cast>([](const ObjectRef& node, IRPrinter* p) {
618 auto* op = static_cast<const Cast*>(node.get());
619 p->stream << op->type << '(';
620 p->Print(op->value);
621 p->stream << ')';
622 })
__anonc9ea20fc0602(const ObjectRef& node, IRPrinter* p) 623 .set_dispatch<Variable>([](const ObjectRef& node, IRPrinter* p) {
624 auto* op = static_cast<const Variable*>(node.get());
625 // omit the type
626 // stream << op->name << "." << op->type;
627 p->stream << op->name_hint;
628 })
__anonc9ea20fc0702(const ObjectRef& node, IRPrinter* p) 629 .set_dispatch<Add>([](const ObjectRef& node, IRPrinter* p) {
630 auto* op = static_cast<const Add*>(node.get());
631 p->stream << '(';
632 p->Print(op->a);
633 p->stream << " + ";
634 p->Print(op->b);
635 p->stream << ')';
636 })
__anonc9ea20fc0802(const ObjectRef& node, IRPrinter* p) 637 .set_dispatch<Sub>([](const ObjectRef& node, IRPrinter* p) {
638 auto* op = static_cast<const Sub*>(node.get());
639 p->stream << '(';
640 p->Print(op->a);
641 p->stream << " - ";
642 p->Print(op->b);
643 p->stream << ')';
644 })
__anonc9ea20fc0902(const ObjectRef& node, IRPrinter* p) 645 .set_dispatch<Mul>([](const ObjectRef& node, IRPrinter* p) {
646 auto* op = static_cast<const Mul*>(node.get());
647 p->stream << '(';
648 p->Print(op->a);
649 p->stream << "*";
650 p->Print(op->b);
651 p->stream << ')';
652 })
__anonc9ea20fc0a02(const ObjectRef& node, IRPrinter* p) 653 .set_dispatch<Div>([](const ObjectRef& node, IRPrinter* p) {
654 auto* op = static_cast<const Div*>(node.get());
655 p->stream << '(';
656 p->Print(op->a);
657 p->stream << "/";
658 p->Print(op->b);
659 p->stream << ')';
660 })
__anonc9ea20fc0b02(const ObjectRef& node, IRPrinter* p) 661 .set_dispatch<Mod>([](const ObjectRef& node, IRPrinter* p) {
662 auto* op = static_cast<const Mod*>(node.get());
663 p->stream << '(';
664 p->Print(op->a);
665 p->stream << " % ";
666 p->Print(op->b);
667 p->stream << ')';
668 })
__anonc9ea20fc0c02(const ObjectRef& node, IRPrinter* p) 669 .set_dispatch<Min>([](const ObjectRef& node, IRPrinter* p) {
670 auto* op = static_cast<const Min*>(node.get());
671 p->stream << "min(";
672 p->Print(op->a);
673 p->stream << ", ";
674 p->Print(op->b);
675 p->stream << ")";
676 })
__anonc9ea20fc0d02(const ObjectRef& node, IRPrinter* p) 677 .set_dispatch<Max>([](const ObjectRef& node, IRPrinter* p) {
678 auto* op = static_cast<const Max*>(node.get());
679 p->stream << "max(";
680 p->Print(op->a);
681 p->stream << ", ";
682 p->Print(op->b);
683 p->stream << ")";
684 })
__anonc9ea20fc0e02(const ObjectRef& node, IRPrinter* p) 685 .set_dispatch<EQ>([](const ObjectRef& node, IRPrinter* p) {
686 auto* op = static_cast<const EQ*>(node.get());
687 p->stream << '(';
688 p->Print(op->a);
689 p->stream << " == ";
690 p->Print(op->b);
691 p->stream << ')';
692 })
__anonc9ea20fc0f02(const ObjectRef& node, IRPrinter* p) 693 .set_dispatch<NE>([](const ObjectRef& node, IRPrinter* p) {
694 auto* op = static_cast<const NE*>(node.get());
695 p->stream << '(';
696 p->Print(op->a);
697 p->stream << " != ";
698 p->Print(op->b);
699 p->stream << ')';
700 })
__anonc9ea20fc1002(const ObjectRef& node, IRPrinter* p) 701 .set_dispatch<LT>([](const ObjectRef& node, IRPrinter* p) {
702 auto* op = static_cast<const LT*>(node.get());
703 p->stream << '(';
704 p->Print(op->a);
705 p->stream << " < ";
706 p->Print(op->b);
707 p->stream << ')';
708 })
__anonc9ea20fc1102(const ObjectRef& node, IRPrinter* p) 709 .set_dispatch<LE>([](const ObjectRef& node, IRPrinter* p) {
710 auto* op = static_cast<const LE*>(node.get());
711 p->stream << '(';
712 p->Print(op->a);
713 p->stream << " <= ";
714 p->Print(op->b);
715 p->stream << ')';
716 })
__anonc9ea20fc1202(const ObjectRef& node, IRPrinter* p) 717 .set_dispatch<GT>([](const ObjectRef& node, IRPrinter* p) {
718 auto* op = static_cast<const GT*>(node.get());
719 p->stream << '(';
720 p->Print(op->a);
721 p->stream << " > ";
722 p->Print(op->b);
723 p->stream << ')';
724 })
__anonc9ea20fc1302(const ObjectRef& node, IRPrinter* p) 725 .set_dispatch<GE>([](const ObjectRef& node, IRPrinter* p) {
726 auto* op = static_cast<const GE*>(node.get());
727 p->stream << '(';
728 p->Print(op->a);
729 p->stream << " >= ";
730 p->Print(op->b);
731 p->stream << ')';
732 });
733
734 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1402(const ObjectRef& node, IRPrinter* p) 735 .set_dispatch<FloorDiv>([](const ObjectRef& node, IRPrinter* p) {
736 auto* op = static_cast<const FloorDiv*>(node.get());
737 p->stream << "floordiv(" << op->a << ", " << op->b << ")";
738 });
739
740 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1502(const ObjectRef& node, IRPrinter* p) 741 .set_dispatch<FloorMod>([](const ObjectRef& node, IRPrinter* p) {
742 auto* op = static_cast<const FloorMod*>(node.get());
743 p->stream << "floormod(" << op->a << ", " << op->b << ")";
744 });
745
746 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1602(const ObjectRef& node, IRPrinter* p) 747 .set_dispatch<And>([](const ObjectRef& node, IRPrinter* p) {
748 auto* op = static_cast<const And*>(node.get());
749 p->stream << '(';
750 p->Print(op->a);
751 p->stream << " && ";
752 p->Print(op->b);
753 p->stream << ')';
754 });
755
756 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1702(const ObjectRef& node, IRPrinter* p) 757 .set_dispatch<Or>([](const ObjectRef& node, IRPrinter* p) {
758 auto* op = static_cast<const Or*>(node.get());
759 p->stream << '(';
760 p->Print(op->a);
761 p->stream << " || ";
762 p->Print(op->b);
763 p->stream << ')';
764 });
765
766 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1802(const ObjectRef& node, IRPrinter* p) 767 .set_dispatch<Not>([](const ObjectRef& node, IRPrinter* p) {
768 auto* op = static_cast<const Not*>(node.get());
769 p->stream << '!';
770 p->Print(op->a);
771 });
772
773 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1902(const ObjectRef& node, IRPrinter* p) 774 .set_dispatch<Select>([](const ObjectRef& node, IRPrinter* p) {
775 auto* op = static_cast<const Select*>(node.get());
776 p->stream << "select(";
777 p->Print(op->condition);
778 p->stream << ", ";
779 p->Print(op->true_value);
780 p->stream << ", ";
781 p->Print(op->false_value);
782 p->stream << ")";
783 });
784
785 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1a02(const ObjectRef& node, IRPrinter* p) 786 .set_dispatch<Load>([](const ObjectRef& node, IRPrinter* p) {
787 auto* op = static_cast<const Load*>(node.get());
788 p->stream << op->buffer_var << "[";
789 p->Print(op->index);
790 p->stream << "]";
791 if (!is_one(op->predicate)) {
792 p->stream << " if ";
793 p->Print(op->predicate);
794 }
795 });
796
797 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1b02(const ObjectRef& node, IRPrinter* p) 798 .set_dispatch<Ramp>([](const ObjectRef& node, IRPrinter* p) {
799 auto* op = static_cast<const Ramp*>(node.get());
800 p->stream << "ramp(";
801 p->Print(op->base);
802 p->stream << ", ";
803 p->Print(op->stride);
804 p->stream << ", " << op->lanes << ")";
805 });
806
807 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1c02(const ObjectRef& node, IRPrinter* p) 808 .set_dispatch<Broadcast>([](const ObjectRef& node, IRPrinter* p) {
809 auto* op = static_cast<const Broadcast*>(node.get());
810 p->stream << "x" << op->lanes << "(";
811 p->Print(op->value);
812 p->stream << ")";
813 });
814
815 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1d02(const ObjectRef& node, IRPrinter* p) 816 .set_dispatch<Call>([](const ObjectRef& node, IRPrinter* p) {
817 auto* op = static_cast<const Call*>(node.get());
818 p->stream << op->name << "(";
819 for (size_t i = 0; i < op->args.size(); ++i) {
820 p->Print(op->args[i]);
821 if (i < op->args.size() - 1) {
822 p->stream << ", ";
823 }
824 }
825 p->stream << ")";
826 });
827
828 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1e02(const ObjectRef& node, IRPrinter* p) 829 .set_dispatch<Let>([](const ObjectRef& node, IRPrinter* p) {
830 auto* op = static_cast<const Let*>(node.get());
831 p->stream << "(let " << op->var << " = ";
832 p->Print(op->value);
833 p->stream << " in ";
834 p->Print(op->body);
835 p->stream << ")";
836 });
837
838 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc1f02(const ObjectRef& node, IRPrinter* p) 839 .set_dispatch<LetStmt>([](const ObjectRef& node, IRPrinter* p) {
840 auto* op = static_cast<const LetStmt*>(node.get());
841 p->PrintIndent();
842 p->stream << "let " << op->var << " = ";
843 p->Print(op->value);
844 p->stream << '\n';
845 p->Print(op->body);
846 });
847
848 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2002(const ObjectRef& node, IRPrinter* p) 849 .set_dispatch<AttrStmt>([](const ObjectRef& node, IRPrinter* p) {
850 auto* op = static_cast<const AttrStmt*>(node.get());
851 p->PrintIndent();
852 p->stream << "// attr [";
853 p->Print(op->node);
854 p->stream << "] "
855 << op->attr_key << " = ";
856 p->Print(op->value);
857 p->stream << '\n';
858 p->Print(op->body);
859 });
860
861 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2102(const ObjectRef& node, IRPrinter* p) 862 .set_dispatch<AssertStmt>([](const ObjectRef& node, IRPrinter* p) {
863 auto* op = static_cast<const AssertStmt*>(node.get());
864 p->PrintIndent();
865 p->stream << "assert(";
866 p->Print(op->condition);
867 p->stream << ", ";
868 p->Print(op->message);
869 p->stream << ")\n";
870 p->Print(op->body);
871 });
872
873 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2202(const ObjectRef& node, IRPrinter* p) 874 .set_dispatch<ProducerConsumer>([](const ObjectRef& node, IRPrinter* p) {
875 auto* op = static_cast<const ProducerConsumer*>(node.get());
876 if (op->is_producer) {
877 p->PrintIndent();
878 p->stream << "produce " << op->func->func_name() << " {\n";
879 p->indent += 2;
880 p->Print(op->body);
881 p->indent -= 2;
882 p->PrintIndent();
883 p->stream << "}\n";
884 } else {
885 p->Print(op->body);
886 }
887 });
888
// Writes the loop keyword for a ForType: Serial -> "for", Parallel -> "parallel",
// Unrolled -> "unrolled", Vectorized -> "vectorized". Writes nothing for any
// other value.
std::ostream &operator<<(std::ostream& out, ForType type) {  // NOLINT(*)
  switch (type) {
    case ForType::Serial:
      return out << "for";
    case ForType::Parallel:
      return out << "parallel";
    case ForType::Unrolled:
      return out << "unrolled";
    case ForType::Vectorized:
      return out << "vectorized";
  }
  return out;
}
906
907 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2302(const ObjectRef& node, IRPrinter* p) 908 .set_dispatch<For>([](const ObjectRef& node, IRPrinter* p) {
909 auto* op = static_cast<const For*>(node.get());
910 p->PrintIndent();
911 p->stream << op->for_type << " (" << op->loop_var << ", ";
912 p->Print(op->min);
913 p->stream << ", ";
914 p->Print(op->extent);
915 p->stream << ") {\n";
916
917 p->indent += 2;
918 p->Print(op->body);
919 p->indent -= 2;
920
921 p->PrintIndent();
922 p->stream << "}\n";
923 });
924
925 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2402(const ObjectRef& node, IRPrinter* p) 926 .set_dispatch<Store>([](const ObjectRef& node, IRPrinter* p) {
927 auto* op = static_cast<const Store*>(node.get());
928 p->PrintIndent();
929 p->stream << op->buffer_var << "[";
930 p->Print(op->index);
931 p->stream << "] = ";
932 p->Print(op->value);
933 if (!is_one(op->predicate)) {
934 p->stream << " if ";
935 p->Print(op->predicate);
936 }
937 p->stream << '\n';
938 });
939
940 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2502(const ObjectRef& node, IRPrinter* p) 941 .set_dispatch<Provide>([](const ObjectRef& node, IRPrinter* p) {
942 auto* op = static_cast<const Provide*>(node.get());
943 p->PrintIndent();
944 p->stream << op->func->func_name() << "(";
945 for (size_t i = 0; i < op->args.size(); ++i) {
946 p->Print(op->args[i]);
947 if (i < op->args.size() - 1) p->stream << ", ";
948 }
949 p->stream << ")";
950 if (op->func->num_outputs() != 1) {
951 p->stream << ".value[" << op->value_index << "]";
952 }
953 p->stream << " =";
954 p->Print(op->value);
955 p->stream << '\n';
956 });
957
958 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2602(const ObjectRef& node, IRPrinter* p) 959 .set_dispatch<Allocate>([](const ObjectRef& node, IRPrinter* p) {
960 auto* op = static_cast<const Allocate*>(node.get());
961 p->PrintIndent();
962 p->stream << "allocate " << op->buffer_var << "[" << op->type;
963 for (size_t i = 0; i < op->extents.size(); ++i) {
964 p->stream << " * ";
965 p->Print(op->extents[i]);
966 }
967 p->stream << "]";
968 if (!is_one(op->condition)) {
969 p->stream << " if ";
970 p->Print(op->condition);
971 }
972 p->stream << "\n";
973 p->Print(op->body);
974 });
975
976 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2702(const ObjectRef& node, IRPrinter* p) 977 .set_dispatch<Free>([](const ObjectRef& node, IRPrinter* p) {
978 auto* op = static_cast<const Free*>(node.get());
979 p->PrintIndent();
980 p->stream << "free " << op->buffer_var;
981 p->stream << '\n';
982 });
983
984 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2802(const ObjectRef& node, IRPrinter* p) 985 .set_dispatch<Realize>([](const ObjectRef& node, IRPrinter* p) {
986 auto* op = static_cast<const Realize*>(node.get());
987 p->PrintIndent();
988 p->stream << "realize " << op->func->func_name() << "(";
989 for (size_t i = 0; i < op->bounds.size(); ++i) {
990 p->stream << "[";
991 p->Print(op->bounds[i]->min);
992 p->stream << ", ";
993 p->Print(op->bounds[i]->extent);
994 p->stream << "]";
995 if (i < op->bounds.size() - 1) p->stream << ", ";
996 }
997 p->stream << ")";
998 if (op->func->num_outputs() != 1) {
999 p->stream << ".value[" << op->value_index << "]";
1000 }
1001 if (!is_one(op->condition)) {
1002 p->stream << " if ";
1003 p->Print(op->condition);
1004 }
1005 p->stream << " {\n";
1006
1007 p->indent += 2;
1008 p->Print(op->body);
1009 p->indent -= 2;
1010
1011 p->PrintIndent();
1012 p->stream << "}\n";
1013 });
1014
1015 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2902(const ObjectRef& node, IRPrinter* p) 1016 .set_dispatch<Prefetch>([](const ObjectRef& node, IRPrinter* p) {
1017 auto* op = static_cast<const Prefetch*>(node.get());
1018 p->PrintIndent();
1019 p->stream << "prefetch " << op->func->func_name() << "(";
1020 for (size_t i = 0; i < op->bounds.size(); ++i) {
1021 p->stream << "[";
1022 p->Print(op->bounds[i]->min);
1023 p->stream << ", ";
1024 p->Print(op->bounds[i]->extent);
1025 p->stream << "]";
1026 if (i < op->bounds.size() - 1) p->stream << ", ";
1027 }
1028 p->stream << ")";
1029 if (op->func->num_outputs() != 1) {
1030 p->stream << ".value[" << op->value_index << "]";
1031 }
1032 });
1033
1034 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2a02(const ObjectRef& node, IRPrinter* p) 1035 .set_dispatch<Block>([](const ObjectRef& node, IRPrinter* p) {
1036 auto* op = static_cast<const Block*>(node.get());
1037 p->Print(op->first);
1038 if (op->rest.defined()) p->Print(op->rest);
1039 });
1040
1041 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2b02(const ObjectRef& node, IRPrinter* p) 1042 .set_dispatch<IfThenElse>([](const ObjectRef& node, IRPrinter* p) {
1043 auto* op = static_cast<const IfThenElse*>(node.get());
1044 p->PrintIndent();
1045 while (true) {
1046 p->stream << "if (" << op->condition << ") {\n";
1047 p->indent += 2;
1048 p->Print(op->then_case);
1049 p->indent -= 2;
1050
1051 if (!op->else_case.defined()) {
1052 break;
1053 }
1054
1055 if (const IfThenElse *nested_if = op->else_case.as<IfThenElse>()) {
1056 p->PrintIndent();
1057 p->stream << "} else ";
1058 op = nested_if;
1059 } else {
1060 p->PrintIndent();
1061 p->stream << "} else {\n";
1062 p->indent += 2;
1063 p->Print(op->else_case);
1064 p->indent -= 2;
1065 break;
1066 }
1067 }
1068 p->PrintIndent();
1069 p->stream << "}\n";
1070 });
1071
1072 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2c02(const ObjectRef& node, IRPrinter* p) 1073 .set_dispatch<Evaluate>([](const ObjectRef& node, IRPrinter* p) {
1074 auto* op = static_cast<const Evaluate*>(node.get());
1075 p->PrintIndent();
1076 p->Print(op->value);
1077 p->stream << "\n";
1078 });
1079
1080 template<typename T>
PrintList(const Array<T> & exprs,IRPrinter * p)1081 void PrintList(const Array<T> &exprs, IRPrinter* p) {
1082 for (size_t i = 0; i < exprs.size(); ++i) {
1083 p->Print(exprs[i]);
1084 if (i < exprs.size() - 1) {
1085 p->stream << ", ";
1086 }
1087 }
1088 }
1089
1090 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2d02(const ObjectRef& node, IRPrinter* p) 1091 .set_dispatch<Shuffle>([](const ObjectRef& node, IRPrinter* p) {
1092 auto* op = static_cast<const Shuffle*>(node.get());
1093 p->stream << "shuffle(";
1094 PrintList(op->vectors, p);
1095 p->stream << ", ";
1096 PrintList(op->indices, p);
1097 p->stream << ")";
1098 });
1099
1100 // Container printer
1101 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2e02(const ObjectRef& node, IRPrinter* p) 1102 .set_dispatch<ArrayNode>([](const ObjectRef& node, IRPrinter* p) {
1103 auto* op = static_cast<const ArrayNode*>(node.get());
1104 p->stream << '[';
1105 for (size_t i = 0 ; i < op->data.size(); ++i) {
1106 if (i != 0) {
1107 p->stream << ", ";
1108 }
1109 p->Print(op->data[i]);
1110 }
1111 p->stream << ']';
1112 });
1113
1114 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc2f02(const ObjectRef& node, IRPrinter* p) 1115 .set_dispatch<MapNode>([](const ObjectRef& node, IRPrinter* p) {
1116 auto* op = static_cast<const MapNode*>(node.get());
1117 p->stream << '{';
1118 for (auto it = op->data.begin(); it != op->data.end(); ++it) {
1119 if (it != op->data.begin()) {
1120 p->stream << ", ";
1121 }
1122 p->Print(it->first);
1123 p->stream << ": ";
1124 p->Print(it->second);
1125 }
1126 p->stream << '}';
1127 });
1128
1129 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc3002(const ObjectRef& node, IRPrinter* p) 1130 .set_dispatch<StrMapNode>([](const ObjectRef& node, IRPrinter* p) {
1131 auto* op = static_cast<const StrMapNode*>(node.get());
1132 p->stream << '{';
1133 for (auto it = op->data.begin(); it != op->data.end(); ++it) {
1134 if (it != op->data.begin()) {
1135 p->stream << ", ";
1136 }
1137 p->stream << '\"' << it->first << "\": ";
1138 p->Print(it->second);
1139 }
1140 p->stream << '}';
1141 });
1142
1143 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc3102(const ObjectRef& node, IRPrinter* p) 1144 .set_dispatch<Reduce>([](const ObjectRef& node, IRPrinter* p) {
1145 auto* op = static_cast<const Reduce*>(node.get());
1146 p->stream << "reduce(combiner="
1147 << op->combiner;
1148 p->stream << ", source=" << op->source;
1149 p->stream << ", axis=" << op->axis;
1150 p->stream << ", where=" << op->condition;
1151 p->stream << ", value_index=" << op->value_index;
1152 p->stream << ")";
1153 });
1154
1155 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc3202(const ObjectRef& node, IRPrinter* p) 1156 .set_dispatch<CommReducerNode>([](const ObjectRef& node, IRPrinter* p) {
1157 auto* op = static_cast<const CommReducerNode*>(node.get());
1158 p->stream << "comm_reducer(result=" << op->result
1159 << ", lhs=" << op->lhs
1160 << ", rhs=" << op->rhs
1161 << ", identity_element=" << op->identity_element
1162 << ")";
1163 });
1164
1165 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anonc9ea20fc3302(const ObjectRef& node, IRPrinter* p) 1166 .set_dispatch<Any>([](const ObjectRef& node, IRPrinter* p) {
1167 p->stream << "?";
1168 });
1169
// Register the IR expression/statement node types used in this file via
// TVM_REGISTER_NODE_TYPE.
// NOTE(review): presumably this makes each type discoverable/creatable by its
// type key through the node reflection registry — confirm against the macro's
// definition before relying on it. Whether the listing order matters is not
// visible here; check before reordering.
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
TVM_REGISTER_NODE_TYPE(StringImm);
TVM_REGISTER_NODE_TYPE(Cast);
TVM_REGISTER_NODE_TYPE(Variable);
TVM_REGISTER_NODE_TYPE(Add);
TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(FloorDiv);
TVM_REGISTER_NODE_TYPE(FloorMod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
TVM_REGISTER_NODE_TYPE(NE);
TVM_REGISTER_NODE_TYPE(LT);
TVM_REGISTER_NODE_TYPE(LE);
TVM_REGISTER_NODE_TYPE(GT);
TVM_REGISTER_NODE_TYPE(GE);
TVM_REGISTER_NODE_TYPE(And);
TVM_REGISTER_NODE_TYPE(Or);
TVM_REGISTER_NODE_TYPE(Not);
TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Shuffle);
TVM_REGISTER_NODE_TYPE(Prefetch);
TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt);
TVM_REGISTER_NODE_TYPE(AssertStmt);
TVM_REGISTER_NODE_TYPE(ProducerConsumer);
TVM_REGISTER_NODE_TYPE(For);
TVM_REGISTER_NODE_TYPE(Store);
TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
1218
1219 } // namespace ir
1220 } // namespace tvm
1221