1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file ad.cc
22 * \brief API for Automatic Differentiation for the Relay IR.
23 */
24
25 #include <tvm/lowered_func.h>
26 #include <tvm/operation.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/relay/analysis.h>
29 #include <tvm/relay/transform.h>
30 #include "pattern_util.h"
31 #include "pass_util.h"
32 #include "let_list.h"
33 #include "../ir/type_functor.h"
34
35 namespace tvm {
36 namespace relay {
37
38 using namespace tvm::runtime;
39
/*! What is automatic differentiation (AD) and why is it important?
 * By AD, we roughly mean: given a term which denotes some mathematical function,
 * derive a term which denotes the derivative of that mathematical function.
 * Such a method can be compile-time, in which case it is a macro on a completely known function.
 * Formally speaking, this requires the input function to be a closed expression -
 * that is, it only refers to local variables that are its parameters, or defined inside it.
 * Every top-level definition satisfies this criterion.
 * AD can also be run-time, in which case it is merely a function term of type AD : (Float[] -> Float[]) -> (Float[] -> Float[]).
 * In Relay we currently only support compile-time AD, but it should be enough for a lot of use cases.
 *
 * In deep learning, the most common way to train a deep neural network is by gradient descent or one of its variants.
 * Such optimization methods require the gradient of the neural network, which can be obtained easily using AD.
 * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD!
 */
54
/*! In Relay, automatic differentiation (AD) is a macro
 * that transforms a closed expr (an expr without free variables/free type variables) of type
 * (x0, x1, x2, ...) -> Float[] to
 * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
 * where x0, x1, x2, ... are Floats of possibly different shapes.
 * The return value is a pair, whose left-hand side is the original value and whose right-hand side is the gradient of the input.
 * WithGradientType takes the type of the input and produces the type of the output.
 * There are multiple implementations of AD in Relay, with different characteristics.
 * However, they all transform the input expr according to WithGradientType.
 */
65 Type WithGradientType(const Type&);
66
/*! Return an expression that represents differentiation of e (according to WithGradientType).
 * This version only works on first-order code without control flow.
 */
70 Expr FirstOrderGradient(const Expr& e, const Module& mod);
71
Type WithGradientType(const Type& t) {
  // TODO(M.K.): stricter checking
  const auto* fn_ty = t.as<FuncTypeNode>();
  CHECK(fn_ty) << "input should be a function";
  // Result: (original return value, tuple of per-argument gradients).
  Type grad_ty = TupleTypeNode::make(fn_ty->arg_types);
  Type ret_ty = TupleTypeNode::make({fn_ty->ret_type, grad_ty});
  return FuncTypeNode::make(fn_ty->arg_types, ret_ty, {}, {});
}
81
82 //! \brief if the expression is a GlobalVar, transform to it's expression.
DeGlobal(const Module & mod,const Expr & e)83 Expr DeGlobal(const Module& mod, const Expr& e) {
84 if (const auto* x = e.as<GlobalVarNode>()) {
85 return mod->Lookup(GetRef<GlobalVar>(x))->body;
86 } else {
87 return e;
88 }
89 }
90
/*! \brief A fragment of the program being built by the automatic differentation
 * pass.
 */
struct ADValueNode {
  virtual ~ADValueNode() { }
  /*! \brief Downcast this ADValue to the concrete kind T; aborts on mismatch. */
  template <typename T>
  T& get() {
    T* derived = dynamic_cast<T*>(this);
    CHECK(derived) << "cannot downcast";
    return *derived;
  }
};
103
104 using ADValue = std::shared_ptr<ADValueNode>;
105
/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  // Let-binds the forward value and a zeros-initialized gradient accumulator;
  // `reverse` is overwritten with accumulated sums during backprop actions.
  ADTensor(LetList* ll, const Expr& forward) :
    forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) {
    // this->forward is the let-bound Var (member shadows the parameter);
    // copy the checked type of the original expression onto it.
    this->forward->checked_type_ = forward->checked_type();
  }
};
115
116 /*! \brief A staged representation of the program, we reflect
117 * Relay functions into a function over fragments of AD. We
118 * can compute away this function to obtain a reverse mode program.
119 */
120 struct ADFunction : ADValueNode {
121 std::function<ADValue(const Type&,
122 const std::vector<ADValue>&,
123 const Attrs&,
124 const tvm::Array<Type>&)> func;
ADFunctiontvm::relay::ADFunction125 explicit ADFunction(const std::function<ADValue(const Type&,
126 const std::vector<ADValue>&,
127 const Attrs&,
128 const tvm::Array<Type>&)>& func) :
129 func(func) { }
130 };
131
/*! \brief Evaluator that stages first-order Relay code into ADValues.
 * Visiting an expression emits its forward computation into `ll` and,
 * for every primitive call, records a backprop action; replaying the
 * actions in reverse order computes the gradients.
 */
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
  // Per-operator primal gradient functions, looked up from the op registry.
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
  // Forward-order list of gradient-accumulation closures; run in reverse.
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Var, ADValue, NodeHash, NodeEqual> env;
  // LetList collecting the forward computation being built.
  LetList* ll;

  FirstOrderReverseAD(LetList* ll) : ll(ll) { }

  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    CHECK(rev_map.count(op_ref))
      << op->name << " does not have reverse mode defined";
    // Stage the op as an ADFunction: applying it emits the forward call and
    // queues the matching gradient-accumulation action.
    return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type,
                                           const std::vector<ADValue>& args,
                                           const Attrs& attrs,
                                           const tvm::Array<Type>& type_args) {
      std::vector<Expr> call_args;
      for (const ADValue& adval : args) {
        call_args.push_back(adval->get<ADTensor>().forward);
      }
      auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
      orig->checked_type_ = orig_type;
      auto ret = std::make_shared<ADTensor>(ll, orig);
      backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
        // rev[i] is the gradient contribution to args[i] given ret->reverse;
        // accumulate it into each argument's reverse value.
        tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
        CHECK(args.size() == rev.size());
        for (size_t i = 0; i < args.size(); ++i) {
          args[i]->get<ADTensor>().reverse =
            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
        }
      });
      return ret;
    });
  }

  ADValue VisitExpr_(const ConstantNode* op) final {
    // Constants are leaves: wrap as an ADTensor (gradient starts at zeros).
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e);
  }

  ADValue VisitExpr_(const CallNode* op) final {
    // Evaluate callee and arguments, then apply the staged ADFunction.
    ADValue f = VisitExpr(op->op);
    std::vector<ADValue> args;
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(op->checked_type(), args, op->attrs, op->type_args);
  }

  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // todo: assert no closure
    // Stage the function: applying it binds parameters in `env`
    // and evaluates the body.
    return std::make_shared<ADFunction>([this, f](const Type& orig_type,
                                           const std::vector<ADValue>& args,
                                           const Attrs& attrs,
                                           const tvm::Array<Type>& type_args) {
      CHECK_EQ(f->params.size(), args.size());
      for (size_t i = 0; i < f->params.size(); ++i) {
        env[f->params[i]] = args[i];
      }
      return VisitExpr(f->body);
    });
  }

  ADValue VisitExpr_(const VarNode* op) final {
    // Vars must already be bound by an enclosing FunctionNode application.
    Var v = GetRef<Var>(op);
    return env.at(v);
  }
};
202
GradRetType(const Function & f)203 Type GradRetType(const Function& f) {
204 // if type annotations are provided, we will construct a ret type;
205 // otherwise, leave it to be inferred
206 if (!f->ret_type.defined()) {
207 return Type();
208 }
209 std::vector<Type> vt;
210 for (const auto& p : f->params) {
211 if (!p->type_annotation.defined()) {
212 return Type();
213 }
214 vt.push_back(p->type_annotation);
215 }
216
217 return TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)});
218 }
219
/*! \brief First-order reverse-mode gradient.
 * \param re a function (or GlobalVar resolving to one) without control flow.
 * \param mod module used to resolve a GlobalVar input.
 * \return a function mapping the original inputs to
 *         (original output, tuple of per-input gradients).
 */
