1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Testing topi conv2d_transpose operator for VTA"""
19
20import json
21import os
22
23import pytest
24import numpy as np
25from collections import namedtuple
26
27import tvm
28from tvm import te
29from tvm import relay
30from tvm import autotvm
31from tvm.contrib import util
32from tvm.contrib.pickle_memoize import memoize
33from tvm import topi
34import tvm.topi.testing
35import vta
36from vta import program_fpga, reconfig_runtime
37import vta.testing
38from vta.testing import simulator
39
40
# Parameters of one conv2d_transpose layer: batch and input spatial size,
# input/output channel counts, kernel size, padding, stride and the extra
# output padding added to each spatial output dimension.
Workload = namedtuple(
    "Conv2DTransposeWorkload",
    "batch height width in_filter out_filter "
    "hkernel wkernel hpad wpad hstride wstride o_hpad o_wpad",
)
59
# The VTA environment; its BATCH value is used as the batch of every workload.
env = vta.get_env()

# DCGAN transposed-convolution layers. Every layer uses a 4x4 kernel,
# padding 1, stride 2 and no output padding, which doubles the spatial size.
dcgan_wklds = [
    (name, Workload(env.BATCH, size, size, cin, cout, 4, 4, 1, 1, 2, 2, 0, 0))
    for name, size, cin, cout in (
        ("DCGAN.CT1", 4, 1024, 512),
        ("DCGAN.CT2", 8, 512, 256),
        ("DCGAN.CT3", 16, 256, 128),
    )
]
70
# FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Clamp ``x`` to ``[a_min, a_max]`` using two separate elementwise stages.

    Unlike topi's current clip, the upper bound ("clipA") and the lower
    bound ("clipB") are applied as two distinct compute stages.
    """
    lower = tvm.tir.const(a_min, x.dtype)
    upper = tvm.tir.const(a_max, x.dtype)
    capped = te.compute(x.shape, lambda *idx: tvm.te.min(x(*idx), upper), name="clipA")
    return te.compute(
        capped.shape, lambda *idx: tvm.te.max(capped(*idx), lower), name="clipB"
    )
