# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm.topi.nn.util import get_pad_tuple
from tvm.topi.util import get_const_tuple

from ..expr import Tuple, TupleGetItem, const
from . import nn as _nn
from .op import register_gradient
from .reduce import sum as _sum
from .tensor import (
    cos,
    cosh,
    exp,
    less,
    negative,
    ones_like,
    power,
    sin,
    sinh,
    sqrt,
    zeros_like,
    equal,
    shape_of,
    log,
)
from .transform import (
    broadcast_to_like,
    collapse_sum_like,
    cast_like,
    reshape,
    reshape_like,
    strided_slice,
    take,
    tile,
    transpose,
    where,
    repeat,
    expand_dims,
    full_like,
)


@register_gradient("log")
def log_grad(orig, grad):
    """Returns [grad * (1 / x)]"""
    x = orig.args[0]
    return [grad * ones_like(x) / x]


@register_gradient("log2")
def log2_grad(orig, grad):
    """Returns [grad * 1 / (log(2) * x)]"""
    x = orig.args[0]
    ones = ones_like(x)
    two = const(2.0, dtype=x.checked_type.dtype)
    return [grad * ones / (log(two) * x)]


@register_gradient("log10")
def log10_grad(orig, grad):
    """Returns [grad * 1 / (log(10) * x)]"""
    x = orig.args[0]
    ones = ones_like(x)
    ten = const(10.0, dtype=x.checked_type.dtype)
    return [grad * ones / (log(ten) * x)]


@register_gradient("tan")
def tan_grad(orig, grad):
    """Returns [grad / (cos^2(x))]"""
    x = orig.args[0]
    return [grad / (cos(x) * cos(x))]


@register_gradient("cos")
def cos_grad(orig, grad):
    """Returns [grad * (-sin(x))]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * (-ones * sin(x))]


@register_gradient("cosh")
def cosh_grad(orig, grad):
    """Returns [grad * sinh(x)]"""
    x = orig.args[0]
    return [grad * sinh(x)]


@register_gradient("sin")
def sin_grad(orig, grad):
    """Returns [grad * cos(x)]"""
    x = orig.args[0]
    return [grad * cos(x)]


@register_gradient("sinh")
def sinh_grad(orig, grad):
    """Returns [grad * cosh(x)]"""
    x = orig.args[0]
    return [grad * cosh(x)]


@register_gradient("acos")
def acos_grad(orig, grad):
    """Returns [grad * -1/((1 - (x ^ 2)) ^ 1/2)]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * (-ones / sqrt(ones - (x * x)))]


@register_gradient("acosh")
def acosh_grad(orig, grad):
    """Returns [grad * 1/((x - 1) ^ 1/2 * (x + 1) ^ 1/2)]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * ones / sqrt((x * x) - ones)]


@register_gradient("asin")
def asin_grad(orig, grad):
    """Returns [grad * 1/((1 - (x ^ 2)) ^ (1/2))]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * ones / sqrt(ones - (x * x))]


@register_gradient("asinh")
def asinh_grad(orig, grad):
    """Returns [grad * 1/((1 + (x ^ 2)) ^ (1/2))]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * ones / sqrt(ones + (x * x))]


@register_gradient("atan")
def atan_grad(orig, grad):
    """Returns [grad * 1 / (1 + x ^ 2)]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * ones / (ones + (x * x))]


@register_gradient("atanh")
def atanh_grad(orig, grad):
    """Returns [grad * 1 / (1 - x ^ 2)]"""
    x = orig.args[0]
    ones = ones_like(x)
    return [grad * ones / (ones - (x * x))]


@register_gradient("exp")
def exp_grad(orig, grad):
    """Returns [grad * exp(x)]"""
    return [grad * exp(orig.args[0])]


@register_gradient("sqrt")
def sqrt_grad(orig, grad):
    """Returns [grad * 0.5 * (x ^ -0.5)]"""
    x = orig.args[0]
    a = const(0.5, dtype=x.checked_type.dtype)
    return [grad * a * power(x, negative(a))]


@register_gradient("sigmoid")
def sigmoid_grad(orig, grad):
    """Returns [grad * sigmoid(x) * (1 - sigmoid(x))]."""
    return [grad * orig * (ones_like(orig) - orig)]


@register_gradient("tanh")
def tanh_grad(orig, grad):
    """Returns grad * (1 - tanh(x) * tanh(x))."""
    return [grad * (ones_like(orig) - orig * orig)]


@register_gradient("nn.relu")
def relu_grad(orig, grad):
    """Returns grad * (select(x < 0, 0, 1))."""
    x = orig.args[0]
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, zeros), zeros, ones * grad)]


@register_gradient("add")
def add_grad(orig, grad):
    """Returns [grad, grad]"""
    return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])]


@register_gradient("subtract")
def subtract_grad(orig, grad):
    """Returns [grad, -grad]"""
    return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(negative(grad), orig.args[1])]


@register_gradient("multiply")
def multiply_grad(orig, grad):
    """Returns [grad * y, grad * x]"""
    x, y = orig.args
    return [collapse_sum_like(grad * y, x), collapse_sum_like(grad * x, y)]


@register_gradient("divide")
def divide_grad(orig, grad):
    """Returns [grad / y,  - grad * (x / y) / y]"""
    x, y = orig.args
    return [collapse_sum_like(grad / y, x), collapse_sum_like(-(grad * orig / y), y)]


@register_gradient("zeros")
def zeros_grad(orig, grad):
    """Returns [shape]"""
    return [orig.args[0]]


@register_gradient("ones")
def ones_grad(orig, grad):
    """Returns [shape]"""
    return [orig.args[0]]


@register_gradient("zeros_like")
def zeros_like_grad(orig, grad):
    """Returns [0]"""
    return [orig]


@register_gradient("ones_like")
def ones_like_grad(orig, grad):
    """Returns [0]"""
    return [zeros_like(orig.args[0])]


@register_gradient("collapse_sum_like")
def collapse_sum_like_grad(orig, grad):
    """Returns [broadcast_to_like(grad, x), 0]"""
    x, y = orig.args
    return [broadcast_to_like(grad, x), zeros_like(y)]


@register_gradient("collapse_sum_to")
def collapse_sum_to_grad(orig, grad):
    """Returns [broadcast_to_like(grad, x), 0]"""
    x, y = orig.args
    return [broadcast_to_like(grad, x), zeros_like(y)]


@register_gradient("abs")
def abs_grad(orig, grad):
    """Returns grad * (select(x < 0, -1, 1))."""
    x = orig.args[0]
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, zeros), -ones * grad, ones * grad)]


@register_gradient("erf")
def erf_grad(orig, grad):
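    """Returns [grad * 2 / sqrt(pi) * exp(-x * x)]"""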
    # c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi)
    (inp,) = orig.args
    c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype)
    return [c_2_div_sqrt_pi * exp(-inp * inp) * grad]


@register_gradient("clip")
def clip_grad(orig, grad):
    """Returns grad * (select(x < min || max < x , 0, 1))."""
    x = orig.args[0]
    a_min = orig.attrs.get_int("a_min")
    a_max = orig.attrs.get_int("a_max")
    a_mins = broadcast_to_like(const(a_min, dtype=x.checked_type.dtype), x)
    a_maxs = broadcast_to_like(const(a_max, dtype=x.checked_type.dtype), x)
    zeros = zeros_like(x)
    ones = ones_like(x)
    return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]


@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
    """Returns the gradient of max_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.max_pool2d_grad(
        grad,
        orig.args[0],
        pool_size=attrs.pool_size,
        strides=attrs.strides,
        padding=attrs.padding,
        layout=attrs.layout,
        ceil_mode=attrs.ceil_mode,
    )
    return [pool_grad]


@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
    """Returns the gradient of avg_pool2d."""
    attrs = orig.attrs
    pool_grad = _nn.avg_pool2d_grad(
        grad,
        orig.args[0],
        pool_size=attrs.pool_size,
        strides=attrs.strides,
        padding=attrs.padding,
        layout=attrs.layout,
        ceil_mode=attrs.ceil_mode,
        count_include_pad=attrs.count_include_pad,
    )
    return [pool_grad]


@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
    """Returns the gradient of global_avg_pool2d."""
    data = orig.args[0]
    shape = data.checked_type.shape
    layout = orig.attrs.layout

    # we assume NCHW or NHWC layout for now, but easy to add more
    assert layout in ["NCHW", "NHWC"]
    if layout == "NCHW":
        pool_size = shape[2], shape[3]
    elif layout == "NHWC":
        pool_size = shape[1], shape[2]

    pool_grad = _nn.avg_pool2d_grad(
        grad, data, pool_size=pool_size, strides=(1, 1), padding=(0, 0), layout=layout
    )
    return [pool_grad]


# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
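    """Returns zero gradients for a two-element tuple input; placeholder used only for testing."""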
    assert len(orig.args) == 1
    t = orig.args[0]
    x = TupleGetItem(t, 0)
    y = TupleGetItem(t, 1)
    # Assume only two elements in the tuple for now.
    # In the real implementation, concatenate_grad probably needs to be implemented as an operator.
    return [Tuple([zeros_like(x), zeros_like(y)])]


@register_gradient("nn.conv2d")
def conv2d_grad(orig, grad):
    """Gradient of conv2d"""
    attrs = orig.attrs
    data, weight = orig.args
    data_shape = get_const_tuple(data.checked_type.shape)
    weight_shape = get_const_tuple(weight.checked_type.shape)
    _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
    batch, in_channel, in_h, in_w = data_shape
    out_channel, _, filter_h, filter_w = weight_shape

    # infer output_padding
    fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
        get_const_tuple(attrs.padding), (filter_h, filter_w)
    )
    stride_h, stride_w = get_const_tuple(attrs.strides)
    dilation_h, dilation_w = get_const_tuple(attrs.dilation)
    out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
    out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
    output_padding = (in_h - out_h, in_w - out_w)

    assert attrs.data_layout == "NCHW", "only support NCHW data layout"
    assert attrs.kernel_layout == "OIHW", "only support OIHW kernel layout"
    assert attrs.out_layout in ["", "NCHW"], "only support NCHW output layout"

    backward_data = _nn.conv2d_transpose(
        grad,
        weight,
        strides=attrs.strides,
        padding=attrs.padding,
        dilation=attrs.dilation,
        groups=attrs.groups,
        output_padding=output_padding,
    )
    grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
    grad = reshape(grad, [-1, 1, 0, 0])  # batch * oc * ic // groups, 1, oh, ow
    data = reshape(data, [1, -1, 0, 0])  # 1, batch * ic, ih, iw

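    # the weight gradient is computed as a grouped convolution of the reshaped data
    # with the reshaped output gradient; strides and dilation are swapped relative
    # to the forward convolution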
    backward_weight = _nn.conv2d(
        data,
        grad,
        strides=attrs.dilation,
        padding=attrs.padding,
        dilation=attrs.strides,
        groups=in_channel * batch,
    )
    # infer shape of backward_weight
    padded_weight_grad_h = (
        in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
    ) // dilation_h + 1
    padded_weight_grad_w = (
        in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
    ) // dilation_w + 1
    backward_weight = reshape(
        backward_weight,
        [
            batch,
            in_channel // attrs.groups,
            out_channel,
            padded_weight_grad_h,
            padded_weight_grad_w,
        ],
    )
    backward_weight = _sum(backward_weight, axis=0)
    backward_weight = transpose(backward_weight, [1, 0, 2, 3])

    assert padded_weight_grad_h >= filter_h
    assert padded_weight_grad_w >= filter_w
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = strided_slice(
            backward_weight,
            begin=[0, 0, 0, 0],
            end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
        )

    return [backward_data, backward_weight]


def _get_reduce_axis(call):
    """Helper function that returns the reduce axis of the call as plain python ints."""
    x, axis = call.args[0], call.attrs.axis
    shape = x.checked_type.concrete_shape

    # should never exclude when axis is None
    assert not (axis is None and call.attrs.exclude)

    if axis is None:
        return None

    # convert to nonnegative integers and sort
    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
    if call.attrs.exclude:
        axis = [ax for ax in range(len(shape)) if ax not in axis]
    return axis


def _unreduce_expand(x, axis):
    """Helper function that returns x expanded on the reduced dimensions in axis."""
    # assume axis is sorted nonnegative ints
    for ax in axis:
        x = expand_dims(x, ax)
    return x


@register_gradient("max")
def max_grad(orig, grad):
    """Returns the gradient of max"""
    x, axis = orig.args[0], _get_reduce_axis(orig)
    shape = x.checked_type.concrete_shape

    repeated = orig
    if axis is None:
        repeated = full_like(x, repeated)
    else:
        # expand dims (if necessary) and repeat along each axis
        if not orig.attrs.keepdims:
            repeated = _unreduce_expand(repeated, axis)
            grad = _unreduce_expand(grad, axis)
        for ax in axis:
            repeated = repeat(repeated, shape[ax], ax)

    indicators = cast_like(equal(repeated, x), grad)
    num_selected = _sum(indicators, axis, keepdims=True)
    # spread error across all max weights
    return [indicators * grad / num_selected]


@register_gradient("nn.softmax")
def softmax_grad(orig, grad):
    """Gradient of softmax"""
    return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig]


@register_gradient("nn.log_softmax")
def log_softmax_grad(orig, grad):
    """Gradient of log_softmax"""
    x = orig.args[0]
    sm = _nn.softmax(x, axis=orig.attrs.axis)
    grad = grad / sm
    return softmax_grad(sm, grad)


@register_gradient("nn.bias_add")
def bias_add_grad(orig, grad):
    """Returns gradient of bias_add"""
    data = orig.args[0]
    return [
        collapse_sum_like(grad, data),
        _sum(grad, orig.attrs.axis, keepdims=False, exclude=True),
    ]


@register_gradient("nn.dense")
def dense_grad(orig, grad):
    """Returns [grad' @ weight, data @ grad']"""
    data, weight = orig.args
    return [
        collapse_sum_like(
            _nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]), data
        ),
        collapse_sum_like(
            _nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]), weight
        ),
    ]


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
    """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
    grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik
           GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk
    """
    lhs, rhs = orig.args
    return [
        collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs),
        collapse_sum_like(
            _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs
        ),
    ]


@register_gradient("reshape")
def reshape_grad(orig, grad):
    """Gradient of reshape"""
    return [reshape_like(grad, orig.args[0])]


@register_gradient("dyn.reshape")
def dyn_reshape_grad(orig, grad):
    """Gradient of dyn_reshape"""
    return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])]


@register_gradient("shape_of")
def shape_of_grad(orig, grad):
    """Gradient of shape_of"""
    return [zeros_like(orig.args[0])]


@register_gradient("cast")
def cast_grad(orig, grad):
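    """Returns [cast_like(grad, x)]"""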
    x = orig.args[0]
    return [cast_like(grad, x)]


@register_gradient("nn.batch_flatten")
def batch_flatten_grad(orig, grad):
    """Returns grad reshaped to data dims"""
    data = orig.args[0]
    return [reshape_like(grad, data)]


@register_gradient("transpose")
def transpose_grad(orig, grad):
    """Returns grad transposed over the complement of original transpose axes"""
    orig_axes = orig.attrs.axes
    if orig_axes:
        dims = len(orig_axes)
        new_axes = [0] * dims
        for i in range(dims):
            new_axes[int(orig_axes[i])] = i
    else:
        new_axes = None
    return [transpose(grad, axes=new_axes)]


@register_gradient("negative")
def negative_grad(orig, grad):
    """Returns -grad"""
    return [-grad]


@register_gradient("sum")
def sum_grad(orig, grad):
    """Returns grad broadcasted to data dims"""
    data, axis = orig.args[0], _get_reduce_axis(orig)
    if not orig.attrs.keepdims:
        if axis is None:
            axis = list(range(len(data.checked_type.concrete_shape)))
        grad = _unreduce_expand(grad, axis)
    return [broadcast_to_like(grad, data)]


@register_gradient("mean")
def mean_grad(orig, grad):
    """Returns grad broadcasted to data dims"""
    data, axis = orig.args[0], _get_reduce_axis(orig)
    shape = data.checked_type.concrete_shape
    if axis is None:
        axis = list(range(len(data.checked_type.concrete_shape)))
    if not orig.attrs.keepdims:
        grad = _unreduce_expand(grad, axis)
    mult = 1.0
    for a in axis:
        mult /= shape[a]
    return [broadcast_to_like(grad * const(mult, dtype=data.checked_type.dtype), data)]


@register_gradient("variance")
def variance_grad(orig, grad):
    """Note that we take mean as an argument in the variance node"""
    data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig)
    unbiased = orig.attrs.unbiased
    shape = data.checked_type.concrete_shape
    if axis is None:
        axis = list(range(len(data.checked_type.concrete_shape)))
    if not orig.attrs.keepdims:
        grad = _unreduce_expand(grad, axis)
    mult1 = 2.0
    mult2 = -2.0
    count = 1
    for a in axis:
        count *= shape[a]
    if unbiased:
        mult2 = mult2 * count / (count - 1)
        count -= 1
    mult1 /= count
    return [
        (grad * const(mult1, dtype=data.checked_type.dtype)) * data,
        const(mult2, dtype=data.checked_type.dtype) * grad * data_mean,
    ]


@register_gradient("copy")
def copy_grad(orig, grad):
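    """Returns [grad]"""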
    return [grad]


@register_gradient("nn.cross_entropy")
def cross_entropy_grad(orig, grad):
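    """Returns [-grad * y / x, -grad * log(x)], both scaled by 1/batch_size"""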
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype="int32"), axis=0)
    grad = grad / batch_size.astype(x.checked_type.dtype)
    return [-grad * y / x, -grad * log(x)]


@register_gradient("nn.cross_entropy_with_logits")
def cross_entropy_with_logits_grad(orig, grad):
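    """Returns [-grad * y, -grad * x], both scaled by 1/batch_size"""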
    x, y = orig.args
    shape = shape_of(x)
    batch_size = take(shape, const(0, dtype="int32"), axis=0)
    grad = grad / batch_size.astype(x.checked_type.dtype)
    return [-grad * y, -grad * x]