Lines Matching refs:op_ref
142 Op op_ref = GetRef<Op>(op); in VisitExpr_() local
143 CHECK(rev_map.count(op_ref)) in VisitExpr_()
145 return std::make_shared<ADFunction>([this, op_ref](const Type& orig_type, in VisitExpr_()
153 auto orig = CallNode::make(op_ref, call_args, attrs, type_args); in VisitExpr_()
156 backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { in VisitExpr_()
157 tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse); in VisitExpr_()
405 Op op_ref = GetRef<Op>(op_node); in VisitCheckpoint() local
406 CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation"; in VisitCheckpoint()
433 Op op_ref = GetRef<Op>(op_node); in VisitExpr_() local
435 if (op_ref->name == "annotation.checkpoint") { in VisitExpr_()
439 CHECK(rev_map.count(op_ref)) in VisitExpr_()
459 tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); in VisitExpr_()
508 Op op_ref = GetRef<Op>(op); in MissingGrad() local
509 if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) { in MissingGrad()
510 op_names.insert(op_ref->name); in MissingGrad()