# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
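"""Tests for gradients of level-2 Relay operators (pooling, conv2d, dense,
batch_flatten), checked against reference implementations."""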
import numpy as np

import topi
import topi.testing
import tvm
from tvm import relay
from tvm.relay.testing import check_grad, ctx_list, run_infer_type
from tvm.relay.transform import gradient


def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
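    """Compare the Relay max_pool2d gradient against topi.testing.pool_grad_nchw."""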
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                                ceil_mode=ceil_mode)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw],
                                           pool_type='max', ceil_mode=ceil_mode)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_max_pool2d_grad():
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)


def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
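    """Compare the Relay avg_pool2d gradient against topi.testing.pool_grad_nchw."""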
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.avg_pool2d(x, pool_size=pool_size, strides=strides, padding=padding,
                                ceil_mode=ceil_mode, count_include_pad=count_include_pad)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    ph, pw = padding
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    # The reference must use the same count_include_pad setting as the forward op,
    # otherwise border gradients disagree when padding is nonzero.
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=pool_size, strides=strides,
                                           padding=[ph, pw, ph, pw], pool_type='avg',
                                           ceil_mode=ceil_mode,
                                           count_include_pad=count_include_pad)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                           ceil_mode=False, count_include_pad=True)
    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                           ceil_mode=False, count_include_pad=False)


def verify_global_avg_pool2d_grad(x_shape):
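    """Check the global_avg_pool2d gradient via an avg pool whose window covers the whole input."""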
    x = relay.var("x", relay.TensorType(x_shape, "float32"))
    y = tvm.relay.nn.global_avg_pool2d(x)

    fwd_func = relay.Function([x], y)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func))

    data = np.random.rand(*x_shape).astype("float32")
    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
    out_grad = np.ones(shape=y_shape)
    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
                                           ceil_mode=False)

    for target, ctx in ctx_list():
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_global_avg_pool2d_grad():
    verify_global_avg_pool2d_grad((1, 4, 16, 16))
    verify_global_avg_pool2d_grad((1, 8, 8, 24))


def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
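    """Compare the Relay conv2d gradients w.r.t. data and weight against PyTorch
    reference gradients; skips when PyTorch is not installed."""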
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        print('Skip because pytorch is not installed')
        return

    dtype = 'float32'
    data = relay.var('data', shape=dshape, dtype=dtype)
    weight = relay.var('weight', shape=wshape, dtype=dtype)
    conv = relay.nn.conv2d(data, weight, strides=strides, padding=padding, dilation=dilation,
                           groups=groups)
    fwd_func = relay.Function([data, weight], conv)
    fwd_func = run_infer_type(fwd_func)
    bwd_func = run_infer_type(gradient(fwd_func, mode=mode))

    data_pt = torch.randn(*dshape, dtype=torch.float32, requires_grad=True)
    weight_pt = torch.randn(*wshape, dtype=torch.float32, requires_grad=True)
    out_pt = F.conv2d(data_pt, weight_pt, stride=strides, padding=padding, dilation=dilation,
                      groups=groups)
    grad_output_pt = torch.ones(out_pt.shape)
    grad_input_pt = F.grad.conv2d_input(dshape, weight_pt, grad_output_pt, stride=strides,
                                        padding=padding, dilation=dilation, groups=groups) \
                          .detach().numpy()
    grad_weight_pt = F.grad.conv2d_weight(data_pt, wshape, grad_output_pt, stride=strides,
                                          padding=padding, dilation=dilation, groups=groups) \
                           .detach().numpy()

    for target, ctx in ctx_list():
        data = tvm.nd.array(data_pt.detach().numpy(), ctx)
        weight = tvm.nd.array(weight_pt.detach().numpy(), ctx)
        intrp = relay.create_executor(ctx=ctx, target=target)
        op_res, (grad_input, grad_weight) = intrp.evaluate(bwd_func)(data, weight)
        np.testing.assert_allclose(grad_input.asnumpy(), grad_input_pt, rtol=1e-4, atol=1e-4)
        np.testing.assert_allclose(grad_weight.asnumpy(), grad_weight_pt, rtol=1e-4, atol=1e-4)


def test_conv2d_grad():
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [1, 1], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 1, 1), [2, 2], [0, 0], [1, 1])
    verify_conv2d_grad((1, 4, 16, 16), (16, 4, 3, 3), [1, 1], [1, 1], [1, 1], mode='first_order')


def verify_dense_grad(d_shape, w_shape):
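    """Numerically check the dense gradient with check_grad."""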
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    weight = relay.var("weight", relay.TensorType(w_shape, "float32"))
    fwd_func = relay.Function([data, weight], relay.nn.dense(data, weight))
    check_grad(fwd_func)


def test_dense_grad():
    verify_dense_grad((1, 8), (16, 8))
    verify_dense_grad((1, 4), (3, 4))
    verify_dense_grad((5, 4), (3, 4))


def verify_batch_flatten_grad(d_shape):
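    """Numerically check the batch_flatten gradient with check_grad."""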
    data = relay.var("data", relay.TensorType(d_shape, "float32"))
    fwd_func = relay.Function([data], relay.nn.batch_flatten(data))
    check_grad(fwd_func)


def test_batch_flatten_grad():
    verify_batch_flatten_grad((1, 2, 3, 4))
    verify_batch_flatten_grad((1, 8))


if __name__ == "__main__":
    test_max_pool2d_grad()
    test_avg_pool2d_grad()
    test_global_avg_pool2d_grad()
    test_conv2d_grad()
    test_dense_grad()
    test_batch_flatten_grad()