80
81
82# Helper function to get factors
83def _find_factors(n):
84    factors = []
85    for f in range(1, n + 1):
86        if n % f == 0:
87            factors.append(f)
88    return factors
89
90
def run_conv2d_transpose(
    env, remote, wl, target, check_correctness=True, print_ir=False, samples=4
):
    """Build, run, time and optionally verify one conv2d_transpose workload.

    The operator (followed by right-shift, clip and cast to ``env.out_dtype``)
    is compiled for ``target``, uploaded to ``remote`` and timed with
    ``time_evaluator``. When ``check_correctness`` is set, the device output is
    compared against ``tvm.topi.testing.conv2d_transpose_nchw_python``.

    Parameters
    ----------
    env : VTA environment object (as returned by ``vta.get_env()``).
    remote : RPC session used to upload, load and run the module.
    wl : Workload describing the layer.
    target : TVM target; its ``keys`` must contain "arm_cpu" or "vta".
    check_correctness : compare the device result with the NumPy reference.
    print_ir : print the lowered IR of the schedule.
    samples : ``number`` argument forwarded to ``time_evaluator``.

    Returns
    -------
    (correct, cost, stats)
        ``correct`` is always False when ``check_correctness`` is False.
        ``cost`` is the ``time_evaluator`` result. ``stats`` holds simulator
        profiler statistics, or an empty dict when ``env.TARGET`` is not
        "sim"/"tsim".

    Raises
    ------
    ValueError
        If ``target`` is neither an "arm_cpu" nor a "vta" target.
    """
    # The reference below receives a single padding value (wl.hpad), so only
    # symmetric padding is supported.
    assert wl.hpad == wl.wpad

    # Select the compute/schedule pair; packing is only done for the accelerator.
    if "arm_cpu" in target.keys:
        data_pack = False
        device = "CPU"
        fcompute = topi.arm_cpu.conv2d_transpose_nchw
        fschedule = topi.arm_cpu.schedule_conv2d_transpose_nchw
    elif "vta" in target.keys:
        data_pack = True
        device = "VTA"
        fcompute = vta.top.conv2d_transpose_packed
        fschedule = vta.top.schedule_conv2d_transpose_packed
    else:
        # Fail early instead of hitting unbound locals further down.
        raise ValueError(
            "conv2d_transpose test needs an 'arm_cpu' or 'vta' target, got %s" % target
        )

    # Largest value of a signed OUT_WIDTH-bit integer; shared by the device
    # pipeline and the reference so both clip identically.
    out_max = (1 << (env.OUT_WIDTH - 1)) - 1

    # Unpacked shapes: NCHW activations and (in, out, kh, kw) kernels.
    a_shape = (wl.batch, wl.in_filter, wl.height, wl.width)
    w_shape = (wl.in_filter, wl.out_filter, wl.hkernel, wl.wkernel)
    if data_pack:
        # Tiled shapes consumed by the packed VTA operator.
        data_shape = (
            wl.batch // env.BATCH,
            wl.in_filter // env.BLOCK_IN,
            wl.height,
            wl.width,
            env.BATCH,
            env.BLOCK_IN,
        )
        kernel_shape = (
            wl.out_filter // env.BLOCK_OUT,
            wl.in_filter // env.BLOCK_IN,
            wl.hkernel,
            wl.wkernel,
            env.BLOCK_OUT,
            env.BLOCK_IN,
        )
    else:
        data_shape = a_shape
        kernel_shape = w_shape
    data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = te.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    padding = relay.nn.get_pad_tuple2d((wl.hpad, wl.wpad))

    # Define the computation (transpose conv, then requantize) and its schedule.
    with target:
        res = fcompute(
            data, kernel, (wl.hstride, wl.wstride), padding, env.acc_dtype, (wl.o_hpad, wl.o_wpad)
        )
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = my_clip(res, 0, out_max)
        res = topi.cast(res, env.out_dtype)
        s = fschedule([res])
        if print_ir:
            print(vta.lower(s, [data, kernel, res], simple_mode=True))

    # Output spatial size of the transposed convolution and the op count
    # (2 ops per multiply-accumulate) used for the GOPS figure.
    fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + wl.o_hpad
    fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + wl.o_wpad
    num_ops = (
        2
        * wl.batch
        * fout_height
        * fout_width
        * wl.hkernel
        * wl.wkernel
        * wl.out_filter
        * wl.in_filter
    )

    def get_ref_data():
        """Return random activations, random kernel and the reference output."""
        # Bounds for np.random.randint (upper bound is exclusive).
        a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1))
        w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1))
        a_np = np.random.randint(a_min, a_max, size=a_shape).astype(data.dtype)
        w_np = np.random.randint(w_min, w_max, size=w_shape).astype(kernel.dtype)
        r_np = tvm.topi.testing.conv2d_transpose_nchw_python(
            a_np.astype(env.acc_dtype),
            w_np.astype(env.acc_dtype),
            (wl.hstride, wl.wstride),
            wl.hpad,
            (wl.o_hpad, wl.o_wpad),
        ).astype(env.acc_dtype)
        return a_np, w_np, r_np

    # Data in original format
    data_np, kernel_np, res_ref = get_ref_data()
    if data_pack:
        # Reshape/transpose the host arrays into the tiled layouts above.
        data_np = data_np.reshape(
            wl.batch // env.BATCH,
            env.BATCH,
            wl.in_filter // env.BLOCK_IN,
            env.BLOCK_IN,
            wl.height,
            wl.width,
        ).transpose((0, 2, 4, 5, 1, 3))
        kernel_np = kernel_np.reshape(
            wl.in_filter // env.BLOCK_IN,
            env.BLOCK_IN,
            wl.out_filter // env.BLOCK_OUT,
            env.BLOCK_OUT,
            wl.hkernel,
            wl.wkernel,
        ).transpose((2, 0, 4, 5, 3, 1))
        # The packed operator is fed the kernel flipped along both spatial axes.
        kernel_np = np.flip(kernel_np, 2)
        kernel_np = np.flip(kernel_np, 3)

    # Build
    if "vta" in target.keys:
        mod = vta.build(
            s,
            [data, kernel, res],
            target=target,
            target_host=env.target_host,
            name="conv2d_transpose",
        )
    else:
        mod = tvm.build(
            s,
            [data, kernel, res],
            target=target,
            target_host=env.target_host,
            name="conv2d_transpose",
        )
    temp = util.tempdir()
    mod.save(temp.relpath("conv2d_transpose.o"))
    remote.upload(temp.relpath("conv2d_transpose.o"))
    f = remote.load_module("conv2d_transpose.o")
    ctx = remote.context(str(target))

    res_np = np.zeros(topi.util.get_const_tuple(res.shape)).astype(res.dtype)
    data_arr = tvm.nd.array(data_np, ctx)
    kernel_arr = tvm.nd.array(kernel_np, ctx)
    res_arr = tvm.nd.array(res_np, ctx)
    time_f = f.time_evaluator("conv2d_transpose", ctx, number=samples)

    # In vta sim mode, collect simulator runtime statistics
    stats = {}
    cost = None
    if env.TARGET in ["sim", "tsim"]:
        # Check if we're in local RPC mode (allows us to rebuild the
        # runtime on the fly when varying the VTA designs)
        local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
        if local_rpc:
            if env.TARGET == "sim":
                remote.get_function("vta.simulator.profiler_clear")()
            else:
                remote.get_function("vta.tsim.profiler_clear")()
            cost = time_f(data_arr, kernel_arr, res_arr)
            if env.TARGET == "sim":
                stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
            else:
                stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
        else:
            simulator.clear_stats()
            cost = time_f(data_arr, kernel_arr, res_arr)
            stats = simulator.stats()
    else:
        cost = time_f(data_arr, kernel_arr, res_arr)

    # Check correctness against the reference, requantized like the device path.
    correct = False
    if check_correctness:
        res_orig = res_arr.asnumpy()
        if data_pack:
            res_orig = res_orig.transpose((0, 4, 1, 5, 2, 3)).reshape(
                wl.batch, wl.out_filter, fout_height, fout_width
            )
        res_ref = res_ref >> env.WGT_WIDTH
        res_ref = np.clip(res_ref, 0, out_max)
        res_ref = res_ref.astype(env.out_dtype)
        correct = np.allclose(res_orig, res_ref)

    gops = (num_ops / cost.mean) / float(10 ** 9)
    if check_correctness:
        status = "PASSED" if correct else "FAILED"
    else:
        # Do not report a failure for a result that was never checked.
        status = "UNCHECKED"
    print(
        "%s CONV2D TRANSPOSE TEST %s: Time cost = %g sec/op, %g GOPS"
        % (device, status, cost.mean, gops)
    )

    return correct, cost, stats
