1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=invalid-name, unused-argument 18"""Backend compiler related feature registration""" 19from __future__ import absolute_import 20 21from tvm.topi.nn.util import get_pad_tuple 22from tvm.topi.util import get_const_tuple 23 24from ..expr import Tuple, TupleGetItem, const 25from . import nn as _nn 26from .op import register_gradient 27from .reduce import sum as _sum 28from .tensor import ( 29 cos, 30 cosh, 31 exp, 32 less, 33 negative, 34 ones_like, 35 power, 36 sin, 37 sinh, 38 sqrt, 39 zeros_like, 40 equal, 41 shape_of, 42 log, 43) 44from .transform import ( 45 broadcast_to_like, 46 collapse_sum_like, 47 cast_like, 48 reshape, 49 reshape_like, 50 strided_slice, 51 take, 52 tile, 53 transpose, 54 where, 55 repeat, 56 expand_dims, 57 full_like, 58) 59 60 61@register_gradient("log") 62def log_grad(orig, grad): 63 """Returns [grad * (1 / x)]""" 64 x = orig.args[0] 65 return [grad * ones_like(x) / x] 66 67 68@register_gradient("log2") 69def log2_grad(orig, grad): 70 """Returns [grad * 1 / (log(2) * x)]""" 71 x = orig.args[0] 72 ones = ones_like(x) 73 two = const(2.0, dtype=x.checked_type.dtype) 74 return [grad * ones / (log(two) * x)] 75 76 77@register_gradient("log10") 78def log10_grad(orig, grad): 79 """Returns [grad * 1 / (log(10) * x)]""" 80 x = orig.args[0] 81 ones = ones_like(x) 82 ten = const(10.0, dtype=x.checked_type.dtype) 83 return [grad * ones / (log(ten) * x)] 84 85 86@register_gradient("tan") 87def tan_grad(orig, grad): 88 """Returns [grad / (cos^2(x))]""" 89 x = orig.args[0] 90 return [grad / (cos(x) * cos(x))] 91 92 93@register_gradient("cos") 94def cos_grad(orig, grad): 95 """Returns [grad * (-sin(x))]""" 96 x = orig.args[0] 97 ones = ones_like(x) 98 return [grad * (-ones * sin(x))] 99 100 101@register_gradient("cosh") 102def cosh_grad(orig, grad): 103 """Returns [grad * sinh(x)]""" 104 x = orig.args[0] 105 return [grad * sinh(x)] 106 107 108@register_gradient("sin") 109def sin_grad(orig, grad): 110 """Returns [grad * cos(x)]""" 111 x = orig.args[0] 112 return [grad * cos(x)] 113 114 115@register_gradient("sinh") 116def sinh_grad(orig, grad): 117 """Returns [grad * cosh(x)]""" 118 x = orig.args[0] 119 return [grad * cosh(x)] 120 121 122@register_gradient("acos") 123def acos_grad(orig, grad): 124 """Returns [grad * -1/((1 - (x ^ 2)) ^ 1/2)]""" 125 x = orig.args[0] 126 ones = ones_like(x) 127 return [grad * (-ones / sqrt(ones - (x * x)))] 128 129 130@register_gradient("acosh") 131def acosh_grad(orig, grad): 132 """Returns [grad * 1/((x - 1) ^ 1/2 * (x + 1) ^ 1/2)]""" 133 x = orig.args[0] 134 ones = ones_like(x) 135 return [grad * ones / sqrt((x * x) - ones)] 136 137 138@register_gradient("asin") 139def asin_grad(orig, grad): 140 """Returns [grad * 1/((1 - (x ^ 2)) ^ (1/2))]""" 141 x = orig.args[0] 142 ones = ones_like(x) 143 return [grad * ones / sqrt(ones - (x * x))] 144 145 146@register_gradient("asinh") 147def asinh_grad(orig, grad): 148 """Returns [grad * 1/((1 + (x ^ 2)) ^ (1/2))]""" 149 x = orig.args[0] 150 ones = ones_like(x) 151 return [grad * ones / sqrt(ones + (x * x))] 152 153 154@register_gradient("atan") 155def atan_grad(orig, grad): 156 """Returns [grad * 1 / (1 + x ^ 2)]""" 157 x = orig.args[0] 158 ones = ones_like(x) 159 return [grad * ones / (ones + (x * x))] 160 161 162@register_gradient("atanh") 163def atanh_grad(orig, grad): 164 """Returns [grad * 1 / (1 - x ^ 2)]""" 165 x = orig.args[0] 166 ones = ones_like(x) 167 return [grad * ones / (ones - (x * x))] 168 169 170@register_gradient("exp") 171def exp_grad(orig, grad): 172 """Returns [grad * exp(x)]""" 173 return [grad * exp(orig.args[0])] 174 175 176@register_gradient("sqrt") 177def sqrt_grad(orig, grad): 178 """Returns [grad * 0.5 * (x ^ -0.5)]""" 179 x = orig.args[0] 180 a = const(0.5, dtype=x.checked_type.dtype) 181 return [grad * a * power(x, negative(a))] 182 183 184@register_gradient("sigmoid") 185def sigmoid_grad(orig, grad): 186 """Returns [grad * sigmoid(x) * (1 - sigmoid(x))].""" 187 return [grad * orig * (ones_like(orig) - orig)] 188 189 190@register_gradient("tanh") 191def tanh_grad(orig, grad): 192 """Returns grad * (1 - tanh(x) * tanh(x)).""" 193 return [grad * ones_like(orig) - orig * orig] 194 195 196@register_gradient("nn.relu") 197def relu_grad(orig, grad): 198 """Returns grad * (select(x < 0, 0, 1)).""" 199 x = orig.args[0] 200 zeros = zeros_like(x) 201 ones = ones_like(x) 202 return [where(less(x, zeros), zeros, ones * grad)] 203 204 205@register_gradient("add") 206def add_grad(orig, grad): 207 """Returns [grad, grad]""" 208 return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(grad, orig.args[1])] 209 210 211@register_gradient("subtract") 212def subtract_grad(orig, grad): 213 """Returns [grad, -grad]""" 214 return [collapse_sum_like(grad, orig.args[0]), collapse_sum_like(negative(grad), orig.args[1])] 215 216 217@register_gradient("multiply") 218def multiply_grad(orig, grad): 219 """Returns [grad * y, grad * x]""" 220 x, y = orig.args 221 return [collapse_sum_like(grad * y, x), collapse_sum_like(grad * x, y)] 222 223 224@register_gradient("divide") 225def divide_grad(orig, grad): 226 """Returns [grad / y, - grad * (x / y) / y]""" 227 x, y = orig.args 228 return [collapse_sum_like(grad / y, x), collapse_sum_like(-(grad * orig / y), y)] 229 230 231@register_gradient("zeros") 232def zeros_grad(orig, grad): 233 """Returns [shape]""" 234 return [orig.args[0]] 235 236 237@register_gradient("ones") 238def ones_grad(orig, grad): 239 """Returns [shape]""" 240 return [orig.args[0]] 241 242 243@register_gradient("zeros_like") 244def zeros_like_grad(orig, grad): 245 """Returns [0]""" 246 return [orig] 247 248 249@register_gradient("ones_like") 250def ones_like_grad(orig, grad): 251 """Returns [0]""" 252 return [zeros_like(orig.args[0])] 253 254 255@register_gradient("collapse_sum_like") 256def collapse_sum_like_grad(orig, grad): 257 """Returns [broadcast_to_like(grad, x), 0]""" 258 x, y = orig.args 259 return [broadcast_to_like(grad, x), zeros_like(y)] 260 261 262@register_gradient("collapse_sum_to") 263def collapse_sum_to_grad(orig, grad): 264 """Returns [broadcast_to_like(grad, x), 0]""" 265 x, y = orig.args 266 return [broadcast_to_like(grad, x), zeros_like(y)] 267 268 269@register_gradient("abs") 270def abs_grad(orig, grad): 271 """Returns grad * (select(x < 0, -1, 1)).""" 272 x = orig.args[0] 273 zeros = zeros_like(x) 274 ones = ones_like(x) 275 return [where(less(x, zeros), -ones * grad, ones * grad)] 276 277 278@register_gradient("erf") 279def erf_grad(orig, grad): 280 # c_2_div_sqrt_pi = 2.0 / math.sqrt(math.pi) 281 (inp,) = orig.args 282 c_2_div_sqrt_pi = const(1.1283791670955126, dtype=inp.checked_type.dtype) 283 return [c_2_div_sqrt_pi * exp(-inp * inp) * grad] 284 285 286@register_gradient("clip") 287def clip_grad(orig, grad): 288 """Returns grad * (select(x < min || max < x , 0, 1)).""" 289 x = orig.args[0] 290 a_min = orig.attrs.get_int("a_min") 291 a_max = orig.attrs.get_int("a_max") 292 a_mins = broadcast_to_like(const(a_min, dtype=x.checked_type.dtype), x) 293 a_maxs = broadcast_to_like(const(a_max, dtype=x.checked_type.dtype), x) 294 zeros = zeros_like(x) 295 ones = ones_like(x) 296 return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))] 297 298 299@register_gradient("nn.max_pool2d") 300def max_pool2d_grad(orig, grad): 301 """Returns the gradient of max_pool2d.""" 302 attrs = orig.attrs 303 pool_grad = _nn.max_pool2d_grad( 304 grad, 305 orig.args[0], 306 pool_size=attrs.pool_size, 307 strides=attrs.strides, 308 padding=attrs.padding, 309 layout=attrs.layout, 310 ceil_mode=attrs.ceil_mode, 311 ) 312 return [pool_grad] 313 314 315@register_gradient("nn.avg_pool2d") 316def avg_pool2d_grad(orig, grad): 317 """Returns the gradient of avg_pool2d.""" 318 attrs = orig.attrs 319 pool_grad = _nn.avg_pool2d_grad( 320 grad, 321 orig.args[0], 322 pool_size=attrs.pool_size, 323 strides=attrs.strides, 324 padding=attrs.padding, 325 layout=attrs.layout, 326 ceil_mode=attrs.ceil_mode, 327 count_include_pad=attrs.count_include_pad, 328 ) 329 return [pool_grad] 330 331 332@register_gradient("nn.global_avg_pool2d") 333def global_avg_pool2d_grad(orig, grad): 334 """Returns the gradient of global_avg_pool2d.""" 335 data = orig.args[0] 336 shape = data.checked_type.shape 337 layout = orig.attrs.layout 338 339 # we assume NCHW or NHWC layout for now, but easy to add more 340 assert layout in ["NCHW", "NHWC"] 341 if layout == "NCHW": 342 pool_size = shape[2], shape[3] 343 elif layout == "NHWC": 344 pool_size = shape[1], shape[2] 345 346 pool_grad = _nn.avg_pool2d_grad( 347 grad, data, pool_size=pool_size, strides=(1, 1), padding=(0, 0), layout=layout 348 ) 349 return [pool_grad] 350 351 352# not implemented, this is only for testing. 353@register_gradient("concatenate") 354def concatenate_grad(orig, grad): 355 assert len(orig.args) == 1 356 t = orig.args[0] 357 x = TupleGetItem(t, 0) 358 y = TupleGetItem(t, 1) 359 # Assume only two element in tuple rn. 360 # In the real implementation, concatenate_grad probably need to be implemented by an operator. 361 return [Tuple([zeros_like(x), zeros_like(y)])] 362 363 364@register_gradient("nn.conv2d") 365def conv2d_grad(orig, grad): 366 """Gradient of conv2d""" 367 attrs = orig.attrs 368 data, weight = orig.args 369 data_shape = get_const_tuple(data.checked_type.shape) 370 weight_shape = get_const_tuple(weight.checked_type.shape) 371 _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape) 372 batch, in_channel, in_h, in_w = data_shape 373 out_channel, _, filter_h, filter_w = weight_shape 374 375 # infer output_padding 376 fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple( 377 get_const_tuple(attrs.padding), (filter_h, filter_w) 378 ) 379 stride_h, stride_w = get_const_tuple(attrs.strides) 380 dilation_h, dilation_w = get_const_tuple(attrs.dilation) 381 out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h 382 out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w 383 output_padding = (in_h - out_h, in_w - out_w) 384 385 assert attrs.data_layout == "NCHW", "only support NCHW data layout" 386 assert attrs.kernel_layout == "OIHW", "only support OIHW kernel layout" 387 assert attrs.out_layout in ["", "NCHW"], "only support NCHW output layout" 388 389 backward_data = _nn.conv2d_transpose( 390 grad, 391 weight, 392 strides=attrs.strides, 393 padding=attrs.padding, 394 dilation=attrs.dilation, 395 groups=attrs.groups, 396 output_padding=output_padding, 397 ) 398 grad = tile(grad, [1, in_channel // attrs.groups, 1, 1]) 399 grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow 400 data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw 401 402 backward_weight = _nn.conv2d( 403 data, 404 grad, 405 strides=attrs.dilation, 406 padding=attrs.padding, 407 dilation=attrs.strides, 408 groups=in_channel * batch, 409 ) 410 # infer shape of backward_weight 411 padded_weight_grad_h = ( 412 in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom 413 ) // dilation_h + 1 414 padded_weight_grad_w = ( 415 in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right 416 ) // dilation_w + 1 417 backward_weight = reshape( 418 backward_weight, 419 [ 420 batch, 421 in_channel // attrs.groups, 422 out_channel, 423 padded_weight_grad_h, 424 padded_weight_grad_w, 425 ], 426 ) 427 backward_weight = _sum(backward_weight, axis=0) 428 backward_weight = transpose(backward_weight, [1, 0, 2, 3]) 429 430 assert padded_weight_grad_h >= filter_h 431 assert padded_weight_grad_w >= filter_w 432 if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: 433 backward_weight = strided_slice( 434 backward_weight, 435 begin=[0, 0, 0, 0], 436 end=[out_channel, in_channel // attrs.groups, filter_h, filter_w], 437 ) 438 439 return [backward_data, backward_weight] 440 441 442def _get_reduce_axis(call): 443 """Helper function that returns the reduce axis of the call as plain python ints.""" 444 x, axis = call.args[0], call.attrs.axis 445 shape = x.checked_type.concrete_shape 446 447 # should never exclude when axis is None 448 assert not (axis is None and call.attrs.exclude) 449 450 if axis is None: 451 return None 452 453 # convert to nonnegative integers and sort 454 axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)]) 455 if call.attrs.exclude: 456 axis = [ax for ax in range(len(shape)) if ax not in axis] 457 return axis 458 459 460def _unreduce_expand(x, axis): 461 """Helper function that returns x expanded on the reduced dimensions in axis.""" 462 # assume axis is sorted nonnegative ints 463 for ax in axis: 464 x = expand_dims(x, ax) 465 return x 466 467 468@register_gradient("max") 469def max_grad(orig, grad): 470 """Returns the gradient of max""" 471 x, axis = orig.args[0], _get_reduce_axis(orig) 472 shape = x.checked_type.concrete_shape 473 474 repeated = orig 475 if axis is None: 476 repeated = full_like(x, repeated) 477 else: 478 # expand dims (if necessary) and repeat along each axis 479 if not orig.attrs.keepdims: 480 repeated = _unreduce_expand(repeated, axis) 481 grad = _unreduce_expand(grad, axis) 482 for ax in axis: 483 repeated = repeat(repeated, shape[ax], ax) 484 485 indicators = cast_like(equal(repeated, x), grad) 486 num_selected = _sum(indicators, axis, keepdims=True) 487 # spread error across all max weights 488 return [indicators * grad / num_selected] 489 490 491@register_gradient("nn.softmax") 492def softmax_grad(orig, grad): 493 """Gradient of softmax""" 494 return [(grad - _sum(grad * orig, orig.attrs.axis, True)) * orig] 495 496 497@register_gradient("nn.log_softmax") 498def log_softmax_grad(orig, grad): 499 """Gradient of log_softmax""" 500 x = orig.args[0] 501 sm = _nn.softmax(x, axis=orig.attrs.axis) 502 grad = grad / sm 503 return softmax_grad(sm, grad) 504 505 506@register_gradient("nn.bias_add") 507def bias_add_grad(orig, grad): 508 """Returns gradient of bias_add""" 509 data = orig.args[0] 510 return [ 511 collapse_sum_like(grad, data), 512 _sum(grad, orig.attrs.axis, keepdims=False, exclude=True), 513 ] 514 515 516@register_gradient("nn.dense") 517def dense_grad(orig, grad): 518 """Returns [grad' @ weight, data @ grad']""" 519 data, weight = orig.args 520 return [ 521 collapse_sum_like( 522 _nn.dense(grad, transpose(weight), units=weight.checked_type.shape[1]), data 523 ), 524 collapse_sum_like( 525 _nn.dense(transpose(grad), transpose(data), units=data.checked_type.shape[1]), weight 526 ), 527 ] 528 529 530@register_gradient("nn.batch_matmul") 531def batch_matmul_grad(orig, grad): 532 """gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij 533 grads: GRAD_OUT_bij,RHS_bjk->GRAD_IN_LHS_bik 534 GRAD_OUT_bij,LHS_bik->GRAD_IN_RHS_bjk 535 """ 536 lhs, rhs = orig.args 537 return [ 538 collapse_sum_like(_nn.batch_matmul(grad, transpose(rhs, [0, 2, 1])), lhs), 539 collapse_sum_like( 540 _nn.batch_matmul(transpose(grad, [0, 2, 1]), transpose(lhs, [0, 2, 1])), rhs 541 ), 542 ] 543 544 545@register_gradient("reshape") 546def reshape_grad(orig, grad): 547 """Gradient of reshape""" 548 return [reshape_like(grad, orig.args[0])] 549 550 551@register_gradient("dyn.reshape") 552def dyn_reshape_grad(orig, grad): 553 """Gradient of dyn_reshape""" 554 return [reshape_like(grad, orig.args[0]), zeros_like(orig.args[1])] 555 556 557@register_gradient("shape_of") 558def shape_of_grad(orig, grad): 559 """Gradient of shape_of""" 560 return [zeros_like(orig.args[0])] 561 562 563@register_gradient("cast") 564def cast_grad(orig, grad): 565 x = orig.args[0] 566 return [cast_like(grad, x)] 567 568 569@register_gradient("nn.batch_flatten") 570def batch_flatten_grad(orig, grad): 571 """Returns grad reshaped to data dims""" 572 data = orig.args[0] 573 return [reshape_like(grad, data)] 574 575 576@register_gradient("transpose") 577def transpose_grad(orig, grad): 578 """Returns grad transposed over the complement of original transpose axes""" 579 orig_axes = orig.attrs.axes 580 if orig_axes: 581 dims = len(orig_axes) 582 new_axes = [0] * dims 583 for i in range(dims): 584 new_axes[int(orig_axes[i])] = i 585 else: 586 new_axes = None 587 return [transpose(grad, axes=new_axes)] 588 589 590@register_gradient("negative") 591def negative_grad(orig, grad): 592 """Returns -grad""" 593 return [-grad] 594 595 596@register_gradient("sum") 597def sum_grad(orig, grad): 598 """Returns grad broadcasted to data dims""" 599 data, axis = orig.args[0], _get_reduce_axis(orig) 600 if not orig.attrs.keepdims: 601 if axis is None: 602 axis = list(range(len(data.checked_type.concrete_shape))) 603 grad = _unreduce_expand(grad, axis) 604 return [broadcast_to_like(grad, data)] 605 606 607@register_gradient("mean") 608def mean_grad(orig, grad): 609 """Returns grad broadcasted to data dims""" 610 data, axis = orig.args[0], _get_reduce_axis(orig) 611 shape = data.checked_type.concrete_shape 612 if axis is None: 613 axis = list(range(len(data.checked_type.concrete_shape))) 614 if not orig.attrs.keepdims: 615 grad = _unreduce_expand(grad, axis) 616 mult = 1.0 617 for a in axis: 618 mult /= shape[a] 619 return [broadcast_to_like(grad * const(mult, dtype=data.checked_type.dtype), data)] 620 621 622@register_gradient("variance") 623def variance_grad(orig, grad): 624 """Note that we take mean as an argument in the variance node""" 625 data, data_mean, axis = orig.args[0], orig.args[1], _get_reduce_axis(orig) 626 unbiased = orig.attrs.unbiased 627 shape = data.checked_type.concrete_shape 628 if axis is None: 629 axis = list(range(len(data.checked_type.concrete_shape))) 630 if not orig.attrs.keepdims: 631 grad = _unreduce_expand(grad, axis) 632 mult1 = 2.0 633 mult2 = -2.0 634 count = 1 635 for a in axis: 636 count *= shape[a] 637 if unbiased: 638 mult2 = mult2 * count / (count - 1) 639 count -= 1 640 mult1 /= count 641 return [ 642 (grad * const(mult1, dtype=data.checked_type.dtype)) * data, 643 const(mult2, dtype=data.checked_type.dtype) * grad * data_mean, 644 ] 645 646 647@register_gradient("copy") 648def copy_grad(orig, grad): 649 return [grad] 650 651 652@register_gradient("nn.cross_entropy") 653def cross_entropy_grad(orig, grad): 654 x, y = orig.args 655 shape = shape_of(x) 656 batch_size = take(shape, const(0, dtype="int32"), axis=0) 657 grad = grad / batch_size.astype(x.checked_type.dtype) 658 return [-grad * y / x, -grad * log(x)] 659 660 661@register_gradient("nn.cross_entropy_with_logits") 662def cross_entropy_with_logits_grad(orig, grad): 663 x, y = orig.args 664 shape = shape_of(x) 665 batch_size = take(shape, const(0, dtype="int32"), axis=0) 666 grad = grad / batch_size.astype(x.checked_type.dtype) 667 return [-grad * y, -grad * x] 668