1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import os
18import tvm
19import numpy as np
20from scipy import signal
21from tvm.contrib import nvcc
22
23import topi
24from topi.util import get_const_tuple
25from topi.cuda.depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
26
27TASK = "depthwise_conv2d"
28USE_MANUAL_CODE = False
29
30@tvm.register_func
31def tvm_callback_cuda_compile(code):
32    ptx = nvcc.compile_cuda(code, target="ptx")
33    return ptx
34
35def write_code(code, fname):
36    with open(fname, "w") as f:
37        f.write(code)
38
39@tvm.register_func
40def tvm_callback_cuda_postproc(code):
41    if not os.path.exists("perf"):
42        os.mkdir("perf")
43    write_code(code, "perf/%s_generated.cu" % TASK)
44    if USE_MANUAL_CODE:
45        code = open("perf/%s_manual.cu" % TASK).read()
46    return code
47
48def test_depthwise_conv2d_nchw():
49    """You may test different settings."""
50    batch = 1
51    in_channel = 256
52    in_height = 96
53    in_width = 96
54
55    filter_channel = in_channel
56    channel_multiplier = 1
57    filter_height = 3
58    filter_width = 3
59
60    stride_h = 1
61    stride_w = 1
62
63    padding = 'SAME' # or 'VALID'
64
65    # Placeholder
66    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
67    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
68    Stride = [stride_h, stride_w]
69    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
70    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
71    # Declare
72    DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, Filter, Stride, padding)
73    ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
74    Relu = topi.nn.relu(ScaleShift)
75    # Schedule
76    s1 = schedule_depthwise_conv2d_nchw(DepthwiseConv2d)
77    s2 = schedule_depthwise_conv2d_nchw(ScaleShift)
78    s3 = schedule_depthwise_conv2d_nchw(Relu)
79    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
80    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
81    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
82    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
83
84    def check_device(device):
85        if not tvm.module.enabled(device):
86            print("Skip because %s is not enabled" % device)
87            return
88        ctx = tvm.context(device, 0)
89        # Build the kernel
90        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
91        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
92        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
93        # Prepare data
94        input_tvm = tvm.nd.array(input_np, ctx)
95        filter_tvm = tvm.nd.array(filter_np, ctx)
96        scale_tvm = tvm.nd.array(scale_np, ctx)
97        shift_tvm = tvm.nd.array(shift_np, ctx)
98
99        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
100        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
101        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
102        # Measure time cost of kernel 1 (depthwise_conv2d)
103        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
104        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
105        # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
106        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
107        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
108        # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
109        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
110        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
111        print("Input shape = " + str(get_const_tuple(Input.shape)))
112        print("Filter shape = " + str(get_const_tuple(Filter.shape)))
113        print("Stride = (%d, %d)" % (stride_h, stride_w))
114        print("padding = %s\n" % padding)
115        print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
116        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
117        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
118        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
119        # correctness
120        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
121        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
122        for c in range(in_channel * channel_multiplier):
123            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
124        relu_scipy = np.maximum(scale_shift_scipy, 0)
125        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
126        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
127        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
128        print("success")
129
130    for device in ['cuda', 'opencl', 'rocm']:
131        with tvm.build_config(auto_unroll_max_step=128,
132                              unroll_explicit=device == 'rocm',
133                              detect_global_barrier=False,
134                              restricted_func=True):
135            check_device(device)
136
137def test_depthwise_conv2d_nhwc():
138    """You may test different settings."""
139    batch = 1
140    in_channel = 256
141    in_height = 96
142    in_width = 96
143
144    filter_channel = in_channel
145    channel_multiplier = 1
146    filter_height = 3
147    filter_width = 3
148
149    stride_h = 1
150    stride_w = 1
151
152    padding = 'SAME' # or 'VALID'
153
154    # Placeholder
155    Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
156    Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
157    Stride = [stride_h, stride_w]
158    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
159    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
160    # Declare
161    DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, Filter, Stride, padding)
162    ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
163    Relu = topi.nn.relu(ScaleShift)
164    # Schedule
165    s1 = schedule_depthwise_conv2d_nhwc(DepthwiseConv2d)
166    s2 = schedule_depthwise_conv2d_nhwc(ScaleShift)
167    s3 = schedule_depthwise_conv2d_nhwc(Relu)
168
169    input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
170    filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
171    scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
172    shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
173
174    def check_device(device):
175        if not tvm.module.enabled(device):
176            print("Skip because %s is not enabled" % device)
177            return
178        ctx = tvm.context(device, 0)
179        # Build the kernel
180        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
181        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
182        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
183        # Prepare data
184        input_tvm = tvm.nd.array(input_np, ctx)
185        filter_tvm = tvm.nd.array(filter_np, ctx)
186        scale_tvm = tvm.nd.array(scale_np, ctx)
187        shift_tvm = tvm.nd.array(shift_np, ctx)
188        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape),dtype=DepthwiseConv2d.dtype), ctx)
189        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
190        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
191        # Measure time cost of kernel 1 (depthwise_conv2d)
192        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1000)
193        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm).mean
194        # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
195        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=1000)
196        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm).mean
197        # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
198        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=1000)
199        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm).mean
200        print("Input shape = " + str(get_const_tuple(Input.shape)))
201        print("Filter shape = " + str(get_const_tuple(Filter.shape)))
202        print("Stride = (%d, %d)" % (stride_h, stride_w))
203        print("padding = %s\n" % padding)
204        print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
205        print("average time cost of 1000 runs (depthwise_conv2d) = %g us" % (tcost_1*1e6))
206        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift) = %g us" % (tcost_2*1e6))
207        print("average time cost of 1000 runs (depthwise_conv2d + scale_shift + relu) = %g us" % (tcost_3*1e6))
208        # correctness
209        depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nhwc(input_np, filter_np, stride=[stride_h, stride_w], padding=padding)
210        scale_shift_scipy = np.zeros(shape=get_const_tuple(ScaleShift.shape))
211        for c in range(in_channel * channel_multiplier):
212            scale_shift_scipy[:,:,:,c] = depthwise_conv2d_scipy[:,:,:,c] * scale_np[c] + shift_np[c]
213        relu_scipy = np.maximum(scale_shift_scipy, 0)
214        tvm.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
215        tvm.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
216        tvm.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
217        print("success")
218
219    for device in ['cuda', 'opencl', 'rocm']:
220        with tvm.build_config(auto_unroll_max_step=128,
221                              detect_global_barrier=False,
222                              restricted_func=True):
223            check_device(device)
224
225if __name__ == "__main__":
226    test_depthwise_conv2d_nchw()
227    test_depthwise_conv2d_nhwc()
228