/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/interpreter.cc
 * \brief An interpreter for the Relay IR.
 */
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
#include "compile_engine.h"

namespace tvm {
namespace relay {

using namespace runtime;

inline const PackedFunc& GetPackedFunc(const std::string& name) {
  const PackedFunc* pf = tvm::runtime::Registry::Get(name);
  CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
  return *pf;
}

/* Value Implementation */
Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
  NodePtr<ClosureNode> n = make_node<ClosureNode>();
  n->env = std::move(env);
  n->func = std::move(func);
  return Closure(n);
}

TVM_REGISTER_API("relay._make.Closure")
.set_body_typed(ClosureNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const ClosureNode*>(ref.get());
    p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
  });


// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
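// A RecClosure pairs a closure with the Var it was let-bound to, so that
// recursive references to that Var inside the closure body can be re-bound
// to the closure itself when it is invoked (see Invoke below).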
RecClosure RecClosureNode::make(Closure clos, Var bind) {
  NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
  n->clos = std::move(clos);
  n->bind = std::move(bind);
  return RecClosure(n);
}

TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const RecClosureNode*>(ref.get());
    p->stream << "RecClosureNode(" << node->clos << ")";
  });

TupleValue TupleValueNode::make(tvm::Array<Value> value) {
  NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
  n->fields = value;
  return TupleValue(n);
}

TVM_REGISTER_API("relay._make.TupleValue")
.set_body_typed(TupleValueNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const TupleValueNode*>(ref.get());
    p->stream << "TupleValueNode(" << node->fields << ")";
  });

TensorValue TensorValueNode::make(runtime::NDArray data) {
  NodePtr<TensorValueNode> n = make_node<TensorValueNode>();
  n->data = std::move(data);
  return TensorValue(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const TensorValueNode*>(ref.get());
    auto to_str = GetPackedFunc("relay._tensor_value_repr");
    std::string data_str = to_str(GetRef<TensorValue>(node));
    p->stream << "TensorValueNode(" << data_str << ")";
  });

TVM_REGISTER_API("relay._make.TensorValue")
.set_body_typed(TensorValueNode::make);

RefValue RefValueNode::make(Value value) {
  NodePtr<RefValueNode> n = make_node<RefValueNode>();
  n->value = value;
  return RefValue(n);
}

TVM_REGISTER_API("relay._make.RefValue")
.set_body_typed(RefValueNode::make);

TVM_REGISTER_NODE_TYPE(RefValueNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const RefValueNode*>(ref.get());
    p->stream << "RefValueNode(" << node->value << ")";
  });

ConstructorValue ConstructorValueNode::make(int32_t tag,
                                            tvm::Array<Value> fields,
                                            Constructor constructor) {
  NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
  n->tag = tag;
  n->fields = fields;
  n->constructor = constructor;
  return ConstructorValue(n);
}

TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body_typed(ConstructorValueNode::make);

TVM_REGISTER_NODE_TYPE(ConstructorValueNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
  auto* node = static_cast<const ConstructorValueNode*>(ref.get());
  p->stream << "ConstructorValueNode(" << node->tag << ","
            << node->fields << ")";
});

/*!
 * \brief A stack frame in the Relay interpreter.
 *
 * Contains a mapping from relay::Var to relay::Value.
 */
struct Frame {
  /*! \brief The set of local variables and arguments for the frame. */
  tvm::Map<Var, Value> locals;

  explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
};

/*!
 * \brief The call stack in the Relay interpreter.
 *
 * Contains a stack of frames; each corresponding to
 * a function call.
 */
struct Stack {
  /*! \brief The stack frames. */
  std::vector<Frame> frames;
  Stack() : frames() { frames.push_back(Frame({})); }

  Frame& current_frame() { return frames.back(); }

  Value Lookup(const Var& local) {
    for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
      auto elem = frame->locals.find(local);
      if (elem != frame->locals.end()) {
        return (*elem).second;
      }
    }

    LOG(FATAL) << "could not find variable binding for " << local
               << ", address = " << local.operator->();
    return Value();
  }
  /*!
   * A wrapper around Frame to add RAII semantics to pushing and popping
   * stack frames.
   */
  struct LocalFrame {
    Stack& st;
    explicit LocalFrame(Stack& st, const Frame& fr) : st(st) {
      st.frames.push_back(fr);
    }
    ~LocalFrame() { st.frames.pop_back(); }
  };
};

/*! \brief A representation of the interpreter state which can be passed back to Python. */
class InterpreterState;

/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Node {
 public:
  using Frame = tvm::Map<Var, Value>;
  using Stack = tvm::Array<Frame>;

  /*! \brief The current expression under evaluation. */
  Expr current_expr;

  /*! \brief The call stack of the interpreter. */
  Stack stack;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("current_expr", &current_expr);
    v->Visit("stack", &stack);
  }

  static InterpreterState make(Expr current_expr, Stack stack);

  static constexpr const char* _type_key = "relay.InterpreterState";
  TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node);
};

RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef);

InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
  NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>();
  n->current_expr = std::move(current_expr);
  n->stack = std::move(stack);
  return InterpreterState(n);
}

// NOTE: The current interpreter assumes A-normal form (ANF),
// which is better suited for execution.
//
// It will re-run duplicated computations when given a program
// that contains a DAG in dataflow form.
//
// Conversion to ANF is recommended before running the interpreter.
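//
// For example (illustrative Relay pseudo-code, not taken from this file):
//
//   dataflow form:    (%t, %t)  where %t = concat(%x, %y) is shared as a DAG
//     -> the interpreter re-evaluates `concat` once per use of %t.
//
//   A-normal form:    let %t = concat(%x, %y); (%t, %t)
//     -> `concat` is evaluated once; later uses reuse the let-bound value.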
class Interpreter :
      public ExprFunctor<Value(const Expr& n)>,
             PatternFunctor<bool(const Pattern& p, const Value& v)> {
 public:
  Interpreter(Module mod,
              DLContext context,
              Target target)
      : mod_(mod), context_(context), target_(target) {
    engine_ = CompileEngine::Global();
  }

  template <typename T>
  T WithFrame(const Frame& fr, const std::function<T()>& f) {
    Stack::LocalFrame lf(stack_, fr);
    return f();
  }

  void extend(const Var& id, Value v) {
    stack_.current_frame().locals.Set(id, v);
  }

  inline Value Lookup(const Var& local) {
    return stack_.Lookup(local);
  }

  Value Eval(const Expr& expr) {
    return VisitExpr(expr);
  }

  Value VisitExpr(const Expr& expr) final {
    auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
    return ret;
  }

  Value VisitExpr_(const VarNode* var_node) final {
    return Lookup(GetRef<Var>(var_node));
  }

  Value VisitExpr_(const GlobalVarNode* op) final {
    return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
  }

  Value VisitExpr_(const OpNode* id) override {
    // TODO(@jroesch): Eta-expand and return in this case.
    LOG(FATAL) << "internal error: a bare operator must be wrapped in a "
               << "synthetic call node (eta-expanded) before it can be evaluated";
    return Value();
  }

  Value VisitExpr_(const ConstantNode* op) final {
    return TensorValueNode::make(op->data.CopyTo(context_));
  }

  Value VisitExpr_(const TupleNode* op) final {
    std::vector<Value> values;

    for (const auto& field : op->fields) {
      Value field_value = Eval(field);
      values.push_back(field_value);
    }

    return TupleValueNode::make(values);
  }

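  // Create a closure value for `func` by capturing the current values of its
  // free variables. When `letrec_name` is given (the Var a recursive function
  // is let-bound to), that variable is skipped during capture and the closure
  // is wrapped in a RecClosure so recursive calls can be resolved at call time.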
  inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
    tvm::Map<Var, Value> captured_mod;
    Array<Var> free_vars = FreeVars(func);

    for (const auto& var : free_vars) {
      // Evaluate the free var (which could be a function call), unless it is
      // the let-bound name of the function itself (the recursive case).
      if (letrec_name.defined() && letrec_name == var) {
        continue;
      }

      captured_mod.Set(var, Eval(var));
    }

    // We must use mutation here to build a self referential closure.
    auto closure = ClosureNode::make(captured_mod, func);
    if (letrec_name.defined()) {
      return RecClosureNode::make(closure, letrec_name);
    }
    return std::move(closure);
  }

  Value VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);
    return MakeClosure(func);
  }

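  // Run the auxiliary shape function of a primitive function to compute its
  // output shapes when they depend on the input data or shapes (i.e. the
  // function's type contains dynamic dimensions). The shape function is
  // lowered for and executed on the CPU, regardless of the context the
  // interpreter itself runs on.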
  Array<Shape> ComputeDynamicShape(const Function& func,
                                   const Array<Value>& args) {
    auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
    auto cfunc = engine_->LowerShapeFunc(key);
    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();

    std::vector<TVMValue> values(arity);
    std::vector<int> codes(arity);
    TVMArgsSetter setter(values.data(), codes.data());
    std::vector<NDArray> inputs(cfunc->inputs.size());
    std::vector<NDArray> outputs(cfunc->outputs.size());

    DLContext cpu_ctx;
    cpu_ctx.device_type = kDLCPU;
    cpu_ctx.device_id = 0;

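    // Set the i-th input of the packed shape-function call: either the shape
    // of a tensor argument (as a 1-D int64 array on the CPU) or a CPU copy of
    // the tensor data itself, depending on what the shape function requires.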
    auto fset_input = [&](size_t i, Value val, bool need_shape) {
        const TensorValueNode* tv = val.as<TensorValueNode>();
        CHECK(tv != nullptr) << "expect Tensor argument";
        if (need_shape) {
          int64_t ndim = tv->data.Shape().size();
          NDArray shape_arr;
          if (ndim == 0) {
            shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
          } else {
            shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
            int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
            for (auto j = 0; j < ndim; ++j) {
              data[j] = tv->data.Shape()[j];
            }
          }
          inputs[i] = shape_arr;
          setter(i, shape_arr);
        } else {
          auto arr = tv->data.CopyTo(cpu_ctx);
          inputs[i] = arr;
          setter(i, arr);
        }
    };

    size_t arg_counter = 0;
    for (size_t i = 0; i < args.size(); ++i) {
      auto arg = args[i];
      auto param = func->params[i];
      int state = cfunc->shape_func_param_states[i]->value;
      if (arg.as<TensorValueNode>()) {
        if (state & kNeedInputData) {
          fset_input(arg_counter++, arg, false);
        }
        if (state & kNeedInputShape) {
          fset_input(arg_counter++, arg, true);
        }
      } else {
        const TupleValueNode* tuple = arg.as<TupleValueNode>();
        CHECK(tuple != nullptr);
        if (state & kNeedInputData) {
          for (size_t i = 0; i < tuple->fields.size(); ++i) {
            fset_input(arg_counter++, tuple->fields[i], false);
          }
        }
        if (state & kNeedInputShape) {
          for (size_t i = 0; i < tuple->fields.size(); ++i) {
            fset_input(arg_counter++, tuple->fields[i], true);
          }
        }
      }
    }
    CHECK_EQ(arg_counter, cfunc->inputs.size())
      << "Shape function input sizes mismatch";

    auto fset_shape_output = [&](size_t i, Type val_type) {
        // TODO(@icemelon): allow recursive tuple
        const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
        CHECK(rtype != nullptr);
        int64_t ndim = rtype->shape.size();
        auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
        outputs[i] = arr;
        setter(arg_counter + i, arr);
    };

    auto ret_type = func->body->checked_type();
    size_t out_cnt = 0;
    if (auto rtype = ret_type.as<TupleTypeNode>()) {
      out_cnt = rtype->fields.size();
      for (size_t i = 0; i < out_cnt; ++i) {
        fset_shape_output(i, rtype->fields[i]);
      }
    } else {
      out_cnt = 1;
      auto tt = Downcast<TensorType>(ret_type);
      fset_shape_output(0, tt);
    }
    CHECK_EQ(cfunc->outputs.size(), out_cnt)
      << "Shape function output sizes mismatch";

    PackedFunc shape_func;
    TVMRetValue rv;
    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
      tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
      shape_func = m.GetFunction(cfunc->func_name);
    } else {
      LOG(FATAL) << "relay.backend.build is not registered";
    }
    shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);

    // Get output shapes
    Array<Shape> out_shapes;
    for (auto out_tensor : outputs) {
      int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
      Shape out_shape;
      for (int i = 0; i < out_tensor->shape[0]; ++i) {
        out_shape.push_back(tvm::Integer(shape_data[i]));
      }
      out_shapes.push_back(out_shape);
    }
    return out_shapes;
  }

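  // Invoke a primitive (fused) function by JIT-compiling it through the
  // compile engine and calling the resulting PackedFunc. Tuple arguments and
  // tuple results are flattened into individual tensors; output tensors are
  // allocated up front from the (possibly dynamically computed) return type
  // and passed as trailing arguments, following TVM's calling convention.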
  Value InvokePrimitiveOp(const Function& func,
                          const Array<Value>& args) {
    auto call_node = func->body.as<CallNode>();

    if (call_node && call_node->op == Op::Get("debug")) {
      auto dattrs = call_node->attrs.as<DebugAttrs>();
      auto interp_state = this->get_state(call_node->args[0]);

      if (dattrs->debug_func.defined()) {
        dattrs->debug_func(interp_state);
      } else {
        RELAY_DEBUG_INTERP(interp_state);
      }

      return args[0];
    }

    // Marshal the arguments.
    // Handle tuple input/output by flattening them.
    size_t arg_len = 0;
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i].as<TensorValueNode>()) {
        ++arg_len;
      } else {
        const auto* tvalue = args[i].as<TupleValueNode>();
        arg_len += tvalue->fields.size();
      }
    }
    size_t num_inputs = arg_len;
    if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
      arg_len += tuple_type->fields.size();
    } else {
      CHECK(func->body->checked_type().as<TensorTypeNode>())
        << func->body->checked_type();
      arg_len += 1;
    }
    std::vector<TVMValue> values(arg_len);
    std::vector<int> codes(arg_len);
    TVMArgsSetter setter(values.data(), codes.data());

    auto fset_input = [&](size_t i, Value val) {
      const TensorValueNode* tv = val.as<TensorValueNode>();
      CHECK(tv != nullptr) << "expect Tensor argument";
      setter(i, tv->data);
      DLContext arg_ctx = tv->data->ctx;
      CHECK(arg_ctx.device_type == context_.device_type &&
            arg_ctx.device_id == context_.device_id)
        << "Interpreter expects the context to be "
        << context_ << ", but got " << arg_ctx;
    };

    int arg_counter = 0;
    for (Value arg : args) {
      if (arg.as<TensorValueNode>()) {
        fset_input(arg_counter++, arg);
      } else {
        const TupleValueNode* tuple = arg.as<TupleValueNode>();
        CHECK(tuple != nullptr);
        for (size_t i = 0; i < tuple->fields.size(); ++i) {
          fset_input(arg_counter++, tuple->fields[i]);
        }
      }
    }

    // TVM's calling convention is that the final argument is the output
    // buffer. To preserve the illusion of being a functional language
    // we need to allocate space for the output buffer based on the
    // return type.
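    // For example (illustrative, not a specific case from this file): a
    // primitive taking two tensors and returning a 2-tuple is called as
    //   packed_func(in0, in1, out0, out1)
    // where out0/out1 are pre-allocated NDArrays that the kernel writes into.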
    auto fset_output = [&](size_t i, Type val_type) {
      const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
      CHECK(rtype != nullptr);
      // Allocate output tensor.
      std::vector<int64_t> shape;
      for (auto dim : rtype->shape) {
        const auto* ivalue = as_const_int(dim);
        CHECK(ivalue) << "expected concrete dimensions";
        shape.push_back(ivalue[0]);
      }
      DLDataType dtype = Type2TVMType(rtype->dtype);
      auto out_tensor = TensorValueNode::make(
          NDArray::Empty(shape, dtype, context_));
      setter(num_inputs + i, out_tensor->data);
      return out_tensor;
    };

    Array<Shape> out_shapes;
    auto ret_type = func->body->checked_type();
    bool is_dyn = IsDynamic(func->checked_type());
    if (call_node->op == Op::Get("shape_of")) {
      // The output shape of shape_of must be static since Relay doesn't support
      // dynamic rank tensors.
      is_dyn = false;
    }

    if (is_dyn) {
      CHECK(func->IsPrimitive());
      out_shapes = ComputeDynamicShape(func, args);
    }

    PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
    TVMRetValue rv;
    if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
      CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
      Array<Value> fields;
      for (size_t i = 0; i < rtype->fields.size(); ++i) {
        if (is_dyn) {
          auto sh = out_shapes[i];
          auto tt = Downcast<TensorType>(rtype->fields[i]);
          fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
        } else {
          fields.push_back(fset_output(i, rtype->fields[i]));
        }
      }
      packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
      return TupleValueNode::make(fields);
    } else {
      Value out_tensor;
      if (is_dyn) {
        CHECK_EQ(out_shapes.size(), 1);
        auto sh = out_shapes[0];
        auto tt = Downcast<TensorType>(ret_type);
        out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
      } else {
        out_tensor = fset_output(0, ret_type);
      }
      packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
      return out_tensor;
    }
  }

  // Invoke the closure
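  // Primitive functions are dispatched to InvokePrimitiveOp; otherwise a new
  // frame is pushed with the parameters bound to `args` plus the closure's
  // captured environment, and the body is evaluated in that frame. If `bind`
  // is given, it is re-bound to a RecClosure so recursive calls work.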
  Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
    // Get a reference to the function inside the closure.
    if (closure->func->IsPrimitive()) {
      return InvokePrimitiveOp(closure->func, args);
    }
    auto func = closure->func;
    // Allocate a frame with the parameters and free variables.
    tvm::Map<Var, Value> locals;

    CHECK_EQ(func->params.size(), args.size());

    for (size_t i = 0; i < func->params.size(); i++) {
      CHECK_EQ(locals.count(func->params[i]), 0);
      locals.Set(func->params[i], args[i]);
    }

    // Add the var to value mappings from the Closure's environment.
    for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
      CHECK_EQ(locals.count((*it).first), 0);
      locals.Set((*it).first, (*it).second);
    }

    if (bind.defined()) {
      locals.Set(bind, RecClosureNode::make(closure, bind));
    }

    return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
  }

  Value VisitExpr_(const CallNode* call) final {
    tvm::Array<Value> args;
    for (auto arg : call->args) {
      args.push_back(Eval(arg));
    }
    // We should not find bare operators after running fusion
    // and operator lowering.
    //
    // Instead we expect functions containing chunks of operators,
    // which will have been loaded into the operator map.
    if (const auto* op_node = call->op.as<OpNode>()) {
      LOG(FATAL) << "found " << op_node->name
                 << "; operators should be removed by future passes; try "
                    "fusing and lowering";
    }
    if (auto con = call->op.as<ConstructorNode>()) {
      return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
    }
    // Now we just evaluate and expect to find a closure.
    Value fn_val = Eval(call->op);
    if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
      auto closure = GetRef<Closure>(closure_node);
      return this->Invoke(closure, args);
    } else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
      return this->Invoke(closure_node->clos, args, closure_node->bind);
    } else {
      LOG(FATAL) << "internal error: type error, expected function value in the call "
                 << "position";
      return Value();
    }
  }

  Value VisitExpr_(const LetNode* let) final {
    if (auto func = let->value.as<FunctionNode>()) {
      auto clo = MakeClosure(GetRef<Function>(func), let->var);
      this->extend(let->var, clo);
    } else {
      auto value = Eval(let->value);
      this->extend(let->var, value);
    }

    return Eval(let->body);
  }

  Value VisitExpr_(const TupleGetItemNode* op) final {
    Value val = Eval(op->tuple);
    auto product_node = val.as<TupleValueNode>();
    CHECK(product_node)
      << "internal error: when evaluating TupleGetItem expected a tuple value";
    CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
        << "internal error: index out of bounds";
    return product_node->fields[op->index];
  }

  Value VisitExpr_(const IfNode* op) final {
    Value v = Eval(op->cond);
    if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
      DLContext cpu_ctx;
      cpu_ctx.device_type = kDLCPU;
      cpu_ctx.device_id = 0;
      NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
      CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
      // TODO(@jroesch, @MK): Refactor code into helper from DCE.
      if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
        return Eval(op->true_branch);
      } else {
        return Eval(op->false_branch);
      }
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const RefWriteNode* op) final {
    Value r = Eval(op->ref);
    if (const RefValueNode* rv = r.as<RefValueNode>()) {
      rv->value = Eval(op->value);
      return TupleValueNode::make({});
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const RefCreateNode* op) final {
    return RefValueNode::make(Eval(op->value));
  }

  Value VisitExpr_(const RefReadNode* op) final {
    Value r = Eval(op->ref);
    if (const RefValueNode* rv = r.as<RefValueNode>()) {
      return rv->value;
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const MatchNode* op) final {
    Value v = Eval(op->data);
    for (const Clause& c : op->clauses) {
      if (VisitPattern(c->lhs, v)) {
        return VisitExpr(c->rhs);
      }
    }
    LOG(FATAL) << "did not find any match";
    return Value();
  }

  bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
    const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
    CHECK(cvn) << "need to be a constructor for match";
    CHECK_NE(op->constructor->tag, -1);
    CHECK_NE(cvn->tag, -1);
    if (op->constructor->tag == cvn->tag) {
      CHECK_EQ(op->patterns.size(), cvn->fields.size());
      for (size_t i = 0; i < op->patterns.size(); ++i) {
        if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
          return false;
        }
      }
      return true;
    }
    return false;
  }

  bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
    const TupleValueNode* tvn = v.as<TupleValueNode>();
    CHECK(tvn) << "need to be a tuple for match";
    CHECK_EQ(op->patterns.size(), tvn->fields.size());
    for (size_t i = 0; i < op->patterns.size(); ++i) {
      if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
    return true;
  }

  bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
    extend(op->var, v);
    return true;
  }

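  // Snapshot the interpreter's current call stack (one Var -> Value map per
  // frame) together with the expression under evaluation, so it can be handed
  // back to Python, e.g. by the debug operator above.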
  InterpreterState get_state(Expr e = Expr()) const {
    InterpreterStateNode::Stack stack;
    for (auto fr : this->stack_.frames) {
      InterpreterStateNode::Frame frame = fr.locals;
      stack.push_back(frame);
    }
    auto state = InterpreterStateNode::make(e, stack);
    return state;
  }

 private:
  // Module
  Module mod_;
  // For simplicity we only run the interpreter on a single context.
  // Context to run the interpreter on.
  DLContext context_;
  // Target parameter being used by the interpreter.
  Target target_;
  // Value stack.
  Stack stack_;
  // Backend compile engine.
  CompileEngine engine_;
};


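// Create an entry point that evaluates a Relay expression with the interpreter.
// If a module is given, it is first eta-expanded so constructors can appear in
// argument position; the returned function also rejects graph-form (non-ANF)
// programs via a feature check before evaluation.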
TypedPackedFunc<Value(Expr)>
CreateInterpreter(
    Module mod,
    DLContext context,
    Target target) {
  if (mod.defined()) {
    // eta expand to support constructors in argument position
    transform::Sequential seq({
        transform::EtaExpand(
            /* expand_constructor */ true, /* expand_global_var */ false)});
    transform::PassContext pass_ctx = transform::PassContext::Current();
    tvm::With<transform::PassContext> ctx(pass_ctx);
    mod = seq(mod);
  }

  auto intrp = std::make_shared<Interpreter>(mod, context, target);
  auto packed = [intrp](Expr expr) {
    auto f = DetectFeature(expr);
    CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
    return intrp->Eval(expr);
  };
  return TypedPackedFunc<Value(Expr)>(packed);
}

TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body_typed(CreateInterpreter);

TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
TVM_REGISTER_NODE_TYPE(TensorValueNode);

}  // namespace relay
}  // namespace tvm