1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Perform ResNet autoTVM tuning on VTA using Relay."""
19
20import argparse, os, time
21from mxnet.gluon.model_zoo import vision
22import numpy as np
23from PIL import Image
24
25import topi
26import tvm
27from tvm import rpc, autotvm, relay
28from tvm.autotvm.measure.measure_methods import request_remote
29from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
30from tvm.contrib import graph_runtime, util, download
31from tvm.contrib.debugger import debug_runtime
32import vta
33from vta.testing import simulator
34from vta.top import graph_pack
35from tvm.autotvm.task import extract_from_program
36
def parse_arguments():
    """Parse the command-line options of the tuning script.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``model``, ``start_name``, ``stop_name``,
        ``debug_profile``, ``device``, ``measurements``, ``tuner`` and
        ``log_filename``.
    """
    parser = argparse.ArgumentParser(
        description='Tune a ResNet model with AutoTVM and measure its inference time.')
    parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'],
                        help='Input model name.')
    parser.add_argument('--start-name', type=str, default='nn.max_pool2d',
                        help='The name of the node where packing starts')
    parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d',
                        help='The name of the node where packing stops')
    parser.add_argument('--debug-profile', action='store_true',
                        help='Show layer-wise time cost profiling results')
    parser.add_argument('--device', default='vta', choices=['vta', 'arm_cpu'],
                        help='Select device target')
    parser.add_argument('--measurements', type=int, default=1,
                        help='Number of measurements during AutoTVM search')
    # Restrict to the strategy names tune_tasks() accepts so that a typo is
    # rejected here instead of after the model has been imported and compiled.
    parser.add_argument('--tuner', type=str, default="random",
                        choices=['xgb', 'xgb-rank', 'ga', 'random', 'gridsearch'],
                        help='AutoTVM search strategy')
    parser.add_argument('--log-filename', type=str, default="resnet-18.log",
                        help='AutoTVM log file name')

    return parser.parse_args()
58
59
def register_vta_tuning_tasks():
    """Override the AutoTVM ``topi_nn_conv2d`` and ``topi_nn_dense`` templates.

    Both replacement templates build the operator inside the
    ``tvm.target.vta()`` scope, then append a right shift by 8, a clip to
    [0, 127] and a cast to ``int8``. Each returns the schedule together with
    ``[A, W, res]``.
    """
    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args

    @tvm.tag_scope(tag=topi.tag.ELEMWISE)
    def my_clip(x, a_min, a_max):
        """Unlike topi's current clip, put min and max into two stages."""
        lower = tvm.const(a_min, x.dtype)
        upper = tvm.const(a_max, x.dtype)
        capped = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), upper), name="clipA")
        return tvm.compute(capped.shape, lambda *i: tvm.max(capped(*i), lower), name="clipB")

    def _shift_clip_cast(acc):
        """Append the shift / clip / int8-cast stages shared by both templates."""
        shifted = topi.right_shift(acc, 8)
        clipped = my_clip(shifted, 0, 127)
        return topi.cast(clipped, "int8")

    def _schedule_for_current_target(out, vta_schedule):
        """Use ``vta_schedule`` when the active target is VTA, else a default schedule."""
        if tvm.target.current_target().device_name == 'vta':
            return vta_schedule([out])
        return tvm.create_schedule([out.op])

    # init autotvm env to register VTA operator
    TaskExtractEnv()

    @autotvm.task.register("topi_nn_conv2d", override=True)
    def _topi_nn_conv2d(*args, **kwargs):
        assert not kwargs, "Do not support kwargs in template function call"
        args = deserialize_args(args)
        data, kernel = args[:2]

        with tvm.target.vta():
            out = _shift_clip_cast(topi.nn.conv2d(*args, **kwargs))

        sched = _schedule_for_current_target(out, topi.generic.schedule_conv2d_nchw)
        return sched, [data, kernel, out]

    @autotvm.task.register("topi_nn_dense", override=True)
    def _topi_nn_dense(*args, **kwargs):
        assert not kwargs, "Do not support kwargs in template function call"
        args = deserialize_args(args)
        data, weight = args[:2]

        with tvm.target.vta():
            out = _shift_clip_cast(topi.nn.dense(*args, **kwargs))

        sched = _schedule_for_current_target(out, topi.generic.schedule_dense)
        return sched, [data, weight, out]