Expr FirstOrderGradient(const Expr& re, const Module& mod) {
  // Currently we first remove any global functions for the first
  // order case.
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "FOWithGradient expects its argument to be a function: " << f;
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";

  // We will then build a sequence of lets which implement reverse mode.
  Expr body = LetList::With([&](LetList* ll) {
    FirstOrderReverseAD reverse_ad(ll);
    ADValue rev = reverse_ad(e);
    // Wrap each parameter as an ADTensor (gradient accumulator starts at zeros).
    std::vector<ADValue> args;
    for (const auto& p : f->params) {
      args.push_back(std::make_shared<ADTensor>(ll, p));
    }
    // Applying the staged function emits the forward computation into `ll`
    // and fills reverse_ad.backprop_actions.
    auto c = rev->get<ADFunction>().func(f->checked_type(), args, Attrs(), {});
    const auto& res = c->get<ADTensor>();
    Expr grad = LetList::With([&](LetList* ll) {
      // Seed the output gradient with ones, then replay the recorded
      // actions in reverse order to accumulate input gradients.
      res.reverse = OnesLike(res.forward);
      for (auto it = reverse_ad.backprop_actions.rbegin();
           it != reverse_ad.backprop_actions.rend();
           ++it) {
        (*it)(ll);
      }
      std::vector<Expr> grad_res;
      for (const auto& a : args) {
        grad_res.push_back(a->get<ADTensor>().reverse);
      }
      return TupleNode::make(grad_res);
    });
    return Pair(res.forward, grad);
  });

  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
256
// Expose the first-order gradient transform to the frontend API.
TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient);
259
260 struct ReverseADType : TypeMutator {
VisitType_tvm::relay::ReverseADType261 Type VisitType_(const TensorTypeNode* ttn) final {
262 Type t = GetRef<Type>(ttn);
263 return TupleTypeNode::make({t, RefTypeNode::make(t)});
264 }
265 };
266
/*! \brief Apply the reverse-mode type transform to \p t. */
Type ReverseType(const Type& t) {
  ReverseADType mutator;
  return mutator(t);
}
270
271 /*! \brief Lift a function that transform Tensor to a function that also transform more type
272 * by doing a structure preserving map.
273 */
LiftTensor(const std::function<Expr (const Expr & t)> & f,const std::function<Type (const Type &)> & tf,const Type & forward_type,const Expr & e,LetList * ll)274 Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
275 const std::function<Type(const Type&)>& tf,
276 const Type& forward_type,
277 const Expr& e,
278 LetList* ll) {
279 CHECK(IsAtomic(e)) << e;
280 if (forward_type.as<TensorTypeNode>()) {
281 auto ret = f(e);
282 ret->checked_type_ = tf(forward_type);
283 return ret;
284 } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
285 tvm::Array<Expr> fields;
286 tvm::Array<Type> types;
287 for (size_t i = 0; i < tt->fields.size(); ++i) {
288 auto field = LiftTensor(f,
289 tf,
290 tt->fields[i],
291 ll->Push(GetField(e, i)),
292 ll);
293 fields.push_back(field);
294 types.push_back(field->checked_type_);
295 }
296 auto ret = TupleNode::make(fields);
297 ret->checked_type_ = TupleTypeNode::make(types);
298 return std::move(ret);
299 } else {
300 LOG(FATAL) << "unsupported input/output type: " << tt;
301 throw;
302 }
303 }
304
305 /*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
306 * by stitching the references in the AD values.
307 */
TransferGrads(const Type & forward_type,const Expr & from,const Expr & to,LetList * ll)308 void TransferGrads(const Type& forward_type,
309 const Expr& from,
310 const Expr& to,
311 LetList* ll) {
312 CHECK(IsAtomic(from)) << from;
313 CHECK(IsAtomic(to)) << to;
314 if (forward_type.as<TensorTypeNode>()) {
315 auto from_ref = TupleGetItemNode::make(from, 1);
316 auto to_ref = TupleGetItemNode::make(to, 1);
317 ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
318 } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
319 for (size_t i = 0; i < tt->fields.size(); ++i) {
320 TransferGrads(tt->fields[i],
321 ll->Push(TupleGetItemNode::make(from, i)),
322 ll->Push(TupleGetItemNode::make(to, i)),
323 ll);
324 }
325 } else {
326 LOG(FATAL) << "Unsupported input/output type: " << forward_type;
327 throw;
328 }
329 }
330
331 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
GetRev(const Type & forward_type,const Expr & e,LetList * ll)332 Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
333 auto rev = [&](const Expr& e) {
334 return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
335 };
336 auto rev_type = [&](const Type& forward_type) {
337 return ReverseType(forward_type);
338 };
339 return LiftTensor(rev, rev_type, forward_type, e, ll);
340 }
341
342 /*! \brief ReverseType(t) -> t. Get the original value. */
GetValue(const Type & forward_type,const Expr & e,LetList * ll)343 Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
344 auto val = [&](const Expr& e) {
345 return GetField(e, 0);
346 };
347 auto val_type = [&](const Type& forward_type) {
348 return forward_type;
349 };
350 return LiftTensor(val, val_type, forward_type, e, ll);
351 }
352
353 /*! \brief ReverseType(t) -> t. Get the gradient. */
GetGrad(const Type & forward_type,const Expr & e,LetList * ll)354 Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
355 auto grad = [&](const Expr& e) {
356 return ll->Push(RefReadNode::make(GetField(e, 1)));
357 };
358 auto grad_type = [&](const Type& forward_type) {
359 return forward_type;
360 };
361 return LiftTensor(grad, grad_type, forward_type, e, ll);
362 }
363
UpdateGrad(const Type & t,const Expr & arg,const Expr & grad,LetList * ll)364 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
365 if (t.as<TensorTypeNode>()) {
366 ll->Push(RefWriteNode::make(GetField(arg, 1),
367 Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
368 grad)));
369 } else if (auto* tt = t.as<TupleTypeNode>()) {
370 for (size_t i = 0; i < tt->fields.size(); ++i) {
371 UpdateGrad(tt->fields[i],
372 ll->Push(GetField(arg, i)),
373 ll->Push(GetField(grad, i)),
374 ll);
375 }
376 } else {
377 LOG(FATAL) << "unsupported arg type of operator: " << t;
378 throw;
379 }
380 }
381
BPEmpty()382 Expr BPEmpty() {
383 Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
384 return RefCreateNode::make(unitF);
385 }
386
/*! \brief Higher-order reverse-mode AD as an ExprMutator.
 * Tensor-typed values are rewritten to (value, gradient-ref) pairs and each
 * primitive call chains its gradient-accumulation step onto the
 * backpropagator closure stored in the `bp` reference cell.
 */
