1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17import numpy as np 18 19import topi 20import topi.testing 21import tvm 22from tvm import relay 23from tvm.relay.testing import check_grad, ctx_list, run_infer_type 24from tvm.relay.transform import gradient 25 26 27def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode): 28 x = relay.var("x", relay.TensorType(x_shape, "float32")) 29 y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding, 30 ceil_mode=ceil_mode) 31 32 fwd_func = relay.Function([x], y) 33 fwd_func = run_infer_type(fwd_func) 34 bwd_func = run_infer_type(gradient(fwd_func)) 35 36 data = np.random.rand(*x_shape).astype("float32") 37 ph, pw = padding 38 y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) 39 out_grad = np.ones(shape=y_shape) 40 ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides, 41 padding=[ph, pw, ph, pw], 42 pool_type='max', ceil_mode=ceil_mode) 43 44 for target, ctx in ctx_list(): 45 intrp = relay.create_executor(ctx=ctx, target=target) 46 op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) 47 np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) 48 49 50def test_max_pool2d_grad(): 51 verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False) 52 verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False) 53 54 55def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad): 56 x = relay.var("x", relay.TensorType(x_shape, "float32")) 57 y = tvm.relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding, 58 ceil_mode=ceil_mode, count_include_pad=count_include_pad) 59 60 fwd_func = relay.Function([x], y) 61 fwd_func = run_infer_type(fwd_func) 62 bwd_func = run_infer_type(gradient(fwd_func)) 63 64 data = np.random.rand(*x_shape).astype("float32") 65 ph, pw = padding 66 y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) 67 out_grad = np.ones(shape=y_shape) 68 ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides, 69 padding=[ph, pw, ph, pw], 70 pool_type='avg', ceil_mode=ceil_mode) 71 72 for target, ctx in ctx_list(): 73 intrp = relay.create_executor(ctx=ctx, target=target) 74 op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) 75 np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) 76 77def test_avg_pool2d_grad(): 78 verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), 79 ceil_mode=False, count_include_pad=True) 80 verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), 81 ceil_mode=False, count_include_pad=False) 82 83 84def verify_global_avg_pool2d_grad(x_shape): 85 x = relay.var("x", relay.TensorType(x_shape, "float32")) 86 y = tvm.relay.nn.global_avg_pool2d(x) 87 88 fwd_func = relay.Function([x], y) 89 fwd_func = run_infer_type(fwd_func) 90 bwd_func = run_infer_type(gradient(fwd_func)) 91 92 data = np.random.rand(*x_shape).astype("float32") 93 y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape) 94 out_grad = np.ones(shape=y_shape) 95 ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]), 96 strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg', 97 ceil_mode=False) 98 99 for target, ctx in ctx_list(): 100 intrp = relay.create_executor(ctx=ctx, target=target) 101 op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data) 102 np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01) 103 104def test_global_avg_pool2d_grad(): 105 verify_global_avg_pool2d_grad((1, 4, 16, 16)) 106 verify_global_avg_pool2d_grad((1, 8, 8, 24)) 107 108def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'): 109 try: 110 import torch 111 import torch.nn.functional as F 112 except ImportError: 113 print('Skip because pytorch is not installed') 114 return 115 116 dtype = 'float32' 117 data = relay.var('data', shape=dshape, dtype=dtype) 118 weight = relay.var('weight', shape=wshape, dtype=dtype) 119 conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation, 120 groups=groups) 121 fwd_func = relay.Function([data, weight], conv) 122 fwd_func = run_infer_type(fwd_func) 123 bwd_func = run_infer_type(gradient(fwd_func, mode=mode)) 124 125 data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True) 126 weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True) 127 out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation, 128 groups=groups) 129 grad_output_pt = torch.ones(out_pt.shape) 130 grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides, 131 padding=padding, dilation=dilation, groups=groups) \ 132 .detach().numpy() 133 grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides, 134 padding=padding, dilation=dilation, groups=groups) \ 135 .detach().numpy() 136 137 138 for target, ctx in ctx_list(): 139 data = tvm.nd.array(data_pt.detach().numpy(), ctx) 140 weight = tvm.nd.array(weight_pt.detach().numpy(), ctx) 141 intrp = relay.create_executor(ctx=ctx, target=target) 142 op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight) 143 np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4) 144 np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4) 145 146 147def test_conv2d_grad(): 148 verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1]) 149 verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1]) 150 verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1]) 151 verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order') 152 153 154def verify_dense_grad(d_shape, w_shape): 155 data = relay.var("data", relay.TensorType(d_shape, "float32")) 156 weight = relay.var("weight", relay.TensorType(w_shape, "float32")) 157 fwd_func = relay.Function([data, weight], relay.nn.dense(data, weight)) 158 check_grad(fwd_func) 159 160 161def test_dense_grad(): 162 verify_dense_grad((1, 8), (16, 8)) 163 verify_dense_grad((1, 4), (3, 4)) 164 verify_dense_grad((5, 4), (3, 4)) 165 166 167def verify_batch_flatten_grad(d_shape): 168 data = relay.var("data", relay.TensorType(d_shape, "float32")) 169 fwd_func = relay.Function([data], relay.nn.batch_flatten(data)) 170 check_grad(fwd_func) 171 172 173def test_batch_flatten_grad(): 174 verify_batch_flatten_grad((1, 2, 3, 4)) 175 verify_batch_flatten_grad((1, 8)) 176 177 178if __name__ == "__main__": 179 test_max_pool2d_grad() 180 test_avg_pool2d_grad() 181 test_global_avg_pool2d_grad() 182 test_conv2d_grad() 183 test_dense_grad() 184 test_batch_flatten_grad() 185