281
282
@pytest.mark.parametrize("device", ["vta", "arm_cpu"])
def test_conv2d_transpose(device):
    """Run every DCGAN conv2d_transpose workload on ``device`` and check results.

    ``device`` is "vta" (uses ``env.target``) or "arm_cpu" (uses
    ``env.target_vta_cpu``).

    Raises
    ------
    ValueError
        If ``device`` is not one of the two supported values.
    """

    def _run(env, remote):
        if device == "vta":
            target = env.target
            if env.TARGET not in ["sim", "tsim"]:
                # Not a simulator target: program the FPGA and reconfigure
                # the runtime before running.
                assert tvm.runtime.enabled("rpc")
                program_fpga(remote, bitstream=None)
                reconfig_runtime(remote)
        elif device == "arm_cpu":
            target = env.target_vta_cpu
        else:
            # Previously `target` was left unbound for any other value.
            raise ValueError("unknown device: %s" % device)
        with autotvm.tophub.context(target):  # load pre-tuned schedule parameters
            for name, wl in dcgan_wklds:
                print(wl)
                correct, _, _ = run_conv2d_transpose(env, remote, wl, target)
                # run_conv2d_transpose only reports the outcome; make a
                # numerical mismatch actually fail the test.
                assert correct, "%s produced wrong results on %s" % (name, device)

    vta.testing.run(_run)
300
301
if __name__ == "__main__":
    # Run the CPU baseline first, then the accelerator.
    for _device in ("arm_cpu", "vta"):
        test_conv2d_transpose(device=_device)
305