1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable= arguments-differ, too-many-lines, reimported
20"""Base container class for all neural network models."""
21__all__ = ['Block', 'HybridBlock', 'SymbolBlock']
22
23import threading
24import copy
25import warnings
26import re
27import json
28from collections import OrderedDict, defaultdict
29import numpy as np
30
31from ..base import mx_real_t, MXNetError
32from .. import symbol, ndarray, initializer, np_symbol
33from ..symbol import Symbol, load_json
34from ..ndarray import NDArray
35from .. import name as _name
36from .parameter import Parameter, ParameterDict, DeferredInitializationError
37from .utils import _indent, _brief_print_list, HookHandle
38from .utils import _check_same_symbol_type, _check_all_np_ndarrays
39from .. import numpy_extension as _mx_npx
40from .. import numpy as _mx_np
from ..util import is_np_array, np_shape, np_array
42
43
44
45class _BlockScope(object):
46    """Scope for collecting child `Block` s."""
47    _current = threading.local()
48
49    def __init__(self, block):
50        self._block = block
51        self._counter = {}
52        self._old_scope = None
53        self._name_scope = None
54
55    @staticmethod
56    def create(prefix, params, hint):
57        """Creates prefix and params for new `Block`."""
58        current = getattr(_BlockScope._current, "value", None)
59        if current is None:
60            if prefix is None:
61                if not hasattr(_name.NameManager._current, "value"):
62                    _name.NameManager._current.value = _name.NameManager()
63                prefix = _name.NameManager._current.value.get(None, hint) + '_'
64            if params is None:
65                params = ParameterDict(prefix)
66            else:
67                params = ParameterDict(params.prefix, params)
68            return prefix, params
69
70        if prefix is None:
71            count = current._counter.get(hint, 0)
72            prefix = '%s%d_'%(hint, count)
73            current._counter[hint] = count + 1
74        if params is None:
75            parent = current._block.params
76            params = ParameterDict(parent.prefix+prefix, parent._shared)
77        else:
78            params = ParameterDict(params.prefix, params)
79        return current._block.prefix+prefix, params
80
81    def __enter__(self):
82        if self._block._empty_prefix:
83            return self
84        self._old_scope = getattr(_BlockScope._current, "value", None)
85        _BlockScope._current.value = self
86        self._name_scope = _name.Prefix(self._block.prefix)
87        self._name_scope.__enter__()
88        return self
89
90    def __exit__(self, ptype, value, trace):
91        if self._block._empty_prefix:
92            return
93        self._name_scope.__exit__(ptype, value, trace)
94        self._name_scope = None
95        _BlockScope._current.value = self._old_scope
96
97
98def _gather_type_ctx_info(args):
    """Analyze the elements inside the nested args object and find:
        - whether args contains any NDArray
        - whether args contains any Symbol
        - all contexts appearing in args
103
104    Parameters
105    ----------
    args : list or NDArray or Symbol
        Can be an arbitrarily nested structure of lists/tuples of NDArrays and Symbols.
108
109    Returns
110    -------
    has_symbol : bool
        Whether args contains any Symbol
    has_ndarray : bool
        Whether args contains any NDArray
    ctx_set : set of mxnet.context.Context
        Contains all contexts of the inner NDArrays in args. Can be empty if there is no
        NDArray inside args.
    first_ctx : mxnet.context.Context or None
        Context of the first NDArray encountered (kept for backward compatibility)
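
    For example, a minimal illustration (assuming ``mxnet`` is imported as ``mx``
    and the inputs are NDArrays on the default CPU context)::

        has_sym, has_nd, ctx_set, first_ctx = _gather_type_ctx_info(
            [mx.nd.zeros((2,)), [mx.nd.ones((3,))]])
        # has_sym is False, has_nd is True,
        # ctx_set == {mx.cpu(0)} and first_ctx == mx.cpu(0)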
120    """
121    if isinstance(args, NDArray):
122        return False, True, {args.ctx}, args.ctx
123    elif isinstance(args, Symbol):
124        return True, False, set(), None
125    elif isinstance(args, (list, tuple)):
126        has_symbol = False
127        has_ndarray = False
128        ctx_set = set()
129        first_ctx = None
130        for ele in args:
131            ele_has_sym, ele_has_nd, ele_ctx_set, ele_first_ctx =\
132                _gather_type_ctx_info(ele)
133            has_symbol = has_symbol or ele_has_sym
134            has_ndarray = has_ndarray or ele_has_nd
135            if first_ctx is None and ele_first_ctx is not None:
136                first_ctx = ele_first_ctx
137            ctx_set = ctx_set | ele_ctx_set
138            if has_symbol and has_ndarray:
139                break
140        return has_symbol, has_ndarray, ctx_set, first_ctx
141    else:
142        return False, False, set(), None
143
144
145def _flatten(args, inout_str):
146    """Parse the arguments into a flattened list + an additional format array.
147    The format array stores the structure of the original arguments to help reconstruct the inputs.
148
149    Parameters
150    ----------
151    args : NDArray, Symbol, or (nested) list of Symbol or NDArray
152        We allow None inside the args.
    inout_str : str
        A string such as 'input' or 'output', used only in error messages to indicate
        which arguments of the HybridBlock are being flattened.
155
156    Returns
157    -------
    flat : list of Symbol or NDArray
        The flattened version of the input args.
160    fmts : (nested) list of ints
161        Stores the format information of the original structured args.
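
    For example, a minimal illustration of the format encoding (assuming two
    NDArrays ``a`` and ``b``)::

        flat, fmts = _flatten([a, [b, None]], 'input')
        # flat == [a, b, None]
        # fmts == [0, [0, -1]]   (0: a single NDArray, -1: a None placeholder)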
162    """
163    if isinstance(args, NDArray):
164        return [args], int(0)
165    if isinstance(args, Symbol):
166        length = len(args.list_outputs())
167        length = length if length > 1 else 0
168        return [args], int(length)
169    if args is None:
170        return [None], int(-1)
171
172    if not isinstance(args, (list, tuple)):
173        raise ValueError("When hybridized, the input of HybridBlock {}"
174                         " must be (nested) list of Symbol"
175                         " or NDArray, "
176                         "but got {} of type {}".format(inout_str, str(args), str(type(args))))
177    flat = []
178    fmts = []
179    for i in args:
180        arg, fmt = _flatten(i, inout_str)
181        flat.extend(arg)
182        fmts.append(fmt)
183    return flat, fmts
184
185
186def _regroup(args, fmt):
187    """Reconstruct the structured arguments based on the flattened version.
188
189    Parameters
190    ----------
    args : flat list of Symbol, NDArray or None
        The flattened arguments, typically produced by `_flatten`.
193    fmt : (nested) list of ints
194        Stores the format information of the original structured args.
195
196    Returns
197    -------
198    ret : NDArray, Symbol, or (nested) list of Symbol or NDArray
199
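    For example, a minimal sketch of the roundtrip with `_flatten` (assuming two
    NDArrays ``a`` and ``b``)::

        flat, fmt = _flatten([a, [b]], 'input')
        nested = _regroup(flat, fmt)
        # nested == [a, [b]]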
200    """
201    def _merger(args, fmt):
202        """Recursive call to merge the arguments"""
203        if isinstance(fmt, int):
204            if fmt < -1:
205                raise ValueError("Unsupported encoded format {}.".format(fmt))
206            if fmt == 0:
207                return args[0], args[1:]
208            if fmt == -1:
209                if args[0] is not None:
210                    raise ValueError('We do not support passing types that are not None'
211                                     ' when the initial HybridBlock has received NoneType and'
212                                     ' has been hybridized.'
213                                     ' Received arg = {}, fmt = {}.'.format(args[0], fmt))
214                return None, args[1:]
215            else:
216                return args[:fmt], args[fmt:]
217
218        if not isinstance(args, (list, tuple)):
219            raise ValueError("When hybridized, the output of HybridBlock must be (nested)"
220                             " list of Symbol or NDArray, "
221                             "but got {} of type {}".format(args, type(args)))
222        ret = []
223        for i in fmt:
224            res, args = _merger(args, i)
225            ret.append(res)
226        return ret, args
227    return _merger(args, fmt)[0]
228
229
230class Block(object):
231    """Base class for all neural network layers and models. Your models should
232    subclass this class.
233
234    :py:class:`Block` can be nested recursively in a tree structure. You can create and
235    assign child :py:class:`Block` as regular attributes::
236
237        from mxnet.gluon import Block, nn
238        from mxnet import ndarray as F
239
240        class Model(Block):
241            def __init__(self, **kwargs):
242                super(Model, self).__init__(**kwargs)
243                # use name_scope to give child Blocks appropriate names.
244                with self.name_scope():
245                    self.dense0 = nn.Dense(20)
246                    self.dense1 = nn.Dense(20)
247
248            def forward(self, x):
249                x = F.relu(self.dense0(x))
250                return F.relu(self.dense1(x))
251
252        model = Model()
253        model.initialize(ctx=mx.cpu(0))
254        model(F.zeros((10, 10), ctx=mx.cpu(0)))
255
256
    Child :py:class:`Block` s assigned this way will be registered and :py:meth:`collect_params`
258    will collect their Parameters recursively. You can also manually register
259    child blocks with :py:meth:`register_child`.
260
261    Parameters
262    ----------
263    prefix : str
        Prefix acts like a name space. All child blocks created in the parent block's
        :py:meth:`name_scope` will have the parent block's prefix in their names.
266        Please refer to
267        `naming tutorial </api/python/docs/tutorials/packages/gluon/blocks/naming.html>`_
268        for more info on prefix and naming.
269    params : ParameterDict or None
270        :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example,
271        if you want ``dense1`` to share ``dense0``'s weights, you can do::
272
273            dense0 = nn.Dense(20)
274            dense1 = nn.Dense(20, params=dense0.collect_params())
275    """
276    def __init__(self, prefix=None, params=None):
277        self._empty_prefix = prefix == ''
278        self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
279        self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
280        self._scope = _BlockScope(self)
281        self._children = OrderedDict()
282        self._reg_params = {}
283        self._forward_hooks = OrderedDict()
284        self._forward_pre_hooks = OrderedDict()
285
286    def __repr__(self):
287        s = '{name}(\n{modstr}\n)'
288        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
289                                                        block=_indent(block.__repr__(), 2))
290                            for key, block in self.__dict__.items() if isinstance(block, Block)])
291        return s.format(name=self.__class__.__name__, modstr=modstr)
292
293    def __setattr__(self, name, value):
294        """Registers parameters."""
295
296        if hasattr(self, name):
297            existing = getattr(self, name)
298            if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)):
                raise TypeError('Changing attribute type for {name} from {type1} to {type2} ' \
                                'is not allowed.'.format(
                                    name=name, type1=type(existing), type2=type(value)))
302
303        if isinstance(value, Block):
304            self.register_child(value, name)
305        elif isinstance(value, Parameter):
            assert name not in self._reg_params, \
                "Overriding Parameter attribute %s is not allowed. " \
                "If you want to share parameters between blocks, please set " \
                "'params' at Block construction instead." % name
310            self._reg_params[name] = value
311
312        super(Block, self).__setattr__(name, value)
313
314    def _check_container_with_block(self):
315        children = set(self._children.values())
316        def _find_unregistered_block_in_container(data):
317            # Find whether a nested container structure contains Blocks
318            if isinstance(data, (list, tuple)):
319                for ele in data:
320                    if _find_unregistered_block_in_container(ele):
321                        return True
322                return False
323            elif isinstance(data, dict):
324                for _, v in data.items():
325                    if _find_unregistered_block_in_container(v):
326                        return True
327                return False
328            elif isinstance(data, Block):
                return data not in children
330            else:
331                return False
332        for k, v in self.__dict__.items():
333            if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'):
334                if _find_unregistered_block_in_container(v):
                    warnings.warn('"{name}" is an unregistered container with Blocks. '
                                  'Note that Blocks inside the list, tuple or dict will not be '
                                  'registered automatically. Make sure to register them using '
                                  'register_child() or switch to '
                                  'nn.Sequential/nn.HybridSequential instead.'
                                  .format(name=self.__class__.__name__ + "." + k), stacklevel=3)
341
342    def _alias(self):
343        return self.__class__.__name__.lower()
344
345    @property
346    def prefix(self):
347        """Prefix of this :py:class:`Block`."""
348        return self._prefix
349
350    @property
351    def name(self):
        """Name of this :py:class:`Block`, without the trailing '_'."""
353        return self._name
354
355    def name_scope(self):
356        """Returns a name space object managing a child :py:class:`Block` and parameter
357        names. Should be used within a ``with`` statement::
358
359            with self.name_scope():
360                self.dense = nn.Dense(20)
361
362        Please refer to
363        `the naming tutorial </api/python/docs/tutorials/packages/gluon/blocks/naming.html>`_
364        for more info on prefix and naming.
365        """
366        return self._scope
367
368    @property
369    def params(self):
370        """Returns this :py:class:`Block`'s parameter dictionary (does not include its
371        children's parameters)."""
372        return self._params
373
374    def collect_params(self, select=None):
        """Returns a :py:class:`ParameterDict` containing this :py:class:`Block`'s and all of its
        children's Parameters (the default), or only those Parameters whose names match
        the given regular expressions.
378
        For example, collect only the parameters named 'conv1_weight', 'conv1_bias', 'fc_weight'
        and 'fc_bias'::
381
382            model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')
383
        or collect all parameters whose names end with 'weight' or 'bias'; this can be done
        using regular expressions::
386
387            model.collect_params('.*weight|.*bias')
388
389        Parameters
390        ----------
        select : str
            Regular expressions used to select parameters by name.
393
394        Returns
395        -------
396        The selected :py:class:`ParameterDict`
397        """
398        # We need to check here because blocks inside containers are not supported.
399        self._check_container_with_block()
400        ret = ParameterDict(self._params.prefix)
401        if not select:
402            ret.update(self.params)
403        else:
404            pattern = re.compile(select)
405            ret.update({name:value for name, value in self.params.items() if pattern.match(name)})
406        for cld in self._children.values():
407            ret.update(cld.collect_params(select=select))
408        return ret
409
410    def _collect_params_with_prefix(self, prefix=''):
411        if prefix:
412            prefix += '.'
413        ret = {prefix + key : val for key, val in self._reg_params.items()}
414        for name, child in self._children.items():
415            ret.update(child._collect_params_with_prefix(prefix + name))
416        return ret
417
418    def save_parameters(self, filename, deduplicate=False):
419        """Save parameters to file.
420
421        Saved parameters can only be loaded with `load_parameters`. Note that this
422        method only saves parameters, not model structure. If you want to save
423        model structures, please use :py:meth:`HybridBlock.export`.
424
425        Parameters
426        ----------
427        filename : str
428            Path to file.
429        deduplicate : bool, default False
430            If True, save shared parameters only once. Otherwise, if a Block
431            contains multiple sub-blocks that share parameters, each of the
432            shared parameters will be separately saved for every sub-block.
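
        For example, a minimal sketch assuming ``mxnet`` is imported as ``mx``,
        ``mxnet.gluon.nn`` as ``nn``, and a simple block named ``net``::

            net = nn.Dense(10)
            net.initialize()
            net(mx.nd.ones((2, 5)))                # run once so shapes are known
            net.save_parameters('dense.params')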
433
434        References
435        ----------
436        `Saving and Loading Gluon Models \
437        <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
438        """
439        params = self._collect_params_with_prefix()
440
441        if deduplicate:
442            # Shared parameters are stored only a single time as of MXNet 1.6.
443            # Shared parameters are registered under multiple prefixes returned by
444            # _collect_params_with_prefix. We select a single one and only store
            # it. In load_parameters it is sufficient to set a shared parameter
            # under only a single one of its prefixes.
447            reverse_params = {v: k for k, v in params.items()}
448            params = {v: k for k, v in reverse_params.items()}
449
450        arg_dict = {key: val._reduce() for key, val in params.items()}
451        save_fn = _mx_npx.save if is_np_array() else ndarray.save
452        save_fn(filename, arg_dict)
453
454    def save_params(self, filename):
        """[Deprecated] Please use save_parameters. Note that if you want to load
        from SymbolBlock later, please use export instead.
457
458        Save parameters to file.
459
460        filename : str
461            Path to file.
462        """
463        warnings.warn("save_params is deprecated. Please use save_parameters. "
                      "Note that if you want to load from SymbolBlock later, please "
465                      "use export instead. For details, see "
466                      "https://mxnet.apache.org/tutorials/gluon/save_lo"
467                      "ad_params.html")
468        try:
469            self.collect_params().save(filename, strip_prefix=self.prefix)
470        except ValueError as e:
471            raise ValueError('%s\nsave_params is deprecated. Using ' \
472                              'save_parameters may resolve this error.'%e.message)
473
474    def load_parameters(self, filename, ctx=None, allow_missing=False,
475                        ignore_extra=False, cast_dtype=False, dtype_source='current'):
476        """Load parameters from file previously saved by `save_parameters`.
477
478        Parameters
479        ----------
480        filename : str
481            Path to parameter file.
482        ctx : Context or list of Context, default cpu()
483            Context(s) to initialize loaded parameters on.
484        allow_missing : bool, default False
            Whether to silently skip loading parameters not represented in the file.
486        ignore_extra : bool, default False
487            Whether to silently ignore parameters from the file that are not
488            present in this Block.
489        cast_dtype : bool, default False
490            Cast the data type of the NDArray loaded from the checkpoint to the dtype
491            provided by the Parameter if any.
        dtype_source : str, default 'current'
            Must be in {'current', 'saved'}.
            Only valid if cast_dtype=True; specifies the source of the dtype for casting
            the parameters.
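
        For example, a minimal sketch assuming the parameters were previously saved from a
        compatible block (see :py:meth:`save_parameters`) and that ``mxnet`` is imported
        as ``mx`` and ``mxnet.gluon.nn`` as ``nn``::

            net = nn.Dense(10)
            net.load_parameters('dense.params', ctx=mx.cpu())
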
496        References
497        ----------
498        `Saving and Loading Gluon Models \
499        <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
500        """
501        if is_np_array():
502            # failure may happen when loading parameters saved as NDArrays within
503            # NumPy semantics. Check the failure type and recover from it if it happens.
504            try:
505                loaded = _mx_npx.load(filename)
506            except MXNetError as e:
507                err_msg = str(e)
508                if 'is_np_shape' in err_msg:
509                    # Loading failure due to parameters saved without numpy semantics.
510                    # Temporarily disable numpy semantics and load parameters. After it's
511                    # done, resume the numpy semantics. This is fine because the cases
512                    # numpy ndarray covers is a superset of the legacy ndarray's.
513                    with np_array(False):
514                        with np_shape(False):
515                            loaded_nds = ndarray.load(filename)
516                    assert isinstance(loaded_nds, dict),\
517                        'expecting a dict type, got {}'.format(str(type(loaded_nds)))
518                    loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds}
519                else:
520                    raise ValueError(err_msg)
521        else:
522            loaded = ndarray.load(filename)
523        params = self._collect_params_with_prefix()
524        if not loaded and not params:
525            return
526
527        if not any('.' in i for i in loaded.keys()):
528            # legacy loading
529            loaded = None  # This should be changed to `del loaded` when dropping Python 2
530            self.collect_params().load(
531                filename, ctx, allow_missing, ignore_extra, self.prefix,
532                cast_dtype=cast_dtype, dtype_source=dtype_source)
533            return
534
535        if not allow_missing:
536            # Shared parameters are stored only a single time as of MXNet 1.6.
537            # We thus retrieve all prefixes (through _collect_params_with_prefix)
538            # that a shared parameter is used with. Check that there are no
539            # missing parameters that were not yet already loaded from the
540            # shared version.
541            params_inv = defaultdict(list)
542            for k, v in params.items():
543                params_inv[v].append(k)
544
545            for name, param in params.items():
546                assert any(p in loaded for p in params_inv[param]), \
547                    "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
548                    "Set allow_missing=True to ignore missing parameters."%(
549                        name, filename, _brief_print_list(loaded.keys()))
550        for name in loaded:
551            if not ignore_extra and name not in params:
552                raise ValueError(
553                    "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \
554                    "which contains parameters %s. Set ignore_extra=True to ignore. "%(
555                        name, filename, _brief_print_list(self._params.keys())))
556            if name in params:
557                params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)
558
559    def load_params(self, filename, ctx=None, allow_missing=False,
560                    ignore_extra=False):
561        """[Deprecated] Please use load_parameters.
562
563        Load parameters from file.
564
565        filename : str
566            Path to parameter file.
567        ctx : Context or list of Context, default cpu()
568            Context(s) to initialize loaded parameters on.
569        allow_missing : bool, default False
            Whether to silently skip loading parameters not represented in the file.
571        ignore_extra : bool, default False
572            Whether to silently ignore parameters from the file that are not
573            present in this Block.
574        """
575        warnings.warn("load_params is deprecated. Please use load_parameters.")
576        self.load_parameters(filename, ctx, allow_missing, ignore_extra)
577
578    def register_child(self, block, name=None):
579        """Registers block as a child of self. :py:class:`Block` s assigned to self as
580        attributes will be registered automatically."""
581        if name is None:
582            name = str(len(self._children))
583        self._children[name] = block
584
585    def register_forward_pre_hook(self, hook):
586        r"""Registers a forward pre-hook on the block.
587
588        The hook function is called immediately before :func:`forward`.
589        It should not modify the input or output.
590
591        Parameters
592        ----------
593        hook : callable
594            The forward hook function of form `hook(block, input) -> None`.
595
596        Returns
597        -------
598        :class:`mxnet.gluon.utils.HookHandle`
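
        For example, a minimal sketch of a pre-hook that prints the number of positional
        inputs (assuming a block named ``net``)::

            handle = net.register_forward_pre_hook(
                lambda block, inputs: print(len(inputs)))
            # ... run some forward passes ...
            handle.detach()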
599        """
600        handle = HookHandle()
601        handle.attach(self._forward_pre_hooks, hook)
602        return handle
603
604    def register_forward_hook(self, hook):
605        r"""Registers a forward hook on the block.
606
607        The hook function is called immediately after :func:`forward`.
608        It should not modify the input or output.
609
610        Parameters
611        ----------
612        hook : callable
613            The forward hook function of form `hook(block, input, output) -> None`.
614
615        Returns
616        -------
617        :class:`mxnet.gluon.utils.HookHandle`
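
        For example, a minimal sketch of a hook that prints the output shape (assuming a
        block named ``net`` whose forward returns a single NDArray)::

            handle = net.register_forward_hook(
                lambda block, inputs, output: print(output.shape))
            # ... run some forward passes ...
            handle.detach()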
618        """
619        handle = HookHandle()
620        handle.attach(self._forward_hooks, hook)
621        return handle
622
623    def apply(self, fn):
624        r"""Applies ``fn`` recursively to every child block as well as self.
625
626        Parameters
627        ----------
628        fn : callable
629            Function to be applied to each submodule, of form `fn(block)`.
630
631        Returns
632        -------
633        this block
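
        For example, a minimal sketch that prints the name of every block in the tree
        (assuming a model named ``net``)::

            net.apply(lambda block: print(block.name))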
634        """
635        for cld in self._children.values():
636            cld.apply(fn)
637        fn(self)
638        return self
639
640    def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
641                   force_reinit=False):
642        """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children.
643        Equivalent to ``block.collect_params().initialize(...)``
644
645        Parameters
646        ----------
647        init : Initializer
648            Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``.
649            Otherwise, :py:meth:`Parameter.init` takes precedence.
650        ctx : Context or list of Context
651            Keeps a copy of Parameters on one or many context(s).
652        verbose : bool, default False
653            Whether to verbosely print out details on initialization.
654        force_reinit : bool, default False
655            Whether to force re-initialization if parameter is already initialized.
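
        For example, a minimal sketch (assuming ``mxnet`` is imported as ``mx`` and a
        model named ``net``)::

            net.initialize(init=mx.init.Xavier(), ctx=mx.cpu())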
656        """
657        self.collect_params().initialize(init, ctx, verbose, force_reinit)
658
659    def hybridize(self, active=True, **kwargs):
        """Please refer to the description of :py:meth:`HybridBlock.hybridize`.
661        """
662        for cld in self._children.values():
663            cld.hybridize(active, **kwargs)
664
665    def save(self, prefix):
        """Save the model architecture and parameters so they can be loaded again later.
667
668        Saves the model architecture as a nested dictionary where each Block
669        in the model is a dictionary and its children are sub-dictionaries.
670
        Each Block is uniquely identified by its class name and a unique ID.
        We save the name that the parent uses for each child so that it can be
        restored later and matched with the saved parameters.

        Recursively traverses a Block's children in order (since it is an
        OrderedDict) and uses the unique ID to denote that specific Block.
        This assumes that the model is created in an identical order every time.
        If the model cannot be recreated deterministically, do not
        use this set of APIs to save/load your model.
680
681        For HybridBlocks, the cached_graph (Symbol & inputs) is saved if
682        it has already been hybridized.
683
684        Parameters
685        ----------
686        prefix : str
687            The prefix to use in filenames for saving this model:
688            <prefix>-model.json and <prefix>-model.params
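
        For example, a minimal sketch (assuming an initialized model named ``net``)::

            net.save('my_model')   # writes my_model-model.json and my_model-model.params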
689        """
690        # create empty model structure
691        model = {}
692        def _save_cached_graphs(blk, index, structure):
693            # create new entry for this block
694            mdl = {'orig_name': blk.name}
695            # encode unique name based on block type and ID
696            name = type(blk).__name__.lower()
697            structure[name+str(index[0])] = mdl
698            if isinstance(blk, HybridBlock):
699                if blk._cached_graph:
700                    # save in/out formats
701                    mdl['in_format'] = blk._in_format
702                    mdl['out_format'] = blk._out_format
703                    # save cached graph & input symbols
704                    syms, out = blk._cached_graph
705                    mdl_syms = []
706                    for sym in syms:
707                        mdl_syms.append(sym.tojson())
708                    mdl['inputs'] = mdl_syms
709                    mdl['symbol'] = out.tojson()
710                    mdl['hybridized'] = True
711                else:
712                    mdl['hybridized'] = False
713            children = dict()
714            mdl['children'] = children
715            # recursively save children
716            for ch_name, child in blk._children.items():
717                index[0] += 1
718                # save child's original name in this block's map
719                children[child.name] = ch_name
720                _save_cached_graphs(child, index, mdl)
721        # save top-level block
722        index = [0]
723        _save_cached_graphs(self, index, model)
724        # save model
725        with open(prefix+'-model.json', 'w') as fp:
726            json.dump(model, fp)
727        # save params
728        self.save_parameters(prefix+'-model.params')
729
730    def load(self, prefix):
731        """Load a model saved using the `save` API
732
733        Reconfigures a model using the saved configuration. This function
734        does not regenerate the model architecture. It resets the children's
735        names as they were when saved in order to match the names of the
736        saved parameters.
737
        This function assumes the Blocks in the model were created in the same
        order as when the model was saved. This is because each Block is
        uniquely identified by its class name and a unique ID assigned in traversal
        order (since the children are stored in an OrderedDict), and that unique ID
        is used to denote the specific Block. This assumes that the model is created
        in an identical order every time. If the model cannot be recreated
        deterministically, do not use this set of APIs to save/load your model.
745
746        For HybridBlocks, the cached_graph (Symbol & inputs) and settings are
747        restored if it had been hybridized before saving.
748
749        Parameters
750        ----------
751        prefix : str
752            The prefix to use in filenames for loading this model:
753            <prefix>-model.json and <prefix>-model.params
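
        For example, a minimal sketch (assuming a model named ``net`` constructed the
        same way as when it was saved with :py:meth:`save`)::

            net.load('my_model')   # reads my_model-model.json and my_model-model.params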
754        """
755        # load model json from file
756        with open(prefix+'-model.json') as fp:
757            model = json.load(fp)
758
759        def _load_cached_graphs(blk, index, structure):
760            # get block name
761            name = type(blk).__name__.lower()
762            # lookup previous encoded name based on block type and ID
763            mdl = structure[name+str(index[0])]
764            # rename block to what it was when saved
765            blk._name = mdl['orig_name']
766            if isinstance(blk, HybridBlock):
767                if mdl['hybridized']:
768                    # restore in/out formats
769                    blk._in_format = mdl['in_format']
770                    blk._out_format = mdl['out_format']
771                    # get saved symbol
772                    out = load_json(mdl['symbol'])
773                    syms = []
774                    # recreate inputs for this symbol
775                    for inp in mdl['inputs']:
776                        syms.append(load_json(inp))
777                    # reset cached_graph and active status
778                    blk._cached_graph = (syms, out)
779                    blk._active = True
780            # rename params with updated block name
781            pnames = list(blk.params.keys())
782            for p in pnames:
783                param = blk.params._params[p]
784                new_name = blk.name +'_'+ p[len(blk.params._prefix):]
785                blk.params._params.pop(p)
786                blk.params._params[new_name] = param
787            # recursively reload children
788            for ch_name, child in blk._children.items():
789                index[0] += 1
790                _load_cached_graphs(child, index, mdl)
791            # current set of child names
792            ch_names = list(blk._children.keys())
793            # original child names
794            children = mdl['children']
795            # loop and remap children with original names
796            for ch_name in ch_names:
797                child = blk._children[ch_name]
798                blk._children.pop(ch_name)
799                orig_name = children[child.name]
800                blk._children[orig_name] = child
801        # load top-level block
802        index = [0]
803        _load_cached_graphs(self, index, model)
804        # load params
805        self.load_parameters(prefix+'-model.params')
806
807    def cast(self, dtype):
808        """Cast this Block to use another data type.
809
810        Parameters
811        ----------
812        dtype : str or numpy.dtype
813            The new data type.
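
        For example, a minimal sketch (assuming a model named ``net``)::

            net.cast('float16')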
814        """
815        for child in self._children.values():
816            child.cast(dtype)
817        for _, param in self.params.items():
818            param.cast(dtype)
819
820    def __call__(self, *args):
821        """Calls forward. Only accepts positional arguments."""
822        for hook in self._forward_pre_hooks.values():
823            hook(self, args)
824
825        out = self.forward(*args)
826
827        for hook in self._forward_hooks.values():
828            hook(self, args, out)
829        if _mx_npx.is_np_array():
830            _check_all_np_ndarrays(out)
831        return out
832
833    def forward(self, *args):
834        """Overrides to implement forward computation using :py:class:`NDArray`. Only
835        accepts positional arguments.
836
837        Parameters
838        ----------
839        *args : list of NDArray
840            Input tensors.
841        """
842        # pylint: disable= invalid-name
843        raise NotImplementedError
844
845    def register_op_hook(self, callback, monitor_all=False):
846        """Install callback monitor.
847
848        Parameters
849        ----------
850        callback : function
            Takes a string and an NDArrayHandle.
852        monitor_all : bool, default False
853            If true, monitor both input and output, otherwise monitor output only.
854        """
855        for cld in self._children.values():
856            cld.register_op_hook(callback, monitor_all)
857
858    def summary(self, *inputs):
859        """Print the summary of the model's output and parameters.
860
861        The network must have been initialized, and must not have been hybridized.
862
863        Parameters
864        ----------
865        inputs : object
866            Any input that the model supports. For any tensor in the input, only
867            :class:`mxnet.ndarray.NDArray` is supported.
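
        For example, a minimal sketch (assuming ``mxnet`` is imported as ``mx`` and an
        initialized, non-hybridized model named ``net`` that accepts a single 2-D input)::

            net.summary(mx.nd.ones((1, 10)))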
868        """
869        summary = OrderedDict()
870        seen = set()
871        hooks = []
872
873        def _get_shape_str(args):
874            def flatten(args):
875                if not isinstance(args, (list, tuple)):
876                    return [args], int(0)
877                flat = []
878                fmts = []
879                for i in args:
880                    arg, fmt = flatten(i)
881                    flat.extend(arg)
882                    fmts.append(fmt)
883                return flat, fmts
884
885            def regroup(args, fmt):
886                if isinstance(fmt, int):
887                    if fmt == 0:
888                        return args[0], args[1:]
889                    return args[:fmt], args[fmt:]
890                ret = []
891                for i in fmt:
892                    res, args = regroup(args, i)
893                    ret.append(res)
894                return ret, args
895
896            flat_args, fmts = flatten(args)
897            flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x
898                               for x in flat_args]
899            shapes = regroup(flat_arg_shapes, fmts)[0]
900            if isinstance(shapes, list):
901                shape_str = str(shapes)[1:-1]
902            else:
903                shape_str = str(shapes)
904            return shape_str.replace('L', '')
905
906        def _register_summary_hook(block):
907            assert not isinstance(block, HybridBlock) or not block._active, \
908                    '"{}" must not be hybridized to print summary.'.format(block.name)
909            def _summary_hook(block, _, outputs):
910                class_name = block.__class__.__name__
911                block_idx = len(summary) - 1
912
913                m_key = '%s-%i' % (class_name, block_idx+1)
914                summary[m_key] = OrderedDict()
915                summary[m_key]['output_shape'] = _get_shape_str(outputs)
916
917                params = 0
918                summary[m_key]['trainable'] = 0
919                summary[m_key]['shared'] = 0
920                for p in block.params.values():
921                    params += p.data().size
922                    summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size
923                    if p in seen:
924                        summary[m_key]['shared'] += p.data().size
925                    else:
926                        seen.add(p)
927                summary[m_key]['n_params'] = params
928
929            from .nn.basic_layers import Sequential, HybridSequential
930            if not isinstance(block, (Sequential, HybridSequential)):
931                hooks.append(block.register_forward_hook(_summary_hook))
932
933        summary['Input'] = OrderedDict()
934        summary['Input']['output_shape'] = _get_shape_str(inputs)
935        summary['Input']['n_params'] = 0
936        summary['Input']['trainable'] = 0
937        summary['Input']['shared'] = 0
938
939        try:
940            self.apply(_register_summary_hook)
941            self(*inputs)
942
943            line_format = '{:>20}  {:>42} {:>15}'
944            print('-'*80)
945            print(line_format.format('Layer (type)', 'Output Shape', 'Param #'))
946            print('='*80)
947            total_params = 0
948            trainable_params = 0
949            shared_params = 0
950            for layer in summary:
951                print(line_format.format(layer,
952                                         str(summary[layer]['output_shape']),
953                                         summary[layer]['n_params']))
954                total_params += summary[layer]['n_params']
955                trainable_params += summary[layer]['trainable']
956                shared_params += summary[layer]['shared']
957            print('='*80)
958            print('Parameters in forward computation graph, duplicate included')
959            print('   Total params: ' + str(total_params))
960            print('   Trainable params: ' + str(trainable_params))
961            print('   Non-trainable params: ' + str(total_params - trainable_params))
962            print('Shared params in forward computation graph: ' + str(shared_params))
963            print('Unique parameters in model: ' + str(total_params - shared_params))
964            print('-'*80)
965        finally:
966            for h in hooks:
967                h.detach()
968
969
970class HybridBlock(Block):
971    """`HybridBlock` supports forwarding with both Symbol and NDArray.
972
973    `HybridBlock` is similar to `Block`, with a few differences::
974
975        import mxnet as mx
976        from mxnet.gluon import HybridBlock, nn
977
978        class Model(HybridBlock):
979            def __init__(self, **kwargs):
980                super(Model, self).__init__(**kwargs)
981                # use name_scope to give child Blocks appropriate names.
982                with self.name_scope():
983                    self.dense0 = nn.Dense(20)
984                    self.dense1 = nn.Dense(20)
985
986            def hybrid_forward(self, F, x):
987                x = F.relu(self.dense0(x))
988                return F.relu(self.dense1(x))
989
990        model = Model()
991        model.initialize(ctx=mx.cpu(0))
992        model.hybridize()
993        model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))
994
    Forward computation in :py:class:`HybridBlock` must be static to work with :py:class:`Symbol` s,
    i.e. you cannot call :py:meth:`NDArray.asnumpy`, :py:attr:`NDArray.shape`,
    :py:attr:`NDArray.dtype`, `NDArray` indexing (`x[i]`) etc. on tensors.
    Also, you cannot use branching or loop logic that depends on non-constant
    expressions like random numbers or intermediate results, since they would change
    the graph structure on each iteration.
1001
1002    Before activating with :py:meth:`hybridize()`, :py:class:`HybridBlock` works just like normal
1003    :py:class:`Block`. After activation, :py:class:`HybridBlock` will create a symbolic graph
1004    representing the forward computation and cache it. On subsequent forwards,
1005    the cached graph will be used instead of :py:meth:`hybrid_forward`.
1006
1007    Please see references for detailed tutorial.
1008
1009    References
1010    ----------
1011        `Hybrid - Faster training and easy deployment
1012        <https://mxnet.io/tutorials/gluon/hybrid.html>`_
1013    """
1014    def __init__(self, prefix=None, params=None):
1015        super(HybridBlock, self).__init__(prefix=prefix, params=params)
1016        self._cached_graph = ()
1017        self._cached_op = None
1018        self._cached_op_args = []
1019        self._out_format = None
1020        self._in_format = None
1021        self._active = False
1022        self._flags = []
1023        self._callback = None
1024        self._monitor_all = False
1025        self._backend = None
1026        self._backend_opts = {}
1027
1028    def __setattr__(self, name, value):
1029        """Registers parameters."""
1030        super(HybridBlock, self).__setattr__(name, value)
1031        if isinstance(value, HybridBlock):
1032            self._clear_cached_op()
1033
1034    def _get_graph(self, *args):
1035        if not self._cached_graph:
1036            flatten_args, self._in_format = _flatten(args, "input")
1037            flatten_inputs = []
1038            symbol_inputs = []
1039            cnt = 0
1040            real_arg_num = sum([ele is not None for ele in flatten_args])
1041            if real_arg_num == 0:
1042                raise ValueError('All args are None and we do not support such a case.'
1043                                 ' Received args={}'.format(args))
1044            for arg in flatten_args:
1045                if arg is not None:
1046                    if real_arg_num > 1:
1047                        arg_sym = symbol.var('data{}'.format(cnt))
1048                    else:
1049                        arg_sym = symbol.var('data')
1050                    if isinstance(arg, _mx_np.ndarray):
1051                        arg_sym = arg_sym.as_np_ndarray()
1052                    cnt += 1
1053                    flatten_inputs.append(arg_sym)
1054                    symbol_inputs.append(arg_sym)
1055                else:
1056                    flatten_inputs.append(None)
1057            grouped_inputs = _regroup(flatten_inputs, self._in_format)
1058            params = {i: j.var() for i, j in self._reg_params.items()}
1059            with self.name_scope():
1060                out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
1061            out, self._out_format = _flatten(out, "output")
1062
1063            self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out))
1064
1065        return self._cached_graph
1066
1067    def _build_cache(self, *args):
1068        data, out = self._get_graph(*args)
1069        data_names = {data.name: i for i, data in enumerate(data)}
1070        input_names = out.list_inputs()
1071        expected_names = set(input_names)
1072
1073        # try to reuse cached_op_args for params
1074        if len(self._cached_op_args) > 0:
1075            params = {param_tuple[1].name:param_tuple[1]
1076                      for param_tuple in self._cached_op_args
1077                      if isinstance(param_tuple[1], Parameter)}
1078        else:
1079            params = self.collect_params()
1080        param_names = set(params.keys())
1081        for name in expected_names:
1082            assert name in param_names or name in data_names, \
1083                "Unknown input to HybridBlock: %s" %name
1084
1085        used_data_names = [i for i in data_names if i in expected_names]
1086        if len(used_data_names) != len(data_names):
1087            unused = ', '.join(['%d-th'%i for name, i in data_names.items()
1088                                if name not in expected_names])
1089            warnings.warn("The %s input to HybridBlock is not used by any "
1090                          "computation. Is this intended?"%unused, stacklevel=4)
1091
1092        used_param_names = [i for i in param_names if i in expected_names]
1093        if len(used_param_names) != len(param_names):
1094            unused = ', '.join(list(param_names - set(used_param_names)))
1095            warnings.warn("Parameter %s is not used by any computation. "
1096                          "Is this intended?"%unused, stacklevel=4)
1097
1098        args, _ = _flatten(args, "input")
1099        try:
1100            for name in input_names:
1101                if name in params:
1102                    params[name].data()
1103        except DeferredInitializationError:
1104            self._deferred_infer_shape(*args)
1105            for name in input_names:
1106                if name in params:
1107                    params[name]._finish_deferred_init()
1108
1109        arg_dict, aux_dict = dict(), dict()
1110        if self._backend:
1111            # set context for inputs
1112            _, _, ctx_set, _ = _gather_type_ctx_info(list(args))
1113            ctx = ctx_set.pop() if len(ctx_set) > 0 else None
1114            # get list of params in the order of out.list_arguments
1115            input_shapes = dict()
1116            for name in out.list_arguments():
1117                if name in data_names.keys() and data_names[name] < len(args):
1118                    if isinstance(args[data_names[name]], NDArray):
1119                        arg_dict[name] = args[data_names[name]]
1120                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
1121                          '__shape__' in args[data_names[name]].list_attr()):
1122                        shape_str = args[data_names[name]].list_attr()['__shape__']
1123                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
1124                elif name in params:
1125                    arg_dict[name] = params[name].data()
1126
1127            for name in out.list_auxiliary_states():
1128                if name in data_names.keys() and data_names[name] < len(args):
1129                    if isinstance(args[data_names[name]], NDArray):
1130                        aux_dict[name] = args[data_names[name]]
1131                    elif (isinstance(args[data_names[name]], symbol.Symbol) and
1132                          '__shape__' in args[data_names[name]].list_attr()):
1133                        shape_str = args[data_names[name]].list_attr()['__shape__']
1134                        input_shapes[name] = tuple(map(int, shape_str.strip('()').split(',')))
1135                elif name in params:
1136                    aux_dict[name] = params[name].data()
1137
1138            # Partition the graph
1139            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)
1140
1141            # convert to numpy symbol if needed
1142            if _mx_npx.is_np_array():
1143                out = out.as_np_ndarray()
1144
1145            #update cached graph with partitioned graph
1146            self._cached_graph = data, out
1147
1148        input_names = out.list_inputs()
1149        data_indices = []
1150        param_indices = []
1151
        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical)
        # In the case of a Partition API optimized graph, _cached_op_args might contain some parameters from params,
        # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
        # and might not contain some parameters that were deleted during optimization.
1156        self._cached_op_args = []
1157        for i, name in enumerate(input_names):
1158            pair = None
1159            if name in data_names:
1160                data_indices.append(i)
1161                pair = (True, data_names[name])
1162            else:
1163                param_indices.append(i)
1164                if name in params:
1165                    param = params[name]
1166                else:
1167                    # The param is missing from the original params dictionary, which means the param must have
1168                    # been added by the Partition API backend
                    if name in arg_dict:
1170                        param_data = arg_dict[name]
1171                    elif name in aux_dict:
1172                        param_data = aux_dict[name]
1173                    else:
1174                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
1175                                           'added to the parameter dicts.\n'
1176                                           'Please check the backend.')
1177
1178                    param = Parameter(name, dtype=param_data.dtype)
1179                    param._load_init(param_data, param_data.context)
1180                pair = (False, param)
1181
1182            self._cached_op_args.append(pair)
1183
1184        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
1185                self._flags
1186
1187        self._cached_op = ndarray.CachedOp(out, flags)
1188
1189
1190    def _deferred_infer_shape(self, *args):
1191        try:
1192            self.infer_shape(*args)
1193        except Exception as e:
1194            error_msg = "Deferred initialization failed because shape"\
1195                        " cannot be inferred. {}".format(e)
1196            raise ValueError(error_msg)
1197
1198    def _call_cached_op(self, *args):
1199        if self._cached_op is None:
1200            self._build_cache(*args)
1201        assert self._cached_op, "Gluon failed to build the cache. " \
1202                                "This should never happen. " \
1203                                "Please submit an issue on Github" \
1204                                " https://github.com/apache/incubator-mxnet."
1205        if self._callback:
1206            self._cached_op._register_op_hook(self._callback, self._monitor_all)
            if len(self._flags) >= 2 and (self._flags[1][1] or self._flags[0][1]):
1208                warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True "
1209                              " and may not work correctly")
1210
1211        args, fmt = _flatten(args, "input")
1212        if fmt != self._in_format:
            # Do not raise in the case that fmt or the stored format ends with None (-1) and
            # we are relying on the default values.
1215            if len(self._in_format) > len(fmt):
1216                valid = all([self._in_format[i] == -1
1217                             for i in range(len(fmt), len(self._in_format))])
1218                valid = valid and (fmt == self._in_format[:len(fmt)])
1219            elif len(self._in_format) < len(fmt):
1220                valid = all([fmt[i] == -1
1221                             for i in range(len(self._in_format), len(fmt))])
1222                valid = valid and (fmt[:len(self._in_format)] == self._in_format)
1223            else:
1224                valid = False
1225            if not valid:
                raise ValueError("The argument structure of HybridBlock does not match"
                                 " the cached version. Stored format = {}, input format = {}"
                                 .format(self._in_format, fmt))
1229
1230        args_without_none = [ele for ele in args if ele is not None]
1231        cargs = [args_without_none[i] if is_arg else i.data()
1232                 for is_arg, i in self._cached_op_args]
1233        out = self._cached_op(*cargs)
1234        if isinstance(out, NDArray):
1235            out = [out]
1236        return _regroup(out, self._out_format)
1237
1238    def optimize_for(self, x, *args, backend=None, clear=False,
1239                     static_alloc=False,
1240                     static_shape=False,
1241                     inline_limit=2,
1242                     forward_bulk_size=None,
1243                     backward_bulk_size=None,
1244                     **kwargs):
1245        """Partitions the current HybridBlock and optimizes it for a given backend
1246        without executing a forward pass. Modifies the HybridBlock in-place.
1247
1248        Immediately partitions a HybridBlock using the specified backend. Combines
1249        the work done in the hybridize API with part of the work done in the forward
        pass without calling the CachedOp. Can be used in place of hybridize;
        afterwards `export` can be called or inference can be run. See
        example/extensions/lib_subgraph/README.md for more details.
1253
1254        Examples
1255        --------
1256        # partition and then export to file
1257        block.optimize_for(x, backend='myPart')
1258        block.export('partitioned')
1259
1260        # partition and then run inference
1261        block.optimize_for(x, backend='myPart')
1262        block(x)
1263
1264        Parameters
1265        ----------
1266        x : NDArray
1267            first input to model
1268        *args : NDArray
1269            other inputs to model
1270        backend : str
1271            The name of backend, as registered in `SubgraphBackendRegistry`, default None
1272        clear : bool, default False
1273            Clears any previous optimizations
1274        static_alloc : bool, default False
1275            Statically allocate memory to improve speed. Memory usage may increase.
1276        static_shape : bool, default False
1277            Optimize for invariant input shapes between iterations. Must also
1278            set static_alloc to True. Change of input shapes is still allowed
1279            but slower.
1280        inline_limit : optional int, default 2
1281            Maximum number of operators that can be inlined.
1282        forward_bulk_size : optional int, default None
1283            Segment size of bulk execution during forward pass.
        backward_bulk_size : optional int, default None
            Segment size of bulk execution during backward pass.
1286        **kwargs: The backend options, optional
1287            Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
1288        """
1289        if len(kwargs) > 0:
1290            self._backend_opts = kwargs
1291        if not backend:
1292            raise ValueError('Must specify "backend" to optimize_for')
1293
1294        self.hybridize(True, backend, clear, static_alloc, static_shape,
1295                       inline_limit, forward_bulk_size, backward_bulk_size)
1296
1297        # do part of forward API call
1298        has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
1299        if not has_symbol and not has_ndarray:
            raise ValueError('In HybridBlock, there must be at least one NDArray or Symbol'
                             ' in the input. Please check the type of the args.\n')
1302        if len(ctx_set) > 1:
            raise ValueError('Found multiple contexts in the input. '
                             'After hybridization, the HybridBlock only supports one input '
                             'context. You can print ele.ctx for the '
                             'input arguments to inspect their contexts. '
                             'Found contexts = {}'.format(ctx_set))
1308
1309        self._build_cache(x, *args)
1310        assert self._cached_op, "Gluon failed to build the cache. " \
1311                                "This should never happen. " \
1312                                "Please submit an issue on Github" \
1313                                " https://github.com/apache/incubator-mxnet."
1314        # do not actually call the cached_op
1315
1316    def _clear_cached_op(self):
1317        self._cached_graph = ()
1318        self._cached_op = None
1319        self._cached_op_args = []
1320
1321    def register_child(self, block, name=None):
1322        if not isinstance(block, HybridBlock):
1323            raise ValueError(
1324                "Children of HybridBlock must also be HybridBlock, " \
1325                "but %s has type %s. If you are using Sequential, " \
1326                "please try HybridSequential instead."%(
1327                    str(block), str(type(block))))
1328        super(HybridBlock, self).register_child(block, name)
1329        self._clear_cached_op()
1330
1331    def hybridize(self, active=True, backend=None, clear=True,
1332                  static_alloc=False, static_shape=False,
1333                  inline_limit=2,
1334                  forward_bulk_size=None,
1335                  backward_bulk_size=None,
1336                  **kwargs):
1337        """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
1338        non-hybrid children.
1339
1340        Parameters
1341        ----------
1342        active : bool, default True
1343            Whether to turn hybrid on or off.
1344        backend : str
1345            The name of backend, as registered in `SubgraphBackendRegistry`, default None
1346        clear : bool, default True
1347            Clears any previous optimizations
1348        static_alloc : optional bool, default False
1349            Statically allocate memory to improve speed. Memory usage may increase.
1350        static_shape : optional bool, default False
1351            Optimize for invariant input shapes between iterations. Must also
1352            set static_alloc to True. Change of input shapes is still allowed
1353            but slower.
1354        inline_limit : optional int, default 2
1355            Maximum number of operators that can be inlined.
1356        forward_bulk_size : optional int, default None
1357            Segment size of bulk execution during forward pass.
        backward_bulk_size : optional int, default None
            Segment size of bulk execution during backward pass.
1360        **kwargs:  optional
1361            Backend options.
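
        Examples
        --------
        A minimal usage sketch (the network below is illustrative):

        >>> net = gluon.nn.HybridSequential()
        >>> with net.name_scope():
        ...     net.add(gluon.nn.Dense(10))
        >>> net.initialize()
        >>> net.hybridize(static_alloc=True, static_shape=True)
        >>> out = net(mx.nd.ones((2, 5)))  # the first call builds and caches the graph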
1362        """
1363        if len(kwargs) > 0:
1364            self._backend_opts = kwargs
1365
1366        self._backend = backend
1367
1368        self._active = active
1369        self._flags = [("static_alloc", static_alloc), ("static_shape", static_shape),
1370                       ("inline_limit", inline_limit)]
1371        if forward_bulk_size is not None:
1372            self._flags.append(("forward_bulk_size", forward_bulk_size))
1373        if backward_bulk_size is not None:
1374            self._flags.append(("backward_bulk_size", backward_bulk_size))
1375        if clear:
1376            self._clear_cached_op()
        if active and (self._forward_hooks or self._forward_pre_hooks):
1378            warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. '
1379                          'If "{block}" is a child of HybridBlock, the hooks will not take effect.'
1380                          .format(block=self))
1381        super(HybridBlock, self).hybridize(active,
1382                                           static_alloc=static_alloc,
1383                                           static_shape=static_shape,
1384                                           inline_limit=inline_limit,
1385                                           forward_bulk_size=forward_bulk_size,
1386                                           backward_bulk_size=backward_bulk_size)
1387
1388    def cast(self, dtype):
1389        self._clear_cached_op()
1390        super(HybridBlock, self).cast(dtype)
1391
1392    def _infer_attrs(self, infer_fn, attr, *args):
        """Generic helper to infer an attribute (e.g. shape or dtype) of Parameters."""
1394        inputs, out = self._get_graph(*args)
1395        args, _ = _flatten(args, "input")
1396        args_without_none = [ele for ele in args if ele is not None]
1397        with warnings.catch_warnings(record=True) as w:
1398            arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
1399                **{i.name: getattr(j, attr) for i, j in zip(inputs, args_without_none)})
1400            if arg_attrs is None:
1401                raise ValueError(w[0].message)
1402        sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
1403        sdict.update({name : attr for name, attr in \
1404             zip(out.list_auxiliary_states(), aux_attrs)})
1405        for i in self.collect_params().values():
1406            setattr(i, attr, sdict[i.name])
1407
1408    def infer_shape(self, *args):
1409        """Infers shape of Parameters from inputs."""
1410        self._infer_attrs('infer_shape', 'shape', *args)
1411
1412    def infer_type(self, *args):
1413        """Infers data type of Parameters from inputs."""
1414        self._infer_attrs('infer_type', 'dtype', *args)
1415
1416    def export(self, path, epoch=0, remove_amp_cast=True):
1417        """Export HybridBlock to json format that can be loaded by
1418        `gluon.SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.
1419
        .. note:: When there is only one input, it will be named `data`. When there
                  are multiple inputs, they will be named `data0`, `data1`, etc.
1422
1423        Parameters
1424        ----------
1425        path : str
            Path to save model. Two files `path-symbol.json` and `path-xxxx.params`
            will be created, where xxxx is the 4-digit epoch number.
        epoch : int
            Epoch number of saved model.
        remove_amp_cast : bool, optional, default True
            Whether to remove the amp_cast and amp_multicast operators before saving the model.
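
        Examples
        --------
        A minimal usage sketch (model and file names are illustrative):

        >>> net = gluon.model_zoo.vision.resnet18_v1(pretrained=True)
        >>> net.hybridize()
        >>> out = net(mx.nd.ones((1, 3, 224, 224)))
        >>> net.export('resnet18_v1', epoch=0)
        >>> # writes resnet18_v1-symbol.json and resnet18_v1-0000.params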
1430        """
1431        if not self._cached_graph:
1432            raise RuntimeError(
1433                "Please first call block.hybridize() and then run forward with "
1434                "this block at least once before calling export.")
1435        sym = self._cached_graph[1]
1436        sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast)
1437
1438        arg_names = set(sym.list_arguments())
1439        aux_names = set(sym.list_auxiliary_states())
1440        arg_dict = {}
1441        for is_arg, param in self._cached_op_args:
1442            if not is_arg:
1443                name = param.name
                if name in arg_names:
                    arg_dict['arg:{}'.format(name)] = param._reduce()
                elif name in aux_names:
                    arg_dict['aux:{}'.format(name)] = param._reduce()
                else:
                    warnings.warn('Parameter "{name}" is not found in the graph. '
                                  .format(name=name), stacklevel=3)
1452        save_fn = _mx_npx.save if is_np_array() else ndarray.save
1453        save_fn('%s-%04d.params'%(path, epoch), arg_dict)
1454
1455    def register_op_hook(self, callback, monitor_all=False):
1456        """Install op hook for block recursively.
1457
1458        Parameters
1459        ----------
1460        callback : function
            Takes a string and an NDArrayHandle.
1462        monitor_all : bool, default False
1463            If true, monitor both input and output, otherwise monitor output only.
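
        Examples
        --------
        A minimal sketch assuming the callback signature documented above (the exact
        signature may differ between MXNet versions), with a hypothetical block ``net``:

        >>> def mon_callback(name, array):
        ...     print(name)
        >>> net.register_op_hook(mon_callback, monitor_all=True)
        >>> net.hybridize()
        >>> out = net(mx.nd.ones((2, 5)))  # the hook fires as the cached graph executes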
1464        """
1465        self._callback = callback
1466        self._monitor_all = monitor_all
1467        for cld in self._children.values():
1468            cld._callback = callback
1469            cld._monitor_all = monitor_all
1470
1471    def forward(self, x, *args):
1472        """Defines the forward computation. Arguments can be either
1473        :py:class:`NDArray` or :py:class:`Symbol`."""
1474
1475        has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info([x] + list(args))
1476        if has_symbol and has_ndarray:
            raise ValueError('In HybridBlock, mixing NDArray and Symbol types in the input'
                             ' is not supported. Please check the type of the args.\n')
1479        if not has_symbol and not has_ndarray:
            raise ValueError('In HybridBlock, there must be at least one NDArray or Symbol'
                             ' in the input. Please check the type of the args.\n')
1482        if has_ndarray:
1483            ctx = first_ctx
1484            if self._active:
1485                if len(ctx_set) > 1:
                    raise ValueError('Found multiple contexts in the input. '
                                     'After hybridization, the HybridBlock only supports one input '
                                     'context. You can print ele.ctx for the '
                                     'input arguments to inspect their contexts. '
                                     'Found contexts = {}'.format(ctx_set))
1491                with ctx:
1492                    return self._call_cached_op(x, *args)
1493            with ctx:
1494                try:
1495                    params = {k: v.data(ctx) for k, v in self._reg_params.items()}
1496                except DeferredInitializationError:
1497                    self._deferred_infer_shape(x, *args)
1498                    for _, v in self.params.items():
1499                        v._finish_deferred_init()
1500                    params = {k: v.data(ctx) for k, v in self._reg_params.items()}
1501
1502                return self.hybrid_forward(ndarray, x, *args, **params)
1503        params = {i: j.var() for i, j in self._reg_params.items()}
1504        with self.name_scope():
1505            return self.hybrid_forward(symbol, x, *args, **params)
1506
1507    def hybrid_forward(self, F, x, *args, **kwargs):
        """Override this method to construct a symbolic graph for this `Block`.

        Parameters
        ----------
        F : module
            The API namespace to use, either :py:mod:`mxnet.ndarray` (imperative mode)
            or :py:mod:`mxnet.symbol` (after hybridization).
        x : Symbol or NDArray
            The first input tensor.
1514        *args : list of Symbol or list of NDArray
1515            Additional input tensors.
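
        Examples
        --------
        A minimal sketch of a custom block overriding ``hybrid_forward``:

        >>> class Model(gluon.HybridBlock):
        ...     def __init__(self, **kwargs):
        ...         super(Model, self).__init__(**kwargs)
        ...         with self.name_scope():
        ...             self.dense0 = gluon.nn.Dense(20)
        ...     def hybrid_forward(self, F, x):
        ...         # F is mxnet.ndarray imperatively and mxnet.symbol after hybridization
        ...         return F.relu(self.dense0(x))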
1516        """
1517        # pylint: disable= invalid-name
1518        raise NotImplementedError
1519
1520def _common_prefix(names):
1521    """Get the common prefix for all names"""
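    # Worked examples (illustrative only) of the character-wise common prefix:
    #   _common_prefix(['conv0_weight', 'conv0_bias'])    -> 'conv0_'
    #   _common_prefix(['dense0_weight', 'conv0_weight']) -> ''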
1522    if not names:
1523        return ''
1524    prefix = names[0]
1525    for name in names:
1526        i = 0
1527        while i < len(prefix) and i < len(name) and prefix[i] == name[i]:
1528            i += 1
1529        prefix = prefix[:i]
1530    return prefix
1531
1532
1533class SymbolBlock(HybridBlock):
1534    """Construct block from symbol. This is useful for using pre-trained models
    as feature extractors. For example, you may want to extract the output
    from the fc2 layer in AlexNet.
1537
1538    Parameters
1539    ----------
1540    outputs : Symbol or list of Symbol
1541        The desired output for SymbolBlock.
1542    inputs : Symbol or list of Symbol
1543        The Variables in output's argument that should be used as inputs.
1544    params : ParameterDict
        Parameter dictionary for arguments and auxiliary states of outputs
1546        that are not inputs.
1547
1548    Examples
1549    --------
1550    >>> # To extract the feature from fc1 and fc2 layers of AlexNet:
1551    >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
1552                                                 prefix='model_')
1553    >>> inputs = mx.sym.var('data')
1554    >>> out = alexnet(inputs)
1555    >>> internals = out.get_internals()
1556    >>> print(internals.list_outputs())
1557    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
1558    >>> outputs = [internals['model_dense0_relu_fwd_output'],
1559                   internals['model_dense1_relu_fwd_output']]
1560    >>> # Create SymbolBlock that shares parameters with alexnet
1561    >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
1562    >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
1563    >>> print(feat_model(x))
1564    """
1565    @staticmethod
1566    def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
1567                ignore_extra=False):
1568        """Import model previously saved by `gluon.HybridBlock.export` or
1569        `Module.save_checkpoint` as a `gluon.SymbolBlock` for use in Gluon.
1570
1571        Parameters
1572        ----------
1573        symbol_file : str
1574            Path to symbol file.
1575        input_names : list of str
            List of input variable names.
1577        param_file : str, optional
1578            Path to parameter file.
1579        ctx : Context, default None
1580            The context to initialize `gluon.SymbolBlock` on.
1581        allow_missing : bool, default False
            Whether to silently skip loading parameters not represented in the file.
1583        ignore_extra : bool, default False
1584            Whether to silently ignore parameters from the file that are not
1585            present in this Block.
1586
1587        Returns
1588        -------
1589        gluon.SymbolBlock
1590            `gluon.SymbolBlock` loaded from symbol and parameter files.
1591
1592        Examples
1593        --------
1594        >>> net1 = gluon.model_zoo.vision.resnet18_v1(
1595        ...     prefix='resnet', pretrained=True)
1596        >>> net1.hybridize()
1597        >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
1598        >>> out1 = net1(x)
1599        >>> net1.export('net1', epoch=1)
1600        >>>
1601        >>> net2 = gluon.SymbolBlock.imports(
1602        ...     'net1-symbol.json', ['data'], 'net1-0001.params')
1603        >>> out2 = net2(x)
1604        """
1605        if is_np_array():
1606            sym = np_symbol.load(symbol_file)
1607        else:
1608            sym = symbol.load(symbol_file)
1609        if isinstance(input_names, str):
1610            input_names = [input_names]
1611        if param_file is None:
1612            # Get a valid type inference by using fp32
1613            inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names]
1614        else:
1615            # Do not specify type, rely on saved params type instead
1616            inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i) for i in input_names]
1617        ret = SymbolBlock(sym, inputs)
1618        if param_file is not None:
1619            ret.collect_params().load(param_file, ctx, allow_missing, ignore_extra, cast_dtype=True,
1620                                      dtype_source='saved')
1621        return ret
1622
1623    def __repr__(self):
1624        s = '{name}(\n{modstr}\n)'
1625        modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'.format(block=self._cached_graph[1],
1626                                                                           numinputs=len(self._cached_graph[0]),
1627                                                                           numoutputs=len(self._cached_graph[1].
1628                                                                                          list_outputs()))])
1629        return s.format(name=self.__class__.__name__,
1630                        modstr=modstr)
1631
1632    def __init__(self, outputs, inputs, params=None):
1633        super(SymbolBlock, self).__init__(prefix=None, params=None)
1634        self._prefix = ''
1635        self._params = ParameterDict('', params)
1636        if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
1637            inputs = [inputs]
1638        if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
1639            outputs = outputs[0]
1640
1641        syms, self._in_format = _flatten(inputs, "input")
1642        out, self._out_format = _flatten(outputs, "output")
1643        input_names = set()
1644        for i in syms:
1645            assert len(i.get_internals().list_outputs()) == 1, \
                "Input symbols must be variables, but %s is an output of operators"%str(i)
1647            input_names.add(i.name)
1648
1649        # check if any symbol is row_sparse
1650        row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
1651
1652        for i in out:
1653            for j in i.get_internals():
1654                assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
1655                    "SymbolBlock doesn't support Parameter '%s' because its storage " \
1656                    "type is 'row_sparse'." % j.name
1657        if len(out) > 1:
1658            out = symbol.Group(out, _check_same_symbol_type(out))
1659        else:
1660            out = out[0]
1661
1662        # Infer type of parameters. Without this, every parameter will be created with
1663        # default type i.e., fp32
1664        arg_params = out.list_arguments()
1665        aux_params = out.list_auxiliary_states()
1666
1667        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)
1668
1669        for i, arg in enumerate(arg_params):
1670            if arg not in input_names:
1671                self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])
1672
1673        for i, aux in enumerate(aux_params):
1674            if aux not in input_names:
1675                self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])
1676
1677        self._cached_graph = syms, out
1678        len_prefix = len(_common_prefix(list(self._params.keys())))
1679        self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}
1680
1681    def forward(self, x, *args):
1682        if isinstance(x, NDArray):
1683            with x.ctx:
1684                return self._call_cached_op(x, *args)
1685
1686        assert isinstance(x, Symbol), \
1687            "HybridBlock requires the first argument to forward be either " \
1688            "Symbol or NDArray, but got %s"%type(x)
1689        args, in_fmt = _flatten([x] + list(args), "input")
1690        assert in_fmt == self._in_format, "Invalid input format"
1691        ret = copy.copy(self._cached_graph[1])
1692        ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
1693        return _regroup(list(ret), self._out_format)
1694
1695    def _clear_cached_op(self):
1696        tmp = self._cached_graph
1697        super(SymbolBlock, self)._clear_cached_op()
1698        self._cached_graph = tmp
1699
1700    def cast(self, dtype):
1701        self._clear_cached_op()
1702        super(SymbolBlock, self).cast(dtype)
1703        if np.dtype(dtype).name == 'float16':
1704            # correct BatchNorm types back to float32 due to its special requirement
1705            out = self._cached_graph[1]
1706            params_list = out.get_internals().list_inputs()
1707            for node in params_list:
1708                if node.endswith('running_var'):
1709                    prefix = node[:-11]
1710                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
1711                    is_bn = all(p in params_list for p in sibs)
1712                    if is_bn:
1713                        self.params.get(node).cast('float32')
1714                        for sib in sibs:
1715                            self.params.get(sib).cast('float32')
1716                if node.endswith('moving_var'):
                    # the alternative naming convention ('moving_*' instead of 'running_*')
1718                    prefix = node[:-10]
1719                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
1720                    is_bn = all(p in params_list for p in sibs)
1721                    if is_bn:
1722                        self.params.get(node).cast('float32')
1723                        for sib in sibs:
1724                            self.params.get(sib).cast('float32')
1725
1726    def hybrid_forward(self, F, x, *args, **kwargs):
1727        raise NotImplementedError
1728
1729    def reset_ctx(self, ctx):
        """Re-assigns all Parameters to other contexts. If the Block is hybridized, it will
        reset the _cached_op_args.

        Parameters
1732        ----------
1733        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
1734            Assign Parameter to given context. If ctx is a list of Context, a
1735            copy will be made for each context.
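
        Examples
        --------
        A minimal sketch (file names are illustrative), moving an imported block to GPU 0:

        >>> net = gluon.SymbolBlock.imports('net-symbol.json', ['data'], 'net-0000.params')
        >>> net.reset_ctx(mx.gpu(0))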
1736        """
1737        params = self.collect_params()
1738        if self._cached_op:
            # _cached_op_args holds (is_arg, value) pairs; only reset the parameter entries.
            # This also covers parameters created by the partitioning backend that are
            # not part of collect_params().
            for is_arg, p in self._cached_op_args:
                if not is_arg and p.name not in params:
                    p.reset_ctx(ctx)
1743        for p in params.values():
1744            p.reset_ctx(ctx)
1745
1746def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
    """Utility function that helps infer the dtype of arg and aux params
    from the given input params.
1749
1750    Parameters
1751    ----------
1752    in_params: List of Symbol
1753        List of input symbol variables.
1754    out_params: Symbol
1755        Output symbol variable.
    arg_params: List of Str
        List of names of argument parameters.
1758    aux_params: List of Str
1759        List of names of auxiliary parameters.
1760    default_dtype: numpy.dtype or str, default 'float32'
1761        Default data type for arg_params and aux_params, if unable to infer the type.
1762
1763    Returns
1764    -------
1765    arg_types: List of numpy.dtype
1766        List of arg_params type. Order is same as arg_params.
1767        Defaults to 'float32', if unable to infer type.
1768    aux_types: List of numpy.dtype
1769        List of aux_params type. Order is same as aux_params.
1770        Defaults to 'float32', if unable to infer type.
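
    Examples
    --------
    A minimal sketch with a hypothetical graph; the dtypes of ``fc_weight`` and
    ``fc_bias`` are inferred from the float16 ``data`` variable:

    >>> data = mx.sym.var('data', dtype='float16')
    >>> out = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
    >>> arg_types, aux_types = _infer_param_types(
    ...     [data], out, out.list_arguments(), out.list_auxiliary_states())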
1771    """
1772    arg_types = None
1773    aux_types = None
1774
1775    # Get Input symbol details. This will be used to infer types of
1776    # other parameters.
1777    input_sym_names = [in_param.name for in_param in in_params]
1778
1779    # Try to infer input types. If not successful, we will set default dtype.
1780    # If successful, we will try to infer other params in the graph.
1781    input_sym_arg_types = []
1782    can_infer_input_type = True
1783    for in_param in in_params:
1784        input_sym_arg_type = in_param.infer_type()[0]
1785        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
1786            can_infer_input_type = False
1787            break
        else:
            input_sym_arg_types.append(input_sym_arg_type[0])
1790
1791    # Try to infer types of other parameters.
1792    if can_infer_input_type:
1793        params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
1794        try:
1795            arg_types, _, aux_types = out_params.infer_type(**params)
1796        except MXNetError:
1797            # Cannot infer type with current input
1798            arg_types, aux_types = None, None
1799
1800    if arg_types is None or len(arg_types) != len(arg_params):
1801        arg_types = []
1802        for _ in arg_params:
1803            arg_types.append(default_dtype)
1804
1805    if aux_types is None or len(aux_types) != len(aux_params):
1806        aux_types = []
1807        for _ in aux_params:
1808            aux_types.append(default_dtype)
1809
1810    return (arg_types, aux_types)
1811