struct ReverseAD : ExprMutator {
  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;

  // Reference cell holding the backpropagator closure built so far.
  Var bp;
  // Shared Var -> transformed-Var memo; shared so checkpoint duplication
  // reuses the same bindings instead of creating free variables.
  std::shared_ptr<ADVarMap> ad_vars;
  const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");

  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
    : bp(bp), ad_vars(ad_vars) { }

  Expr VisitExpr_(const OpNode* op) final {
    // Bare ops are handled in the CallNode case.
    LOG(FATAL) << "op should only be inside call";
    throw;
  }

  /*! \brief Handle annotation.checkpoint: the annotated subexpression is
   * re-differentiated inside the backward closure (with a fresh bp cell),
   * trading recomputation for not storing its reverse intermediates.
   */
  Expr VisitCheckpoint(const CallNode *call) {
    const OpNode* op_node = call->op.as<OpNode>();
    CHECK(op_node) << "expected op in call";
    Op op_ref = GetRef<Op>(op_node);
    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
    auto x = call->args[0];
    return LetList::With([&](LetList* ll) {
      auto x_var = ll->Push(x);
      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
      auto bpv = ll->Push(RefReadNode::make(bp));
      Expr nbp = FunctionNode::make(
        {},
        LetList::With([&](LetList* ll) {
          // we need a new ReverseAD visitor to avoid clobbering the bp local var
          auto dup_bp = ll->Push(BPEmpty());
          ReverseAD dup_diff(dup_bp, ad_vars);
          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

          // Stitch the outer gradients into the duplicated computation,
          // run its backpropagator, then continue with the previous one.
          TransferGrads(call->checked_type(), ret, dup_ad, ll);
          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
          return CallNode::make(bpv, {});
        }),
        TupleTypeNode::make({}),
        {});
      ll->Push(RefWriteNode::make(bp, nbp));
      return ret;
    });
  }

