/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/interpreter.cc
 * \brief An interpreter for the Relay IR.
 */
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
#include "compile_engine.h"

namespace tvm {
namespace relay {

using namespace runtime;

inline const PackedFunc& GetPackedFunc(const std::string& name) {
  const PackedFunc* pf = tvm::runtime::Registry::Get(name);
  CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
  return *pf;
}

/* Value Implementation */
Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
  NodePtr<ClosureNode> n = make_node<ClosureNode>();
  n->env = std::move(env);
  n->func = std::move(func);
  return Closure(n);
}

TVM_REGISTER_API("relay._make.Closure")
.set_body_typed(ClosureNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const ClosureNode*>(ref.get());
    p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
  });


// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
  NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
  n->clos = std::move(clos);
  n->bind = std::move(bind);
  return RecClosure(n);
}

TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const RecClosureNode*>(ref.get());
    p->stream << "RecClosureNode(" << node->clos << ")";
  });

TupleValue TupleValueNode::make(tvm::Array<Value> value) {
  NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
  n->fields = value;
  return TupleValue(n);
}

TVM_REGISTER_API("relay._make.TupleValue")
.set_body_typed(TupleValueNode::make);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const TupleValueNode*>(ref.get());
    p->stream << "TupleValueNode(" << node->fields << ")";
  });

TensorValue TensorValueNode::make(runtime::NDArray data) {
  NodePtr<TensorValueNode> n = make_node<TensorValueNode>();
  n->data = std::move(data);
  return TensorValue(n);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TensorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const TensorValueNode*>(ref.get());
    auto to_str = GetPackedFunc("relay._tensor_value_repr");
    std::string data_str = to_str(GetRef<TensorValue>(node));
    p->stream << "TensorValueNode(" << data_str << ")";
  });

TVM_REGISTER_API("relay._make.TensorValue")
.set_body_typed(TensorValueNode::make);

RefValue RefValueNode::make(Value value) {
  NodePtr<RefValueNode> n = make_node<RefValueNode>();
  n->value = value;
  return RefValue(n);
}

TVM_REGISTER_API("relay._make.RefValue")
.set_body_typed(RefValueNode::make);

TVM_REGISTER_NODE_TYPE(RefValueNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const RefValueNode*>(ref.get());
    p->stream << "RefValueNode(" << node->value << ")";
  });

ConstructorValue ConstructorValueNode::make(int32_t tag,
                                            tvm::Array<Value> fields,
                                            Constructor constructor) {
  NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
  n->tag = tag;
  n->fields = fields;
  n->constructor = constructor;
  return ConstructorValue(n);
}

TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body_typed(ConstructorValueNode::make);

TVM_REGISTER_NODE_TYPE(ConstructorValueNode);

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ObjectRef& ref, IRPrinter* p) {
    auto* node = static_cast<const ConstructorValueNode*>(ref.get());
    p->stream << "ConstructorValueNode(" << node->tag << ","
              << node->fields << ")";
  });

/*!
 * \brief A stack frame in the Relay interpreter.
 *
 * Contains a mapping from relay::Var to relay::Value.
 */
struct Frame {
  /*! \brief The set of local variables and arguments for the frame. */
  tvm::Map<Var, Value> locals;

  explicit Frame(tvm::Map<Var, Value> locals) : locals(locals) {}
};

/*!
 * \brief The call stack in the Relay interpreter.
 *
 * Contains a stack of frames, each corresponding to
 * a function call.
 */
struct Stack {
  /*! \brief The stack frames. */
  std::vector<Frame> frames;
  Stack() : frames() { frames.push_back(Frame({})); }

  Frame& current_frame() { return frames.back(); }

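  // Search the frames from innermost to outermost and return the first
  // binding found; aborts with a fatal error if the variable is unbound.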
  Value Lookup(const Var& local) {
    for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
      auto elem = frame->locals.find(local);
      if (elem != frame->locals.end()) {
        return (*elem).second;
      }
    }

    LOG(FATAL) << "could not find variable binding for " << local
               << ", address = " << local.operator->();
    return Value();
  }
  /*!
   * A wrapper around Frame to add RAII semantics to pushing and popping
   * stack frames.
   */
  struct LocalFrame {
    Stack& st;
    explicit LocalFrame(Stack& st, const Frame& fr) : st(st) {
      st.frames.push_back(fr);
    }
    ~LocalFrame() { st.frames.pop_back(); }
  };
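  // Usage sketch: `{ Stack::LocalFrame lf(stack, Frame(locals)); /* evaluate */ }`
  // pushes the frame on entry and pops it automatically when `lf` goes out of
  // scope; see Interpreter::WithFrame below.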
};

/*! \brief A representation of the interpreter state which can be passed back to Python. */
class InterpreterState;

/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Node {
 public:
  using Frame = tvm::Map<Var, Value>;
  using Stack = tvm::Array<Frame>;

  /*! \brief The current expression under evaluation. */
  Expr current_expr;

  /*! \brief The call stack of the interpreter. */
  Stack stack;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("current_expr", &current_expr);
    v->Visit("stack", &stack);
  }

  static InterpreterState make(Expr current_expr, Stack stack);

  static constexpr const char* _type_key = "relay.InterpreterState";
  TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node);
};

RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef);

InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
  NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>();
  n->current_expr = std::move(current_expr);
  n->stack = std::move(stack);
  return InterpreterState(n);
}

// NOTE: the current interpreter assumes the program is in A-normal form,
// which is better suited for execution.
//
// When given a program in dataflow form (i.e. whose expressions form a DAG
// with shared sub-expressions), it will re-run the shared computations.
//
// Converting to ANF is therefore recommended before interpretation.
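//
// For example (illustrative sketch), in the dataflow program
//   %0 = add(%x, %x); %1 = multiply(%0, %0)
// the add is shared, but a tree-walking evaluation of %1 computes it twice;
// after ANF conversion,
//   let %v0 = add(%x, %x); let %v1 = multiply(%v0, %v0); %v1
// each binding is evaluated exactly once.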
class Interpreter :
    public ExprFunctor<Value(const Expr& n)>,
    PatternFunctor<bool(const Pattern& p, const Value& v)> {
 public:
  Interpreter(Module mod,
              DLContext context,
              Target target)
      : mod_(mod), context_(context), target_(target) {
    engine_ = CompileEngine::Global();
  }

  template <typename T>
  T WithFrame(const Frame& fr, const std::function<T()>& f) {
    Stack::LocalFrame lf(stack_, fr);
    return f();
  }

  void extend(const Var& id, Value v) {
    stack_.current_frame().locals.Set(id, v);
  }

  inline Value Lookup(const Var& local) {
    return stack_.Lookup(local);
  }

  Value Eval(const Expr& expr) {
    return VisitExpr(expr);
  }

  Value VisitExpr(const Expr& expr) final {
    auto ret = ExprFunctor<Value(const Expr& n)>::VisitExpr(expr);
    return ret;
  }

  Value VisitExpr_(const VarNode* var_node) final {
    return Lookup(GetRef<Var>(var_node));
  }

  Value VisitExpr_(const GlobalVarNode* op) final {
    return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
  }

  Value VisitExpr_(const OpNode* id) override {
    // TODO(@jroesch): Eta-expand and return in this case.
    LOG(FATAL) << "internal error: operators can only be evaluated in a call position; "
               << "they should be eta-expanded into a synthetic call node first";
    return Value();
  }

  Value VisitExpr_(const ConstantNode* op) final {
    return TensorValueNode::make(op->data.CopyTo(context_));
  }

  Value VisitExpr_(const TupleNode* op) final {
    std::vector<Value> values;

    for (const auto& field : op->fields) {
      Value field_value = Eval(field);
      values.push_back(field_value);
    }

    return TupleValueNode::make(values);
  }

  inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
    tvm::Map<Var, Value> captured_mod;
    Array<Var> free_vars = FreeVars(func);

    for (const auto& var : free_vars) {
      // Evaluate the free variable (which could itself be a function call)
      // unless it is the let-bound name of this very function (a letrec),
      // in which case it is handled by the RecClosure created below.
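      // For example (sketch), in `let %f = fn (%n) { ... %f(%n - 1) ... }` the
      // name %f is free in the function body but must not be captured eagerly
      // here; it is instead re-bound at call time via the RecClosure.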
      if (letrec_name.defined() && letrec_name == var) {
        continue;
      }

      captured_mod.Set(var, Eval(var));
    }

    // We must use mutation here to build a self referential closure.
    auto closure = ClosureNode::make(captured_mod, func);
    if (letrec_name.defined()) {
      return RecClosureNode::make(closure, letrec_name);
    }
    return std::move(closure);
  }

  Value VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);
    return MakeClosure(func);
  }

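  // Compute the concrete output shape(s) of a primitive function whose result
  // type contains dynamic dimensions, by lowering its shape function and
  // running it on the CPU over the argument shapes/data.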
  Array<Shape> ComputeDynamicShape(const Function& func,
                                   const Array<Value>& args) {
    auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
    auto cfunc = engine_->LowerShapeFunc(key);
    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();

    std::vector<TVMValue> values(arity);
    std::vector<int> codes(arity);
    TVMArgsSetter setter(values.data(), codes.data());
    std::vector<NDArray> inputs(cfunc->inputs.size());
    std::vector<NDArray> outputs(cfunc->outputs.size());

    DLContext cpu_ctx;
    cpu_ctx.device_type = kDLCPU;
    cpu_ctx.device_id = 0;

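    // Populate packed-argument slot i with either the tensor's shape (as a
    // 1-D int64 NDArray on the CPU) or its data copied to the CPU, depending
    // on what the shape function needs for that parameter.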
    auto fset_input = [&](size_t i, Value val, bool need_shape) {
      const TensorValueNode* tv = val.as<TensorValueNode>();
      CHECK(tv != nullptr) << "expect Tensor argument";
      if (need_shape) {
        int64_t ndim = tv->data.Shape().size();
        NDArray shape_arr;
        if (ndim == 0) {
          shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
        } else {
          shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
          int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
          for (auto j = 0; j < ndim; ++j) {
            data[j] = tv->data.Shape()[j];
          }
        }
        inputs[i] = shape_arr;
        setter(i, shape_arr);
      } else {
        auto arr = tv->data.CopyTo(cpu_ctx);
        inputs[i] = arr;
        setter(i, arr);
      }
    };

    size_t arg_counter = 0;
    for (size_t i = 0; i < args.size(); ++i) {
      auto arg = args[i];
      auto param = func->params[i];
      int state = cfunc->shape_func_param_states[i]->value;
      if (arg.as<TensorValueNode>()) {
        if (state & kNeedInputData) {
          fset_input(arg_counter++, arg, false);
        }
        if (state & kNeedInputShape) {
          fset_input(arg_counter++, arg, true);
        }
      } else {
        const TupleValueNode* tuple = arg.as<TupleValueNode>();
        CHECK(tuple != nullptr);
        if (state & kNeedInputData) {
          for (size_t i = 0; i < tuple->fields.size(); ++i) {
            fset_input(arg_counter++, tuple->fields[i], false);
          }
        }
        if (state & kNeedInputShape) {
          for (size_t i = 0; i < tuple->fields.size(); ++i) {
            fset_input(arg_counter++, tuple->fields[i], true);
          }
        }
      }
    }
    CHECK_EQ(arg_counter, cfunc->inputs.size())
      << "Shape function input sizes mismatch";

    auto fset_shape_output = [&](size_t i, Type val_type) {
      // TODO(@icemelon): allow recursive tuple
      const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
      CHECK(rtype != nullptr);
      int64_t ndim = rtype->shape.size();
      auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
      outputs[i] = arr;
      setter(arg_counter + i, arr);
    };

    auto ret_type = func->body->checked_type();
    size_t out_cnt = 0;
    if (auto rtype = ret_type.as<TupleTypeNode>()) {
      out_cnt = rtype->fields.size();
      for (size_t i = 0; i < out_cnt; ++i) {
        fset_shape_output(i, rtype->fields[i]);
      }
    } else {
      out_cnt = 1;
      auto tt = Downcast<TensorType>(ret_type);
      fset_shape_output(0, tt);
    }
    CHECK_EQ(cfunc->outputs.size(), out_cnt)
      << "Shape function output sizes mismatch";

    PackedFunc shape_func;
    TVMRetValue rv;
    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
      tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
      shape_func = m.GetFunction(cfunc->func_name);
    } else {
      LOG(FATAL) << "relay.backend.build is not registered";
    }
    shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);

    // Get output shapes
    Array<Shape> out_shapes;
    for (auto out_tensor : outputs) {
      int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
      Shape out_shape;
      for (int i = 0; i < out_tensor->shape[0]; ++i) {
        out_shape.push_back(tvm::Integer(shape_data[i]));
      }
      out_shapes.push_back(out_shape);
    }
    return out_shapes;
  }

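  // Invoke a fused primitive function: marshal the interpreter values into
  // TVM's packed calling convention, JIT-compile the function through the
  // CompileEngine, and wrap the resulting output tensor(s) back into Values.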
  Value InvokePrimitiveOp(const Function& func,
                          const Array<Value>& args) {
    auto call_node = func->body.as<CallNode>();

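    // Calls to the "debug" operator are intercepted here: capture the current
    // interpreter state, pass it to the user-supplied debug callback (falling
    // back to RELAY_DEBUG_INTERP), and return the argument unchanged.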
    if (call_node && call_node->op == Op::Get("debug")) {
      auto dattrs = call_node->attrs.as<DebugAttrs>();
      auto interp_state = this->get_state(call_node->args[0]);

      if (dattrs->debug_func.defined()) {
        dattrs->debug_func(interp_state);
      } else {
        RELAY_DEBUG_INTERP(interp_state);
      }

      return args[0];
    }

    // Marshal the arguments.
    // Handle tuple input/output by flattening them.
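    // For example (sketch), a tuple argument %t : (Tensor[(2), float32],
    // Tensor[(3), float32]) contributes two entries to the packed argument
    // list, one per field; a tuple result likewise gets one output slot per
    // field.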
    size_t arg_len = 0;
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i].as<TensorValueNode>()) {
        ++arg_len;
      } else {
        const auto* tvalue = args[i].as<TupleValueNode>();
        arg_len += tvalue->fields.size();
      }
    }
    size_t num_inputs = arg_len;
    if (const auto* tuple_type = func->body->checked_type().as<TupleTypeNode>()) {
      arg_len += tuple_type->fields.size();
    } else {
      CHECK(func->body->checked_type().as<TensorTypeNode>())
        << func->body->checked_type();
      arg_len += 1;
    }
    std::vector<TVMValue> values(arg_len);
    std::vector<int> codes(arg_len);
    TVMArgsSetter setter(values.data(), codes.data());

    auto fset_input = [&](size_t i, Value val) {
      const TensorValueNode* tv = val.as<TensorValueNode>();
      CHECK(tv != nullptr) << "expect Tensor argument";
      setter(i, tv->data);
      DLContext arg_ctx = tv->data->ctx;
      CHECK(arg_ctx.device_type == context_.device_type &&
            arg_ctx.device_id == context_.device_id)
        << "Interpreter expects the context to be "
        << context_ << ", but got " << arg_ctx;
    };

    int arg_counter = 0;
    for (Value arg : args) {
      if (arg.as<TensorValueNode>()) {
        fset_input(arg_counter++, arg);
      } else {
        const TupleValueNode* tuple = arg.as<TupleValueNode>();
        CHECK(tuple != nullptr);
        for (size_t i = 0; i < tuple->fields.size(); ++i) {
          fset_input(arg_counter++, tuple->fields[i]);
        }
      }
    }

    // TVM's calling convention is that the final argument is the output
    // buffer. To preserve the illusion of being a functional language
    // we need to allocate space for the output buffer based on the
    // return type.
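    // For example (sketch), a primitive with two tensor inputs and one tensor
    // result is invoked as packed_func(in0, in1, out0), where out0 is the
    // empty NDArray allocated below from the checked return type.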
    auto fset_output = [&](size_t i, Type val_type) {
      const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
      CHECK(rtype != nullptr);
      // Allocate output tensor.
      std::vector<int64_t> shape;
      for (auto dim : rtype->shape) {
        const auto* ivalue = as_const_int(dim);
        CHECK(ivalue) << "expected concrete dimensions";
        shape.push_back(ivalue[0]);
      }
      DLDataType dtype = Type2TVMType(rtype->dtype);
      auto out_tensor = TensorValueNode::make(
        NDArray::Empty(shape, dtype, context_));
      setter(num_inputs + i, out_tensor->data);
      return out_tensor;
    };

    Array<Shape> out_shapes;
    auto ret_type = func->body->checked_type();
    bool is_dyn = IsDynamic(func->checked_type());
    if (call_node->op == Op::Get("shape_of")) {
      // The output shape of shape_of must be static since Relay doesn't support
      // dynamic rank tensors.
      is_dyn = false;
    }

    if (is_dyn) {
      CHECK(func->IsPrimitive());
      out_shapes = ComputeDynamicShape(func, args);
    }

    PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
    TVMRetValue rv;
    if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
      CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
      Array<Value> fields;
      for (size_t i = 0; i < rtype->fields.size(); ++i) {
        if (is_dyn) {
          auto sh = out_shapes[i];
          auto tt = Downcast<TensorType>(rtype->fields[i]);
          fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
        } else {
          fields.push_back(fset_output(i, rtype->fields[i]));
        }
      }
      packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
      return TupleValueNode::make(fields);
    } else {
      Value out_tensor;
      if (is_dyn) {
        CHECK_EQ(out_shapes.size(), 1);
        auto sh = out_shapes[0];
        auto tt = Downcast<TensorType>(ret_type);
        out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
      } else {
        out_tensor = fset_output(0, ret_type);
      }
      packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
      return out_tensor;
    }
  }

  // Invoke the closure
  Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
    // Get a reference to the function inside the closure.
    if (closure->func->IsPrimitive()) {
      return InvokePrimitiveOp(closure->func, args);
    }
    auto func = closure->func;
    // Allocate a frame with the parameters and free variables.
    tvm::Map<Var, Value> locals;

    CHECK_EQ(func->params.size(), args.size());

    for (size_t i = 0; i < func->params.size(); i++) {
      CHECK_EQ(locals.count(func->params[i]), 0);
      locals.Set(func->params[i], args[i]);
    }

    // Add the var to value mappings from the Closure's environment.
    for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
      CHECK_EQ(locals.count((*it).first), 0);
      locals.Set((*it).first, (*it).second);
    }

    if (bind.defined()) {
      locals.Set(bind, RecClosureNode::make(closure, bind));
    }

    return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
  }

  Value VisitExpr_(const CallNode* call) final {
    tvm::Array<Value> args;
    for (auto arg : call->args) {
      args.push_back(Eval(arg));
    }
    // We should not find bare operators after fusion and operator lowering
    // have run. Instead, we see functions containing chunks of operators,
    // which are loaded into the operator map.
    if (const auto* op_node = call->op.as<OpNode>()) {
      LOG(FATAL) << "found " << op_node->name
                 << "; operators should be removed by future passes; try "
                    "fusing and lowering";
    }
    if (auto con = call->op.as<ConstructorNode>()) {
      return ConstructorValueNode::make(con->tag, args, GetRef<Constructor>(con));
    }
    // Now we just evaluate and expect to find a closure.
    Value fn_val = Eval(call->op);
    if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
      auto closure = GetRef<Closure>(closure_node);
      return this->Invoke(closure, args);
    } else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
      return this->Invoke(closure_node->clos, args, closure_node->bind);
    } else {
      LOG(FATAL) << "internal error: type error, expected function value in the call "
                 << "position";
      return Value();
    }
  }

  Value VisitExpr_(const LetNode* let) final {
    if (auto func = let->value.as<FunctionNode>()) {
      auto clo = MakeClosure(GetRef<Function>(func), let->var);
      this->extend(let->var, clo);
    } else {
      auto value = Eval(let->value);
      this->extend(let->var, value);
    }

    return Eval(let->body);
  }

  Value VisitExpr_(const TupleGetItemNode* op) final {
    Value val = Eval(op->tuple);
    auto product_node = val.as<TupleValueNode>();
    CHECK(product_node)
      << "internal error: when evaluating TupleGetItem expected a tuple value";
    CHECK_LT(static_cast<size_t>(op->index), product_node->fields.size())
      << "internal error: index out of bounds";
    return product_node->fields[op->index];
  }

  Value VisitExpr_(const IfNode* op) final {
    Value v = Eval(op->cond);
    if (const TensorValueNode* bv = v.as<TensorValueNode>()) {
      DLContext cpu_ctx;
      cpu_ctx.device_type = kDLCPU;
      cpu_ctx.device_id = 0;
      NDArray cpu_array = bv->data.CopyTo(cpu_ctx);
      CHECK_EQ(TVMType2Type(cpu_array->dtype), Bool());
      // TODO(@jroesch, @MK): Refactor code into helper from DCE.
      if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
        return Eval(op->true_branch);
      } else {
        return Eval(op->false_branch);
      }
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const RefWriteNode* op) final {
    Value r = Eval(op->ref);
    if (const RefValueNode* rv = r.as<RefValueNode>()) {
      rv->value = Eval(op->value);
      return TupleValueNode::make({});
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const RefCreateNode* op) final {
    return RefValueNode::make(Eval(op->value));
  }

  Value VisitExpr_(const RefReadNode* op) final {
    Value r = Eval(op->ref);
    if (const RefValueNode* rv = r.as<RefValueNode>()) {
      return rv->value;
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
      return Value();
    }
  }

  Value VisitExpr_(const MatchNode* op) final {
    Value v = Eval(op->data);
    for (const Clause& c : op->clauses) {
      if (VisitPattern(c->lhs, v)) {
        return VisitExpr(c->rhs);
      }
    }
    LOG(FATAL) << "did not find any match";
    return Value();
  }

  bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
    const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
    CHECK(cvn) << "need to be a constructor for match";
    CHECK_NE(op->constructor->tag, -1);
    CHECK_NE(cvn->tag, -1);
    if (op->constructor->tag == cvn->tag) {
      CHECK_EQ(op->patterns.size(), cvn->fields.size());
      for (size_t i = 0; i < op->patterns.size(); ++i) {
        if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
          return false;
        }
      }
      return true;
    }
    return false;
  }

  bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
    const TupleValueNode* tvn = v.as<TupleValueNode>();
    CHECK(tvn) << "need to be a tuple for match";
    CHECK_EQ(op->patterns.size(), tvn->fields.size());
    for (size_t i = 0; i < op->patterns.size(); ++i) {
      if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
    return true;
  }

  bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
    extend(op->var, v);
    return true;
  }

  InterpreterState get_state(Expr e = Expr()) const {
    InterpreterStateNode::Stack stack;
    for (auto fr : this->stack_.frames) {
      InterpreterStateNode::Frame frame = fr.locals;
      stack.push_back(frame);
    }
    auto state = InterpreterStateNode::make(e, stack);
    return state;
  }

 private:
  // Module
  Module mod_;
  // For simplicity we only run the interpreter on a single context,
  // so this is the context on which all evaluation happens.
  DLContext context_;
  // Target parameter being used by the interpreter.
  Target target_;
  // Value stack.
  Stack stack_;
  // Backend compile engine.
  CompileEngine engine_;
};

TypedPackedFunc<Value(Expr)>
CreateInterpreter(
    Module mod,
    DLContext context,
    Target target) {
  if (mod.defined()) {
    // eta expand to support constructors in argument position
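    // For example (sketch), a constructor used as a first-class value, such
    // as `Some` in `map(Some, xs)`, is rewritten to `fn (%x) { Some(%x) }`,
    // which the interpreter can then evaluate as an ordinary closure.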
    transform::Sequential seq({
        transform::EtaExpand(
            /* expand_constructor */ true, /* expand_global_var */ false)});
    transform::PassContext pass_ctx = transform::PassContext::Current();
    tvm::With<transform::PassContext> ctx(pass_ctx);
    mod = seq(mod);
  }

  auto intrp = std::make_shared<Interpreter>(mod, context, target);
  auto packed = [intrp](Expr expr) {
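    // Refuse programs whose expressions form a graph with shared
    // sub-expressions; see the A-normal form note on the Interpreter above.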
    auto f = DetectFeature(expr);
    CHECK(f.is_subset_of(FeatureSet::All() - fGraph));
    return intrp->Eval(expr);
  };
  return TypedPackedFunc<Value(Expr)>(packed);
}

TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body_typed(CreateInterpreter);

TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
TVM_REGISTER_NODE_TYPE(TensorValueNode);

}  // namespace relay
}  // namespace tvm