/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/vm/compiler.cc
 * \brief A compiler from relay::Module to the VM byte code.
 */

#include <tvm/operation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <topi/tags.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../../../runtime/vm/naive_allocator.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
#include "compiler.h"

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass InlinePrimitives();
Pass RemoveUnusedFunctions(Array<tvm::Expr> entry_functions);

Pass ManifestAlloc(Target target_host) {
  auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc");
  CHECK(f != nullptr) << "could not load memory allocation pass";
  return (*f)(target_host);
}

}  // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

void InstructionPrint(std::ostream& os, const Instruction& instr);

// Represents a runtime object that is going to be matched by pattern match expressions
struct MatchValue {
  virtual ~MatchValue() {}
};
using MatchValuePtr = std::shared_ptr<MatchValue>;

// A runtime object that resides in a register
struct RegisterValue : MatchValue {
  // The register num
  RegName register_num;

  explicit RegisterValue(RegName reg) : register_num(reg) {}

  ~RegisterValue() {}
};

// The value is a field of another runtime object
struct AccessField : MatchValue {
  MatchValuePtr parent;
  // Field index
  size_t index;
  // Runtime register num after compiling the access field path
  RegName reg{-1};

  AccessField(MatchValuePtr parent, size_t index)
  : parent(parent), index(index) {}

  ~AccessField() {}
};
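
// Note: nested patterns are compiled into chains of AccessField values. For
// example, the second field of the first field of a value held in register r
// is represented as AccessField(AccessField(RegisterValue(r), 0), 1), and
// CompileMatchValue below emits one GetField instruction per link.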

/*!
 * \brief Condition in a decision tree
 */
struct ConditionNode {
  virtual ~ConditionNode() {}
};

using ConditionNodePtr = std::shared_ptr<ConditionNode>;

/*!
 * \brief A var binding condition
 */
struct VarBinding : ConditionNode {
  Var var;
  MatchValuePtr val;

  VarBinding(Var var, MatchValuePtr val)
          : var(var), val(val) {}

  ~VarBinding() {}
};

/*!
 * \brief Compare the tag of the object
 */
struct TagCompare : ConditionNode {
  /*! \brief The object to be examined */
  MatchValuePtr obj;

  /*! \brief The expected tag */
  int target_tag;

  TagCompare(MatchValuePtr obj, size_t target)
          : obj(obj), target_tag(target) {
  }

  ~TagCompare() {}
};

using TreeNodePtr = typename relay::TreeNode<ConditionNodePtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionNodePtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionNodePtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionNodePtr>;

TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
                                         Pattern pattern,
                                         TreeNodePtr then_branch,
                                         TreeNodePtr else_branch) {
  if (pattern.as<PatternWildcardNode>()) {
    // We ignore a wildcard binding since it does not produce new vars
    return then_branch;
  } else if (pattern.as<PatternVarNode>()) {
    auto pat = pattern.as<PatternVarNode>();
    auto pattern = GetRef<PatternVar>(pat);
    auto cond = std::make_shared<VarBinding>(pattern->var, data);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else if (auto pcn = pattern.as<PatternConstructorNode>()) {
    auto tag = pcn->constructor->tag;

    size_t field_index = 0;
    for (auto& p : pcn->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
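    // The tag comparison wraps the field sub-trees built above, so the fields
    // are only accessed once the constructor tag is known to match.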
    auto cond = std::make_shared<TagCompare>(data, tag);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else {
    auto pt = pattern.as<PatternTupleNode>();
    CHECK(pt) << "unhandled case: " << pattern;
    size_t field_index = 0;
    for (auto& p : pt->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
    return then_branch;
  }
}

TreeNodePtr BuildDecisionTreeFromClause(MatchValuePtr data,
                                        Clause clause,
                                        TreeNodePtr else_branch) {
  return BuildDecisionTreeFromPattern(data, clause->lhs,
                                      TreeLeafNode::Make(clause->rhs), else_branch);
}

TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // When nothing matches, the VM throws a fatal error
  TreeNodePtr else_branch = TreeLeafFatalNode::Make();
  // Start from the last clause
  for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
    else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
  }
  return else_branch;
}
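
// Note: because the clauses are folded from the last one outward, the
// resulting tree checks the first clause first, falls through to the next
// clause on failure, and only reaches the fatal leaf when every clause fails.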

std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
  std::vector<int64_t> raw_shape;
  DLTensor tensor = shape.ToDLPack()->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standardize the bit width of
  // all of the shape manipulating code.
  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
  for (auto i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(int_ptr[i]);
  }
  return raw_shape;
}


std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
  std::vector<int64_t> raw_shape;
  DLTensor tensor = shape.ToDLPack()->dl_tensor;
  CHECK_EQ(tensor.ndim, 1u);
  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;

  // TODO(@jroesch): we really need to standardize the bit width of
  // all of the shape manipulating code.
  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
  for (auto i = 0; i < tensor.shape[0]; i++) {
    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
  }
  return raw_shape;
}


class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 public:
  VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
      : last_register_(0),
        registers_num_(0),
        engine_(CompileEngine::Global()),
        context_(context),
        targets_(targets),
        target_host_(target_host) {}

  VMFunction Compile(const GlobalVar& var, const Function& func) {
    size_t i = 0;
    // Assign a register number to each parameter.
    for (auto param : func->params) {
      auto arg_register = NewRegister();
      CHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      params_.push_back(param->name_hint());
      ++i;
    }

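    // For closures produced by lambda lifting, the outer function's parameters
    // are assumed to hold the captured free variables while the inner
    // function's parameters are the actual call arguments, so both groups
    // receive consecutive registers.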
    if (IsClosure(func)) {
      Function inner_func = Downcast<Function>(func->body);
      for (auto param : inner_func->params) {
        auto arg_register = NewRegister();
        CHECK_EQ(i, arg_register);
        var_register_map_.insert({param, arg_register});
        params_.push_back(param->name_hint());
        ++i;
      }
      this->VisitExpr(inner_func->body);
    } else {
      this->VisitExpr(func->body);
    }
    instructions_.push_back(Instruction::Ret(last_register_));
    return VMFunction(var->name_hint, params_, instructions_, registers_num_);
  }

 protected:
  size_t NewRegister() { return registers_num_++; }

  inline void Emit(const Instruction& instr) {
    DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
    CHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
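    // Instructions that produce a value record their destination register in
    // last_register_ so that the enclosing expression can pick up the result.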
    switch (instr.op) {
      case Opcode::AllocADT:
      case Opcode::AllocTensor:
      case Opcode::AllocTensorReg:
      case Opcode::GetField:
      case Opcode::GetTag:
      case Opcode::LoadConst:
      case Opcode::LoadConsti:
      case Opcode::Invoke:
      case Opcode::AllocClosure:
      case Opcode::AllocStorage:
      case Opcode::Move:
      case Opcode::InvokeClosure:
        last_register_ = instr.dst;
        break;
      case Opcode::InvokePacked:
      case Opcode::If:
      case Opcode::Ret:
      case Opcode::Goto:
      case Opcode::Fatal:
        break;
    }
    instructions_.push_back(instr);
  }

  void VisitExpr_(const ConstantNode* const_node) {
    size_t konst_idx = context_->constants.size();
    context_->constants.push_back(const_node->data);
    Emit(Instruction::LoadConst(konst_idx, NewRegister()));
  }

  void VisitExpr_(const VarNode* var_node) {
    auto var = GetRef<Var>(var_node);
    auto reg_it = this->var_register_map_.find(var);
    CHECK(reg_it != this->var_register_map_.end());
    last_register_ = reg_it->second;
  }

  void VisitExpr_(const TupleNode* tuple_node) {
    auto tuple = GetRef<Tuple>(tuple_node);
    std::vector<Index> fields_registers;

    for (auto& field : tuple->fields) {
      this->VisitExpr(field);
      fields_registers.push_back(last_register_);
    }

    // TODO(@jroesch): use correct tag
    Emit(Instruction::AllocADT(
      0,
      tuple->fields.size(),
      fields_registers,
      NewRegister()));
  }

  void VisitExpr_(const MatchNode* match_node) {
    auto match = GetRef<Match>(match_node);

    this->VisitExpr(match->data);
    CompileMatch(match);
  }

  void VisitExpr_(const LetNode* let_node) {
    DLOG(INFO) << PrettyPrint(let_node->value);
    this->VisitExpr(let_node->value);
    var_register_map_.insert({let_node->var, this->last_register_});
    this->VisitExpr(let_node->body);
  }

  void VisitExpr_(const TupleGetItemNode* get_node) {
    auto get = GetRef<TupleGetItem>(get_node);
    this->VisitExpr(get->tuple);
    auto tuple_register = last_register_;
    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
  }

  void VisitExpr_(const GlobalVarNode* gvar) {
    auto var = GetRef<GlobalVar>(gvar);
    auto func = context_->module->Lookup(var);
    auto it = context_->global_map.find(var);
    CHECK(it != context_->global_map.end());
    // Allocate closure with zero free vars
    Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
  }

  void VisitExpr_(const IfNode* if_node) {
    this->VisitExpr(if_node->cond);

    size_t test_register = last_register_;

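    // Load the constant 1 to compare the condition against: the If instruction
    // takes the true branch when the test register equals the target register.
    // The branch offsets are emitted as placeholders and patched below once
    // both branches have been generated.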
    this->Emit(Instruction::LoadConsti(1, NewRegister()));
    auto after_cond = instructions_.size();
    auto target_register = last_register_;
    this->Emit(Instruction::If(test_register, target_register, 0, 0));
    this->VisitExpr(if_node->true_branch);

    // This register saves the result of the If-Else expression.
    auto merge_register = NewRegister();
    Emit(Instruction::Move(last_register_, merge_register));
    Emit(Instruction::Goto(0));

    // Finally, record how many instructions there are
    // in the true branch.
    auto after_true = this->instructions_.size();

    this->VisitExpr(if_node->false_branch);

    size_t false_register = last_register_;

    // In the else branch, write the result into the same merge register.
    Emit(Instruction::Move(false_register, merge_register));
    // Compute the total number of instructions
    // after generating false.
    auto after_false = this->instructions_.size();

    // Now we will compute the jump targets in order
    // to properly patch the If and Goto instructions
    // with the requisite targets.

    // After we emit the true and false bodies,
    // we patch up the If instruction and the Goto.
    auto true_offset = 1;
    auto false_offset = after_true - after_cond;
    instructions_[after_cond].if_op.true_offset = true_offset;
    instructions_[after_cond].if_op.false_offset = false_offset;

    // Patch the Goto.
    this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

    this->last_register_ = merge_register;
  }

  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
    // Lower shape function
    auto key = CCacheKeyNode::make(func, target_host_);
    auto cfunc = engine_->LowerShapeFunc(key);
    int op_index = -1;
    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    // Prepare input and output registers
    std::vector<Index> argument_registers;
    for (auto input : inputs) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : outputs) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

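    // InvokePacked receives the total arity and the number of trailing
    // registers that hold outputs; the lowered shape function writes its
    // results into those output tensors.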
    Emit(Instruction::InvokePacked(op_index,
      argument_registers.size(),
      outputs.size(),
      argument_registers));
  }

  void EmitInvokeTVMOp(const Function& func,
                       const Expr& inputs,
                       const Expr& outputs) {
    std::vector<Index> argument_registers;

    CHECK(func->IsPrimitive())
      << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";

    auto input_tuple = inputs.as<TupleNode>();
    CHECK(input_tuple)
      << "internal error: invoke_tvm_op inputs must be a tuple, "
      << "please file a bug in the memory manifestation pass";

    auto output_tuple = outputs.as<TupleNode>();
    CHECK(output_tuple)
      << "internal error: invoke_tvm_op outputs must be a tuple, "
      << "please file a bug in the memory manifestation pass";

    for (auto input : input_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(input));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    for (auto output : output_tuple->fields) {
      auto reg = var_register_map_.find(Downcast<Var>(output));
      CHECK(reg != var_register_map_.end())
        << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    // Next generate the invoke instruction.
    Target target;
    if (targets_.size() == 1) {
      // homogeneous execution.
      for (auto kv : targets_) {
        target = kv.second;
      }
    } else {
      // heterogeneous execution.
      LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
    }

    auto key = CCacheKeyNode::make(func, target);
    auto cfunc = engine_->Lower(key);

    // TODO(jroesch): support lowered funcs for multiple targets
    CHECK_EQ(cfunc->funcs.size(), 1);
    auto op_index = -1;
    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
      op_index = context_->cached_funcs.size();
      context_->cached_funcs.push_back(cfunc);
      context_->seen_funcs[cfunc->funcs[0]] = op_index;
    } else {
      op_index = context_->seen_funcs[cfunc->funcs[0]];
    }

    Emit(Instruction::InvokePacked(op_index,
      argument_registers.size(),
      output_tuple->fields.size(),
      argument_registers));
  }

  void VisitExpr_(const CallNode* call_node) {
    Expr op = call_node->op;

    // First we handle the case in which we are using an opaque
    // operator used to define a sub-dialect, such as memory
    // allocation operations.
    if (op.as<OpNode>()) {
      OpMatch<void> matcher;
      matcher.Match("memory.invoke_tvm_op",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          EmitInvokeTVMOp(Downcast<Function>(args[0]), args[1], args[2]);
      }).Match("memory.alloc_tensor",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);

          // Get the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
              << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          // The storage will be passed dynamically.
          this->VisitExpr(args[0]);
          auto storage_register = last_register_;

          // If the shape is constant then we will emit a static tensor allocation instruction.
          auto const_shape = args[1].as<ConstantNode>();

          if (const_shape) {
            NDArray shape = const_shape->data;
            std::vector<int64_t> raw_shape;
            DLTensor tensor = shape.ToDLPack()->dl_tensor;
            // TODO(@jroesch): we need to get an RFC done to standardize this
            if (tensor.dtype.bits == 64) {
              raw_shape = ToAllocTensorShape64(shape);
            } else if (tensor.dtype.bits == 32) {
              raw_shape = ToAllocTensorShape32(shape);
            } else {
              LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
            }

            // Add context field.
            Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
          } else {
            this->VisitExpr(args[1]);
            auto shape_register = last_register_;
            Emit(Instruction::AllocTensorReg(
              storage_register,
              shape_register,
              dtype,
              NewRegister()));
          }
      }).Match("memory.alloc_storage",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 2);
          // Compute the size of the allocation.
          this->VisitExpr(args[0]);
          auto size_register = last_register_;

          this->VisitExpr(args[1]);
          auto alignment_register = last_register_;

          // Get the dtype hint from the attributes.
          auto alloc_attrs = attrs.as<AllocTensorAttrs>();
          CHECK(alloc_attrs != nullptr)
              << "must be the alloc tensor attrs";
          auto dtype = alloc_attrs->dtype;

          Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister()));
      }).Match("memory.shape_func",
        [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          CHECK_EQ(args.size(), 3);
          auto shape_func = Downcast<Function>(args[0]);
          auto inputs = Downcast<Tuple>(args[1]);
          auto outputs = Downcast<Tuple>(args[2]);
          EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
      }).Match("memory.kill",
        [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
          LOG(FATAL) << "memory.kill is not yet supported";
      });
      matcher(GetRef<Call>(call_node));
      return;
    }

    // In case it's not one of these specialized operators, we generate code
    // for one of the "standard" cases.
    std::vector<Index> args_registers;

    for (auto arg : call_node->args) {
      this->VisitExpr(arg);
      args_registers.push_back(last_register_);
    }

    if (auto global_node = op.as<GlobalVarNode>()) {
      // When invoking a global we need to find its global ID,
      // check whether it is a closure invocation or a standard
      // global, and emit the corresponding calling convention.
      auto global = GetRef<GlobalVar>(global_node);
      auto it = context_->global_map.find(global);
      CHECK(it != context_->global_map.end());
      DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
                      << " with func_index=" << it->second;
      auto func = context_->module->Lookup(global);
      if (IsClosure(func)) {
        auto arity = func->params.size();
        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
      } else {
        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
      }
    } else if (auto constructor_node = op.as<ConstructorNode>()) {
      // In the constructor case, we simply need to find its tag
      // and emit a call to allocate the data structure.
      auto constructor = GetRef<Constructor>(constructor_node);
      Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
                                      NewRegister()));
    } else if (auto var_node = op.as<VarNode>()) {
      // If we are calling a variable, it must be a closure, so we
      // emit an invoke closure here.
      VisitExpr(GetRef<Var>(var_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      // Any other case is a bug.
      LOG(FATAL) << "internal error: unreachable code, "
                 << "should be transformed away by previous passes: "
                 << PrettyPrint(GetRef<Expr>(call_node));
    }
  }

  void VisitExpr_(const FunctionNode* func_node) {
    if (!func_node->IsPrimitive()) {
      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                 << "AST: " << GetRef<Function>(func_node);
    }
  }

  /*!
   * \brief Compile a match value
   * Generate bytecode that computes the value specified in val
   *
   * \return The register number assigned to the final value
   */
  RegName CompileMatchValue(MatchValuePtr val) {
    if (std::dynamic_pointer_cast<RegisterValue>(val)) {
      auto r = std::dynamic_pointer_cast<RegisterValue>(val);
      return r->register_num;
    } else {
      auto path = std::dynamic_pointer_cast<AccessField>(val);
      auto p = CompileMatchValue(path->parent);
      Emit(Instruction::GetField(p, path->index, NewRegister()));
      path->reg = last_register_;
      return path->reg;
    }
  }

  void CompileTreeNode(TreeNodePtr tree) {
    if (std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree);
      VisitExpr(node->body);
    } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
      Emit(Instruction::Fatal());
    } else if (std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
      auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree);
      if (std::dynamic_pointer_cast<TagCompare>(node->cond)) {
        // For a tag comparison, generate branches
        auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond);
        auto r = CompileMatchValue(cond->obj);
        Emit(Instruction::GetTag(r, NewRegister()));
        auto operand1 = last_register_;
        Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
        auto operand2 = last_register_;

        Emit(Instruction::If(operand1, operand2, 1, 0));
        auto cond_offset = instructions_.size() - 1;
        CompileTreeNode(node->then_branch);
        auto if_reg = last_register_;
        Emit(Instruction::Goto(1));
        auto goto_offset = instructions_.size() - 1;
        CompileTreeNode(node->else_branch);
        auto else_reg = last_register_;
        Emit(Instruction::Move(else_reg, if_reg));
        last_register_ = if_reg;
        auto else_offset = instructions_.size() - 1;
        // Fix up the branch offsets now that both branches have been emitted.
        instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
        instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
      } else {
        // For a var-binding condition, bind the variable and go straight to then_branch
        auto cond = std::dynamic_pointer_cast<VarBinding>(node->cond);
        var_register_map_[cond->var] = CompileMatchValue(cond->val);
        CompileTreeNode(node->then_branch);
      }
    }
  }

  /*!
   * \brief Compile a pattern match expression
   * It first converts the pattern match expression into a decision tree; the conditions
   * are either tag comparisons or variable bindings. If a condition fails in a clause,
   * the decision tree switches to check the conditions of the next clause, and so on.
   * If no clause matches the value, a fatal node is reached.
   *
   * After the decision tree is built, we convert it into bytecode using If/Goto.
   */
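  // Illustrative example: for a match such as
  //   match xs { Cons(x, rest) => x; Nil => d }
  // (names here are hypothetical), the resulting tree looks roughly like
  //   TagCompare(xs, Cons) ? (bind x, rest to the fields; evaluate x)
  //                        : (TagCompare(xs, Nil) ? evaluate d : Fatal)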
  void CompileMatch(Match match) {
    auto data = std::make_shared<RegisterValue>(last_register_);
    auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
    CompileTreeNode(decision_tree);
  }

 protected:
  /*! \brief Store the expression a variable points to. */
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> expr_map_;
  /*! \brief Instructions in the VMFunction. */
  std::vector<Instruction> instructions_;
  /*! \brief Parameter names of the function. */
  std::vector<std::string> params_;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, NodeHash, NodeEqual> var_register_map_;
  /*! \brief Last used register number. */
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
  /*! \brief Compiler engine to lower primitive functions. */
  CompileEngine engine_;
  /*! \brief Global shared metadata. */
  VMCompilerContext* context_;
  /*! \brief Target devices. */
  TargetsMap targets_;
  /*! \brief Host target. */
  Target target_host_;
};


PackedFunc VMCompiler::GetFunction(const std::string& name,
                                   const ObjectPtr<Object>& sptr_to_self) {
  if (name == "compile") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK_EQ(args.num_args, 3);
      Module mod = args[0];
      this->Compile(mod, args[1], args[2]);
    });
  } else if (name == "get_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      *rv = runtime::Module(exec_);
    });
  } else if (name == "set_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      Map<std::string, Constant> params = args[0];
      for (const auto& kv : params) {
        this->SetParam(kv.first, kv.second->data);
      }
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
    return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
  }
}

void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
  params_[name] = data_in;
}

relay::Function VMCompiler::BindParamsByName(
    relay::Function func,
    const std::unordered_map<std::string, runtime::NDArray>& params) {
  std::unordered_map<std::string, relay::Var> name_dict;
  std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
  for (auto arg : func->params) {
    const auto& name = arg->name_hint();
    if (name_dict.count(name)) {
      repeat_var.insert(arg);
    } else {
      name_dict[name] = arg;
    }
  }
  std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
  for (auto& kv : params) {
    if (name_dict.count(kv.first) == 0) {
      continue;
    }
    auto arg = name_dict.at(kv.first);
    if (repeat_var.count(arg)) {
      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
    }
    bind_dict[arg] = ConstantNode::make(kv.second);
  }
  Expr bound_expr = relay::Bind(func, bind_dict);
  Function ret = Downcast<Function>(bound_expr);
  CHECK(ret.defined())
      << "The returned value is expected to be a Relay Function."
      << "\n";
  return ret;
}

void VMCompiler::Compile(Module mod,
                         const TargetsMap& targets,
                         const tvm::Target& target_host) {
  CHECK_EQ(targets.size(), 1)
    << "Currently VM compiler doesn't support heterogeneous compilation";
  if (params_.size()) {
    auto f = BindParamsByName(mod->Lookup("main"), params_);
    auto gvar = mod->GetGlobalVar("main");
    mod->Add(gvar, f);
  }

  InitVM();
  targets_ = targets;
  target_host_ = target_host;

  // Run the optimizations necessary to target the VM.
  context_.module = OptimizeModule(mod, targets_);

  // Populate the global map.
  //
  // This maps global variables to a global index
  // in the VMFunction table.
  PopulateGlobalMap();

  // Next we get ready by allocating space for
  // the global state.
  exec_->functions.resize(context_.module->functions.size());

  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    auto func = named_func.second;
    VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
    auto vm_func = func_compiler.Compile(gvar, func);

    size_t func_index = context_.global_map.at(gvar);
    CHECK(func_index < exec_->functions.size());
    exec_->functions[func_index] = vm_func;
  }

#if USE_RELAY_DEBUG
  for (auto vm_func : exec_->functions) {
    DLOG(INFO) << vm_func << "-------------";
  }
#endif  // USE_RELAY_DEBUG

  // populate constants
  for (auto data : context_.constants) {
    exec_->constants.push_back(vm::Tensor(data));
  }

  LibraryCodegen();

  for (auto gv : context_.global_map) {
    exec_->global_map.insert({gv.first->name_hint, gv.second});
  }
}

Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) {
  Array<Pass> pass_seqs;
  Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // Legalize pass is restricted to homogeneous execution for now.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::Legalize());
  }

  // Eta-expand to support constructors in argument position.
  pass_seqs.push_back(transform::EtaExpand(
    /* expand_constructor */ true, /* expand_global_var */ false));

  pass_seqs.push_back(transform::SimplifyInference());
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    if (expr.as<CallNode>()) {
      auto call_node = expr.as<CallNode>();
      auto op_node = call_node->op.as<OpNode>();
      if (op_node != nullptr && op_node->name == "cast") {
        auto attrs = call_node->attrs.as<CastAttrs>();
        if (attrs->dtype == Int(32)) {
          *rv = true;
          return;
        }
      }
    }
    *rv = false;
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::InlinePrimitives());

  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());

  // Alter layout transformation is only applied to homogeneous execution for now.
  if (targets.size() == 1) {
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  pass_seqs.push_back(transform::FoldConstant());

  pass_seqs.push_back(transform::FuseOps());
  pass_seqs.push_back(transform::ToANormalForm());
  pass_seqs.push_back(transform::LambdaLift());
  pass_seqs.push_back(transform::InlinePrimitives());

  // Manifest the allocations.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
  // Compute away possibly introduced constant computation.
  pass_seqs.push_back(transform::FoldConstant());
  // Fuse the shape functions.
  pass_seqs.push_back(transform::FuseOps());
  // Manifest the allocations needed for the shape functions.
  pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));

  transform::Sequential seq(pass_seqs);
  transform::PassContext pass_ctx = PassContext::Current();
  // TODO(wweic): Support heterogeneous execution
  tvm::With<relay::transform::PassContext> ctx(pass_ctx);
  if (targets.size() == 1) {
    const auto& it = targets.begin();
    With<Target> tctx((*it).second);
    return seq(mod);
  }
  return seq(mod);
}

void VMCompiler::PopulateGlobalMap() {
  // First we populate the global map.
  size_t global_index = 0;
  for (auto named_func : context_.module->functions) {
    auto gvar = named_func.first;
    context_.global_map.insert({gvar, global_index++});
  }
}

void VMCompiler::LibraryCodegen() {
  auto const& cached_funcs = context_.cached_funcs;
  if (cached_funcs.size() == 0) {
    return;
  }
  std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
  for (auto& cfunc : cached_funcs) {
    std::string target_str = cfunc->target->str();
    if (tgt_funcs.count(target_str) == 0) {
      tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
    } else {
      tgt_funcs[target_str].push_back(cfunc->funcs[0]);
    }
  }
  Map<Target, Array<LoweredFunc>> funcs;
  for (auto& it : tgt_funcs) {
    funcs.Set(Target::Create(it.first), it.second);
  }

  if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
    // The target is just a dummy argument because funcs already contains the
    // corresponding targets, so the target won't be used in the build function.
    runtime::Module mod = (*f)(funcs, Target(), target_host_);
    CHECK(mod.operator->());
    exec_->lib = mod;
  } else {
    LOG(FATAL) << "relay.backend.build is not registered";
  }
  size_t primitive_index = 0;
  for (auto cfunc : cached_funcs) {
    exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
  }
}

runtime::Module CreateVMCompiler() {
  auto exec = make_object<VMCompiler>();
  return runtime::Module(exec);
}

TVM_REGISTER_GLOBAL("relay._vm._VMCompiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateVMCompiler();
});

}  // namespace vm
}  // namespace relay
}  // namespace tvm