111
112
def compile_network(opt, env, target):
    """Import the Gluon model, quantize it in Relay and pack it for VTA.

    Parameters
    ----------
    opt : argparse.Namespace
        Options; ``opt.model``, ``opt.start_name`` and ``opt.stop_name`` are read.
    env : object
        VTA environment; ``BATCH``, ``BLOCK_IN``, ``BLOCK_OUT`` and
        ``WGT_WIDTH`` are read.
    target : object
        Compilation target. Graph packing is applied only when
        ``target.device_name == "vta"``.

    Returns
    -------
    tuple
        ``(relay_prog, params)``: the quantized (and, for VTA, packed) Relay
        program and the parameters returned by the MXNet frontend.
    """
    # Only the network input shape is passed to the frontend.
    shape_dict = {"data": (env.BATCH, 3, 224, 224)}

    # Get off the shelf gluon model, and convert to relay
    gluon_model = vision.get_model(opt.model, pretrained=True)
    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

    # Perform quantization in Relay
    # Note: We set opt_level to 3 in order to fold batch norm
    with relay.build_config(opt_level=3):
        with relay.quantize.qconfig(global_scale=8.0,
                                    skip_conv_layers=[0]):
            relay_prog = relay.quantize.quantize(mod["main"], params=params)

    # Perform graph packing and constant folding for VTA target
    if target.device_name == "vta":
        # A single block size (BLOCK_OUT) is handed to graph_pack below, so the
        # input and output block sizes have to agree.
        assert env.BLOCK_IN == env.BLOCK_OUT
        relay_prog = graph_pack(
            relay_prog,
            env.BATCH,
            env.BLOCK_OUT,
            env.WGT_WIDTH,
            start_name=opt.start_name,
            stop_name=opt.stop_name)

    return relay_prog, params
146
147
def _create_tuner(tuner, tsk):
    """Instantiate the AutoTVM tuner selected by name for a single task."""
    if tuner in ('xgb', 'xgb-rank'):
        return XGBTuner(tsk, loss_type='rank')
    if tuner == 'ga':
        return GATuner(tsk, pop_size=50)
    if tuner == 'random':
        return RandomTuner(tsk)
    if tuner == 'gridsearch':
        return GridSearchTuner(tsk)
    raise ValueError("Invalid tuner: " + tuner)


