# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Perform ResNet autoTVM tuning on VTA using Relay."""

import argparse
import os
import time
from mxnet.gluon.model_zoo import vision
import numpy as np
from PIL import Image

import topi
import tvm
from tvm import rpc, autotvm, relay
from tvm.autotvm.measure.measure_methods import request_remote
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib import graph_runtime, util, download
from tvm.contrib.debugger import debug_runtime
import vta
from vta.testing import simulator
from vta.top import graph_pack
from tvm.autotvm.task import extract_from_program


def parse_arguments():
    """Parse the command-line options of the tuning script.

    Returns
    -------
    argparse.Namespace
        Parsed options: ``model``, ``start_name``, ``stop_name``,
        ``debug_profile``, ``device``, ``measurements``, ``tuner`` and
        ``log_filename``.
    """
    parser = argparse.ArgumentParser(
        description='Perform ResNet AutoTVM tuning on VTA using Relay.')
    parser.add_argument('--model', type=str, default='resnet18_v1', choices=['resnet18_v1'],
                        help='Input model name.')
    parser.add_argument('--start-name', type=str, default='nn.max_pool2d',
                        help='The name of the node where packing starts')
    parser.add_argument('--stop-name', type=str, default='nn.global_avg_pool2d',
                        help='The name of the node where packing stops')
    parser.add_argument('--debug-profile', action='store_true',
                        help='Show layer-wise time cost profiling results')
    parser.add_argument('--device', default='vta', choices=['vta', 'arm_cpu'],
                        help='Select device target')
    parser.add_argument('--measurements', type=int, default=1,
                        help='Number of measurements during AutoTVM search')
    # Restrict to the strategies tune_tasks() understands so that a typo is
    # rejected immediately instead of after the network has been compiled.
    parser.add_argument('--tuner', type=str, default="random",
                        choices=['xgb', 'xgb-rank', 'ga', 'random', 'gridsearch'],
                        help='AutoTVM search strategy')
    parser.add_argument('--log-filename', type=str, default="resnet-18.log",
                        help='AutoTVM log file name')

    return parser.parse_args()


def register_vta_tuning_tasks():
    """Override the AutoTVM ``topi_nn_conv2d`` and ``topi_nn_dense`` templates.

    Each overriding template builds the operator under the VTA target and
    appends a right shift by 8, a clip to [0, 127] and a cast to ``int8`` to
    its result before scheduling.
    """
    from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args

    @tvm.tag_scope(tag=topi.tag.ELEMWISE)
    def my_clip(x, a_min, a_max):
        """Unlike topi's current clip, put min and max into two stages."""
        const_min = tvm.const(a_min, x.dtype)
        const_max = tvm.const(a_max, x.dtype)
        x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
        x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
        return x

    # init autotvm env to register VTA operator
    TaskExtractEnv()

    @autotvm.task.register("topi_nn_conv2d", override=True)
    def _topi_nn_conv2d(*args, **kwargs):
        """conv2d template followed by shift, clip and int8 cast."""
        assert not kwargs, "Do not support kwargs in template function call"
        args = deserialize_args(args)
        A, W = args[:2]

        with tvm.target.vta():
            res = topi.nn.conv2d(*args, **kwargs)
            res = topi.right_shift(res, 8)
            res = my_clip(res, 0, 127)
            res = topi.cast(res, "int8")

        # Use the generic conv2d schedule when tuning for VTA, otherwise
        # fall back to a default schedule.
        if tvm.target.current_target().device_name == 'vta':
            s = topi.generic.schedule_conv2d_nchw([res])
        else:
            s = tvm.create_schedule([res.op])
        return s, [A, W, res]

    @autotvm.task.register("topi_nn_dense", override=True)
    def _topi_nn_dense(*args, **kwargs):
        """dense template followed by shift, clip and int8 cast."""
        assert not kwargs, "Do not support kwargs in template function call"
        args = deserialize_args(args)
        A, W = args[:2]

        with tvm.target.vta():
            res = topi.nn.dense(*args, **kwargs)
            res = topi.right_shift(res, 8)
            res = my_clip(res, 0, 127)
            res = topi.cast(res, "int8")

        # Use the generic dense schedule when tuning for VTA, otherwise
        # fall back to a default schedule.
        if tvm.target.current_target().device_name == 'vta':
            s = topi.generic.schedule_dense([res])
        else:
            s = tvm.create_schedule([res.op])

        return s, [A, W, res]


def compile_network(opt, env, target):
    """Convert the selected gluon model into a quantized Relay program.

    Parameters
    ----------
    opt : argparse.Namespace
        Options from :func:`parse_arguments`; ``model``, ``start_name`` and
        ``stop_name`` are read.
    env : VTA environment
        Provides ``BATCH``, ``BLOCK_IN``, ``BLOCK_OUT`` and ``WGT_WIDTH``.
    target : tvm target
        Compilation target; graph packing is applied only when its
        ``device_name`` is ``"vta"``.

    Returns
    -------
    relay_prog
        The quantized (and, for VTA, packed) Relay program.
    params : dict
        Model parameters returned by the MXNet frontend.
    """

    # Populate the shape and data type dictionary
    dtype_dict = {"data": 'float32'}
    shape_dict = {"data": (env.BATCH, 3, 224, 224)}

    # Get off the shelf gluon model, and convert to relay
    gluon_model = vision.get_model(opt.model, pretrained=True)
    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)

    # Update shape and type dictionary
    shape_dict.update({k: v.shape for k, v in params.items()})
    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})

    # Perform quantization in Relay
    # Note: We set opt_level to 3 in order to fold batch norm
    with relay.build_config(opt_level=3):
        with relay.quantize.qconfig(global_scale=8.0,
                                    skip_conv_layers=[0]):
            relay_prog = relay.quantize.quantize(mod["main"], params=params)

    # Perform graph packing and constant folding for VTA target
    if target.device_name == "vta":
        assert env.BLOCK_IN == env.BLOCK_OUT
        relay_prog = graph_pack(
            relay_prog,
            env.BATCH,
            env.BLOCK_OUT,
            env.WGT_WIDTH,
            start_name=opt.start_name,
            stop_name=opt.stop_name)

    return relay_prog, params


