/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file jacobian.cc
 * \brief Calculate the Jacobian dY/dX of two tensors.
 *        X must be a direct input tensor of Y.
 *        The resulting Jacobian has shape (Y.shape, X.shape);
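 *        e.g., Y.shape = (3, 4) and X.shape = (5,) yield a Jacobian of shape (3, 4, 5).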
 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/autodiff.h>
#include <tvm/tir/stmt_functor.h>

#include <memory>
#include <string>
#include <unordered_set>

#include "ad_util.h"

namespace tvm {
namespace te {

#define NOT_IMPLEMENTED                                                                   \
  {                                                                                       \
    LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef<PrimExpr>(op); \
    throw;                                                                                \
  }

/*! \brief Differentiate an expression wrt a variable or a tensor element */
class JacobianMutator : public ExprMutator {
 public:
  /*!
   * \brief Differentiate wrt `input(indices)`.
   * \param input The input tensor.
   * \param indices The indices of the element with respect to which to differentiate.
   */
  explicit JacobianMutator(Tensor input, Array<PrimExpr> indices)
      : input_(input), indices_(indices) {}
  /*!
   * \brief Differentiate wrt the input variable.
   * \param input The input variable.
   */
  explicit JacobianMutator(Var input) : input_var_(input) {}

  PrimExpr Mutate(PrimExpr e) {
    if (e.dtype().is_int() || e.dtype().is_uint()) {
      LOG(WARNING) << "For now we assume that the derivative of any integer expression is always 0."
                   << " e = " << e;
      return make_zero(e.dtype());
    } else {
      return ExprMutator::VisitExpr(e);
    }
  }

  PrimExpr VisitExpr_(const VarNode* op) {
    if (input_var_.get() && input_var_.get() == op && op->dtype.is_float()) {
      return FloatImm(op->dtype, 1.0);
    } else {
      return make_zero(op->dtype);
    }
  }

  PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED;

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    auto tensor = Downcast<te::Tensor>(op->producer);
    if (input_.get() && tensor == input_) {
      // This is an element of the input tensor itself: the derivative is the
      // Kronecker delta, 1.0 iff all the indices coincide.
      CHECK_EQ(indices_.size(), op->indices.size());
      PrimExpr condition = const_true();
      for (size_t i = 0; i < input_.ndim(); ++i) {
        condition = And(condition, EQ(indices_[i], op->indices[i]));
      }
      return Cast(op->dtype, condition);
    } else {
      return make_zero(op->dtype);
    }
  }

  PrimExpr VisitExpr_(const CallNode* op) {
    PrimExpr expr = GetRef<PrimExpr>(op);
    if (op->op.same_as(op_exp_)) {
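      // d exp(x) = exp(x) * dx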
      return Mul(Mutate(op->args[0]), expr);
    } else if (op->op.same_as(op_log_)) {
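      // d log(x) = dx / x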
      return Div(Mutate(op->args[0]), op->args[0]);
    } else if (op->op.same_as(op_sigmoid_)) {
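      // d sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) * dx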
      return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
    } else if (op->op.same_as(op_sqrt_)) {
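      // d sqrt(x) = dx / (2 * sqrt(x))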
      return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
    } else if (op->op.same_as(op_tanh_)) {
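      // d tanh(x) = (1 - tanh(x)^2) * dx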
      return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
    } else if (op->op.same_as(op_pow_)) {
      auto x = op->args[0], y = op->args[1];
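      // Generalized power rule: d(x^y) = x^y * (dy * log(x) + dx * y / x)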
      return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
    } else if (op->op.same_as(op_fabs_)) {
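      // d |x| = sign(x) * dx (the derivative at x = 0 is taken to be 1)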
      auto type = op->args[0].dtype();
      return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0),
                                             FloatImm(type, -1.0)));
    } else if (op->op.same_as(op_if_then_else_)) {
      Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
      return Call(op->dtype, op->op, new_args);
    } else if (piecewise_const.count(op->op)) {
      return FloatImm(expr.dtype(), 0.0);
    } else {
      LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
      return PrimExpr();
    }
  }

  PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); }

  PrimExpr VisitExpr_(const SubNode* op) { return Sub(Mutate(op->a), Mutate(op->b)); }

  PrimExpr VisitExpr_(const MulNode* op) {
    return Add(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b)));
  }

  PrimExpr VisitExpr_(const DivNode* op) {
    return Div(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b));
  }

  PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED;

  PrimExpr VisitExpr_(const FloorDivNode* op) {
    return FloorDiv(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b));
  }

  PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED;

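  // For min/max we differentiate the argument that is selected,
  // breaking ties in favor of the first argument.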
  PrimExpr VisitExpr_(const MinNode* op) {
    return Select(LE(op->a, op->b), Mutate(op->a), Mutate(op->b));
  }

  PrimExpr VisitExpr_(const MaxNode* op) {
    return Select(GE(op->a, op->b), Mutate(op->a), Mutate(op->b));
  }

  PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED;

  PrimExpr VisitExpr_(const ReduceNode* op) {
    // This case is relatively difficult because a reduction expression
    // may use an arbitrary combiner.
    // The resulting reduction expression will return a tuple containing
    // both the derivatives and the original results (in exactly this order).
    // The order matters when the original init value differs from the init value
    // of its derivative and the two depend on each other during the gradient
    // calculation: the derivatives must be computed first, while the original
    // init value is still available. Switching the order (original results first,
    // then derivatives) would overwrite the original value before it is used,
    // producing incorrect results.

    // Example of a ReduceNode,
    // reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]),
    //   source=[A(k)], axis=[iter_var(k, range(min=0, ext=5))], where=(bool)1, value_index=0)
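    //
    // After differentiation, this example becomes, schematically:
    // reduce(combiner=comm_reducer(result=[(x.jac + y.jac), (x + y)],
    //   lhs=[x.jac, x], rhs=[y.jac, y], identity_element=[0f, 0f]),
    //   source=[dA(k)/dX, A(k)], axis=[...], where=(bool)1, value_index=0)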

    // We have to clone the reduction axes because otherwise the original expression
    // cannot be used together with the derivative (it would lead to errors during lowering)
    PrimExpr expr_with_new_axes = te::CloneReduction(GetRef<PrimExpr>(op));
    const ReduceNode* new_op = expr_with_new_axes.as<ReduceNode>();

    CHECK(new_op->init.empty()) << "Derivative of Reduction with initialization is not implemented";

    // New lhs and rhs variables of the new combiner consist of
    // variables representing derivatives (which are later derived from new_op->source)
    // followed by the original variables.
    Array<Var> new_lhs;
    for (const auto& var : new_op->combiner->lhs) {
      new_lhs.push_back(var.copy_with_suffix(".jac"));
    }
    for (const auto& var : new_op->combiner->lhs) {
      new_lhs.push_back(var);
    }

    Array<Var> new_rhs;
    for (const auto& var : new_op->combiner->rhs) {
      new_rhs.push_back(var.copy_with_suffix(".jac"));
    }
    for (const auto& var : new_op->combiner->rhs) {
      new_rhs.push_back(var);
    }

    // The new combiner result also consists of the resulting derivatives
    // followed by the original results.
    Array<PrimExpr> new_result;
    for (const auto& res : new_op->combiner->result) {
      // By the chain rule, each resulting derivative is a sum of the derivatives
      // of the combiner result wrt lhs and rhs, multiplied by the derivatives
      // of lhs and rhs themselves.
      PrimExpr new_res = make_zero(res.dtype());
      for (size_t i = 0; i < new_op->combiner->lhs.size(); ++i) {
        PrimExpr res_di = Derivative(res, new_op->combiner->lhs[i]);
        // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor)
        new_res = Add(new_res, Mul(new_lhs[i], res_di));
      }
      for (size_t i = 0; i < new_op->combiner->rhs.size(); ++i) {
        PrimExpr res_di = Derivative(res, new_op->combiner->rhs[i]);
        // new_rhs[i] is the derivative of rhs[i] (wrt our input tensor)
        new_res = Add(new_res, Mul(new_rhs[i], res_di));
      }
      new_result.push_back(new_res);
    }
    // Append the original results.
    for (const auto& res : new_op->combiner->result) {
      new_result.push_back(res);
    }

    // The identity elements are transformed in the same way.
    Array<PrimExpr> new_identity;
    for (const auto& id : new_op->combiner->identity_element) {
      new_identity.push_back(Mutate(id));
    }
    for (const auto& id : new_op->combiner->identity_element) {
      new_identity.push_back(id);
    }

    // And so are the sources: derivatives first, then the originals.
    Array<PrimExpr> new_source;
    for (const auto& src : new_op->source) {
      new_source.push_back(Mutate(src));
    }
    for (const auto& src : new_op->source) {
      new_source.push_back(src);
    }

    CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
    // Also simplify the resulting combiner
    // (mostly to get rid of unused components, e.g., the original expressions)
    return analyzer_.Simplify(Reduce(new_combiner, new_source, new_op->axis, new_op->condition,
                                     new_op->value_index, new_op->init));
  }

  PrimExpr VisitExpr_(const CastNode* op) {
    if (op->dtype.is_float()) {
      return Cast(op->dtype, Mutate(op->value));
    } else {
      return make_zero(op->dtype);
    }
  }

  PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED;

  PrimExpr VisitExpr_(const SelectNode* op) {
    return Select(op->condition, Mutate(op->true_value), Mutate(op->false_value));
  }

  PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED;
  PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED;

  PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); }

  PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); }

  PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED;

 private:
  Tensor input_;
  Array<PrimExpr> indices_;
  Var input_var_;
  arith::Analyzer analyzer_;

  const Op& op_exp_ = Op::Get("tir.exp");
  const Op& op_log_ = Op::Get("tir.log");
  const Op& op_sigmoid_ = Op::Get("tir.sigmoid");
  const Op& op_sqrt_ = Op::Get("tir.sqrt");
  const Op& op_tanh_ = Op::Get("tir.tanh");
  const Op& op_pow_ = Op::Get("tir.pow");
  const Op& op_fabs_ = Op::Get("tir.fabs");
  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
  std::unordered_set<RelayExpr, ObjectPtrHash, ObjectPtrEqual> piecewise_const = {
      Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")};
};

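/*! \brief Take the derivative of expr wrt the variable var. */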
PrimExpr Derivative(const PrimExpr& expr, const Var& var) {
  return JacobianMutator(var).Mutate(expr);
}

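/*! \brief Take the derivative of expr wrt the tensor element input(indices). */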
PrimExpr Jacobian(const PrimExpr& expr, const Tensor& input, const Array<PrimExpr>& indices) {
  return JacobianMutator(input, indices).Mutate(expr);
}

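/*!
 * \brief Take the derivative of output wrt the input, elementwise.
 *        The result has shape (output.shape, input.shape); e.g., for
 *        output.shape = (2, 3) and input.shape = (4,), element (i, j, k)
 *        of the result is d output(i, j) / d input(k).
 */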
Tensor Jacobian(const Tensor& output, const Tensor& input) {
  const ComputeOpNode* op = output->op.as<ComputeOpNode>();
  CHECK(op) << "Derivative of this operation is not implemented: " << output->op;
  bool is_input_tensor = false;
  for (const Tensor& child : op->InputTensors()) {
    if (input == child) {
      is_input_tensor = true;
      break;
    }
  }
  CHECK(is_input_tensor) << "Jacobian is called on a pair of tensors such that the output "
                         << "does not directly depend on the input.";

  // We have to clone the iteration axes because otherwise the original expression
  // cannot be used together with the derivative (it would lead to errors during lowering)
  Array<IterVar> new_axis;
  Map<Var, PrimExpr> vmap;
  std::tie(new_axis, vmap) = te::CloneIterVars(op->axis);

  Array<PrimExpr> input_indices;
  size_t i = 0;
  for (PrimExpr ext : input->shape) {
    IterVar new_v =
        IterVar(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar);
    // Append the Jacobian iteration variable to new_axis
    new_axis.push_back(new_v);
    // Differentiate wrt input[input_indices]
    input_indices.push_back(new_v);
  }
  arith::Analyzer analyzer;
  // Compute the Jacobian
  PrimExpr new_body =
      Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices);
  new_body = analyzer.Simplify(new_body);

  int value_index = 0;
  Array<PrimExpr> new_bodies;

  // If this is a reduction, then it may return a tuple and we have
  // to repeat the body several times
  if (const ReduceNode* red = new_body.as<ReduceNode>()) {
    value_index = red->value_index;
    for (size_t idx = 0; idx < red->source.size(); ++idx) {
      new_bodies.push_back(
          Reduce(red->combiner, red->source, red->axis, red->condition, idx, red->init));
    }
  } else {
    new_bodies.push_back(new_body);
  }

  auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);

  // Jacobian shape = output.shape + input.shape
  Array<PrimExpr> new_shape = output->shape;
  for (const auto& e : input->shape) {
    new_shape.push_back(e);
  }

  Tensor ret = Tensor(new_shape, output->dtype, new_op, value_index);
  ret = RemoveJacobianAndLiftNonzeroCond(ret);
  return ret;
}

}  // namespace te
}  // namespace tvm