# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmark and verify TOPI depthwise_conv2d (NCHW and NHWC) on GPU backends.

Each test builds three kernels -- depthwise_conv2d, depthwise_conv2d +
scale_shift, and depthwise_conv2d + scale_shift + relu -- times them with
``time_evaluator`` and compares their outputs against NumPy references.
"""
import os
import tvm
import tvm.testing
import numpy as np
from scipy import signal
from tvm.contrib import nvcc

import topi
import topi.testing
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc

TASK = "depthwise_conv2d"
# When True, the CUDA post-processing callback swaps the generated kernel
# source for the hand-written one in perf/<TASK>_manual.cu.
USE_MANUAL_CODE = False
# Number of invocations averaged by each time_evaluator measurement.
NUM_RUNS = 1000
# Backends to try; any that are not enabled in this TVM build are skipped.
DEVICES = ['cuda', 'opencl', 'rocm']


@tvm.register_func
def tvm_callback_cuda_compile(code):
    """Compile CUDA source ``code`` to PTX with nvcc and return the PTX.

    Registered as a global TVM function under this name (presumably picked
    up by TVM's CUDA build in place of its default compile step).
    """
    ptx = nvcc.compile_cuda(code, target="ptx")
    return ptx


def write_code(code, fname):
    """Write the string ``code`` to the file ``fname``, overwriting it."""
    with open(fname, "w") as f:
        f.write(code)


@tvm.register_func
def tvm_callback_cuda_postproc(code):
    """Dump generated CUDA source and optionally replace it with manual code.

    The generated source is always written to perf/<TASK>_generated.cu.
    When USE_MANUAL_CODE is set, the contents of perf/<TASK>_manual.cu are
    returned instead of ``code``.
    """
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        # Use a context manager so the manual-source file handle is closed.
        with open("perf/%s_manual.cu" % TASK) as f:
            code = f.read()
    return code


def _zeros_on_device(tensor, ctx):
    """Allocate a zero-filled device array with ``tensor``'s static shape and dtype."""
    return tvm.nd.array(np.zeros(shape=get_const_tuple(tensor.shape), dtype=tensor.dtype), ctx)


def _time_kernel(func, ctx, *args):
    """Run ``func(*args)`` NUM_RUNS times on ``ctx`` and return the mean cost in seconds.

    The runs also write the kernel's results into the output array in ``args``,
    which the callers later compare against the reference.
    """
    timer = func.time_evaluator(func.entry_name, ctx, number=NUM_RUNS)
    return timer(*args).mean


def _print_report(Input, Filter, Output, stride_h, stride_w, padding, tcosts):
    """Print the problem configuration and the mean time (in us) of each kernel.

    ``tcosts`` holds, in order, the time costs (seconds) of depthwise_conv2d,
    depthwise_conv2d + scale_shift and depthwise_conv2d + scale_shift + relu.
    """
    print("Input shape = " + str(get_const_tuple(Input.shape)))
    print("Filter shape = " + str(get_const_tuple(Filter.shape)))
    print("Stride = (%d, %d)" % (stride_h, stride_w))
    print("padding = %s\n" % padding)
    print("Output shape = " + str(get_const_tuple(Output.shape)))
    labels = ("depthwise_conv2d",
              "depthwise_conv2d + scale_shift",
              "depthwise_conv2d + scale_shift + relu")
    for label, tcost in zip(labels, tcosts):
        print("average time cost of %d runs (%s) = %g us" % (NUM_RUNS, label, tcost * 1e6))


def _reference_scale_shift_relu(conv_ref, scale_np, shift_np, channel_axis):
    """Return NumPy (scale_shift, relu) references computed from ``conv_ref``.

    ``scale_np`` / ``shift_np`` are 1-D per-channel vectors broadcast along
    ``channel_axis`` of ``conv_ref`` (1 for NCHW, 3 for NHWC).
    """
    bcast_shape = [1] * conv_ref.ndim
    bcast_shape[channel_axis] = -1
    scale_shift_ref = conv_ref * scale_np.reshape(bcast_shape) + shift_np.reshape(bcast_shape)
    relu_ref = np.maximum(scale_shift_ref, 0)
    return scale_shift_ref, relu_ref


def _run_on_devices(check_device):
    """Call ``check_device(device)`` for every backend in DEVICES under one shared build config."""
    for device in DEVICES:
        # Both layouts share this config; ROCm alone gets explicit unrolling.
        with tvm.build_config(auto_unroll_max_step=128,
                              unroll_explicit=device == 'rocm',
                              detect_global_barrier=False,
                              restricted_func=True):
            check_device(device)


def test_depthwise_conv2d_nchw():
    """Benchmark and verify depthwise_conv2d (+ scale_shift, + relu) in NCHW layout.

    You may test different settings by editing the values below.
    """
    batch = 1
    in_channel = 256
    in_height = 96
    in_width = 96

    filter_channel = in_channel
    channel_multiplier = 1
    filter_height = 3
    filter_width = 3

    stride_h = 1
    stride_w = 1

    padding = 'SAME'  # or 'VALID'

    # Placeholder
    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
    Stride = [stride_h, stride_w]
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
    # Declare
    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
    ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
    Relu = topi.nn.relu(ScaleShift)
    # Schedule
    s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
    s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
    s3 = schedule_depthwise_conv2d_nchw(Relu)
    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)
        # Build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
        # Prepare data
        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = _zeros_on_device(DepthwiseConv2d, ctx)
        scale_shift_tvm = _zeros_on_device(ScaleShift, ctx)
        relu_tvm = _zeros_on_device(Relu, ctx)
        # Measure the time cost of each kernel; the runs also fill the outputs.
        tcosts = (
            _time_kernel(f1, ctx, input_tvm, filter_tvm, depthwise_conv2d_tvm),
            _time_kernel(f2, ctx, input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm),
            _time_kernel(f3, ctx, input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm),
        )
        _print_report(Input, Filter, DepthwiseConv2d, stride_h, stride_w, padding, tcosts)
        # Correctness
        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
            input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
        scale_shift_scipy, relu_scipy = _reference_scale_shift_relu(
            depthwise_conv2d_scipy, scale_np, shift_np, channel_axis=1)
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
        print("success")

    _run_on_devices(check_device)


