/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file jacobian.cc
 * \brief Calculate the Jacobian of two tensors, dY/dX.
 *        X must be a direct input tensor of Y.
 *        The resulting Jacobian has shape (Y.shape, X.shape).
 */
26 #include <tvm/arith/analyzer.h>
27 #include <tvm/runtime/registry.h>
28 #include <tvm/te/autodiff.h>
29 #include <tvm/tir/stmt_functor.h>
30
31 #include <memory>
32
33 #include "ad_util.h"
34
35 namespace tvm {
36 namespace te {
37
38 #define NOT_IMPLEMENTED \
39 { \
40 LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef<PrimExpr>(op); \
41 throw; \
42 }
43
44 /*! \brief Differentiate an expression wrt a variable or a tensor element */
45 class JacobianMutator : public ExprMutator {
46 public:
47 /*!
48 * \brief Differentiate wrt `input(indices)`.
49 * \param input The input tensor.
50 * \param indices The indices of the element with respect to which to differentiate.
51 */
JacobianMutator(Tensor input,Array<PrimExpr> indices)52 explicit JacobianMutator(Tensor input, Array<PrimExpr> indices)
53 : input_(input), indices_(indices) {}
54 /*!
55 * \brief Differentiate wrt the input variable.
56 * \param input The input variable.
57 */
JacobianMutator(Var input)58 explicit JacobianMutator(Var input) : input_var_(input) {}
59
Mutate(PrimExpr e)60 PrimExpr Mutate(PrimExpr e) {
61 if (e.dtype().is_int() || e.dtype().is_uint()) {
62 LOG(WARNING) << "For now we assume that the derivative of any integer expression is always 0."
63 << " e = " << e;
64 return make_zero(e.dtype());
65 } else {
66 return ExprMutator::VisitExpr(e);
67 }
68 }
69
VisitExpr_(const VarNode * op)70 PrimExpr VisitExpr_(const VarNode* op) {
71 if (input_var_.get() && input_var_.get() == op && op->dtype.is_float()) {
72 return FloatImm(op->dtype, 1.0);
73 } else {
74 return make_zero(op->dtype);
75 }
76 }
77
78 PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED;
79 PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED;
80
VisitExpr_(const ProducerLoadNode * op)81 PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
82 auto tensor = Downcast<te::Tensor>(op->producer);
83 if (input_.get() && tensor == input_) {
84 // Tensor(indices)
85 CHECK_EQ(indices_.size(), op->indices.size());
86 PrimExpr condition = const_true();
87 for (size_t i = 0; i < input_.ndim(); ++i) {
88 condition = And(condition, EQ(indices_[i], op->indices[i]));
89 }
90 return Cast(op->dtype, condition);
91 } else {
92 return make_zero(op->dtype);
93 }
94 }
95
VisitExpr_(const CallNode * op)96 PrimExpr VisitExpr_(const CallNode* op) {
97 PrimExpr expr = GetRef<PrimExpr>(op);
98 if (op->op.same_as(op_exp_)) {
99 return Mul(Mutate(op->args[0]), expr);
100 } else if (op->op.same_as(op_log_)) {
101 return Div(Mutate(op->args[0]), op->args[0]);
102 } else if (op->op.same_as(op_sigmoid_)) {
103 return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr)));
104 } else if (op->op.same_as(op_sqrt_)) {
105 return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0)));
106 } else if (op->op.same_as(op_tanh_)) {
107 return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr)));
108 } else if (op->op.same_as(op_pow_)) {
109 auto x = op->args[0], y = op->args[1];
110 return expr * (Mutate(y) * log(x) + Mutate(x) * y / x);
111 } else if (op->op.same_as(op_fabs_)) {
112 auto type = op->args[0].dtype();
113 return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0),
114 FloatImm(type, -1.0)));
115 } else if (op->op.same_as(op_if_then_else_)) {
116 Array<PrimExpr> new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])};
117 return Call(op->dtype, op->op, new_args);
118 } else if (piecewise_const.count(op->op)) {
119 return FloatImm(expr.dtype(), 0.0);
120 } else {
121 LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op;
122 return PrimExpr();
123 }
124 }
125
VisitExpr_(const AddNode * op)126 PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); }
127
VisitExpr_(const SubNode * op)128 PrimExpr VisitExpr_(const SubNode* op) { return Sub(Mutate(op->a), Mutate(op->b)); }
129
VisitExpr_(const MulNode * op)130 PrimExpr VisitExpr_(const MulNode* op) {
131 return Add(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b)));
132 }
133
VisitExpr_(const DivNode * op)134 PrimExpr VisitExpr_(const DivNode* op) {
135 return Div(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b));
136 }
137
138 PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED;
139
VisitExpr_(const FloorDivNode * op)140 PrimExpr VisitExpr_(const FloorDivNode* op) {
141 return FloorDiv(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b));
142 }
143
144 PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED;
145
VisitExpr_(const MinNode * op)146 PrimExpr VisitExpr_(const MinNode* op) {
147 return Select(LE(op->a, op->b), Mutate(op->a), Mutate(op->b));
148 }
149
VisitExpr_(const MaxNode * op)150 PrimExpr VisitExpr_(const MaxNode* op) {
151 return Select(GE(op->a, op->b), Mutate(op->a), Mutate(op->b));
152 }
153
154 PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED;
155 PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED;
156 PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED;
157 PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED;
158 PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED;
159 PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED;
160 PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED;
161 PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED;
162
VisitExpr_(const ReduceNode * op)163 PrimExpr VisitExpr_(const ReduceNode* op) {
164 // This case is relatively difficult because a reduction expression
165 // may use an arbitrary combiner.
166 // The resulting reduction expression will return a tuple containing
167 // both derivatives and the original results (in exactly this order).
168 // The order matters when original init value is different from its derivative init value,
169 // and they depend on each other during gradient calculation,
170 // we must calculate derivatives first (using origin's init value),
171 // switching the order (original results first, then derivatives)
172 // makes the origin value be replaced before using,
173 // produces incorrect results.
174
175 // Example of a ReduceNode,
176 // reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0f]),
177 // source=[A(k)], axis=[iter_var(k, range(min=0, ext=5))], where=(bool)1, value_index=0)
178
179 // We have to clone the reduction axes because otherwise the original expression
180 // cannot be used together with the derivative (it will lead to errors during lowering)
181 PrimExpr expr_with_new_axes = te::CloneReduction(GetRef<PrimExpr>(op));
182 const ReduceNode* new_op = expr_with_new_axes.as<ReduceNode>();
183
184 CHECK(new_op->init.empty()) << "Derivative of Reduction with initialization is not implemented";
185
186 // New lhs and rhs variables of the new combiner consist of
187 // variables representing derivatives (which are later derived from new_op->source)
188 // followed by the original variables.
189 Array<Var> new_lhs;
190 for (const auto& var : new_op->combiner->lhs) {
191 new_lhs.push_back(var.copy_with_suffix(".jac"));
192 }
193 for (const auto& var : new_op->combiner->lhs) {
194 new_lhs.push_back(var);
195 }
196
197 Array<Var> new_rhs;
198 for (const auto& var : new_op->combiner->rhs) {
199 new_rhs.push_back(var.copy_with_suffix(".jac"));
200 }
201 for (const auto& var : new_op->combiner->rhs) {
202 new_rhs.push_back(var);
203 }
204
205 // The new combiner result also consists of the resulting derivatives
206 // followed by the original results.
207 Array<PrimExpr> new_result;
208 for (const auto& res : new_op->combiner->result) {
209 // Each resulting derivative is computed as a sum of derivatives
210 // wrt lhs and rhs multiplied by the derivatives of lhs and rhs
211 PrimExpr new_res = make_zero(res.dtype());
212 for (size_t i = 0; i < new_op->combiner->lhs.size(); ++i) {
213 PrimExpr res_di = Derivative(res, new_op->combiner->lhs[i]);
214 // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor)
215 new_res = Add(new_res, Mul(new_lhs[i], res_di));
216 }
217 for (size_t i = 0; i < new_op->combiner->rhs.size(); ++i) {
218 PrimExpr res_di = Derivative(res, new_op->combiner->rhs[i]);
219 // new_rhs[i] is the derivative of rhs[i] (wrt our input tensor)
220 new_res = Add(new_res, Mul(new_rhs[i], res_di));
221 }
222 new_result.push_back(new_res);
223 }
224 // add original results
225 for (const auto& res : new_op->combiner->result) {
226 new_result.push_back(res);
227 }
228
229 // The identity is transformed in a similar way
230 Array<PrimExpr> new_identity;
231 for (const auto& id : new_op->combiner->identity_element) {
232 new_identity.push_back(Mutate(id));
233 }
234 for (const auto& id : new_op->combiner->identity_element) {
235 new_identity.push_back(id);
236 }
237
238 // Same as source
239 Array<PrimExpr> new_source;
240 for (const auto& src : new_op->source) {
241 new_source.push_back(Mutate(src));
242 }
243 for (const auto& src : new_op->source) {
244 new_source.push_back(src);
245 }
246
247 CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
248 // Also simplify the resulting combiner
249 // (mostly to get rid of unused components, e.g., the original expressions)
250 return analyzer_.Simplify(Reduce(new_combiner, new_source, new_op->axis, new_op->condition,
251 new_op->value_index, new_op->init));
252 }
253
VisitExpr_(const CastNode * op)254 PrimExpr VisitExpr_(const CastNode* op) {
255 if (op->dtype.is_float()) {
256 return Cast(op->dtype, Mutate(op->value));
257 } else {
258 return make_zero(op->dtype);
259 }
260 }
261
262 PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED;
263
VisitExpr_(const SelectNode * op)264 PrimExpr VisitExpr_(const SelectNode* op) {
265 return Select(op->condition, Mutate(op->true_value), Mutate(op->false_value));
266 }
267
268 PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED;
269 PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED;
270 PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED;
271
VisitExpr_(const IntImmNode * op)272 PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); }
273
VisitExpr_(const FloatImmNode * op)274 PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); }
275
276 PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED;
277
278 private:
279 Tensor input_;
280 Array<PrimExpr> indices_;
281 Var input_var_;
282 arith::Analyzer analyzer_;
283
284 const Op& op_exp_ = Op::Get("tir.exp");
285 const Op& op_log_ = Op::Get("tir.log");
286 const Op& op_sigmoid_ = Op::Get("tir.sigmoid");
287 const Op& op_sqrt_ = Op::Get("tir.sqrt");
288 const Op& op_tanh_ = Op::Get("tir.tanh");
289 const Op& op_pow_ = Op::Get("tir.pow");
290 const Op& op_fabs_ = Op::Get("tir.fabs");
291 const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
292 std::unordered_set<RelayExpr, ObjectPtrHash, ObjectPtrEqual> piecewise_const = {
293 Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")};
294 };
295
/*!
 * \brief Differentiate \p expr with respect to the variable \p var.
 * \param expr The expression to differentiate.
 * \param var The variable to differentiate with respect to.
 * \return The derivative expression.
 */
