# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for TOPI pooling operators on the OpenGL target.

Copied from topi/tests/python/test_topi_pooling.py.
Should be removed once OpenGL testing on Jenkins is fixed.
"""
import math

import numpy as np
import tvm
import tvm.testing
import topi
from topi.util import get_const_tuple

# Targets exercised by this file; each one is skipped when not enabled in the build.
_DEVICES = ['opengl']

# NumPy reductions used by the reference implementations, keyed by pool_type.
_REDUCERS = {'avg': np.mean, 'max': np.max}


def _reducer(pool_type):
    """Return the NumPy reduction for ``pool_type`` ('avg' or 'max').

    Raises
    ------
    ValueError
        If ``pool_type`` has no reference implementation here.
    """
    try:
        return _REDUCERS[pool_type]
    except KeyError:
        raise ValueError("Unsupported pool_type: %r" % (pool_type,))


def _expected_out_dim(in_size, kernel, stride, pad, ceil_mode):
    """Return the expected pooled size along one spatial axis.

    Computes ``round((in_size - kernel + 2 * pad) / stride) + 1``, where the
    rounding is ``ceil`` when ``ceil_mode`` is true and ``floor`` otherwise.
    """
    rounder = math.ceil if ceil_mode else math.floor
    return int(rounder(float(in_size - kernel + pad * 2) / stride) + 1)


def _pool_reference(a_np, kh, kw, sh, sw, ph, pw, oh, ow, pool_type):
    """NumPy reference for 2-D pooling followed by ReLU.

    ``a_np`` is indexed as (n, c, h, w). It is zero-padded by ``ph`` rows and
    ``pw`` columns on each side, reduced over ``kh x kw`` windows taken with
    strides ``(sh, sw)``, and an ``oh x ow`` result is produced.
    """
    reduce_fn = _reducer(pool_type)
    n, ic, ih, iw = a_np.shape
    pad_np = np.zeros(shape=(n, ic, ih + 2 * ph, iw + 2 * pw)).astype(a_np.dtype)
    pad_np[:, :, ph:ph + ih, pw:pw + iw] = a_np

    b_np = np.zeros(shape=(n, ic, oh, ow)).astype(a_np.dtype)
    for i in range(oh):
        for j in range(ow):
            # With ceil_mode the last window may run past the padded array; NumPy
            # slicing truncates it. That is harmless for 'max', but 'avg' would
            # average fewer elements. NOTE(review): confirm this matches topi
            # before adding avg + ceil_mode cases.
            window = pad_np[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]
            b_np[:, :, i, j] = reduce_fn(window, axis=(2, 3))
    return np.maximum(b_np, 0.0)


def _run_on_devices(schedule_fn, A, B, a_np, b_np):
    """Build ``B`` (computed from ``A``) on every enabled target and compare to ``b_np``.

    ``schedule_fn`` is called inside the target scope, so a generic TOPI
    schedule dispatches to the implementation for that target.
    """
    for device in _DEVICES:
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            continue
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = schedule_fn(B)
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)


def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
    """Check ``topi.nn.pool`` followed by ReLU against a NumPy reference.

    Parameters
    ----------
    n, ic : int
        Batch size and channel count of the input placeholder.
    ih : int
        Input height; the width is set equal to it.
    kh, sh : int
        Kernel size and stride, used for both spatial axes.
    padding : sequence of two ints
        ``(ph, pw)``; the reference pads this amount on both sides of each axis.
    pool_type : str
        ``'avg'`` or ``'max'``.
    ceil_mode : bool
        Whether output sizes are rounded up instead of down.
    """
    iw = ih
    kw = kh
    sw = sh
    ph, pw = padding
    A = tvm.placeholder((n, ic, ih, iw), name='A')
    B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
                     pool_type=pool_type, ceil_mode=ceil_mode)
    B = topi.nn.relu(B)

    # The spatial output sizes must follow the standard pooling formula.
    _, _, oh, ow = get_const_tuple(B.shape)
    assert oh == _expected_out_dim(ih, kh, sh, ph, ceil_mode)
    assert ow == _expected_out_dim(iw, kw, sw, pw, ceil_mode)

    a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(A.dtype)
    b_np = _pool_reference(a_np, kh, kw, sh, sw, ph, pw, oh, ow, pool_type)

    _run_on_devices(topi.generic.schedule_pool, A, B, a_np, b_np)


def test_pool():
    verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False)
    verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False)
    verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
    verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)


def verify_global_pool(n, c, h, w, pool_type):
    """Check ``topi.nn.global_pool`` followed by ReLU against a NumPy reference.

    The reference reduces over axes 2 and 3 with ``keepdims=True``.
    ``pool_type`` is ``'avg'`` or ``'max'``.
    """
    A = tvm.placeholder((n, c, h, w), name='A')
    B = topi.nn.global_pool(A, pool_type=pool_type)
    B = topi.nn.relu(B)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = _reducer(pool_type)(a_np, axis=(2, 3), keepdims=True)
    b_np = np.maximum(b_np, 0.0)

    _run_on_devices(topi.generic.schedule_global_pool, A, B, a_np, b_np)


def test_global_pool():
    verify_global_pool(1, 1024, 7, 7, 'avg')
    verify_global_pool(4, 1024, 7, 7, 'avg')
    verify_global_pool(1, 1024, 7, 7, 'max')
    verify_global_pool(4, 1024, 7, 7, 'max')


if __name__ == "__main__":
    test_pool()
    test_global_pool()