  Expr VisitExpr_(const CallNode* call) final {
    if (const OpNode* op_node = call->op.as<OpNode>()) {
      Op op_ref = GetRef<Op>(op_node);

      if (op_ref->name == "annotation.checkpoint") {
        return VisitCheckpoint(call);
      }

      CHECK(rev_map.count(op_ref))
        << op_node->name << " does not have reverse mode defined";
      return LetList::With([&](LetList* ll) {
        // Transform the arguments, then strip them back to forward values
        // so the original op can be re-emitted.
        std::vector<Var> args;
        for (const auto& arg : call->args) {
          args.push_back(ll->Push(VisitExpr(arg)));
        }
        std::vector<Expr> orig_args;
        for (size_t i = 0; i < args.size(); i++) {
          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
        }
        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
        orig->checked_type_ = call->checked_type();
        Var orig_var = ll->Push(orig);
        orig_var->checked_type_ = call->checked_type();
        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
        auto bpv = ll->Push(RefReadNode::make(bp));
        // Extend the backpropagator: accumulate this call's input gradients,
        // then invoke the previously stored backpropagator.
        Expr nbp = FunctionNode::make(
          {},
          LetList::With([&](LetList* ll) {
            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
            CHECK(args.size() == rev.size());
            for (size_t i = 0; i < args.size(); ++i) {
              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
            }
            return CallNode::make(bpv, {});
          }),
          TupleTypeNode::make({}),
          {});
        ll->Push(RefWriteNode::make(bp, nbp));
        return ret;
      });
    }
    return ExprMutator::VisitExpr_(call);
  }

