# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Functions for enabling AMP (automatic mixed precision)."""
__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
           'convert_hybrid_block', 'list_lp16_ops', 'list_fp32_ops',
           'list_lp16_fp32_ops', 'list_conditional_fp32_ops',
           'list_widest_type_cast', 'list_loss_output_functions', 'list_lp16_use_fp32_params',
           'convert_symbol']

from array import array
import ctypes
import logging
import contextlib
import numpy as np

from ... import symbol
from ...context import gpu
from ...symbol import Symbol
from ...module import BucketingModule
from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from . import lists
from ...gluon import trainer
from ... import base
from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf
from ... import optimizer as opt
from .loss_scaler import LossScaler

bfloat16 = np.dtype([('bfloat16', np.uint16)])

def _cast_symbol_NDArray(s, dtype):
    """Cast a Symbol unconditionally to `dtype`; cast an NDArray only when its
    dtype differs from `dtype` and is a float type castable on its context.
    Any other input is returned unchanged."""
    float_types_gpu = (np.float16, np.float32)
    float_types_cpu = (bfloat16, np.float32)
    if isinstance(s, Symbol):
        return symbol.amp_cast(s, dtype=dtype)
    elif isinstance(s, NDArray):
        if (s.dtype != dtype and s.dtype in float_types_gpu and s.context.device_type != 'cpu'):
            return ndarray.amp_cast(s, dtype=dtype)
        elif (s.dtype != dtype and s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
            return ndarray.amp_cast(s, dtype=dtype)
        else:
            return s
    else:
        return s

def _get_fun_to_wrap(name, module, submodule_dict):
    """Resolve the operator `name` to the attribute name and module object
    that should be patched (handles op-name prefixes and internal ops)."""
    module_internal = getattr(module, "_internal")
    prefix = base._get_op_name_prefix(name)
    if len(prefix) > 0:
        if prefix != '_random_' or name.endswith('_like'):
            func_name = name[len(prefix):]
            cur_module = submodule_dict[prefix]
        else:
            func_name = name
            cur_module = module_internal
    elif name.startswith('_'):
        func_name = name
        cur_module = module_internal
    else:
        func_name = name
        cur_module = module
    return func_name, cur_module

def _wrap_symbol_functions(module, target_dtype, target_precision_ops=None,
                           conditional_fp32_ops=None, fp32_ops=None):
    """Patch the operators of `module` (symbol or ndarray) so that their float
    inputs are automatically cast according to the AMP op lists."""
    def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
        def _new_fun(*args, **kwargs):
            if cond_arg is not None:
                if (cond_arg[0] not in kwargs or
                        kwargs[cond_arg[0]] not in cond_arg[1]):
                    return f(*args, **kwargs)
            if fp32_param:
                new_args = []
                for i, x in enumerate(args):
                    if fp32_param[i]:
                        new_args.append(x)
                    else:
                        new_args.append(_cast_symbol_NDArray(x, target_dtype))
            else:
                new_args = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype), args))
            args = tuple(new_args)
            if fp32_param:
                new_kwargs = {}
                for k, v in kwargs.items():
                    if k in fp32_param:
                        new_kwargs[k] = v
                    else:
                        new_kwargs[k] = _cast_symbol_NDArray(v, target_dtype)
                kwargs = new_kwargs
            else:
                kwargs = {k: _cast_symbol_NDArray(v, target_dtype) for k, v in kwargs.items()}
            return f(*args, **kwargs)
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    def _symbol_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
        def _new_fun(*args, **kwargs):
            if cond_arg is not None:
                if (cond_arg[0] not in kwargs or
                        kwargs[cond_arg[0]] not in cond_arg[1]):
                    return f(*args, **kwargs)
            sym = f(*args, **kwargs)
            inputs = sym.get_children()
            aux = sym.list_auxiliary_states()
            if fp32_param:
                new_inputs = []
                for i, x in enumerate(inputs):
                    if (x.name in aux) or fp32_param[i]:
                        new_inputs.append(x)
                    else:
                        new_inputs.append(_cast_symbol_NDArray(x, target_dtype))
                inputs = new_inputs
            else:
                inputs = list(map(lambda x: _cast_symbol_NDArray(x, target_dtype)
                                  if x.name not in aux else x, inputs))
            atomic_sym = sym._gen_atomic_symbol()
            wrapped_sym = atomic_sym(*inputs)
            wrapped_sym._set_attr(name=sym.name)
            return wrapped_sym
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    def _symbol_widest_wrapper(f):
        def _new_fun(*args, **kwargs):
            symbols = []
            is_symbol = False
            args = list(args)
            for i, arg in enumerate(args):
                if isinstance(arg, (Symbol, NDArray)):
                    symbols.append((args, i, arg))
                    is_symbol = is_symbol or isinstance(arg, Symbol)
            for k, arg in kwargs.items():
                if isinstance(arg, (Symbol, NDArray)):
                    symbols.append((kwargs, k, arg))
                    is_symbol = is_symbol or isinstance(arg, Symbol)
            if not is_symbol:
                # NDArray case
                widest_type = target_dtype
                for _, _, arg in symbols:
                    if isinstance(arg, NDArray):
                        if arg.dtype == np.float32:
                            widest_type = np.float32
                for arr, index, arg in symbols:
                    if arg.dtype != widest_type and arg.dtype == target_dtype:
                        arr[index] = ndarray.amp_cast(arg, dtype=widest_type)
            else:
                # Symbol case
                sym_to_check = list(map(lambda x: x[2], symbols))
                casted_syms = symbol.amp_multicast(*sym_to_check, num_outputs=len(sym_to_check))
                symbols = list(map(lambda x_y: (x_y[0][0], x_y[0][1], x_y[1]),
                                   zip(symbols, casted_syms)))
                for arr, index, arg in symbols:
                    arr[index] = arg

            return f(*args, **kwargs)
        _new_fun.__name__ = f.__name__
        _new_fun.__module__ = f.__module__
        _new_fun.__doc__ = f.__doc__
        return _new_fun

    _wrapper = _symbol_wrapper if module in (symbol, Symbol, symbol_contrib) else _ndarray_wrapper

    submodule_dict = {}
    for op_name_prefix in base._OP_NAME_PREFIX_LIST:
        submodule_dict[op_name_prefix] =\
                getattr(module, op_name_prefix[1:-1])
    fp32_param_list = list_lp16_use_fp32_params(target_dtype)
    wrap_list = target_precision_ops if target_precision_ops is not None \
                    else list_lp16_ops(target_dtype)
    for fun_name in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            fp32_param = fp32_param_list[fun_name] if (fp32_param_list and fun_name in fp32_param_list) else None
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, target_dtype, fp32_param=fp32_param))
        except AttributeError:
            raise

    wrap_list = fp32_ops if fp32_ops is not None else list_fp32_ops(target_dtype)
    for fun_name in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32))
        except AttributeError:
            raise

    wrap_list = conditional_fp32_ops if conditional_fp32_ops is not None \
                    else list_conditional_fp32_ops(target_dtype)
    for fun_name, arg, arg_values in wrap_list:
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
            if cur_module == module:
                setattr(module.op, fun_name, _wrapper(f_to_wrap, np.float32, cond_arg=(arg, arg_values)))
        except AttributeError:
            raise


    for fun_name in list_widest_type_cast(target_dtype):
        try:
            fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
            f_to_wrap = getattr(cur_module, fun_name)
            setattr(cur_module, fun_name, _symbol_widest_wrapper(f_to_wrap))
            if cur_module == module:
                setattr(module.op, fun_name, _symbol_widest_wrapper(f_to_wrap))
        except AttributeError:
            raise

def _wrap_loss_output_functions(module, ls, target_dtype):
    """Patch loss output ops: in the ndarray case multiply their grad_scale by
    the current loss scale; in the symbol case warn that dynamic loss scaling
    is not supported."""
    if module == ndarray:
        def _wrapper(f):
            def _scaling_wrapper(*args, **kwargs):
                if 'grad_scale' in kwargs:
                    kwargs['grad_scale'] = kwargs['grad_scale'] * ls.loss_scale
                else:
                    kwargs['grad_scale'] = ls.loss_scale
                return f(*args, **kwargs)
            _scaling_wrapper.__name__ = f.__name__
            _scaling_wrapper.__module__ = f.__module__
            _scaling_wrapper.__doc__ = f.__doc__
            return _scaling_wrapper
    else:
        def _wrapper(f):
            def _warning_wrapper(*args, **kwargs):
                logging.warning("%s does not support dynamic loss scaling "
                                "in symbolic and hybridized execution.", f.__name__)
                return f(*args, **kwargs)
            _warning_wrapper.__name__ = f.__name__
            _warning_wrapper.__module__ = f.__module__
            _warning_wrapper.__doc__ = f.__doc__
            return _warning_wrapper

    for fun_name in list_loss_output_functions(target_dtype):
        try:
            f_to_wrap = getattr(module, fun_name)
            setattr(module, fun_name, _wrapper(f_to_wrap))
        except AttributeError:
            pass

_amp_initialized = False
_amp_loss_scale_initialized = False
_loss_scaler = None

@contextlib.contextmanager
def scale_loss(loss, optimizer_or_trainer):
    """Context manager that scales the loss by the current dynamic loss scale
    and adjusts the Trainer's gradient rescaling factor so the extra scale is
    divided out during the optimizer step. Call backward on the yielded
    (scaled) loss. Requires amp.init_trainer() to have been called first.

    Parameters
    ----------
    loss : NDArray or list of NDArray
        Loss (or losses) computed inside the autograd scope.
    optimizer_or_trainer : Trainer
        Gluon Trainer initialized with amp.init_trainer().
    """
    assert optimizer_or_trainer._amp_loss_scaler is not None, \
        'Loss scaler is not initialized, did you forget to call amp.init_trainer()?'
    optimizer_or_trainer._scale = (optimizer_or_trainer._amp_original_scale /
                                   optimizer_or_trainer._amp_loss_scaler.loss_scale)
    if isinstance(loss, (list, tuple)):
        yield [l * optimizer_or_trainer._amp_loss_scaler.loss_scale for l in loss]
    else:
        yield optimizer_or_trainer._amp_loss_scaler.loss_scale * loss

def init(target_dtype='float16', target_precision_ops=None,
         conditional_fp32_ops=None, fp32_ops=None):
    """Initialize AMP (automatic mixed precision).

    This needs to be done before model creation.

    Parameters
    ----------
    target_dtype : {'float16', 'bfloat16'}
        Target low precision type for AMP. Currently only float16 and bfloat16 are supported.
    target_precision_ops : list of string
        Override the list of functions casted to target_dtype. Entries in this list
        are names of the functions casted to target_dtype.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of functions conditionally casted to FP32. The format
        of the list is (name of the function, name of the parameter, list of
        values of the parameter that make the function be casted to FP32).
    fp32_ops : list of string
        Override the list of functions casted to FP32. Entries in this list
        are names of the functions casted to FP32.
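
    Examples
    --------
    A minimal usage sketch, not a complete program: it assumes the usual
    ``import mxnet as mx``, ``from mxnet import gluon, autograd`` and
    ``from mxnet.contrib import amp`` imports, and that ``build_network``,
    ``loss_fn``, ``data``, ``label`` and ``batch_size`` are defined by the user::

        amp.init()                         # must run before the model is created
        net = build_network()              # hypothetical model constructor
        trainer = gluon.Trainer(net.collect_params(), 'sgd')
        amp.init_trainer(trainer)
        with autograd.record():
            out = net(data)
            loss = loss_fn(out, label)
            with amp.scale_loss(loss, trainer) as scaled_loss:
                autograd.backward(scaled_loss)
        trainer.step(batch_size)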
    """
    global _amp_initialized
    global _loss_scaler
    if not _amp_initialized:
        assert target_dtype in ['float16', np.float16, 'bfloat16', bfloat16], \
               "AMP currently supports only float16 or bfloat16 as a target_dtype"
        _amp_initialized = True
        logging.info("Using AMP")
        if target_dtype == "bfloat16":
            target_dtype = bfloat16
        else:
            target_dtype = np.dtype(target_dtype)
        _wrap_symbol_functions(symbol, target_dtype, target_precision_ops,
                               conditional_fp32_ops, fp32_ops)
        _wrap_symbol_functions(ndarray, target_dtype, target_precision_ops,
                               conditional_fp32_ops, fp32_ops)
        _loss_scaler = LossScaler()
        _wrap_loss_output_functions(ndarray, _loss_scaler, target_dtype)
        _wrap_loss_output_functions(symbol, _loss_scaler, target_dtype)

def init_trainer(optimizer_or_trainer):
    """Initialize trainer or optimizer to work with AMP dynamic loss scaling.

    Parameters
    ----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet Optimizer or Gluon trainer to initialize with AMP
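
    Examples
    --------
    A minimal sketch (assumes ``net`` is an existing Gluon model and
    ``amp.init()`` has already been called)::

        trainer = gluon.Trainer(net.collect_params(), 'sgd',
                                {'learning_rate': 0.1})
        amp.init_trainer(trainer)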
    """
    global _amp_loss_scale_initialized
    global _amp_initialized
    global _loss_scaler
    assert _amp_initialized, "AMP not initialized, did you forget to call amp.init()?"
    if not _amp_loss_scale_initialized:
        _amp_loss_scale_initialized = True
        loss_scaler = _loss_scaler
    else:
        loss_scaler = LossScaler()
    #_wrap_output
    if isinstance(optimizer_or_trainer, trainer.Trainer):
        optimizer_or_trainer._amp_loss_scaler = loss_scaler
        optimizer_or_trainer._amp_original_scale = optimizer_or_trainer._scale
    elif isinstance(optimizer_or_trainer, opt.Optimizer):
        # TODO(ptredak): make it work with the optimizer
        raise TypeError("AMP is currently only compatible with Gluon Trainer")
    else:
        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                        "an optimizer, instead is %s" % type(optimizer_or_trainer))

def unscale(optimizer_or_trainer):
    """Check and unscale the gradients manually. This function should only be used
    if accessing gradients is necessary, e.g. for gradient clipping.

    Parameters
    ----------
    optimizer_or_trainer : Optimizer or Trainer
        MXNet optimizer or Gluon Trainer used when scaling the gradients
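
    Examples
    --------
    A gradient-clipping sketch (assumes a training step like the one shown for
    ``init``, with ``net``, ``trainer`` and ``batch_size`` defined by the user;
    the clipping threshold of 1.0 is arbitrary)::

        with amp.scale_loss(loss, trainer) as scaled_loss:
            autograd.backward(scaled_loss)
        amp.unscale(trainer)          # gradients can now be inspected directly
        for param in net.collect_params().values():
            if param.grad_req != 'null':
                grad = param.grad()
                grad[:] = mx.nd.clip(grad, -1.0, 1.0)
        trainer.step(batch_size)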
    """
    if isinstance(optimizer_or_trainer, trainer.Trainer):
        valid_grads = [p._grad for p in optimizer_or_trainer._params if p._grad is not None]
        for grads in valid_grads:
            # TODO(ptredak): make a bulked unscale
            for g in grads:
                g[:] *= optimizer_or_trainer._scale
        optimizer_or_trainer._scale = 1.
    elif isinstance(optimizer_or_trainer, opt.Optimizer):
        # TODO(ptredak): make it work with the optimizer
        raise TypeError("AMP is currently only compatible with Gluon Trainer")
    else:
        raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                        "an optimizer, instead is %s" % type(optimizer_or_trainer))

def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
                   fp32_ops=None, conditional_fp32_ops=None,
                   excluded_sym_names=None, data_names=None,
                   cast_optional_params=False):
    """Given a symbol object representing an FP32 neural network and a target_dtype,
    add cast layers according to the op lists (target_dtype_ops, fp32_ops,
    conditional_fp32_ops) if provided, otherwise use the default
    lists provided by the framework.

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    target_dtype : str or numpy dtype, optional (default: 'float16')
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs, optional
        Override the list of operator names casted to the target_dtype.
        If None, uses the framework's default list to be casted to target_dtype.
    fp32_ops : list of strs, optional
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string), optional
        Override the list of functions to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32)
    excluded_sym_names : list of strs, optional
        A list of strings that represent the names of symbols that users want to exclude
        from being casted to LP16 or FP32.
    data_names : list of strs, optional
        A list of strings that represent input data tensor names to the model
    cast_optional_params : bool, default False
        Whether to cast parameters (arg_params and aux_params) that are not strictly required
        to be in LP16 because a cast layer follows them; casting them reduces the computation
        and memory overhead of the model.
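
    Examples
    --------
    A minimal sketch (assumes a checkpoint prefix ``'model'`` previously saved
    with ``mx.model.save_checkpoint``; only the symbol is converted here)::

        sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)
        fp16_sym = amp.convert_symbol(sym, target_dtype='float16')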
    """
    assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol"

    assert target_dtype in ['float16', 'bfloat16'], \
               "Only target_dtype float16 and bfloat16 are supported currently"

    if target_dtype == 'bfloat16':
        target_dtype = bfloat16

    if target_dtype_ops is not None:
        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
    else:
        target_dtype_ops = list_lp16_ops(target_dtype)

    if fp32_ops is not None:
        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
    else:
        fp32_ops = list_fp32_ops(target_dtype)

    if conditional_fp32_ops is not None:
        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
    else:
        conditional_fp32_ops = list_conditional_fp32_ops(target_dtype)

    original_conditional_op_names = []
    conditional_op_names = []
    param_names = []
    param_vals = []
    indptr = [0]
    for conditional_fp32_op in conditional_fp32_ops:
        assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
            and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
                                                          "(str, str, list of str)"
        param_vals += conditional_fp32_op[2]
        indptr.append(len(param_vals))
        param_names.append(conditional_fp32_op[1])
        conditional_op_names.append(conditional_fp32_op[0])

    if excluded_sym_names is not None:
        assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
    else:
        excluded_sym_names = []

    for original_conditional_fp32_op in list_conditional_fp32_ops(target_dtype):
        original_conditional_op_names.append(original_conditional_fp32_op[0])

    # Op lists should not have intersection
    common_ops = set(target_dtype_ops) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
    common_ops = set(target_dtype_ops) & set(conditional_op_names)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
    common_ops = set(conditional_op_names) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)

    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
    all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + list_fp32_ops(target_dtype)
                            + list_lp16_fp32_ops(target_dtype) + original_conditional_op_names)

    illegal_ops = combined_ops - all_lp16_fp32_ops
    assert not illegal_ops, '''Can only choose ops from one of the four lists
                            when overriding target_dtype_ops, fp32_ops and conditional_fp32_ops:
                            1. amp.list_lp16_ops(target_dtype)
                            2. amp.list_fp32_ops(target_dtype)
                            3. amp.list_lp16_fp32_ops(target_dtype)
                            4. amp.list_conditional_fp32_ops(target_dtype)
                            Op(s) %s not in any of them''' % (illegal_ops)

    widest_dtype_ops = list_widest_type_cast(target_dtype)
    if target_dtype == bfloat16:
        target_dtype = _DTYPE_NP_TO_MX[bfloat16]
    else:
        target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]

    # Prepare a data_names list based on list_inputs if it's not provided.
    # Add all input names whose nodes in the symbol don't have __dtype__ set.
    attr_dict = sym.attr_dict()
    if data_names is None:
        data_names = []
        for sym_name in sym.list_inputs():
            if sym_name not in attr_dict:
                data_names.append(sym_name)
                continue
            if "__dtype__" not in attr_dict[sym_name]:
                data_names.append(sym_name)
    model_param_names = list(set(sym.list_inputs()) - set(data_names))

    # Since the assumption is that this is an FP32 model, set the dtype for all
    # data_names to float32
    str_keys = []
    sdata = []
    for k in data_names:
        str_keys.append(k)
        sdata.append(0)
    keys = c_str_array(str_keys)
    out = SymbolHandle()
    check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
                                            ctypes.byref(out),
                                            mx_uint(len(sdata)),
                                            c_array_buf(ctypes.c_int, array('i', sdata)),
                                            mx_uint(len(indptr)),
                                            c_array_buf(ctypes.c_int, array('i', indptr)),
                                            ctypes.byref(ctypes.c_int(target_dtype)),
                                            ctypes.c_int(cast_optional_params),
                                            mx_uint(len(target_dtype_ops)),
                                            mx_uint(len(fp32_ops)),
                                            mx_uint(len(widest_dtype_ops)),
                                            mx_uint(len(conditional_op_names)),
                                            mx_uint(len(excluded_sym_names)),
                                            mx_uint(len(model_param_names)),
                                            c_str_array(target_dtype_ops),
                                            c_str_array(fp32_ops),
                                            c_str_array(widest_dtype_ops),
                                            c_str_array(conditional_op_names),
                                            c_str_array(excluded_sym_names),
                                            c_str_array(param_names),
                                            c_str_array(param_vals),
                                            c_str_array(model_param_names),
                                            keys))
    return Symbol(out)

def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None,
                  cast_optional_params=False):
    """API for converting a model from an FP32 model to a mixed precision model.
    MXNet tries to convert the FP32 model to a mixed precision model by adding
    cast layers using amp_cast and amp_multicast operators which can be used for inference use cases.
    The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
    in MXNet. These lists can be overridden by the user by providing their own lists
    using: target_dtype_ops, fp32_ops, conditional_fp32_ops

    Parameters
    ----------
    sym : Symbol
        FP32 neural network symbol
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    target_dtype : str
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being executed in lower precision.
    cast_optional_params : bool, default False
        Whether to cast parameters (arg_params and aux_params) that are not strictly required
        to be in LP16 because a cast layer follows them; casting them reduces the computation
        and memory overhead of the model.
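
    Examples
    --------
    A minimal sketch (assumes a checkpoint prefix ``'model'`` previously saved
    with ``mx.model.save_checkpoint``)::

        sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)
        sym, arg_params, aux_params = amp.convert_model(sym, arg_params, aux_params,
                                                        target_dtype='float16')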
    """
    if excluded_sym_names is None:
        excluded_sym_names = []
    if not isinstance(excluded_sym_names, list):
        raise ValueError('excluded_sym_names must be a list of strings representing'
                         ' the names of the symbols that should not be casted,'
                         ' while received type %s' % str(type(excluded_sym_names)))
    assert target_dtype in ['float16', 'bfloat16'], \
               "Only target_dtype float16 and bfloat16 are supported currently"

    assert isinstance(sym, Symbol), "First argument to convert_model should be Symbol"
    assert isinstance(arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray"
    assert isinstance(aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray"

    param_names = list(arg_params.keys()) + list(aux_params.keys())

    # Only pass non params as data_names, param types can be inferred
    data_names = list(set(sym.list_inputs()) - set(param_names))
    sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                         fp32_ops, conditional_fp32_ops,
                         excluded_sym_names, data_names,
                         cast_optional_params)

    # If dtype is set for params, cast the param to that dtype
    attr_dict = sym.attr_dict()
    for sym_name in sym.list_arguments():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                if typ == bfloat16:
                    arg_params[sym_name] = _cast_symbol_NDArray(arg_params[sym_name], bfloat16)
                else:
                    arg_params[sym_name] = arg_params[sym_name].astype(typ)

    for sym_name in sym.list_auxiliary_states():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                if typ == bfloat16:
                    aux_params[sym_name] = _cast_symbol_NDArray(aux_params[sym_name], bfloat16)
                else:
                    aux_params[sym_name] = aux_params[sym_name].astype(typ)

    # Return the converted symbol and casted params
    return sym, arg_params, aux_params

def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
                         fp32_ops=None, conditional_fp32_ops=None,
                         excluded_sym_names=None, ctx=gpu(0),
                         cast_optional_params=False):
    """Given a HybridBlock or SymbolBlock representing an FP32 model and a target_dtype,
    return a block with mixed precision support which can be used for inference use cases.

    Parameters
    ----------
    block : HybridBlock or SymbolBlock object
        FP32 HybridBlock or SymbolBlock object
    target_dtype : str or numpy dtype
        Currently only supports float16 and bfloat16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target_dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (str, str, list of str)
        Override the list of functions to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being casted to LP16 or FP32.
    ctx : Context
        Context on which model parameters should live
    cast_optional_params : bool, default False
        Whether to cast parameters (arg_params and aux_params) that are not strictly required
        to be in LP16 because a cast layer follows them; casting them reduces the computation
        and memory overhead of the model.
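
    Examples
    --------
    A minimal sketch (assumes ``net`` is a pretrained Gluon HybridBlock and a
    GPU is available; the input shape is arbitrary)::

        net.hybridize()
        _ = net(mx.nd.ones((1, 3, 224, 224), ctx=mx.gpu(0)))  # run forward once
        lp16_net = amp.convert_hybrid_block(net)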
    """
    from ...gluon import HybridBlock, SymbolBlock
    assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
    if not block._cached_graph:
        raise RuntimeError(
            "Please first call block.hybridize() and then run forward with "
            "this block at least once before calling convert_hybrid_block.")

    # Prepare inputs to pass to the convert_symbol API
    inputs, sym = block._cached_graph
    input_names = []
    for inp in inputs:
        input_names.append(inp.name)
    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                                   fp32_ops, conditional_fp32_ops,
                                   excluded_sym_names, data_names=input_names,
                                   cast_optional_params=cast_optional_params)

    arg_names = set(converted_sym.list_arguments())
    aux_names = set(converted_sym.list_auxiliary_states())
    arg_dict = {}

    # If dtype for the param was set in the json, cast the
    # param to this dtype
    attr_dict = converted_sym.attr_dict()
    for name, param in block.collect_params().items():
        if name in arg_names:
            arg_dict['arg:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    if typ == bfloat16:
                        arg_dict['arg:%s' % name] = _cast_symbol_NDArray(arg_dict['arg:%s' % name], bfloat16)
                    else:
                        arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ)
        else:
            assert name in aux_names
            arg_dict['aux:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    if typ == bfloat16:
                        arg_dict['aux:%s' % name] = _cast_symbol_NDArray(arg_dict['aux:%s' % name], 'bfloat16')
                    else:
                        arg_dict['aux:%s'%name] = arg_dict['aux:%s'%name].astype(typ)

    # Create a symbolblock and cast the params to the dtypes based
    # on the dtype information from the converted_symbol
    ret = SymbolBlock(converted_sym, inputs)
    for key, param in ret.collect_params().items():
        arg_param_name = "arg:%s" % key
        if arg_param_name in arg_dict and param.dtype != arg_dict[arg_param_name].dtype:
            param.cast(arg_dict[arg_param_name].dtype)

        aux_param_name = "aux:%s" % key
        if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype:
            param.cast(arg_dict[aux_param_name].dtype)

    ret.collect_params().load_dict(arg_dict, ctx=ctx)
    return ret

def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype_ops=None,
                             fp32_ops=None, conditional_fp32_ops=None,
                             excluded_sym_names=None, cast_optional_params=False):
    """Given a BucketingModule, cast the symbols associated with it,
    and its params if cast_optional_params is set.

    Parameters
    ----------
    bucketing_mod : BucketingModule instance
        FP32 BucketingModule to convert.
    target_dtype : str
        Currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being executed in lower precision.
    cast_optional_params : bool, default False
        Whether to cast parameters (arg_params and aux_params) that are not strictly required
        to be in LP16 because a cast layer follows them; casting them reduces the computation
        and memory overhead of the model.
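
    Examples
    --------
    A minimal sketch (assumes ``bucketing_mod`` is a trained BucketingModule
    whose parameters are already initialized)::

        lp16_mod = amp.convert_bucketing_module(bucketing_mod,
                                                target_dtype='float16')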
    """
    assert isinstance(bucketing_mod, BucketingModule), "module should be instance of bucketing module"
    assert len(bucketing_mod._buckets) > 0, "Bucketing Module should not be empty"

    sym_dict = {}
    assert bucketing_mod.params_initialized, \
        "bucketing_mod params should be initialized for mixed precision conversion"
    arg_params, aux_params = bucketing_mod._curr_module._arg_params, bucketing_mod._curr_module._aux_params
    for key, val in bucketing_mod._buckets.items():
        sym_dict[key], result_arg_params, result_aux_params = convert_model(val._symbol,
                                                                            arg_params,
                                                                            aux_params,
                                                                            target_dtype=target_dtype,
                                                                            target_dtype_ops=target_dtype_ops,
                                                                            fp32_ops=fp32_ops,
                                                                            conditional_fp32_ops=conditional_fp32_ops,
                                                                            excluded_sym_names=excluded_sym_names,
                                                                            cast_optional_params=cast_optional_params)
    result_mod = BucketingModule.load_dict(sym_dict,
                                           sym_gen=bucketing_mod._sym_gen,
                                           arg_params=result_arg_params,
                                           aux_params=result_aux_params,
                                           default_bucket_key=bucketing_mod._default_bucket_key,
                                           logger=bucketing_mod.logger,
                                           context=bucketing_mod._context,
                                           work_load_list=bucketing_mod._work_load_list,
                                           fixed_param_names=bucketing_mod._fixed_param_names,
                                           state_names=bucketing_mod._state_names,
                                           group2ctxs=bucketing_mod._group2ctxs,
                                           compression_params=bucketing_mod._compression_params)
    return result_mod

def list_lp16_ops(target_dtype):
    """Get the default list of LP16 ops for AMP
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP16_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_FUNCS

def list_fp32_ops(target_dtype):
    """Get the default list of FP32 ops for AMP
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.FP32_FUNCS

def list_lp16_fp32_ops(target_dtype):
    """Get the default list of ops which run in both LP16 and FP32
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.FP16_FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_FP32_FUNCS

def list_conditional_fp32_ops(target_dtype):
    """Get the conditional fp32 ops list
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.CONDITIONAL_FP32_FUNCS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.CONDITIONAL_FP32_FUNCS

def list_widest_type_cast(target_dtype):
    """Get the widest type cast ops list
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.WIDEST_TYPE_CASTS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.WIDEST_TYPE_CASTS

def list_loss_output_functions(target_dtype):
    """Get loss function list
    """
    if target_dtype in ['float16', np.float16]:
        return lists.symbol_fp16.LOSS_OUTPUT_FUNCTIONS
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.LOSS_OUTPUT_FUNCTIONS

def list_lp16_use_fp32_params(target_dtype):
    """Get the per-op mapping of parameters which should be kept in FP32 even
    for LP16 ops (None for float16, since no such restrictions apply).
    """
    if target_dtype in ['float16', np.float16]:
        return None
    else:
        assert (target_dtype == bfloat16), "not supported type"
        return lists.symbol_bf16.BF16_USE_FP32_PARAMS