PrimExpr Derivative(const PrimExpr& expr, const Var& var) {
  return JacobianMutator(var).Mutate(expr);
}

/*!
 * \brief Differentiate \p expr with respect to the element input(indices).
 * \param expr The expression to differentiate.
 * \param input The input tensor to differentiate with respect to.
 * \param indices The indices of the element of \p input.
 * \return The derivative expression.
 */
PrimExpr Jacobian(const PrimExpr& expr, const Tensor& input, const Array<PrimExpr>& indices) {
  return JacobianMutator(input, indices).Mutate(expr);
}

/*!
 * \brief Compute the Jacobian tensor d(output)/d(input).
 * \param output The output tensor; must be produced by a ComputeOp and
 *               directly depend on \p input.
 * \param input The input tensor.
 * \return A tensor of shape (output.shape, input.shape).
 */
Tensor Jacobian(const Tensor& output, const Tensor& input) {
  const ComputeOpNode* op = output->op.as<ComputeOpNode>();
  CHECK(op) << "Derivative of this operation is not implemented: " << output->op;
  bool is_input_tensor = false;
  for (const Tensor& child : op->InputTensors()) {
    if (input == child) {
      is_input_tensor = true;
      break;
    }
  }
  CHECK(is_input_tensor) << "Jacobian is called on a pair of tensors such that the output "
                         << "does not directly depend on the input.";

  // We have to clone the iteration axes because otherwise the original expression
  // cannot be used together with the derivative (it will lead to errors during lowering)
  Array<IterVar> new_axis;
  Map<Var, PrimExpr> vmap;
  std::tie(new_axis, vmap) = te::CloneIterVars(op->axis);

  Array<PrimExpr> input_indices;
  size_t i = 0;
  for (PrimExpr ext : input->shape) {
    IterVar new_v =
        IterVar(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar);
    // Append jacobian iter to new_axis
    new_axis.push_back(new_v);
    // Differentiate wrt input[input_indices]
    input_indices.push_back(new_v);
  }
  arith::Analyzer analyzer;
  // Compute Jacobian
  PrimExpr new_body =
      Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices);
  new_body = analyzer.Simplify(new_body);

  int value_index = 0;
  Array<PrimExpr> new_bodies;

  // If this is a reduction then it may return a tuple and we have
  // to repeat the body several times
  if (const ReduceNode* red = new_body.as<ReduceNode>()) {
    value_index = red->value_index;
    for (size_t idx = 0; idx < red->source.size(); ++idx) {
      new_bodies.push_back(
          Reduce(red->combiner, red->source, red->axis, red->condition, idx, red->init));
    }
  } else {
    new_bodies.push_back(new_body);
  }

  auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies);

  // Jacobian shape = output.shape + input.shape
  Array<PrimExpr> new_shape = output->shape;
  for (const auto& e : input->shape) {
    new_shape.push_back(e);
  }

  Tensor ret = Tensor(new_shape, output->dtype, new_op, value_index);
  ret = RemoveJacobianAndLiftNonzeroCond(ret);
  return ret;
}

}  // namespace te
}  // namespace tvm