  Expr VisitExpr_(const ConstantNode* op) final {
    // Constants pair with a fresh zero-gradient ref.
    Expr e = GetRef<Expr>(op);
    return Pair(e, RefCreateNode::make(ZerosLike(e)));
  }

  Expr VisitExpr_(const IfNode* op) final {
    // The condition is transformed into a (value, grad-ref) pair;
    // branch on the value component.
    return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
                        VisitExpr(op->true_branch),
                        VisitExpr(op->false_branch));
  }

  Expr VisitExpr_(const VarNode* var) final {
    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
    auto var_ref = GetRef<Var>(var);
    if (!ad_vars->count(var_ref)) {
      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
      (*ad_vars)[var_ref] = res;
    }

    return ad_vars->at(var_ref);
  }

  Type VisitType(const Type& t) final {
    // Rewrite all defined types to their reverse-mode counterparts.
    return t.defined() ? ReverseType(t) : t;
  }
};
501
MissingGrad(const Expr & e)502 bool MissingGrad(const Expr& e) {
503 struct MGVisitor : ExprVisitor {
504 const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
505 std::unordered_set<std::string> op_names;
506
507 void VisitExpr_(const OpNode* op) final {
508 Op op_ref = GetRef<Op>(op);
509 if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
510 op_names.insert(op_ref->name);
511 }
512 ExprVisitor::VisitExpr_(op);
513 }
514 };
515
516 MGVisitor mg;
517 mg.VisitExpr(e);
518
519 if (mg.op_names.size() > 0) {
520 LOG(WARNING) << "found operators with missing gradients:";
521 for (const auto& op : mg.op_names) {
522 LOG(WARNING) << " " << op;
523 }
524 return true;
525 }
526
527 return false;
528 }
529
/*! \brief Higher-order reverse-mode gradient (supports control flow and
 * checkpointing, via the ReverseAD mutator).
 * \param re a function (or GlobalVar resolving to one) with tensor parameters.
 * \param mod module used to resolve a GlobalVar input.
 * \return a function from the original inputs to
 *         (original output, tuple of per-input gradients).
 */
Expr Gradient(const Expr& re, const Module& mod) {
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  CHECK(f) << "input need to be a function";
  CHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    CHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
  }
  CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    // bp holds the backpropagator closure; ReverseAD threads updates through it.
    Var bp = ll->Push(BPEmpty());
    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
    // Each parameter becomes a (value, zero-gradient ref) pair.
    std::vector<Expr> args;
    for (const auto& p : f->params) {
      args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
    }
    auto c = ll->Push(CallNode::make(rev, args));
    // Seed the output gradient(s) with ones, following the (possibly
    // tuple-nested) structure of the return type.
    std::function<void(const Expr&, const Type&)> init_grad;
    init_grad = [&](const Expr& e, const Type& t) {
      if (t.as<TensorTypeNode>()) {
        ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
      } else if (auto tt = t.as<TupleTypeNode>()) {
        CHECK_GT(tt->fields.size(), 0);
        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    init_grad(c, f->body->checked_type());
    // Run the accumulated backpropagator, then read out the input gradients.
    ll->Push(CallNode::make(RefReadNode::make(bp), {}));
    std::vector<Expr> ret;
    for (const auto& a : args) {
      ret.push_back(RefReadNode::make(GetField(a, 1)));
    }
    // Strip the (value, grad-ref) pairing off the forward result,
    // recursing through tuple return types.
    std::function<Expr(const Expr&, const Type&)> get_final_result;
    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
      if (t.as<TensorTypeNode>()) {
        return GetField(e, 0);
      } else if (auto tt = t.as<TupleTypeNode>()) {
        tvm::Array<Expr> fields;
        for (size_t i = 0; i < tt->fields.size(); ++i) {
          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
        }
        return TupleNode::make(fields);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
  });
  return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
584
// Expose the higher-order gradient transform to the frontend API.
TVM_REGISTER_API("relay._transform.gradient")
.set_body_typed(Gradient);
587
588 } // namespace relay
589 } // namespace tvm
590