def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune every task and store the best records in ``log_filename``.

    Tasks are tuned in reverse order. All measurements are appended to a
    temporary ``log_filename + ".tmp"`` file, from which the best records
    are picked into ``log_filename`` once all tasks are done; the temporary
    file is then removed.

    Parameters
    ----------
    tasks : list
        AutoTVM tasks to tune.
    measure_option : dict
        Measurement options forwarded to the tuner.
    tuner : str
        One of ``'xgb'``, ``'xgb-rank'``, ``'ga'``, ``'random'`` or
        ``'gridsearch'``.
    n_trial : int
        Maximum number of trials per task; capped at the size of the task's
        config space.
    early_stopping : int or None
        Forwarded unchanged to the tuner.
    log_filename : str
        Destination file for the best records.
    use_transfer_learning : bool
        When true, records already in the temporary log are loaded into each
        new tuner as history.
    try_winograd : bool
        Not used by this function; kept so existing callers keep working.

    Raises
    ------
    ValueError
        If ``tuner`` is not one of the supported names.
    """

    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        n_trial_ = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial_,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial_, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)


if __name__ == '__main__':

    opt = parse_arguments()

    # Make sure that TVM was compiled with RPC=1
    assert tvm.module.enabled("rpc")

    # Read in VTA environment
    env = vta.get_env()

    # Get remote from fleet node
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        # Exit with a failure status so wrapping scripts can detect the
        # missing configuration.
        raise SystemExit(1)

    # The environment yields a string; convert once so that every consumer
    # (request_remote and RPCRunner) receives an integer port.
    tracker_port = int(tracker_port)

    # Get remote
    if env.TARGET != "sim":

        # Measure build start time
        reconfig_start = time.time()

        # Get remote from fleet node
        remote = autotvm.measure.request_remote(env.TARGET, tracker_host, tracker_port,
                                                timeout=10000)

        # Reconfigure the JIT runtime and FPGA.
        # You can program the FPGA with your own custom bitstream
        # by passing the path to the bitstream file instead of None.
        vta.reconfig_runtime(remote)
        vta.program_fpga(remote, bitstream=None)

        # Report on reconfiguration time
        reconfig_time = time.time() - reconfig_start
        print("Reconfigured FPGA and RPC runtime in {0:.2f}s!".format(reconfig_time))

    # In simulation mode, host the RPC server locally.
    else:
        remote = rpc.LocalSession()

    # VTA target and execution context
    target = env.target if opt.device == "vta" else env.target_vta_cpu
    ctx = remote.ext_dev(0) if opt.device == "vta" else remote.cpu(0)

    # Compile Relay program
    print("Initial compile...")
    relay_prog, params = compile_network(opt, env, target)

    # Register VTA tuning tasks
    register_vta_tuning_tasks()

    # Perform task extraction on Relay program
    print("Extracting tasks...")
    tasks = extract_from_program(func=relay_prog,
                                 params=params,
                                 ops=(tvm.relay.op.nn.conv2d,),
                                 target=target,
                                 target_host=env.target_host)

    # Perform Autotuning
    print("Tuning...")
    tuning_opt = {
        'log_filename': opt.log_filename,
        'tuner': opt.tuner,
        # Effectively "the whole config space"; tune_tasks caps it per task.
        'n_trial': int(1e9),
        'early_stopping': None,
        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
            runner=autotvm.RPCRunner(env.TARGET, tracker_host, tracker_port,
                                     number=4, min_repeat_ms=150, repeat=opt.measurements,
                                     timeout=60, check_correctness=True))
    }
    tune_tasks(tasks, **tuning_opt)

    # Compile kernels with history best records
    with autotvm.tophub.context(target, extra_files=[opt.log_filename]):

        # Compile network
        print("Compiling network with best tuning parameters...")
        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            if target.device_name != "vta":
                graph, lib, params = relay.build(
                    relay_prog, target=target,
                    params=params, target_host=env.target_host)
            else:
                with vta.build_config():
                    graph, lib, params = relay.build(
                        relay_prog, target=target,
                        params=params, target_host=env.target_host)

        # Export library
        temp = util.tempdir()
        lib.save(temp.relpath("graphlib.o"))
        remote.upload(temp.relpath("graphlib.o"))
        lib = remote.load_module("graphlib.o")

        # If detailed runtime info is needed build with debug runtime
        if opt.debug_profile:
            m = debug_runtime.create(graph, lib, ctx)
        else:
            m = graph_runtime.create(graph, lib, ctx)

        # Set the network parameters and synthetic input.
        # The batch dimension must match the shape the network was compiled
        # with in compile_network(), i.e. env.BATCH.
        image = tvm.nd.array(
            (np.random.uniform(size=(env.BATCH, 3, 224, 224))).astype('float32'))
        m.set_input(**params)
        m.set_input('data', image)

        # Perform inference
        timer = m.module.time_evaluator("run", ctx, number=4, repeat=opt.measurements)
        tcost = timer()
        prof_res = np.array(tcost.results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

        # Display profile information
        if opt.debug_profile:
            m.run()