# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Functions for enabling AMP (automatic mixed precision)."""
__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
           'convert_hybrid_block', 'list_lp16_ops', 'list_fp32_ops',
           'list_lp16_fp32_ops', 'list_conditional_fp32_ops',
           'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
           'convert_symbol']

from array import array
import ctypes
import logging
import contextlib
import numpy as np

from ... import symbol
from ...context import gpu
from ...symbol import Symbol
from ...module import BucketingModule
from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from . import lists
from ...gluon import trainer
from ... import base
from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf
from ... import optimizer as opt
from .loss_scaler import LossScaler

bfloat16 = np.dtype([('bfloat16', np.uint16)])

def _cast_symbol_NDArray(s, dtype):
    float_types_gpu = (np.float16, np.float32)
    float_types_cpu = (bfloat16, np.float32)
    if isinstance(s, Symbol):
        return symbol.amp_cast(s, dtype=dtype)
    elif isinstance(s, NDArray):
        if (s.dtype != dtype and s.dtype in float_types_gpu and s.context.device_type != 'cpu'):
            return ndarray.amp_cast(s, dtype=dtype)
        elif (s.dtype != dtype and s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
            return ndarray.amp_cast(s, dtype=dtype)
        else:
            return s
    else:
        return s

def _get_fun_to_wrap(name, module, submodule_dict):
    module_internal = getattr(module, "_internal")
    prefix = base._get_op_name_prefix(name)
    if len(prefix) > 0:
        if prefix != '_random_' or name.endswith('_like'):
            func_name = name[len(prefix):]
            cur_module = submodule_dict[prefix]
        else:
            func_name = name
            cur_module = module_internal
    elif name.startswith('_'):
        func_name = name
        cur_module = module_internal
    else:
        func_name = name
        cur_module = module
    return func_name, cur_module

def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
                           conditional_fp32_ops=None, fp32_ops=None):
    def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
        def _new_fun(*args, **kwargs):
            if cond_arg is not None:
                if (cond_arg[0] not in kwargs or
                        kwargs[cond_arg[0]] not in cond_arg[1]):
                    return f(*args, **kwargs)
            if fp32_param:
                new_args = []
                for i, x in enumerate(args):
                    if fp32_param[i]:
                        new_args.append(x)
                    else:
                        new_args.append(_cast_symbol_NDArray(x, target_dtype))
            else:
                new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
            args = tuple(new_args)
            if fp32_param:
                new_kwargs = {}
                for k, v in kwargs.items():
                    if k in fp32_param:
                        new_kwargs[k] = v
                    else:
                        new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype)
                kwargs = new_kwargs
            else:
                kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in kwargs.items()}
            return f(*args, **kwargs)
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    def _symbol_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
        def _new_fun(*args, **kwargs):
            if cond_arg is not None:
                if (cond_arg[0] not in kwargs or
                        kwargs[cond_arg[0]] not in cond_arg[1]):
                    return f(*args, **kwargs)
            sym = f(*args, **kwargs)
            inputs = sym.get_children()
            aux = sym.list_auxiliary_states()
            if fp32_param:
                new_inputs = []
                for i, x in enumerate(inputs):
                    if (x.name in aux) or fp32_param[i]:
                        new_inputs.append(x)
                    else:
                        new_inputs.append(_cast_symbol_NDArray(x, target_dtype))
                inputs = new_inputs
            else:
                inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
                                  if x.name not in aux else x, inputs))
            atomic_sym = sym._gen_atomic_symbol()
            wrapped_sym = atomic_sym(*inputs)
            wrapped_sym._set_attr(name=sym.name)
            return wrapped_sym
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    def _symbol_widest_wrapper(f):
        def _new_fun(*args, **kwargs):
            symbols = []
            is_symbol = False
            args = list(args)
            for i, arg in enumerate(args):
                if isinstance(arg, (Symbol, NDArray)):
                    symbols.append((args, i, arg))
                    is_symbol = is_symbol or isinstance(arg, Symbol)
            for k, arg in kwargs.items():
                if isinstance(arg, (Symbol, NDArray)):
                    symbols.append((kwargs, k, arg))
                    is_symbol = is_symbol or isinstance(arg, Symbol)
            if not is_symbol:
                # NDArray case
                widest_type = target_dtype
                for _, _, arg in symbols:
                    if isinstance(arg, NDArray):
                        if arg.dtype == np.float32:
                            widest_type = np.float32
                for arr, index, arg in symbols:
                    if arg.dtype != widest_type and arg.dtype == target_dtype:
                        arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
            else:
                # Symbol case
                sym_to_check = list(map(lambda x: x[2], symbols))
                casted_syms = symbol.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
                symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
                                   zip(symbols, casted_syms)))
                for arr, index, arg in symbols:
                    arr[index] = arg

            return f(*args, **kwargs)
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    _wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) else _ndarray_wrapper

    submodule_dict = {}
    for op_name_prefix in base._OP_NAME_PREFIX_LIST:
        submodule_dict[op_name_prefix] =\
            getattr(module, op_name_prefix[1:-1])
    fp32_param_list = list_lp16_use_fp32_params(target_dtype)
    wrap_list = target_precision_ops if target_precision_ops is not None \
        else list_lp16_ops(target_dtype)
    for fun_name in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            fp32_param = fp32_param_list[fun_name] if (fp32_param_list and fun_name in fp32_param_list) else None
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
        except AttributeError:
            raise

    wrap_list = fp32_ops if fp32_ops is not None else list_fp32_ops(target_dtype)
    for fun_name in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
        except AttributeError:
            raise

    wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
        else list_conditional_fp32_ops(target_dtype)
    for fun_name, arg, arg_values in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
        except AttributeError:
            raise


    for fun_name in list_widest_type_cast(target_dtype):
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
            if cur_module == module:
                setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
        except AttributeError:
            raise

def _wrap_loss_output_functions(module, ls, target_dtype):
    if module == ndarray:
        def _wrapper(f):
            def _scaling_wrapper(*args, **kwargs):
                if 'grad_scale' in kwargs:
                    kwargs['grad_scale'] = kwargs['grad_scale'] * ls.loss_scale
                else:
                    kwargs['grad_scale'] = ls.loss_scale
                return f(*args, **kwargs)
            _scaling_wrapper.__name__ = f.__name__
            _scaling_wrapper.__module__ = f.__module__
            _scaling_wrapper.__doc__ = f.__doc__
            return _scaling_wrapper
    else:
        def _wrapper(f):
            def _warning_wrapper(*args, **kwargs):
                logging.warning("%s does not support dynamic loss scaling "
                                "in symbolic and hybridized execution.", f.__name__)
                return f(*args, **kwargs)
            _warning_wrapper.__name__ = f.__name__
            _warning_wrapper.__module__ = f.__module__
            _warning_wrapper.__doc__ = f.__doc__
            return _warning_wrapper

    for fun_name in list_loss_output_functions(target_dtype):
        try:
            f_to_wrap = getattr(module, fun_name)
            setattr(module, fun_name, _wrapper(f_to_wrap))
        except AttributeError:
            pass

_amp_initialized = False
_amp_loss_scale_initialized = False
_loss_scaler = None

@contextlib.contextmanager
def scale_loss(loss, optimizer_or_trainer):
    """Context manager that yields the loss (or list of losses) multiplied by the
    current dynamic loss scale and adjusts the trainer's gradient rescaling so
    that the scale is removed again during the optimizer update."""
    assert optimizer_or_trainer._amp_loss_scaler is not None, \
        'Loss scaler is not initialized, did you forget to call amp.init_trainer()?'
    optimizer_or_trainer._scale = (optimizer_or_trainer._amp_original_scale /
                                   optimizer_or_trainer._amp_loss_scaler.loss_scale)
    if isinstance(loss, (list, tuple)):
        yield [l * optimizer_or_trainer._amp_loss_scaler.loss_scale for l in loss]
    else:
        yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss

def init(target_dtype='float16', target_precision_ops=None,
         conditional_fp32_ops=None, fp32_ops=None):
    """Initialize AMP (automatic mixed precision).

    This needs to be done before model creation.

    Parameters
    ----------
    target_dtype : {'float16', 'bfloat16'}
        Target low precision type for AMP. Currently only float16 and bfloat16 are supported.
    target_precision_ops : list of string
        Override the list of functions cast to target_dtype. Entries in this list
        are names of the functions cast to target_dtype.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of functions conditionally cast to FP32. The format
        of the list is (name of the function, name of the parameter, list of
        values of the parameter that make the function be cast to FP32).
    fp32_ops : list of string
        Override the list of functions cast to FP32. Entries in this list
        are names of the functions cast to FP32.
    """
    global _amp_initialized
    global _loss_scaler
    if not _amp_initialized:
        assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
               "AMP currently supports only float16 or bfloat16 as a target_dtype"
        _amp_initialized = True
        logging.info("Using AMP")
        if target_dtype == "bfloat16":
            target_dtype = bfloat16
        else:
            target_dtype = np.dtype(target_dtype)
        _wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
                               conditional_fp32_ops, fp32_ops)
        _wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
                               conditional_fp32_ops, fp32_ops)
        _loss_scaler = LossScaler()
        _wrap_loss_output_functions(ndarray, _loss_scaler, target_dtype)
        _wrap_loss_output_functions(symbol, _loss_scaler, target_dtype)

def init_trainer(optimizer_or_trainer):
    """Initialize trainer or optimizer to work with AMP dynamic loss scaling.

    Parameters
    ----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet Optimizer or Gluon Trainer to initialize with AMP
    """
    global _amp_loss_scale_initialized
    global _amp_initialized
    global _loss_scaler
    assert _amp_initialized, "AMP not initialized, did you forget to call amp.init()?"
    if not _amp_loss_scale_initialized:
        _amp_loss_scale_initialized = True
        loss_scaler = _loss_scaler
    else:
        loss_scaler = LossScaler()
    #_wrap_output
    if isinstance(optimizer_or_trainer, trainer.Trainer):
        optimizer_or_trainer._amp_loss_scaler = loss_scaler
        optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
    elif isinstance(optimizer_or_trainer, opt.Optimizer):
        # TODO(ptredak): make it work with the optimizer
        raise TypeError("AMP is currently only compatible with Gluon Trainer")
    else:
        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                        "an optimizer, but got %s" % type(optimizer_or_trainer))
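
# A minimal usage sketch of the training workflow built from init(), init_trainer()
# and scale_loss() above. It is illustrative only: ``build_net`` is a hypothetical
# model constructor and ``train_data`` a user-provided iterable of (data, label)
# batches; neither belongs to this module.
#
#   import mxnet as mx
#   from mxnet import autograd, gluon
#   from mxnet.contrib import amp
#
#   amp.init()                                    # patch operators; call before model creation
#   net = build_net()                             # hypothetical FP32 Gluon model
#   net.initialize(ctx=mx.gpu(0))
#   trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
#   amp.init_trainer(trainer)                     # attach the shared dynamic LossScaler
#   loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
#   for data, label in train_data:
#       with autograd.record():
#           loss = loss_fn(net(data), label)
#           with amp.scale_loss(loss, trainer) as scaled_loss:
#               autograd.backward(scaled_loss)
#       trainer.step(data.shape[0])               # the scale is removed during the update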

def unscale(optimizer_or_trainer):
    """Check and unscale the gradients manually. This function should only be used
    if accessing gradients is necessary, e.g. for gradient clipping.

    Parameters
    ----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet Optimizer or Gluon Trainer used when scaling the gradients
    """
    if isinstance(optimizer_or_trainer, trainer.Trainer):
        valid_grads = [p._grad for p in optimizer_or_trainer._params if p._grad is not None]
        for grads in valid_grads:
            # TODO(ptredak): make a bulked unscale
            for g in grads:
                g[:] *= optimizer_or_trainer._scale
        optimizer_or_trainer._scale = 1.
    elif isinstance(optimizer_or_trainer, opt.Optimizer):
        # TODO(ptredak): make it work with the optimizer
        raise TypeError("AMP is currently only compatible with Gluon Trainer")
    else:
        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                        "an optimizer, but got %s" % type(optimizer_or_trainer))
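
# A short sketch of manual gradient access with unscale(), e.g. for gradient
# clipping. Illustrative only: ``net``, ``trainer``, ``loss_fn``, ``data`` and
# ``label`` are assumed to follow the training-loop sketch above, and
# clip_global_norm is the existing gluon.utils helper.
#
#   with autograd.record():
#       loss = loss_fn(net(data), label)
#       with amp.scale_loss(loss, trainer) as scaled_loss:
#           autograd.backward(scaled_loss)
#   amp.unscale(trainer)                          # gradients are now in the unscaled range
#   grads = [p.grad() for p in net.collect_params().values() if p.grad_req != 'null']
#   gluon.utils.clip_global_norm(grads, max_norm=1.0)
#   trainer.step(data.shape[0])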

def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
                   fp32_ops=None, conditional_fp32_ops=None,
                   excluded_sym_names=None, data_names=None,
                   cast_optional_params=False):
    """Given a symbol object representing a neural network of data type FP32 and a target_dtype,
    add cast layers according to the op lists (target_dtype_ops, fp32_ops,
    conditional_fp32_ops) if provided, otherwise use the default
    lists provided by the framework.

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    target_dtype : str or numpy dtype, optional, defaults to 'float16'
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs, optional
        Override the list of operator names cast to the target_dtype.
        If None, uses the framework's default list to be cast to target_dtype.
    fp32_ops : list of strs, optional
        Override the list of operator names cast to FP32.
        If None, uses the framework's default list to be cast to FP32.
    conditional_fp32_ops : list of (string, string, list of string), optional
        Override the list of functions to be cast to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator be cast to FP32)
    excluded_sym_names : list of strs, optional
        A list of strings that represent the names of symbols that users want to exclude
        from being cast to LP16 or FP32.
    data_names : list of strs, optional
        A list of strings that represent input data tensor names to the model
    cast_optional_params : bool, default False
        Whether to cast the arg_params and aux_params that are not required to be in LP16
        because a cast layer follows them anyway; casting them reduces the computation and
        memory overhead of the model.
    """
    assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol"

    assert target_dtype in ['float16', 'bfloat16'], \
        "Only target_dtype float16 and bfloat16 are supported currently"

    if target_dtype == 'bfloat16':
        target_dtype = bfloat16

    if target_dtype_ops is not None:
        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
    else:
        target_dtype_ops = list_lp16_ops(target_dtype)

    if fp32_ops is not None:
        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
    else:
        fp32_ops = list_fp32_ops(target_dtype)

    if conditional_fp32_ops is not None:
        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
    else:
        conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)

    original_conditional_op_names = []
    conditional_op_names = []
    param_names = []
    param_vals = []
    indptr = [0]
    for conditional_fp32_op in conditional_fp32_ops:
        assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
            and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
                                                          "(str, str, list of str)"
        param_vals += conditional_fp32_op[2]
        indptr.append(len(param_vals))
        param_names.append(conditional_fp32_op[1])
        conditional_op_names.append(conditional_fp32_op[0])

    if excluded_sym_names is not None:
        assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
    else:
        excluded_sym_names = []

    for original_conditional_fp32_op in list_conditional_fp32_ops(target_dtype):
        original_conditional_op_names.append(original_conditional_fp32_op[0])

    # Op lists should not have intersection
    common_ops = set(target_dtype_ops) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
    common_ops = set(target_dtype_ops) & set(conditional_op_names)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
    common_ops = set(conditional_op_names) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)

    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
    all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + list_fp32_ops(target_dtype)
                            + list_lp16_fp32_ops(target_dtype) + original_conditional_op_names)

    illegal_ops = combined_ops - all_lp16_fp32_ops
    assert not illegal_ops, '''Can only choose ops from one of the four lists
                            for lp16_ops and fp32_ops
                            1. amp.list_lp16_ops(target_dtype)
                            2. amp.list_fp32_ops(target_dtype)
                            3. amp.list_lp16_fp32_ops(target_dtype)
                            4. amp.list_conditional_fp32_ops(target_dtype)
                            Op %s not in any of them''' % (illegal_ops)

    widest_dtype_ops = list_widest_type_cast(target_dtype)
    if target_dtype == bfloat16:
        target_dtype = _DTYPE_NP_TO_MX[bfloat16]
    else:
        target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]

    # Prepare a data_names list based on list_inputs if it's not provided.
    # Add all names of the nodes in the symbol which don't have __dtype__ set.
    attr_dict = sym.attr_dict()
    if data_names is None:
        data_names = []
        for sym_name in sym.list_inputs():
            if not sym_name in attr_dict:
                data_names.append(sym_name)
                continue
            if not "__dtype__" in attr_dict[sym_name]:
                data_names.append(sym_name)
    model_param_names = list(set(sym.list_inputs()) - set(data_names))

    # Since the assumption is that it is a FP32 model, set dtypes for all
    # data_names to float32
    str_keys = []
    sdata = []
    for k in data_names:
        str_keys.append(k)
        sdata.append(0)
    keys = c_str_array(str_keys)
    out = SymbolHandle()
    check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
                                            ctypes.byref(out),
                                            mx_uint(len(sdata)),
                                            c_array_buf(ctypes.c_int, array('i', sdata)),
                                            mx_uint(len(indptr)),
                                            c_array_buf(ctypes.c_int, array('i', indptr)),
                                            ctypes.byref(ctypes.c_int(target_dtype)),
                                            ctypes.c_int(cast_optional_params),
                                            mx_uint(len(target_dtype_ops)),
                                            mx_uint(len(fp32_ops)),
                                            mx_uint(len(widest_dtype_ops)),
                                            mx_uint(len(conditional_op_names)),
                                            mx_uint(len(excluded_sym_names)),
                                            mx_uint(len(model_param_names)),
                                            c_str_array(target_dtype_ops),
                                            c_str_array(fp32_ops),
                                            c_str_array(widest_dtype_ops),
                                            c_str_array(conditional_op_names),
                                            c_str_array(excluded_sym_names),
                                            c_str_array(param_names),
                                            c_str_array(param_vals),
                                            c_str_array(model_param_names),
                                            keys))
    return Symbol(out)
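
# A small illustrative sketch of converting a standalone FP32 symbol with
# convert_symbol(). The file names and the excluded symbol name 'softmax' are
# placeholders chosen for the example, not values defined by this module.
#
#   import mxnet as mx
#   from mxnet.contrib import amp
#
#   sym = mx.sym.load('model-symbol.json')        # previously saved FP32 graph
#   amp_sym = amp.convert_symbol(sym,
#                                target_dtype='float16',
#                                excluded_sym_names=['softmax'])  # keep the output in FP32
#   amp_sym.save('model-amp-symbol.json')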

def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None,
                  cast_optional_params=False):
    """API for converting a FP32 model to a mixed precision model.
    MXNet tries to convert the FP32 model to a mixed precision model by adding
    cast layers using the amp_cast and amp_multicast operators; the result can be used for
    inference use cases. The decision on which cast layers to add is based on hardcoded lists
    for Automatic Mixed Precision in MXNet. These lists can be overridden by the user by
    providing their own lists using: target_dtype_ops, fp32_ops, conditional_fp32_ops

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    target_dtype : str
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names cast to target_dtype.
        If None, uses the framework's default list to be cast to target_dtype.
    fp32_ops : list of strs
        Override the list of operator names cast to FP32.
        If None, uses the framework's default list to be cast to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be cast to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator be cast to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being executed in lower precision.
    cast_optional_params : bool, default False
        Whether to cast the arg_params and aux_params that are not required to be in LP16
        because a cast layer follows them anyway; casting them reduces the computation and
        memory overhead of the model.
    """
    if excluded_sym_names is None:
        excluded_sym_names = []
    if not isinstance(excluded_sym_names, list):
        raise ValueError('excluded_sym_names must be a list of strings representing'
                         ' the names of the symbols that should not be casted,'
                         ' while received type %s' % str(type(excluded_sym_names)))
    assert target_dtype in ['float16', 'bfloat16'], \
        "Only target_dtype float16 and bfloat16 are supported currently"

    assert isinstance(sym, Symbol), "First argument to convert_model should be Symbol"
    assert isinstance(arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray"
    assert isinstance(aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray"

    param_names = list(arg_params.keys()) + list(aux_params.keys())

    # Only pass non-params as data_names, param types can be inferred
    data_names = list(set(sym.list_inputs()) - set(param_names))
    sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                         fp32_ops, conditional_fp32_ops,
                         excluded_sym_names, data_names,
                         cast_optional_params)

    # If dtype is set for params, cast the param to that dtype
    attr_dict = sym.attr_dict()
    for sym_name in sym.list_arguments():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                if typ == bfloat16:
                    arg_params[sym_name] = _cast_symbol_NDArray(arg_params[sym_name], bfloat16)
                else:
                    arg_params[sym_name] = arg_params[sym_name].astype(typ)

    for sym_name in sym.list_auxiliary_states():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                if typ == bfloat16:
                    aux_params[sym_name] = _cast_symbol_NDArray(aux_params[sym_name], bfloat16)
                else:
                    aux_params[sym_name] = aux_params[sym_name].astype(typ)

    # Return the converted symbol and casted params
    return sym, arg_params, aux_params
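
# An illustrative inference-conversion sketch with convert_model(). The checkpoint
# prefix 'model', epoch 0, and the input shape are placeholders for a user's own
# saved Module checkpoint, not values provided by this module.
#
#   import mxnet as mx
#   from mxnet.contrib import amp
#
#   sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)
#   sym, arg_params, aux_params = amp.convert_model(sym, arg_params, aux_params,
#                                                   target_dtype='float16',
#                                                   cast_optional_params=True)
#   mod = mx.mod.Module(sym, label_names=None, context=mx.gpu(0))
#   mod.bind(data_shapes=[('data', (1, 3, 224, 224))], for_training=False)
#   mod.set_params(arg_params, aux_params)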

def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
                         fp32_ops=None, conditional_fp32_ops=None,
                         excluded_sym_names=None, ctx=gpu(0),
                         cast_optional_params=False):
    """Given a hybrid block/symbol block representing a FP32 model and a target_dtype,
    return a block with mixed precision support which can be used for inference use cases.

    Parameters
    ----------
    block : HybridBlock or SymbolBlock object
        FP32 HybridBlock or SymbolBlock object
    target_dtype : str or numpy dtype
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names cast to target_dtype.
        If None, uses the framework's default list to be cast to target_dtype.
    fp32_ops : list of strs
        Override the list of operator names cast to FP32.
        If None, uses the framework's default list to be cast to FP32.
    conditional_fp32_ops : list of (str, str, list of str)
        Override the list of functions to be cast to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator be cast to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being cast to lower precision
    ctx : Context
        Context on which model parameters should live
    cast_optional_params : bool, default False
        Whether to cast the arg_params and aux_params that are not required to be in LP16
        because a cast layer follows them anyway; casting them reduces the computation and
        memory overhead of the model.
    """
    from ...gluon import HybridBlock, SymbolBlock
    assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
    if not block._cached_graph:
        raise RuntimeError(
            "Please first call block.hybridize() and then run forward with "
            "this block at least once before calling convert_hybrid_block.")

    # Prepare inputs to pass to the convert_symbol API
    inputs, sym = block._cached_graph
    input_names = []
    for inp in inputs:
        input_names.append(inp.name)
    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                                   fp32_ops, conditional_fp32_ops,
                                   excluded_sym_names, data_names=input_names,
                                   cast_optional_params=cast_optional_params)

    arg_names = set(converted_sym.list_arguments())
    aux_names = set(converted_sym.list_auxiliary_states())
    arg_dict = {}

    # If dtype for the param was set in the json, cast the
    # param to this dtype
    attr_dict = converted_sym.attr_dict()
    for name, param in block.collect_params().items():
        if name in arg_names:
            arg_dict['arg:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    if typ == bfloat16:
                        arg_dict['arg:%s' % name] = _cast_symbol_NDArray(arg_dict['arg:%s' % name], bfloat16)
                    else:
                        arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ)
        else:
            assert name in aux_names
            arg_dict['aux:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    if typ == bfloat16:
                        arg_dict['aux:%s' % name] = _cast_symbol_NDArray(arg_dict['aux:%s' % name], 'bfloat16')
                    else:
                        arg_dict['aux:%s'%name] = arg_dict['aux:%s'%name].astype(typ)

    # Create a SymbolBlock and cast the params to the dtypes based
    # on the dtype information from the converted_sym
    ret = SymbolBlock(converted_sym, inputs)
    for key, param in ret.collect_params().items():
        arg_param_name = "arg:%s" % key
        if arg_param_name in arg_dict and param.dtype != arg_dict[arg_param_name].dtype:
            param.cast(arg_dict[arg_param_name].dtype)

        aux_param_name = "aux:%s" % key
        if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype:
            param.cast(arg_dict[aux_param_name].dtype)

    ret.collect_params().load_dict(arg_dict, ctx=ctx)
    return ret
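
# A brief illustrative sketch for Gluon inference with convert_hybrid_block().
# resnet50_v1 comes from the gluon model zoo; the input shape and context are
# arbitrary example values, not requirements of this API.
#
#   import mxnet as mx
#   from mxnet.gluon.model_zoo import vision
#   from mxnet.contrib import amp
#
#   net = vision.resnet50_v1(pretrained=True)
#   net.hybridize()
#   _ = net(mx.nd.zeros((1, 3, 224, 224)))        # run forward once to build the cached graph
#   amp_net = amp.convert_hybrid_block(net, target_dtype='float16', ctx=mx.gpu(0))
#   out = amp_net(mx.nd.zeros((1, 3, 224, 224), ctx=mx.gpu(0)))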

def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None,
                             fp32_ops=None, conditional_fp32_ops=None,
                             excluded_sym_names=None, cast_optional_params=False):
    """Given a BucketingModule, cast the symbols associated with it
    and, if cast_optional_params is set, the params as well.

    Parameters
    ----------
    bucketing_mod : BucketingModule instance
    target_dtype : str
        Currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names cast to target_dtype.
        If None, uses the framework's default list to be cast to target_dtype.
    fp32_ops : list of strs
        Override the list of operator names cast to FP32.
        If None, uses the framework's default list to be cast to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be cast to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator be cast to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being executed in lower precision.
    cast_optional_params : bool, default False
        Whether to cast the arg_params and aux_params that are not required to be in LP16
        because a cast layer follows them anyway; casting them reduces the computation and
        memory overhead of the model.
    """
    assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module"
    assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty"

    sym_dict = {}
    assert bucketing_mod.params_initialized, \
        "bucketing_mod params should be initialized for mixed precision conversion"
    arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params
    for key, val in bucketing_mod._buckets.items():
        sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol,
                                                                            arg_params,
                                                                            aux_params,
                                                                            target_dtype=target_dtype,
                                                                            target_dtype_ops=target_dtype_ops,
                                                                            fp32_ops=fp32_ops,
                                                                            conditional_fp32_ops=conditional_fp32_ops,
                                                                            excluded_sym_names=excluded_sym_names,
                                                                            cast_optional_params=cast_optional_params)
    result_mod = BucketingModule.load_dict(sym_dict,
                                           sym_gen=bucketing_mod._sym_gen,
                                           arg_params=result_arg_params,
                                           aux_params=result_aux_params,
                                           default_bucket_key=bucketing_mod._default_bucket_key,
                                           logger=bucketing_mod.logger,
                                           context=bucketing_mod._context,
                                           work_load_list=bucketing_mod._work_load_list,
                                           fixed_param_names=bucketing_mod._fixed_param_names,
                                           state_names=bucketing_mod._state_names,
                                           group2ctxs=bucketing_mod._group2ctxs,
                                           compression_params=bucketing_mod._compression_params)
    return result_mod
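
# A minimal illustrative sketch of convert_bucketing_module(). It assumes an
# already trained and parameter-initialized BucketingModule ``bucketing_mod``
# (e.g. from an RNN bucketing example), which is not defined in this module.
#
#   from mxnet.contrib import amp
#
#   amp_bucketing_mod = amp.convert_bucketing_module(bucketing_mod, target_dtype='float16')
#   # bind and run inference with amp_bucketing_mod as with any BucketingModule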

def list_lp16_ops(target_dtype):
    """Get the default list of LP16 ops for AMP
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP16_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_FUNCS

def list_fp32_ops(target_dtype):
    """Get the default list of FP32 ops for AMP
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.FP32_FUNCS

def list_lp16_fp32_ops(target_dtype):
    """Get the default list of ops which run in both LP16 and FP32
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP16_FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_FP32_FUNCS

def list_conditional_fp32_ops(target_dtype):
    """Get the conditional FP32 ops list
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.CONDITIONAL_FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.CONDITIONAL_FP32_FUNCS

def list_widest_type_cast(target_dtype):
    """Get the list of ops that run in the widest precision among their inputs
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.WIDEST_TYPE_CASTS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.WIDEST_TYPE_CASTS

def list_loss_output_functions(target_dtype):
    """Get the list of loss output functions
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.LOSS_OUTPUT_FUNCTIONS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.LOSS_OUTPUT_FUNCTIONS

def list_lp16_use_fp32_params(target_dtype):
    """Get the mapping of LP16 ops whose params should stay in FP32
    """
    if target_dtype in ['float16', np.float16]:
        return None
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_USE_FP32_PARAMS