/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file ad.cc
 * \brief API for Automatic Differentiation for the Relay IR.
 */

#include <tvm/lowered_func.h>
#include <tvm/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pass_util.h"
#include "let_list.h"
#include "../ir/type_functor.h"

namespace tvm {
namespace relay {

using namespace tvm::runtime;

/*! What is automatic differentiation (AD) and why is it important?
 * By AD, we roughly mean: given a term that denotes some mathematical function,
 * derive a term that denotes the derivative of that mathematical function.
 * Such a method can be compile-time, i.e. a macro over a completely known function.
 * Formally speaking, this requires the input function to be a closed expression -
 * that is, it only refers to local variables that are its parameters or are defined inside it.
 * Every top-level definition satisfies this criterion.
 * AD can also be run-time, in which case it is merely a function term of type
 * AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
 * In Relay we currently only support compile-time AD, but it should be enough for many use cases.
 *
 * In deep learning, the most common way to train a deep neural network is by gradient descent or one of its variants.
 * Such optimization methods require the gradient of the neural network, which can be obtained easily using AD.
 * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD!
 */
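
/* For intuition, a rough sketch (illustrative Relay-like pseudo-syntax, not the
 * literal output of this pass): compile-time AD turns
 *
 *   fn (%x: Tensor[(), float32]) { %x * %x }
 *
 * into a function returning both the original value and the gradient with
 * respect to each parameter, i.e. something of the shape
 *
 *   fn (%x: Tensor[(), float32]) { (%x * %x, (2 * %x,)) }
 */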

/*! In Relay, automatic differentiation (AD) is a macro
 *  that transforms a closed expr (an expr without free variables/free type variables) of type
 *  (x0, x1, x2, ...) -> Float[] into one of type
 *  (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
 *  where x0, x1, x2, ... are Floats of possibly different shapes.
 *  The return value is a pair: the left-hand side is the original value, and the right-hand side is the gradient of the input.
 *  WithGradientType takes the type of the input and produces the type of the output.
 *  There are multiple implementations of AD in Relay, each with different characteristics.
 *  However, they all transform the input expr according to WithGradientType.
 */
Type WithGradientType(const Type&);

/*! \brief Return an expression that represents the differentiation of e (according to WithGradientType).
 *  This version only works on first-order code without control flow.
 */
Expr FirstOrderGradient(const Expr& e, const Module& mod);

Type WithGradientType(const Type& t) {
  // TODO(M.K.): stricter checking
  auto ty = t.as<FuncTypeNode>();
  CHECK(ty) << "input should be a function";
  return FuncTypeNode::make(ty->arg_types,
                            TupleTypeNode::make({
                              ty->ret_type,
                              TupleTypeNode::make(ty->arg_types)}), {}, {});
}
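
/* A concrete sketch of the type transformation above (illustrative only; the
 * shapes are made up): WithGradientType maps
 *
 *   fn (Tensor[(3, 4), float32], Tensor[(4,), float32]) -> Tensor[(3,), float32]
 *
 * to
 *
 *   fn (Tensor[(3, 4), float32], Tensor[(4,), float32])
 *      -> (Tensor[(3,), float32], (Tensor[(3, 4), float32], Tensor[(4,), float32]))
 */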

//! \brief If the expression is a GlobalVar, transform it into the expression it refers to.
Expr DeGlobal(const Module& mod, const Expr& e) {
  if (const auto* x = e.as<GlobalVarNode>()) {
    return mod->Lookup(GetRef<GlobalVar>(x))->body;
  } else {
    return e;
  }
}

/*! \brief A fragment of the program being built by the automatic differentiation
 *  pass.
 */
struct ADValueNode {
  virtual ~ADValueNode() { }
  template <typename T>
  T& get() {
    auto ret = dynamic_cast<T*>(this);
    CHECK(ret) << "cannot downcast";
    return *ret;
  }
};

using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward) :
    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
    this->forward->checked_type_ = forward->checked_type();
  }
};
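
/* Note on usage (descriptive, added for clarity): `forward` holds a let-bound
 * variable carrying the primal value, while `reverse` accumulates the adjoint.
 * It starts as zeros_like(forward) and is rebound by the backprop actions,
 * which add each incoming gradient contribution to it.
 */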

/*! \brief A staged representation of the program: we reflect
 * Relay functions into functions over fragments of AD. We
 * can compute away this function to obtain a reverse-mode program.
 */
struct ADFunction : ADValueNode {
  std::function<ADValue(const Type&,
                        const std::vector<ADValue>&,
                        const Attrs&,
                        const tvm::Array<Type>&)> func;
  explicit ADFunction(const std::function<ADValue(const Type&,
                                                  const std::vector<ADValue>&,
                                                  const Attrs&,
                                                  const tvm::Array<Type>&)>& func) :
    func(func) { }
};

struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
  LetList* ll;

  FirstOrderReverseAD(LetList* ll) : ll(ll) { }

  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    CHECK(rev_map.count(op_ref))
      << op->name << " does not have reverse mode defined";
    return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
                                                       const std::vector<ADValue>& args,
                                                       const Attrs& attrs,
                                                       const tvm::Array<Type>& type_args) {
      std::vector<Expr> call_args;
      for (const ADValue& adval : args) {
        call_args.push_back(adval->get<ADTensor>().forward);
      }
      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
      orig->checked_type_ = orig_type;
      auto ret = std::make_shared<ADTensor>(ll, orig);
      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
        CHECK(args.size() == rev.size());
        for (size_t i = 0; i < args.size(); ++i) {
          args[i]->get<ADTensor>().reverse =
            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
        }
      });
      return ret;
    });
  }

  ADValue VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e);
  }

  ADValue VisitExpr_(const CallNode* op) final {
    ADValue f = VisitExpr(op->op);
    std::vector<ADValue> args;
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
  }

  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // TODO: assert no closure
    return std::make_shared<ADFunction>([this, f](const Type& orig_type,
                                                  const std::vector<ADValue>& args,
                                                  const Attrs& attrs,
                                                  const tvm::Array<Type>& type_args) {
        CHECK_EQ(f->params.size(), args.size());
        for (size_t i = 0; i < f->params.size(); ++i) {
          env[f->params[i]] = args[i];
        }
        return VisitExpr(f->body);
      });
  }

  ADValue VisitExpr_(const VarNode* op) final {
    Var v = GetRef<Var>(op);
    return env.at(v);
  }
};

Type GradRetType(const Function& f) {
  // if type annotations are provided, we will construct a ret type;
  // otherwise, leave it to be inferred
  if (!f->ret_type.defined()) {
    return Type();
  }
  std::vector<Type> vt;
  for (const auto& p : f->params) {
    if (!p->type_annotation.defined()) {
      return Type();
    }
    vt.push_back(p->type_annotation);
  }

  return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
}
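
/* Shape of the code emitted by FirstOrderGradient below (a rough, hedged sketch
 * rather than a verbatim dump): the outer LetList holds the forward computation
 * plus one zero-initialized adjoint variable per intermediate value; the gradient
 * part is a second let chain that seeds the output adjoint with ones_like,
 * replays the recorded backprop actions in reverse order, and finally tuples up
 * the adjoints of the parameters. The resulting function returns
 * (forward_result, (grad_x0, grad_x1, ...)), matching WithGradientType.
 */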

Expr FirstOrderGradient(const Expr& re, const Module& mod) {
  // Currently we first remove any global functions for the first
  // order case.
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "FirstOrderGradient expects its argument to be a function: " << f;
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";

  // We will then build a sequence of lets which implement reverse mode.
  Expr body = LetList::With([&](LetList* ll) {
    FirstOrderReverseAD reverse_ad(ll);
    ADValue rev = reverse_ad(e);
    std::vector<ADValue> args;
    for (const auto& p : f->params) {
      args.push_back(std::make_shared<ADTensor>(ll, p));
    }
    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
    const auto& res = c->get<ADTensor>();
    Expr grad = LetList::With([&](LetList* ll) {
      res.reverse = OnesLike(res.forward);
      for (auto it = reverse_ad.backprop_actions.rbegin();
           it != reverse_ad.backprop_actions.rend();
           ++it) {
        (*it)(ll);
      }
      std::vector<Expr> grad_res;
      for (const auto& a : args) {
        grad_res.push_back(a->get<ADTensor>().reverse);
      }
      return TupleNode::make(grad_res);
    });
    return Pair(res.forward, grad);
  });

  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient);

struct ReverseADType : TypeMutator {
  Type VisitType_(const TensorTypeNode* ttn) final {
    Type t = GetRef<Type>(ttn);
    return TupleTypeNode::make({t, RefTypeNode::make(t)});
  }
};

Type ReverseType(const Type& t) {
  return ReverseADType()(t);
}
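
/* Example (illustrative): ReverseType maps Tensor[(3,), float32] to
 * (Tensor[(3,), float32], Ref(Tensor[(3,), float32])), i.e. a pair of the primal
 * value and a mutable reference cell holding its accumulated gradient. Tuple
 * types are mapped field-by-field through the default TypeMutator recursion.
 */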

/*! \brief Lift a function that transforms Tensors into a function that transforms
 * richer types as well, by doing a structure-preserving map.
 */
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                const std::function<Type(const Type&)>& tf,
                const Type& forward_type,
                const Expr& e,
                LetList* ll) {
  CHECK(IsAtomic(e)) << e;
  if (forward_type.as<TensorTypeNode>()) {
    auto ret = f(e);
    ret->checked_type_ = tf(forward_type);
    return ret;
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    tvm::Array<Expr> fields;
    tvm::Array<Type> types;
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      auto field = LiftTensor(f,
                              tf,
                              tt->fields[i],
                              ll->Push(GetField(e, i)),
                              ll);
      fields.push_back(field);
      types.push_back(field->checked_type_);
    }
    auto ret = TupleNode::make(fields);
    ret->checked_type_ = TupleTypeNode::make(types);
    return std::move(ret);
  } else {
    LOG(FATAL) << "unsupported input/output type: " << forward_type;
    throw;
  }
}
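
/* For example (hedged sketch): lifting over forward_type = (Tensor[a], Tensor[b])
 * applies f to each tensor component of e, so GetRev below turns a tuple value
 * into ((e.0, ref(zeros_like(e.0))), (e.1, ref(zeros_like(e.1)))), while GetValue
 * and GetGrad project the value and the gradient component back out.
 */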

/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
 * by stitching the references in the AD values.
 */
void TransferGrads(const Type& forward_type,
                   const Expr& from,
                   const Expr& to,
                   LetList* ll) {
  CHECK(IsAtomic(from)) << from;
  CHECK(IsAtomic(to)) << to;
  if (forward_type.as<TensorTypeNode>()) {
    auto from_ref = TupleGetItemNode::make(from, 1);
    auto to_ref = TupleGetItemNode::make(to, 1);
    ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      TransferGrads(tt->fields[i],
                    ll->Push(TupleGetItemNode::make(from, i)),
                    ll->Push(TupleGetItemNode::make(to, i)),
                    ll);
    }
  } else {
    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
    throw;
  }
}

/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
  auto rev = [&](const Expr& e) {
    return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
  };
  auto rev_type = [&](const Type& forward_type) {
    return ReverseType(forward_type);
  };
  return LiftTensor(rev, rev_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
  auto val = [&](const Expr& e) {
    return GetField(e, 0);
  };
  auto val_type = [&](const Type& forward_type) {
    return forward_type;
  };
  return LiftTensor(val, val_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
  auto grad = [&](const Expr& e) {
    return ll->Push(RefReadNode::make(GetField(e, 1)));
  };
  auto grad_type = [&](const Type& forward_type) {
    return forward_type;
  };
  return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
  if (t.as<TensorTypeNode>()) {
    ll->Push(RefWriteNode::make(GetField(arg, 1),
                                Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
                                    grad)));
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      UpdateGrad(tt->fields[i],
                 ll->Push(GetField(arg, i)),
                 ll->Push(GetField(grad, i)),
                 ll);
    }
  } else {
    LOG(FATAL) << "unsupported arg type of operator: " << t;
    throw;
  }
}

Expr BPEmpty() {
  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
  return RefCreateNode::make(unitF);
}
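
/* The reference cell returned by BPEmpty holds the current backpropagator: a
 * nullary function that, when called, propagates gradients for everything
 * translated so far. It starts as a no-op; each translated call wraps the
 * previous backpropagator in a new closure and writes it back into the cell.
 */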

struct ReverseAD : ExprMutator {
  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;

  Var bp;
  std::shared_ptr<ADVarMap> ad_vars;
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
      : bp(bp), ad_vars(ad_vars) { }

  Expr VisitExpr_(const OpNode* op) final {
    LOG(FATAL) << "op should only be inside call";
    throw;
  }

  Expr VisitCheckpoint(const CallNode* call) {
    const OpNode* op_node = call->op.as<OpNode>();
    CHECK(op_node) << "expected op in call";
    Op op_ref = GetRef<Op>(op_node);
    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
    auto x = call->args[0];
    return LetList::With([&](LetList* ll) {
      auto x_var = ll->Push(x);
      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
      auto bpv = ll->Push(RefReadNode::make(bp));
      Expr nbp = FunctionNode::make(
        {},
        LetList::With([&](LetList* ll) {
          // we need a new ReverseAD visitor to avoid clobbering the bp local var
          auto dup_bp = ll->Push(BPEmpty());
          ReverseAD dup_diff(dup_bp, ad_vars);
          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

          TransferGrads(call->checked_type(), ret, dup_ad, ll);
          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
          return CallNode::make(bpv, {});
        }),
        TupleTypeNode::make({}),
        {});
      ll->Push(RefWriteNode::make(bp, nbp));
      return ret;
    });
  }

  Expr VisitExpr_(const CallNode* call) final {
    if (const OpNode* op_node = call->op.as<OpNode>()) {
      Op op_ref = GetRef<Op>(op_node);

      if (op_ref->name == "annotation.checkpoint") {
        return VisitCheckpoint(call);
      }

      CHECK(rev_map.count(op_ref))
        << op_node->name << " does not have reverse mode defined";
      return LetList::With([&](LetList* ll) {
        std::vector<Var> args;
        for (const auto& arg : call->args) {
          args.push_back(ll->Push(VisitExpr(arg)));
        }
        std::vector<Expr> orig_args;
        for (size_t i = 0; i < args.size(); i++) {
          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
        }
        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
        orig->checked_type_ = call->checked_type();
        Var orig_var = ll->Push(orig);
        orig_var->checked_type_ = call->checked_type();
        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
        auto bpv = ll->Push(RefReadNode::make(bp));
        Expr nbp = FunctionNode::make(
          {},
          LetList::With([&](LetList* ll) {
            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
            CHECK(args.size() == rev.size());
            for (size_t i = 0; i < args.size(); ++i) {
              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
            }
            return CallNode::make(bpv, {});
          }),
          TupleTypeNode::make({}),
          {});
        ll->Push(RefWriteNode::make(bp, nbp));
        return ret;
      });
    }
    return ExprMutator::VisitExpr_(call);
  }

  Expr VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return Pair(e, RefCreateNode::make(ZerosLike(e)));
  }

  Expr VisitExpr_(const IfNode* op) final {
    return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
                        VisitExpr(op->true_branch),
                        VisitExpr(op->false_branch));
  }

  Expr VisitExpr_(const VarNode* var) final {
    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
    auto var_ref = GetRef<Var>(var);
    if (!ad_vars->count(var_ref)) {
      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
      (*ad_vars)[var_ref] = res;
    }

    return ad_vars->at(var_ref);
  }

  Type VisitType(const Type& t) final {
    return t.defined() ? ReverseType(t) : t;
  }
};
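
/* Sketch of the rewriting performed by ReverseAD above (hedged and simplified,
 * in Relay-like pseudo-syntax): a call op(%x) becomes, roughly,
 *
 *   let %x_rev  = <already-translated %x>;     // a (value, ref grad) pair
 *   let %y      = op(%x_rev.0);                // forward computation
 *   let %y_rev  = (%y, ref(zeros_like(%y)));
 *   let %old_bp = !bp;
 *   bp := fn () {
 *     %x_rev.1 := !%x_rev.1 + FPrimalGradient(op)(%y, !%y_rev.1);
 *     %old_bp()
 *   };
 *   %y_rev
 */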

bool MissingGrad(const Expr& e) {
  struct MGVisitor : ExprVisitor {
    const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
    std::unordered_set<std::string> op_names;

    void VisitExpr_(const OpNode* op) final {
      Op op_ref = GetRef<Op>(op);
      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
        op_names.insert(op_ref->name);
      }
      ExprVisitor::VisitExpr_(op);
    }
  };

  MGVisitor mg;
  mg.VisitExpr(e);

  if (mg.op_names.size() > 0) {
    LOG(WARNING) << "found operators with missing gradients:";
    for (const auto& op : mg.op_names) {
      LOG(WARNING) << "    " << op;
    }
    return true;
  }

  return false;
}
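
/* Gradient below wraps the ReverseAD-translated function: it pairs every
 * parameter with a fresh zero gradient reference, calls the translated body,
 * seeds the output gradient with ones_like, invokes the backpropagator stored
 * in bp, and finally reads the parameter gradient references back out. The
 * result again follows WithGradientType: (original_output, (grad_per_param...)).
 */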

Expr Gradient(const Expr& re, const Module& mod) {
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "input needs to be a function";
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensors";
  }
  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    Var bp = ll->Push(BPEmpty());
    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
    std::vector<Expr> args;
    for (const auto& p : f->params) {
      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
    }
    auto c = ll->Push(CallNode::make(rev, args));
    std::function<void(const Expr&, const Type&)> init_grad;
    init_grad = [&](const Expr& e, const Type& t) {
      if (t.as<TensorTypeNode>()) {
        ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
      } else if (auto tt = t.as<TupleTypeNode>()) {
        CHECK_GT(tt->fields.size(), 0);
        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    init_grad(c, f->body->checked_type());
    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
    std::vector<Expr> ret;
    for (const auto& a : args) {
      ret.push_back(RefReadNode::make(GetField(a, 1)));
    }
    std::function<Expr(const Expr&, const Type&)> get_final_result;
    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
      if (t.as<TensorTypeNode>()) {
        return GetField(e, 0);
      } else if (auto tt = t.as<TupleTypeNode>()) {
        tvm::Array<Expr> fields;
        for (size_t i = 0; i < tt->fields.size(); ++i) {
          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
        }
        return TupleNode::make(fields);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
  });
  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}

TVM_REGISTER_API("relay._transform.gradient")
.set_body_typed(Gradient);

}  // namespace relay
}  // namespace tvm