def test_depthwise_conv2d_nhwc():
    """Benchmark and verify depthwise_conv2d (+ scale_shift, + relu) in NHWC layout.

    You may test different settings by editing the values below.
    """
    batch = 1
    in_channel = 256
    in_height = 96
    in_width = 96

    filter_channel = in_channel
    channel_multiplier = 1
    filter_height = 3
    filter_width = 3

    stride_h = 1
    stride_w = 1

    padding = 'SAME'  # or 'VALID'

    # Placeholder
    Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
    Filter = tvm.placeholder((filter_height, filter_width, filter_channel, channel_multiplier), name='Filter')
    Stride = [stride_h, stride_w]
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
    # Declare
    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
    ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
    Relu = topi.nn.relu(ScaleShift)
    # Schedule
    s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
    s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
    s3 = schedule_depthwise_conv2d_nhwc(Relu)

    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)
        # Build the kernels
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
        # Prepare data
        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = _zeros_on_device(DepthwiseConv2d, ctx)
        scale_shift_tvm = _zeros_on_device(ScaleShift, ctx)
        relu_tvm = _zeros_on_device(Relu, ctx)
        # Measure the time cost of each kernel; the runs also fill the outputs.
        tcosts = (
            _time_kernel(f1, ctx, input_tvm, filter_tvm, depthwise_conv2d_tvm),
            _time_kernel(f2, ctx, input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm),
            _time_kernel(f3, ctx, input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm),
        )
        _print_report(Input, Filter, DepthwiseConv2d, stride_h, stride_w, padding, tcosts)
        # Correctness
        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(
            input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
        scale_shift_scipy, relu_scipy = _reference_scale_shift_relu(
            depthwise_conv2d_scipy, scale_np, shift_np, channel_axis=3)
        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
        print("success")

    # Same build config as the NCHW test (previously this loop omitted
    # unroll_explicit for ROCm, so the two layouts were built differently).
    _run_on_devices(check_device)


if __name__ == "__main__":
    test_depthwise_conv2d_nchw()
    test_depthwise_conv2d_nhwc()