def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune every task and write the best records to ``log_filename``.

    Tasks are tuned in reverse order. All measurements are logged to
    ``log_filename + ".tmp"``; once every task has been tuned the best
    records are picked into ``log_filename`` and the temporary file is removed.

    Parameters
    ----------
    tasks : list
        AutoTVM tasks to tune. An empty list is a no-op (after removing a
        stale temporary log, if any).
    measure_option : dict
        Passed through to each tuner's ``tune`` call.
    tuner : str
        One of ``'xgb'``, ``'xgb-rank'``, ``'ga'``, ``'random'``, ``'gridsearch'``.
    n_trial : int or float
        Upper bound on trials per task; truncated to ``int`` and capped at
        the size of the task's config space.
    early_stopping : int or None
        Passed through to each tuner's ``tune`` call.
    log_filename : str
        Destination file for the best records.
    use_transfer_learning : bool
        When true, records already written to the temporary log are loaded
        as history into each new tuner.
    try_winograd : bool
        Accepted for interface compatibility; not used by this function.

    Raises
    ------
    ValueError
        If ``tuner`` is not a recognised name (raised when the first task is
        about to be tuned).
    """
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    # Without tasks nothing is ever written to the temporary log, so picking
    # the best records from it (and removing it) would fail on a missing file.
    if not tasks:
        print("No tasks to tune; skipping.")
        return

    # Callers may pass a float such as 1e9; the trial count must be an integer.
    n_trial = int(n_trial)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # create tuner
        tuner_obj = _create_tuner(tuner, tsk)

        if use_transfer_learning and os.path.isfile(tmp_log_file):
            tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        n_trial_ = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial_,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial_, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)
193
if __name__ == '__main__':
    # Script entry point: connect to the device, compile and tune the network,
    # then rebuild it with the best tuning records and time its inference.

    opt = parse_arguments()

    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")

    # Read in VTA environment
    env = vta.get_env()

    # Get remote from fleet node
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        # SystemExit with a message reports on stderr and exits with a
        # non-zero status, so callers can detect the misconfiguration.
        raise SystemExit(
            "Set your AutoTVM tracker node host and port variables to run the autotuner")
    # Convert the port once so that every consumer below receives an int.
    try:
        tracker_port = int(tracker_port)
    except ValueError:
        raise SystemExit("TVM_TRACKER_PORT must be an integer, got %r" % tracker_port)

    # Get remote
    if env.TARGET != "sim":

        # Measure build start time
        reconfig_start = time.time()

        # Get remote from fleet node
        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port, timeout=10000)

        # Reconfigure the JIT runtime and FPGA.
        # You can program the FPGA with your own custom bitstream
        # by passing the path to the bitstream file instead of None.
        vta.reconfig_runtime(remote)
        vta.program_fpga(remote, bitstream=None)

        # Report on reconfiguration time
        reconfig_time = time.time() - reconfig_start
        print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

    # In simulation mode, host the RPC server locally.
    else:
        remote = rpc.LocalSession()

    # VTA target and execution context
    target = env.target if opt.device == "vta" else env.target_vta_cpu
    ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)

    # Compile Relay program
    print("Initial compile...")
    relay_prog, params = compile_network(opt, env, target)

    # Register VTA tuning tasks
    register_vta_tuning_tasks()

    # Perform task extraction on Relay program
    print("Extracting tasks...")
    tasks = extract_from_program(func=relay_prog,
                                 params=params,
                                 ops=(tvm.relay.op.nn.conv2d,),
                                 target=target,
                                 target_host=env.target_host)

    # Perform Autotuning
    print("Tuning...")
    tuning_opt = {
        'log_filename': opt.log_filename,
        'tuner': opt.tuner,
        # Large enough that tune_tasks caps it at each task's config-space size;
        # kept an int because it is used as a trial count.
        'n_trial': int(1e9),
        'early_stopping': None,
        'measure_option': autotvm.measure_option(
                builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
                runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
                    number=4, min_repeat_ms=150, repeat=opt.measurements, timeout=60,
                    check_correctness=True))
    }
    tune_tasks(tasks, **tuning_opt)

    # Compile kernels with history best records
    with autotvm.tophub.context(target, extra_files=[opt.log_filename]):

        # Compile network
        print("Compiling network with best tuning parameters...")
        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            if target.device_name != "vta":
                graph, lib, params = relay.build(
                    relay_prog, target=target,
                    params=params, target_host=env.target_host)
            else:
                with vta.build_config():
                    graph, lib, params = relay.build(
                        relay_prog, target=target,
                        params=params, target_host=env.target_host)

        # Export library
        temp = util.tempdir()
        lib.save(temp.relpath("graphlib.o"))
        remote.upload(temp.relpath("graphlib.o"))
        lib = remote.load_module("graphlib.o")

        # If detailed runtime info is needed build with debug runtime
        if opt.debug_profile:
            m = debug_runtime.create(graph, lib, ctx)
        else:
            m = graph_runtime.create(graph, lib, ctx)

        # Set the network parameters and synthetic input.
        # The batch dimension follows env.BATCH, the same value the network
        # was compiled with in compile_network().
        image = tvm.nd.array(
            (np.random.uniform(size=(env.BATCH, 3, 224, 224))).astype('float32'))
        m.set_input(**params)
        m.set_input('data', image)

        # Perform inference
        timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements)
        tcost = timer()
        prof_res = np.array(tcost.results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

        # Display profile information
        if opt.debug_profile:
            m.run()
312