1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Test code for pooling
18Copied from topi/tests/python/test_topi_pooling.py.
19Should be removed once we fix OpenGL testing on Jenkins.
20"""
21import numpy as np
22import tvm
23import topi
24import math
25from topi.util import get_const_tuple
26
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
    """Build pool+relu with topi, check its static shape, and compare the
    OpenGL result against a NumPy reference.

    Input, kernel and stride are square: the height values (ih, kh, sh)
    are reused for the width dimension.
    """
    iw, kw, sw = ih, kh, sh
    pad_h, pad_w = padding

    data = tvm.placeholder((n, ic, ih, iw), name='A')
    pooled = topi.nn.pool(data, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                          pool_type=pool_type, ceil_mode=ceil_mode)
    pooled = topi.nn.relu(pooled)
    dtype = data.dtype

    # Standard pooling output-size formula; ceil_mode only changes rounding.
    out_shape = get_const_tuple(pooled.shape)
    in_shape = get_const_tuple(data.shape)
    round_fn = math.ceil if ceil_mode else math.floor
    assert out_shape[2] == int(round_fn(float(in_shape[2] - kh + pad_h * 2) / sh) + 1)
    assert out_shape[3] == int(round_fn(float(in_shape[3] - kw + pad_w * 2) / sw) + 1)

    # NumPy reference: zero-pad the input, then reduce each pooling window.
    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
    pad_np = np.zeros(shape=(n, ic, ih + 2 * pad_h, iw + 2 * pad_w)).astype(dtype)
    interior = (range(n), range(ic), range(pad_h, ih + pad_h), range(pad_w, iw + pad_w))
    pad_np[np.ix_(*interior)] = a_np
    _, oc, oh, ow = get_const_tuple(pooled.shape)
    b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)

    # NOTE: avg includes the zero padding in the mean, matching topi's
    # default count_include_pad behavior.
    reducers = {'avg': np.mean, 'max': np.max}
    if pool_type in reducers:
        reduce_fn = reducers[pool_type]
        for row in range(oh):
            for col in range(ow):
                window = pad_np[:, :, row * sh:row * sh + kh, col * sw:col * sw + kw]
                b_np[:, :, row, col] = reduce_fn(window, axis=(2, 3))
    b_np = np.maximum(b_np, 0.0)  # relu applied to the reference as well

    def check_device(device):
        # Skip silently when the backend is not compiled in.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_pool(pooled)
        ctx = tvm.context(device, 0)
        a_arr = tvm.nd.array(a_np, ctx)
        b_arr = tvm.nd.array(np.zeros(get_const_tuple(pooled.shape), dtype=dtype), ctx)
        print(tvm.lower(sched, [data, pooled], simple_mode=True))

        func = tvm.build(sched, [data, pooled], device)
        func(a_arr, b_arr)
        tvm.testing.assert_allclose(b_arr.asnumpy(), b_np, rtol=1e-5)

    for device in ['opengl']:
        check_device(device)
83
def test_pool():
    """Run verify_pool over a table of (ih, kh, sh, padding, type, ceil) cases."""
    cases = [
        (32, 2, 2, [0, 0], 'avg', False),
        (31, 3, 3, [1, 2], 'avg', False),
        (32, 2, 2, [0, 0], 'max', False),
        (31, 3, 3, [2, 1], 'max', False),
        (31, 3, 3, [2, 1], 'max', True),
    ]
    for ih, kh, sh, padding, pool_type, ceil_mode in cases:
        verify_pool(1, 256, ih, kh, sh, padding, pool_type, ceil_mode)
90
91
92
def verify_global_pool(n, c, h, w, pool_type):
    """Compare topi global pooling + relu against a NumPy reference on OpenGL."""
    data = tvm.placeholder((n, c, h, w), name='A')
    pooled = topi.nn.global_pool(data, pool_type=pool_type)
    pooled = topi.nn.relu(pooled)

    # NumPy reference: reduce the full spatial extent, keep NCHW rank.
    a_np = np.random.uniform(size=get_const_tuple(data.shape)).astype(data.dtype)
    reducers = {'avg': np.mean, 'max': np.max}
    if pool_type in reducers:
        b_np = reducers[pool_type](a_np, axis=(2, 3), keepdims=True)
    b_np = np.maximum(b_np, 0.0)  # relu applied to the reference as well

    def check_device(device):
        # Skip silently when the backend is not compiled in.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            sched = topi.generic.schedule_global_pool(pooled)
        ctx = tvm.context(device, 0)
        a_arr = tvm.nd.array(a_np, ctx)
        b_arr = tvm.nd.array(np.zeros(get_const_tuple(pooled.shape), dtype=pooled.dtype), ctx)
        func = tvm.build(sched, [data, pooled], device)
        func(a_arr, b_arr)
        tvm.testing.assert_allclose(b_arr.asnumpy(), b_np, rtol=1e-5)

    for device in ['opengl']:
        check_device(device)
121
def test_global_pool():
    """Run verify_global_pool for avg/max at batch sizes 1 and 4."""
    for pool_type in ('avg', 'max'):
        for batch in (1, 4):
            verify_global_pool(batch, 1024, 7, 7, pool_type)
127
128
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_pool()
    test_global_pool()
132