# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable
"""Definition of task function.

Task can be constructed from tuple of func, args, and kwargs.
func is a state-less function, or a string that
registers the standard task.
"""

import numpy as np

from ... import tensor, expr, container, target as _target

from ..util import get_const_int, get_const_tuple, get_func_name
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
from .space import ConfigSpace


def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
    """Placeholder task function used when the real one is not registered.

    Installed as the fallback in ``TASK_TABLE`` lookups so that calling an
    unknown task produces a helpful error instead of a KeyError.
    """
    raise RuntimeError("The function of this task is not found. Possibly the function "
                       "of this task is registered in another python file "
                       "which is not imported in this run")


class Task(object):
    """A Tunable Task

    Parameters
    ----------
    name: str
        The name of the task.
    args: Tuple
        Positional argument of func
    """
    def __init__(self, name, args):
        self.name = name
        self.args = args
        self.kwargs = {}  # currently unused

        # init null config space
        self.config_space = None
        # Fall back to a function that raises a descriptive error when the
        # task function has not been registered in this process.
        self.func = TASK_TABLE.get(name, _raise_error)

        # auxiliary info, available after `init_space` is called
        self.workload = None
        self.flop = None
        self.target = None
        self.target_host = None

    def instantiate(self, config):
        """Instantiate this task function (template) with a config.
        Returns corresponding schedule.

        Parameters
        ----------
        config: template.ConfigEntity
            parameter config for this template

        Returns
        -------
        sch: tvm.schedule.Schedule
            The tvm schedule
        arg_bufs: Array of tvm.tensor.Tensor
            The input/output buffers
        """
        config.flop = 0
        with ApplyConfig(config):
            sch, arg_bufs = self.func(*self.args, **self.kwargs)
            if not self.flop:
                # The template may record FLOP itself (cfg.add_flop);
                # otherwise estimate it from the generated schedule.
                config.flop = config.flop or compute_flop(sch)
                self.flop = config.flop
        return sch, arg_bufs

    def __getstate__(self):
        # custom pickle implementation is required for
        # some unpickable local task functions.
        # So we only pickle the name of the function
        # and restore the function by name when unpickling it.
        return {
            "name": self.name,
            "args": self.args,
            "kwargs": self.kwargs,
            "config_space": self.config_space,
            "workload": self.workload,
            "flop": self.flop,
            "target": self.target,
            "target_host": self.target_host
        }

    def __setstate__(self, state):
        self.name = state["name"]
        self.args = state["args"]
        self.kwargs = state["kwargs"]
        self.config_space = state["config_space"]
        # Restore the task function by name; it may be missing if the module
        # that registered it has not been imported in this process.
        self.func = TASK_TABLE.get(state["name"], _raise_error)
        self.workload = state["workload"]
        self.flop = state["flop"]
        self.target = state["target"]
        self.target_host = state["target_host"]

    def __repr__(self):
        return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
            self.name, self.args, self.kwargs, self.workload
        )

# Global registry mapping task name -> task (template) function.
TASK_TABLE = {
}

def register(name, func=None, override=False):
    """Register a task function.

    Can be used directly (``register(name, func)``) or as a decorator
    factory (``@register(name)``).

    Parameters
    ----------
    name : str
        The name to identify the task.
    func : callable
        The function to be registered.
    override : bool
        Whether override existing registration.

    Returns
    -------
    func: callable
        The registered function

    Raises
    ------
    ValueError
        If `name` is already registered and `override` is False.
    """
    def _do_reg(myf):
        if name in TASK_TABLE and not override:
            raise ValueError(
                "Key %s is already registered" % name)
        TASK_TABLE[name] = myf
        return myf
    if func is not None:
        return _do_reg(func)
    # No function given: act as a decorator factory.
    return _do_reg

def create(func_name, args, target, target_host=None, template_key=None):
    """Create a tuning task and initialize its search space

    Parameters
    ----------
    func_name : str or callable
        The task function
    args : List
        Positional arguments
    target : Target
        The compilation target
    target_host: Target, optional
        The compilation target for host side

    Returns
    -------
    tsk: Task
        a task object
    """
    if callable(func_name):
        # register this function if it is not registered before
        func = func_name
        func_name = func.func_name if hasattr(func, 'func_name') else func.__name__
        if func_name in TASK_TABLE:
            assert func == TASK_TABLE[func_name], "Find name conflict in task registration. " \
                                                 "Consider to choose another name for this task"
        else:
            register(func_name, func=func)

    func = TASK_TABLE[func_name]
    ret = Task(func_name, args)

    if isinstance(target, str):
        target = _target.create(target)

    # init config space
    ret.config_space = ConfigSpace()
    ret.config_space.template_key = template_key or ""

    # Run the template once under a config-collection context to build the
    # search space and record the workload.
    ctx = ApplyConfig(ret.config_space)
    with ctx:
        with target:
            sch, _ = func(*args)
            ret.config_space.code_hash = getattr(sch, 'code_hash', None)

    ret.workload = ctx.workload
    ret.flop = ret.config_space.flop or compute_flop(sch)
    ret.target = target
    ret.target_host = target_host

    return ret

def args_to_workload(x, topi_compute_func=None):
    """Convert argument list to hashable workload tuple.
    This function will convert list to tuple, tvm node to python value and
    flatten tvm.tensor.Tensor to a tuple

    Parameters
    ----------
    x: primitive hashable types or tensor.Tensor
        The original value
    topi_compute_func: topi compute function
        The function name will be added as first element of the workload tuple

    Returns
    -------
    ret: hashable
        The hashable value

    Raises
    ------
    RuntimeError
        If `x` (or a nested element) has an unsupported type.
    """
    if isinstance(x, tensor.Tensor):
        # A tensor is identified by its static shape and dtype only.
        workload = get_const_tuple(x.shape) + (x.dtype, )
    elif isinstance(x, (tuple, list, container.Array)):
        workload = tuple(args_to_workload(a) for a in x)
    elif isinstance(x, (str, int, float, np.integer, np.floating, expr.Var)):
        # NOTE: np.int/np.float were removed in NumPy 1.24; np.integer and
        # np.floating correctly match NumPy scalar types as well.
        workload = x
    elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
        workload = x.value
    elif x is None:
        workload = 0
    else:
        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
                           'primitive types or tvm.expr.Var only' % type(x))
    return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

