/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/vm/compiler.cc
 * \brief A compiler from relay::Module to the VM byte code.
 */

#include <tvm/operation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
#include "compiler.h"

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();
Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);

Pass ManifestAlloc(Target target_host) {
  auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
  CHECK(f != nullptr) << "could not load memory allocation pass";
  return (*f)(target_host);
}

}  // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

void InstructionPrint(std::ostream& os, const Instruction& instr);

// Represents a runtime object that will be matched by pattern-match expressions.
struct MatchValue {
  virtual ~MatchValue() {}
};
using MatchValuePtr = std::shared_ptr<MatchValue>;

// A runtime object that resides in a register.
struct RegisterValue : MatchValue {
  // The register number.
  RegName register_num;

  explicit RegisterValue(RegName reg) : register_num(reg) {}

  ~RegisterValue() {}
};

// The value is a field of another runtime object.
struct AccessField : MatchValue {
  MatchValuePtr parent;
  // Field index.
  size_t index;
  // Runtime register number after compiling the access-field path.
  RegName reg{-1};

  AccessField(MatchValuePtr parent, size_t index)
      : parent(parent), index(index) {}

  ~AccessField() {}
};

/*!
 * \brief Condition in a decision tree
 */
struct ConditionNode {
  virtual ~ConditionNode() {}
};

using ConditionNodePtr = std::shared_ptr<ConditionNode>;

/*!
 * \brief A var binding condition
 */
struct VarBinding : ConditionNode {
  Var var;
  MatchValuePtr val;

  VarBinding(Var var, MatchValuePtr val)
      : var(var), val(val) {}

  ~VarBinding() {}
};

/*!
 * \brief Compare the tag of the object
 */
struct TagCompare : ConditionNode {
  /*! \brief The object to be examined */
  MatchValuePtr obj;

  /*! \brief The expected tag */
  int target_tag;

  TagCompare(MatchValuePtr obj, size_t target)
      : obj(obj), target_tag(target) {}

  ~TagCompare() {}
};

using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;

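/*!
 * \brief Recursively build a decision tree for a single pattern.
 *
 * A wildcard contributes nothing and returns then_branch unchanged. A
 * variable pattern wraps then_branch in a VarBinding. A constructor pattern
 * first recurses into its field patterns (each field reached through an
 * AccessField path) and then guards the whole subtree with a TagCompare on
 * the constructor tag. A tuple pattern recurses into its fields without a
 * tag check, since tuples have a single runtime layout.
 */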
TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
                                         Pattern pattern,
                                         TreeNodePtr then_branch,
                                         TreeNodePtr else_branch) {
  if (pattern.as<PatternWildcardNode>()) {
    // We ignore the wildcard binding since it doesn't produce new vars.
    return then_branch;
  } else if (pattern.as<PatternVarNode>()) {
    auto pat = pattern.as<PatternVarNode>();
    auto pattern = GetRef<PatternVar>(pat);
    auto cond = std::make_shared<VarBinding>(pattern->var, data);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
    auto tag = pcn->constructor->tag;

    size_t field_index = 0;
    for (auto& p : pcn->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
    auto cond = std::make_shared<TagCompare>(data, tag);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else {
    auto pt = pattern.as<PatternTupleNode>();
    CHECK(pt) << "unhandled case: " << pattern;
    size_t field_index = 0;
    for (auto& p : pt->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
    return then_branch;
  }
}

TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
                                        Clause clause,
                                        TreeNodePtr else_branch) {
  return BuildDecisionTreeFromPattern(data, clause->lhs,
                                      TreeLeafNode::Make(clause->rhs), else_branch);
}

TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // When nothing matches, the VM throws a fatal error.
  TreeNodePtr else_branch = TreeLeafFatalNode::Make();
  // Start from the last clause.
  for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
    else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
  }
  return else_branch;
}

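// Convert a rank-1 tensor of 64-bit integers into the shape vector consumed
// by the AllocTensor instruction. The 32-bit variant below widens each
// element to int64_t.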
std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
  std::vector<int64_t> raw_shape;
  DLTensor tensor = shape.ToDLPack()->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standardize the bit width of
  // all of the shape manipulating code.
  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
  for (int64_t i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(int_ptr[i]);
  }
  return raw_shape;
}

std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
  std::vector<int64_t> raw_shape;
  DLTensor tensor = shape.ToDLPack()->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standardize the bit width of
  // all of the shape manipulating code.
  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
  for (int64_t i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
  }
  return raw_shape;
}

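/*!
 * \brief Compile a single Relay function into a VMFunction.
 *
 * The compiler walks the A-normal-form body of the function, assigning a
 * fresh virtual register to each intermediate value and emitting VM
 * instructions into instructions_. last_register_ always names the register
 * holding the most recently computed value.
 */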
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 public:
  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
      : last_register_(0),
        registers_num_(0),
        engine_(CompileEngine::Global()),
        context_(context),
        targets_(targets),
        target_host_(target_host) {}

  VMFunction Compile(const GlobalVar& var, const Function& func) {
    size_t i = 0;
    // Assign a register to each of the function's parameters.
    for (auto param : func->params) {
      auto arg_register = NewRegister();
      CHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      params_.push_back(param->name_hint());
      ++i;
    }

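    // For closures produced by lambda lifting, the outer function's
    // parameters hold the captured free variables and the inner function
    // carries the actual arguments, so both parameter lists map onto the
    // register file here.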
    if (IsClosure(func)) {
      Function inner_func = Downcast<Function>(func->body);
      for (auto param : inner_func->params) {
        auto arg_register = NewRegister();
        CHECK_EQ(i, arg_register);
        var_register_map_.insert({param, arg_register});
        params_.push_back(param->name_hint());
        ++i;
      }
      this->VisitExpr(inner_func->body);
    } else {
      this->VisitExpr(func->body);
    }
    instructions_.push_back(Instruction::Ret(last_register_));
    return VMFunction(var->name_hint, params_, instructions_, registers_num_);
  }

 protected:
  size_t NewRegister() { return registers_num_++; }

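  /*!
   * \brief Append an instruction to the function being compiled.
   *
   * Instructions that produce a value update last_register_ with their
   * destination register; pure control-flow instructions leave it untouched.
   */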
  inline void Emit(const Instruction& instr) {
    DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
    CHECK(static_cast<int>(instr.op) < 100) << "Invalid opcode " << static_cast<int>(instr.op);
    switch (instr.op) {
      case Opcode::AllocADT:
      case Opcode::AllocTensor:
      case Opcode::AllocTensorReg:
      case Opcode::GetField:
      case Opcode::GetTag:
      case Opcode::LoadConst:
      case Opcode::LoadConsti:
      case Opcode::Invoke:
      case Opcode::AllocClosure:
      case Opcode::AllocStorage:
      case Opcode::Move:
      case Opcode::InvokeClosure:
        last_register_ = instr.dst;
        break;
      case Opcode::InvokePacked:
      case Opcode::If:
      case Opcode::Ret:
      case Opcode::Goto:
      case Opcode::Fatal:
        break;
    }
    instructions_.push_back(instr);
  }

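  // Constants are pooled in the compiler context and referenced by index;
  // the emitted LoadConst fetches the pooled NDArray into a fresh register
  // at run time.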
  void VisitExpr_(const ConstantNode* const_node) {
    size_t konst_idx = context_->constants.size();
    context_->constants.push_back(const_node->data);
    Emit(Instruction::LoadConst(konst_idx, NewRegister()));
  }

  void VisitExpr_(const VarNode* var_node) {
    auto var = GetRef<Var>(var_node);
    auto reg_it = this->var_register_map_.find(var);
    CHECK(reg_it != this->var_register_map_.end());
    last_register_ = reg_it->second;
  }

  void VisitExpr_(const TupleNode* tuple_node) {
    auto tuple = GetRef<Tuple>(tuple_node);
    std::vector<Index> fields_registers;

    for (auto& field : tuple->fields) {
      this->VisitExpr(field);
      fields_registers.push_back(last_register_);
    }

    // TODO(@jroesch): use correct tag
    Emit(Instruction::AllocADT(
      0,
      tuple->fields.size(),
      fields_registers,
      NewRegister()));
  }

  void VisitExpr_(const MatchNode* match_node) {
    auto match = GetRef<Match>(match_node);

    this->VisitExpr(match->data);
    CompileMatch(match);
  }

  void VisitExpr_(const LetNode* let_node) {
    DLOG(INFO) << PrettyPrint(let_node->value);
    this->VisitExpr(let_node->value);
    var_register_map_.insert({let_node->var, this->last_register_});
    this->VisitExpr(let_node->body);
  }

  void VisitExpr_(const TupleGetItemNode* get_node) {
    auto get = GetRef<TupleGetItem>(get_node);
    this->VisitExpr(get->tuple);
    auto tuple_register = last_register_;
    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
  }

  void VisitExpr_(const GlobalVarNode* gvar) {
    auto var = GetRef<GlobalVar>(gvar);
    auto func = context_->module->Lookup(var);
    auto it = context_->global_map.find(var);
    CHECK(it != context_->global_map.end());
    // Allocate a closure with zero free vars.
    Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
  }

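  /*!
   * \brief Compile an if-expression into a conditional branch.
   *
   * The emitted sequence is: LoadConsti(1), If, <true body>, Move, Goto,
   * <false body>, Move. The If and Goto are emitted with placeholder
   * offsets and patched once both branch sizes are known. The VM's If
   * instruction compares its two register operands for equality, so the
   * condition register is tested against the constant 1.
   */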
  void VisitExpr_(const IfNode* if_node) {
    this->VisitExpr(if_node->cond);

    size_t test_register = last_register_;

    this->Emit(Instruction::LoadConsti(1, NewRegister()));
    auto after_cond = instructions_.size();
    auto target_register = last_register_;
    this->Emit(Instruction::If(test_register, target_register, 0, 0));
    this->VisitExpr(if_node->true_branch);

    // This register saves the result of the if-else expression.
    auto merge_register = NewRegister();
    Emit(Instruction::Move(last_register_, merge_register));
    Emit(Instruction::Goto(0));

    // Record the number of instructions after emitting the true branch.
    auto after_true = this->instructions_.size();

    this->VisitExpr(if_node->false_branch);

    size_t false_register = last_register_;

    // In the else branch, override the then-branch register.
    Emit(Instruction::Move(false_register, merge_register));
    // Compute the total number of instructions
    // after generating the false branch.
    auto after_false = this->instructions_.size();

    // Now that both bodies have been emitted, compute the jump targets and
    // patch the If and Goto instructions with the requisite offsets.
    auto true_offset = 1;
    auto false_offset = after_true - after_cond;
    instructions_[after_cond].if_op.true_offset = true_offset;
    instructions_[after_cond].if_op.false_offset = false_offset;

    // Patch the Goto.
    this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

    this->last_register_ = merge_register;
  }

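  /*!
   * \brief Lower a shape function and emit an InvokePacked call to it.
   *
   * Input and output tensors must already live in registers; those
   * registers are passed to the packed function in order, inputs first,
   * then outputs.
   */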
  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
    // Lower the shape function.
    auto key = CCacheKeyNode::make(func, target_host_);
    auto cfunc = engine_->LowerShapeFunc(key);
    int op_index = -1;
    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    // Prepare input and output registers.
    std::vector<Index> argument_registers;
    for (auto input : inputs) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : outputs) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    Emit(Instruction::InvokePacked(op_index,
                                   argument_registers.size(),
                                   outputs.size(),
                                   argument_registers));
  }

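  /*!
   * \brief Lower a primitive function and emit an InvokePacked call to it.
   *
   * Both inputs and outputs arrive as tuples of variables introduced by the
   * memory manifestation pass. The packed function writes its results into
   * the pre-allocated output tensors, which is why the output registers are
   * appended to the argument list.
   */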
  void EmitInvokeTVMOp(const Function& func,
                       const Expr& inputs,
                       const Expr& outputs) {
    std::vector<Index> argument_registers;

    CHECK(func->IsPrimitive())
      << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";

    auto input_tuple = inputs.as<TupleNode>();
    CHECK(input_tuple)
      << "internal error: invoke_tvm_op inputs must be a tuple, "
      << "please file a bug in the memory manifestation pass";

    auto output_tuple = outputs.as<TupleNode>();
    CHECK(output_tuple)
      << "internal error: invoke_tvm_op outputs must be a tuple, "
      << "please file a bug in the memory manifestation pass";

    for (auto input : input_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : output_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    // Next, generate the invoke instruction.
    Target target;
    if (targets_.size() == 1) {
      // Homogeneous execution.
      for (auto kv : targets_) {
        target = kv.second;
      }
    } else {
      // Heterogeneous execution.
      LOG(FATAL) << "Currently the VM compiler doesn't support heterogeneous compilation";
    }

    auto key = CCacheKeyNode::make(func, target);
    auto cfunc = engine_->Lower(key);

    // TODO(jroesch): support lowered funcs for multiple targets
    CHECK_EQ(cfunc->funcs.size(), 1);
    auto op_index = -1;
    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    Emit(Instruction::InvokePacked(op_index,
                                   argument_registers.size(),
                                   output_tuple->fields.size(),
                                   argument_registers));
  }

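  /*!
   * \brief Compile a call expression.
   *
   * Calls into the memory.* sub-dialect are matched first and compiled to
   * dedicated allocation/invocation instructions. All remaining calls
   * target either a global function, an ADT constructor, or a local
   * variable bound to a closure; anything else should have been removed by
   * earlier passes.
   */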
  void VisitExpr_(const CallNode* call_node) {
    Expr op = call_node->op;

    // First we handle the case in which we are using an opaque
    // operator used to define a sub-dialect, such as memory
    // allocation operations.
    if (op.as<OpNode>()) {
      OpMatch<void> matcher;
      matcher.Match("memory.invoke_tvm_op",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
        }).Match("memory.alloc_tensor",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);

          // Get the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
            << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          // The storage will be passed dynamically.
          this->VisitExpr(args[0]);
          auto storage_register = last_register_;

          // If the shape is constant, we emit a static tensor allocation instruction.
          auto const_shape = args[1].as<ConstantNode>();

          if (const_shape) {
            NDArray shape = const_shape->data;
            std::vector<int64_t> raw_shape;
            DLTensor tensor = shape.ToDLPack()->dl_tensor;
            // TODO(@jroesch): we need to get an RFC done to standardize this
            if (tensor.dtype.bits == 64) {
              raw_shape = ToAllocTensorShape64(shape);
            } else if (tensor.dtype.bits == 32) {
              raw_shape = ToAllocTensorShape32(shape);
            } else {
              LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
            }

            // Add context field.
            Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
          } else {
            this->VisitExpr(args[1]);
            auto shape_register = last_register_;
            Emit(Instruction::AllocTensorReg(
              storage_register,
              shape_register,
              dtype,
              NewRegister()));
          }
        }).Match("memory.alloc_storage",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);
          // Compute the size of the allocation.
          this->VisitExpr(args[0]);
          auto size_register = last_register_;

          this->VisitExpr(args[1]);
          auto alignment_register = last_register_;

          // Get the dtype hint from the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
            << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
        }).Match("memory.shape_func",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          auto shape_func = Downcast<Function>(args[0]);
          auto inputs = Downcast<Tuple>(args[1]);
          auto outputs = Downcast<Tuple>(args[2]);
          EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
        }).Match("memory.kill",
        [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          LOG(FATAL) << "memory.kill is not yet supported";
        });
      matcher(GetRef<Call>(call_node));
      return;
    }

    // If it's not one of these specialized operators, we generate code
    // for one of the "standard" cases.
    std::vector<Index> args_registers;

    for (auto arg : call_node->args) {
      this->VisitExpr(arg);
      args_registers.push_back(last_register_);
    }

    if (auto global_node = op.as<GlobalVarNode>()) {
      // When invoking a global we need to find its global ID, then check
      // whether it is a closure invocation or a standard global call, and
      // emit the correct calling convention.
      auto global = GetRef<GlobalVar>(global_node);
      auto it = context_->global_map.find(global);
      CHECK(it != context_->global_map.end());
      DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                 << " with func_index=" << it->second;
      auto func = context_->module->Lookup(global);
      if (IsClosure(func)) {
        auto arity = func->params.size();
        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
      } else {
        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
      }
    } else if (auto constructor_node = op.as<ConstructorNode>()) {
      // In the constructor case, we simply need to find its tag
      // and emit a call to allocate the data structure.
      auto constructor = GetRef<Constructor>(constructor_node);
      Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
                                 NewRegister()));
    } else if (auto var_node = op.as<VarNode>()) {
      // If we are calling a variable, it must be bound to a closure, so we
      // emit an invoke-closure instruction here.
      VisitExpr(GetRef<Var>(var_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      // Any other case is a bug.
      LOG(FATAL) << "internal error: unreachable code, "
                 << "should be transformed away by previous passes: "
                 << PrettyPrint(GetRef<Expr>(call_node));
    }
  }

  void VisitExpr_(const FunctionNode* func_node) {
    if (!func_node->IsPrimitive()) {
      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                 << "AST: " << GetRef<Function>(func_node);
    }
  }

  /*!
   * \brief Compile a match value.
   * Generates bytecode that computes the value specified in val.
   *
   * \return The register number assigned to the final value.
   */
  RegName CompileMatchValue(MatchValuePtr val) {
    if (std::dynamic_pointer_cast<RegisterValue>(val)) {
      auto r = std::dynamic_pointer_cast<RegisterValue>(val);
      return r->register_num;
    } else {
      auto path = std::dynamic_pointer_cast<AccessField>(val);
      auto p = CompileMatchValue(path->parent);
      Emit(Instruction::GetField(p, path->index, NewRegister()));
      path->reg = last_register_;
      return path->reg;
    }
  }

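  /*!
   * \brief Emit code for one node of the decision tree.
   *
   * Leaves compile the clause body (or a Fatal instruction for the failure
   * leaf). A TagCompare branch loads the object's tag and the expected tag
   * into registers, tests them with If, and patches the If/Goto offsets
   * once both subtrees have been emitted. A VarBinding is not a branch at
   * all: it just records the register of the matched value for the bound
   * variable and continues with the then-branch.
   */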
  void CompileTreeNode(TreeNodePtr tree) {
    if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
      VisitExpr(node->body);
    } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
      Emit(Instruction::Fatal());
    } else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
      if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
        // For tag comparison, generate branches.
        auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
        auto r = CompileMatchValue(cond->obj);
        Emit(Instruction::GetTag(r, NewRegister()));
        auto operand1 = last_register_;
        Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
        auto operand2 = last_register_;

        Emit(Instruction::If(operand1, operand2, 1, 0));
        auto cond_offset = instructions_.size() - 1;
        CompileTreeNode(node->then_branch);
        auto if_reg = last_register_;
        Emit(Instruction::Goto(1));
        auto goto_offset = instructions_.size() - 1;
        CompileTreeNode(node->else_branch);
        auto else_reg = last_register_;
        Emit(Instruction::Move(else_reg, if_reg));
        last_register_ = if_reg;
        auto else_offset = instructions_.size() - 1;
        // Fix up the offsets.
        instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
        instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
      } else {
        // For other non-branch conditions, move to the then-branch directly.
        auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
        var_register_map_[cond->var] = CompileMatchValue(cond->val);
        CompileTreeNode(node->then_branch);
      }
    }
  }

  /*!
   * \brief Compile a pattern match expression.
   * It first converts the pattern match expression into a decision tree; the
   * conditions may be object comparisons or variable bindings. If any
   * condition fails in a clause, the decision tree switches to check the
   * conditions of the next clause, and so on. If no clause matches the
   * value, a fatal node is reached.
   *
   * After the decision tree is built, we convert it into bytecode using
   * If/Goto instructions.
   */
  void CompileMatch(Match match) {
    auto data = std::make_shared<RegisterValue>(last_register_);
    auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
    CompileTreeNode(decision_tree);
  }

 protected:
  /*! \brief Store the expression a variable points to. */
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map_;
  /*! \brief Instructions in the VMFunction. */
  std::vector<Instruction> instructions_;
  /*! \brief Parameter names of the function. */
  std::vector<std::string> params_;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map_;
  /*! \brief Last used register number. */
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
  /*! \brief Compile engine to lower primitive functions. */
  CompileEngine engine_;
  /*! \brief Global shared meta data. */
  VMCompilerContext* context_;
  /*! \brief Target devices. */
  TargetsMap targets_;
  /*! \brief Host target. */
  Target target_host_;
};


PackedFunc VMCompiler::GetFunction(const std::string& name,
                                   const ObjectPtr<Object>& sptr_to_self) {
  if (name == "compile") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK_EQ(args.num_args, 3);
      Module mod = args[0];
      this->Compile(mod, args[1], args[2]);
    });
  } else if (name == "get_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = runtime::Module(exec_);
    });
  } else if (name == "set_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      Map<std::string, Constant> params = args[0];
      for (const auto& kv : params) {
        this->SetParam(kv.first, kv.second->data);
      }
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
  }
}

void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
  params_[name] = data_in;
}

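/*!
 * \brief Bind the stored parameters into the function by name.
 *
 * Each function parameter whose name matches an entry in params is replaced
 * by a constant node holding the supplied NDArray; duplicated parameter
 * names are rejected because the binding would be ambiguous.
 */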
relay::Function VMCompiler::BindParamsByName(
    relay::Function func,
    const std::unordered_map<std::string, runtime::NDArray>& params) {
  std::unordered_map<std::string, relay::Var> name_dict;
  std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
  for (auto arg : func->params) {
    const auto& name = arg->name_hint();
    if (name_dict.count(name)) {
      repeat_var.insert(arg);
    } else {
      name_dict[name] = arg;
    }
  }
  std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
  for (auto& kv : params) {
    if (name_dict.count(kv.first) == 0) {
      continue;
    }
    auto arg = name_dict.at(kv.first);
    if (repeat_var.count(arg)) {
      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
    }
    bind_dict[arg] = ConstantNode::make(kv.second);
  }
  Expr bound_expr = relay::Bind(func, bind_dict);
  Function ret = Downcast<Function>(bound_expr);
  CHECK(ret.defined())
    << "The returned value is expected to be a Relay Function.\n";
  return ret;
}

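/*!
 * \brief Compile a Relay module into the VM executable.
 *
 * The module is first optimized, then every global function is compiled to
 * a VMFunction, the constant pool is populated, and the primitive operators
 * are code-generated into a runtime::Module.
 */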
void VMCompiler::Compile(Module mod,
                         const TargetsMap& targets,
                         const tvm::Target& target_host) {
  CHECK_EQ(targets.size(), 1)
    << "Currently the VM compiler doesn't support heterogeneous compilation";
  if (params_.size()) {
    auto f = BindParamsByName(mod->Lookup("main"), params_);
    auto gvar = mod->GetGlobalVar("main");
    mod->Add(gvar, f);
  }

  InitVM();
  targets_ = targets;
  target_host_ = target_host;

  // Run the optimizations necessary to target the VM.
  context_.module = OptimizeModule(mod, targets_);

  // Populate the global map.
  //
  // This maps global variables to a global index
  // in the VMFunction table.
  PopulateGlobalMap();

  // Next we get ready by allocating space for
  // the global state.
  exec_->functions.resize(context_.module->functions.size());

  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    auto func = named_func.second;
    VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
    auto vm_func = func_compiler.Compile(gvar, func);

    size_t func_index = context_.global_map.at(gvar);
    CHECK(func_index < exec_->functions.size());
    exec_->functions[func_index] = vm_func;
  }

#if USE_RELAY_DEBUG
  for (auto vm_func : exec_->functions) {
    DLOG(INFO) << vm_func << "-------------";
  }
#endif  // USE_RELAY_DEBUG

  // Populate the constant pool.
  for (auto data : context_.constants) {
    exec_->constants.push_back(vm::Tensor(data));
  }

  LibraryCodegen();

  for (auto gv : context_.global_map) {
    exec_->global_map.insert({gv.first->name_hint, gv.second});
  }
}

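/*!
 * \brief Run the pass pipeline that prepares a module for the VM.
 *
 * Note the ordering at the tail of the pipeline: ToANormalForm and
 * LambdaLift put the program in the shape VMFunctionCompiler expects, and
 * ManifestAlloc runs twice because fusing the shape functions it introduces
 * can itself require new allocations.
 */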
Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
  Array<Pass> pass_seqs;
  Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // The Legalize pass is restricted to homogeneous execution for now.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::Legalize());
  }

  // Eta-expand to support constructors in argument position.
  pass_seqs.push_back(transform::EtaExpand(
    /* expand_constructor */ true, /* expand_global_var */ false));

  pass_seqs.push_back(transform::SimplifyInference());
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    // By default, do not skip the expression; mark casts to int32 as
    // skipped so common-subexpression elimination leaves them alone.
    *rv = false;
    if (expr.as<CallNode>()) {
      auto call_node = expr.as<CallNode>();
      auto op_node = call_node->op.as<OpNode>();
      if (op_node && op_node->name == "cast") {
        auto attrs = call_node->attrs.as<CastAttrs>();
        if (attrs->dtype == Int(32)) {
          *rv = true;
        }
      }
    }
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::InlinePrimitives());

  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());

  // The AlterOpLayout transformation is only applied to homogeneous
  // execution for now.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  pass_seqs.push_back(transform::FoldConstant());

  pass_seqs.push_back(transform::FuseOps());
  pass_seqs.push_back(transform::ToANormalForm());
  pass_seqs.push_back(transform::LambdaLift());
  pass_seqs.push_back(transform::InlinePrimitives());

  // Manifest the allocations.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
  // Compute away possibly introduced constant computation.
  pass_seqs.push_back(transform::FoldConstant());
  // Fuse the shape functions.
  pass_seqs.push_back(transform::FuseOps());
  // Manifest the allocations needed for the shape functions.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

  transform::Sequential seq(pass_seqs);
  transform::PassContext pass_ctx = PassContext::Current();
  // TODO(wweic): Support heterogeneous execution.
  tvm::With<relay::transform::PassContext> ctx(pass_ctx);
  if (targets.size() == 1) {
    const auto& it = targets.begin();
    With<Target> tctx((*it).second);
    return seq(mod);
  }
  return seq(mod);
}

void VMCompiler::PopulateGlobalMap() {
  // First we populate the global map, assigning an index to every function.
  size_t global_index = 0;
  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    context_.global_map.insert({gvar, global_index++});
  }
}

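/*!
 * \brief Generate machine code for all lowered primitive functions.
 *
 * The cached functions are grouped by target and handed to the registered
 * relay.backend.build function; the resulting runtime::Module is attached
 * to the executable, and each primitive's name is mapped to a stable index.
 */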
void VMCompiler::LibraryCodegen() {
  auto const& cached_funcs = context_.cached_funcs;
  if (cached_funcs.size() == 0) {
    return;
  }
  std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
  for (auto& cfunc : cached_funcs) {
    std::string target_str = cfunc->target->str();
    if (tgt_funcs.count(target_str) == 0) {
      tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
    } else {
      tgt_funcs[target_str].push_back(cfunc->funcs[0]);
    }
  }
  Map<Target, Array<LoweredFunc>> funcs;
  for (auto& it : tgt_funcs) {
    funcs.Set(Target::Create(it.first), it.second);
  }

  if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
    // The target is just a dummy argument because funcs already contains
    // the corresponding targets, so it won't be used by the build function.
    runtime::Module mod = (*f)(funcs, Target(), target_host_);
    CHECK(mod.operator->());
    exec_->lib = mod;
  } else {
    LOG(FATAL) << "relay.backend.build is not registered";
  }
  size_t primitive_index = 0;
  for (auto cfunc : cached_funcs) {
    exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
  }
}

runtime::Module CreateVMCompiler() {
  auto exec = make_object<VMCompiler>();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateVMCompiler();
});

}  // namespace vm
}  // namespace relay
}  // namespace tvm