def template(func):
    """
    Decorate a function as a tunable schedule template

    Parameters
    ----------
    func: callable
        A callable template function.
        Its argument should be hashable values.
        Its return value should be a Tuple(Schedule, Array of Tensor)

    Returns
    -------
    func: callable
        The decorated function

    Examples
    --------
    The following code is a tunable template for a blocked matrix multiplication

    .. code-block:: python

        @autotvm.template
        def matmul(N, L, M, dtype):
            A = tvm.placeholder((N, L), name='A', dtype=dtype)
            B = tvm.placeholder((L, M), name='B', dtype=dtype)

            k = tvm.reduce_axis((0, L), name='k')
            C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
            s = tvm.create_schedule(C.op)

            # schedule
            y, x = s[C].op.axis
            k = s[C].op.reduce_axis[0]

            ##### define space begin #####
            cfg = autotvm.get_config()
            cfg.define_split("tile_y", y, num_outputs=2)
            cfg.define_split("tile_x", x, num_outputs=2)
            ##### define space end #####

            # schedule according to config
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)

            s[C].reorder(yo, xo, k, yi, xi)

            return s, [A, B, C]
    """
    # pylint: disable=unused-variable

    fname = get_func_name(func)

    @register(fname)
    @dispatcher
    def config_dispatcher(*args, **kwargs):
        # Compute the workload key used to look up a config for this call.
        assert not kwargs, "Do not support kwargs in template function call"
        return (fname, ) + args_to_workload(args)

    @config_dispatcher.register("")
    def template_call(cfg, *args, **kwargs):
        # Default template body: apply the dispatched config, then run the
        # user's template function.
        assert not kwargs, "Do not support kwargs in template function call"
        with ApplyConfig(cfg):
            return func(*args, **kwargs)

    config_dispatcher.func_name = fname
    return config_dispatcher

def get_config():
    """Get current config object

    Returns
    -------
    cfg: ConfigSpace or ConfigEntity
        The current config
    """
    return DispatchContext.current.query(None, None)

class FlopCalculationError(RuntimeError):
    """Error happens when estimating FLOP for a compute op"""


def compute_flop(sch):
    """Calculate number of FLOP (floating number operations) of the compute ops in a schedule

    Parameters
    ----------
    sch: tvm.schedule.Schedule
        schedule

    Returns
    -------
    flop: int
        number of FLOP in this schedule

    Raises
    ------
    RuntimeError
        If the FLOP count cannot be estimated (non-constant axis lengths,
        unsupported ops, or zero floating-point work found).
    """
    def _prod_length(axes):
        """compute product of the lengths of a list of axes"""
        try:
            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
        except ValueError:
            raise FlopCalculationError("The length of axis is not constant. ")
        return num_iter

    def _count_flop(exp):
        """compute flop for a single expression"""
        if isinstance(exp, expr.Reduce):
            num_iter = _prod_length(exp.axis)
            combiner = exp.combiner.result
            source = exp.source
            if len(combiner) != 1:
                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
            if len(source) != 1:
                raise FlopCalculationError("Found multiple output in the source of reduce op")
            # One combine + one source evaluation per reduction iteration.
            return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
        if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
            return 0
        if isinstance(exp, expr.Cast):
            return _count_flop(exp.value)
        if isinstance(exp, expr.Var):
            return 0
        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
                            expr.Div, expr.Mod,
                            expr.FloorDiv, expr.FloorMod,
                            expr.Max, expr.Min,
                            expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                            expr.And, expr.Or, expr.Not)):
            base = 1

            if isinstance(exp, expr.Not):  # unary
                return base + _count_flop(exp.a)

            return base + _count_flop(exp.a) + _count_flop(exp.b)
        if isinstance(exp, expr.Select):
            # Only one branch executes, so take the more expensive one.
            return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                    _count_flop(exp.false_value))
        if isinstance(exp, expr.Call):
            if exp.call_type == expr.Call.Halide:
                # Ignore flops from indexing expressions.
                return 0

            return sum([_count_flop(x) for x in exp.args])

        raise FlopCalculationError("Found unsupported operator in the compute expr")

    def traverse(ops):
        """accumulate flops"""
        ret = 0
        for op in ops:
            if isinstance(op, tensor.ComputeOp):
                num_element = _prod_length(op.axis)

                body = op.body
                if len(body) != 1:
                    raise FlopCalculationError("Found multiple output in the compute")
                exp = body[0]

                ret += num_element * _count_flop(exp)
                ret += traverse([t.op for t in op.input_tensors])

            elif isinstance(op, tensor.PlaceholderOp):
                pass
            else:
                raise FlopCalculationError("Only support tvm.compute currently. "
                                           "Other ops like tvm.scan/tvm.extern is not supported")
        return ret

    try:
        ret = traverse(sch.outputs)
    except FlopCalculationError as exc:
        raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
                           + str(exc) + ". Please use `cfg.add_flop` to manually set "
                           "FLOP for this operator")

    if ret == 0:
        raise RuntimeError("Cannot find float number operation in this operator. "
                           "Please use `cfg.add_flop` to manually set "
                           "FLOP for this operator")
    return ret