1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18# coding: utf-8 19# pylint: disable= arguments-differ, too-many-lines, reimported 20"""Base container class for all neural network models.""" 21__all__ = ['Block', 'HybridBlock', 'SymbolBlock'] 22 23import threading 24import copy 25import warnings 26import re 27import json 28from collections import OrderedDict, defaultdict 29import numpy as np 30 31from ..base import mx_real_t, MXNetError 32from .. import symbol, ndarray, initializer, np_symbol 33from ..symbol import Symbol, load_json 34from ..ndarray import NDArray 35from .. import name as _name 36from .parameter import Parameter, ParameterDict, DeferredInitializationError 37from .utils import _indent, _brief_print_list, HookHandle 38from .utils import _check_same_symbol_type, _check_all_np_ndarrays 39from .. import numpy_extension as _mx_npx 40from .. import numpy as _mx_np 41from .. util import is_np_array, np_shape, np_array 42 43 44 45class _BlockScope(object): 46 """Scope for collecting child `Block` s.""" 47 _current = threading.local() 48 49 def __init__(self, block): 50 self._block = block 51 self._counter = {} 52 self._old_scope = None 53 self._name_scope = None 54 55 @staticmethod 56 def create(prefix, params, hint): 57 """Creates prefix and params for new `Block`.""" 58 current = getattr(_BlockScope._current, "value", None) 59 if current is None: 60 if prefix is None: 61 if not hasattr(_name.NameManager._current, "value"): 62 _name.NameManager._current.value = _name.NameManager() 63 prefix = _name.NameManager._current.value.get(None, hint) + '_' 64 if params is None: 65 params = ParameterDict(prefix) 66 else: 67 params = ParameterDict(params.prefix, params) 68 return prefix, params 69 70 if prefix is None: 71 count = current._counter.get(hint, 0) 72 prefix = '%s%d_'%(hint, count) 73 current._counter[hint] = count + 1 74 if params is None: 75 parent = current._block.params 76 params = ParameterDict(parent.prefix+prefix, parent._shared) 77 else: 78 params = ParameterDict(params.prefix, params) 79 return current._block.prefix+prefix, params 80 81 def __enter__(self): 82 if self._block._empty_prefix: 83 return self 84 self._old_scope = getattr(_BlockScope._current, "value", None) 85 _BlockScope._current.value = self 86 self._name_scope = _name.Prefix(self._block.prefix) 87 self._name_scope.__enter__() 88 return self 89 90 def __exit__(self, ptype, value, trace): 91 if self._block._empty_prefix: 92 return 93 self._name_scope.__exit__(ptype, value, trace) 94 self._name_scope = None 95 _BlockScope._current.value = self._old_scope 96 97 98def _gather_type_ctx_info(args): 99 """Analyze the elements inside the nested args object and find: 100 - If there exists ndarray 101 - If there exists symbol 102 - All contexts appearing in args 103 104 
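    For illustration (a sketch, assuming the default context is ``cpu(0)``): a nested
    input such as ``[mx.nd.zeros((2,)), None]`` yields ``(False, True, {cpu(0)}, cpu(0))``,
    while a bare ``mx.sym.var('x')`` yields ``(True, False, set(), None)``.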
Parameters 105 ---------- 106 args : list or NDArray or Symbol 107 Could be a nested architecture. 108 109 Returns 110 ------- 111 has_symbol : bool 112 Whether the elements in args contains symbols 113 has_ndarray : bool 114 Whether the elements in args contains ndarrays 115 ctx_set : set of mxnet.context.Context 116 Contains all possible contexts of the inner ndarrays in args. Can be empty if there is no 117 ndarray inside args. 118 first_ctx : mxnet.context.Context or None 119 Context of the first appeared NDArray (for backward-compatibility) 120 """ 121 if isinstance(args, NDArray): 122 return False, True, {args.ctx}, args.ctx 123 elif isinstance(args, Symbol): 124 return True, False, set(), None 125 elif isinstance(args, (list, tuple)): 126 has_symbol = False 127 has_ndarray = False 128 ctx_set = set() 129 first_ctx = None 130 for ele in args: 131 ele_has_sym, ele_has_nd, ele_ctx_set, ele_first_ctx =\ 132 _gather_type_ctx_info(ele) 133 has_symbol = has_symbol or ele_has_sym 134 has_ndarray = has_ndarray or ele_has_nd 135 if first_ctx is None and ele_first_ctx is not None: 136 first_ctx = ele_first_ctx 137 ctx_set = ctx_set | ele_ctx_set 138 if has_symbol and has_ndarray: 139 break 140 return has_symbol, has_ndarray, ctx_set, first_ctx 141 else: 142 return False, False, set(), None 143 144 145def _flatten(args, inout_str): 146 """Parse the arguments into a flattened list + an additional format array. 147 The format array stores the structure of the original arguments to help reconstruct the inputs. 148 149 Parameters 150 ---------- 151 args : NDArray, Symbol, or (nested) list of Symbol or NDArray 152 We allow None inside the args. 153 inout_str : str 154 The name of the HybridBlock 155 156 Returns 157 ------- 158 flat : list of Symbol or NDArray 159 The flatten version of the input args. 160 fmts : (nested) list of ints 161 Stores the format information of the original structured args. 162 """ 163 if isinstance(args, NDArray): 164 return [args], int(0) 165 if isinstance(args, Symbol): 166 length = len(args.list_outputs()) 167 length = length if length > 1 else 0 168 return [args], int(length) 169 if args is None: 170 return [None], int(-1) 171 172 if not isinstance(args, (list, tuple)): 173 raise ValueError("When hybridized, the input of HybridBlock {}" 174 " must be (nested) list of Symbol" 175 " or NDArray, " 176 "but got {} of type {}".format(inout_str, str(args), str(type(args)))) 177 flat = [] 178 fmts = [] 179 for i in args: 180 arg, fmt = _flatten(i, inout_str) 181 flat.extend(arg) 182 fmts.append(fmt) 183 return flat, fmts 184 185 186def _regroup(args, fmt): 187 """Reconstruct the structured arguments based on the flattened version. 188 189 Parameters 190 ---------- 191 args : NDArray, Symbol, or (nested) list of Symbol or NDArray 192 We allow None inside the args. 193 fmt : (nested) list of ints 194 Stores the format information of the original structured args. 195 196 Returns 197 ------- 198 ret : NDArray, Symbol, or (nested) list of Symbol or NDArray 199 200 """ 201 def _merger(args, fmt): 202 """Recursive call to merge the arguments""" 203 if isinstance(fmt, int): 204 if fmt < -1: 205 raise ValueError("Unsupported encoded format {}.".format(fmt)) 206 if fmt == 0: 207 return args[0], args[1:] 208 if fmt == -1: 209 if args[0] is not None: 210 raise ValueError('We do not support passing types that are not None' 211 ' when the initial HybridBlock has received NoneType and' 212 ' has been hybridized.' 
213 ' Received arg = {}, fmt = {}.'.format(args[0], fmt)) 214 return None, args[1:] 215 else: 216 return args[:fmt], args[fmt:] 217 218 if not isinstance(args, (list, tuple)): 219 raise ValueError("When hybridized, the output of HybridBlock must be (nested)" 220 " list of Symbol or NDArray, " 221 "but got {} of type {}".format(args, type(args))) 222 ret = [] 223 for i in fmt: 224 res, args = _merger(args, i) 225 ret.append(res) 226 return ret, args 227 return _merger(args, fmt)[0] 228 229 230class Block(object): 231 """Base class for all neural network layers and models. Your models should 232 subclass this class. 233 234 :py:class:`Block` can be nested recursively in a tree structure. You can create and 235 assign child :py:class:`Block` as regular attributes:: 236 237 from mxnet.gluon import Block, nn 238 from mxnet import ndarray as F 239 240 class Model(Block): 241 def __init__(self, **kwargs): 242 super(Model, self).__init__(**kwargs) 243 # use name_scope to give child Blocks appropriate names. 244 with self.name_scope(): 245 self.dense0 = nn.Dense(20) 246 self.dense1 = nn.Dense(20) 247 248 def forward(self, x): 249 x = F.relu(self.dense0(x)) 250 return F.relu(self.dense1(x)) 251 252 model = Model() 253 model.initialize(ctx=mx.cpu(0)) 254 model(F.zeros((10, 10), ctx=mx.cpu(0))) 255 256 257 Child :py:class:`Block` assigned this way will be registered and :py:meth:`collect_params` 258 will collect their Parameters recursively. You can also manually register 259 child blocks with :py:meth:`register_child`. 260 261 Parameters 262 ---------- 263 prefix : str 264 Prefix acts like a name space. All children blocks created in parent block's 265 :py:meth:`name_scope` will have parent block's prefix in their name. 266 Please refer to 267 `naming tutorial </api/python/docs/tutorials/packages/gluon/blocks/naming.html>`_ 268 for more info on prefix and naming. 269 params : ParameterDict or None 270 :py:class:`ParameterDict` for sharing weights with the new :py:class:`Block`. For example, 271 if you want ``dense1`` to share ``dense0``'s weights, you can do:: 272 273 dense0 = nn.Dense(20) 274 dense1 = nn.Dense(20, params=dense0.collect_params()) 275 """ 276 def __init__(self, prefix=None, params=None): 277 self._empty_prefix = prefix == '' 278 self._prefix, self._params = _BlockScope.create(prefix, params, self._alias()) 279 self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix 280 self._scope = _BlockScope(self) 281 self._children = OrderedDict() 282 self._reg_params = {} 283 self._forward_hooks = OrderedDict() 284 self._forward_pre_hooks = OrderedDict() 285 286 def __repr__(self): 287 s = '{name}(\n{modstr}\n)' 288 modstr = '\n'.join([' ({key}): {block}'.format(key=key, 289 block=_indent(block.__repr__(), 2)) 290 for key, block in self.__dict__.items() if isinstance(block, Block)]) 291 return s.format(name=self.__class__.__name__, modstr=modstr) 292 293 def __setattr__(self, name, value): 294 """Registers parameters.""" 295 296 if hasattr(self, name): 297 existing = getattr(self, name) 298 if isinstance(existing, (Parameter, Block)) and not isinstance(value, type(existing)): 299 raise TypeError('Changing attribute type for {name} from {type1} to {type2}' \ 300 'is not allowed.'.format( 301 name=name, type1=type(existing), type2=type(value))) 302 303 if isinstance(value, Block): 304 self.register_child(value, name) 305 elif isinstance(value, Parameter): 306 assert name not in self._reg_params, \ 307 "Overriding Parameter attribute %s is not allowed. 
" \ 308 "If you want to share parameters between blocks, please set " \ 309 "'params' at Block construction instead." 310 self._reg_params[name] = value 311 312 super(Block, self).__setattr__(name, value) 313 314 def _check_container_with_block(self): 315 children = set(self._children.values()) 316 def _find_unregistered_block_in_container(data): 317 # Find whether a nested container structure contains Blocks 318 if isinstance(data, (list, tuple)): 319 for ele in data: 320 if _find_unregistered_block_in_container(ele): 321 return True 322 return False 323 elif isinstance(data, dict): 324 for _, v in data.items(): 325 if _find_unregistered_block_in_container(v): 326 return True 327 return False 328 elif isinstance(data, Block): 329 return not data in children 330 else: 331 return False 332 for k, v in self.__dict__.items(): 333 if isinstance(v, (list, tuple, dict)) and not (k.startswith('__') or k == '_children'): 334 if _find_unregistered_block_in_container(v): 335 warnings.warn('"{name}" is an unregistered container with Blocks. ' 336 'Note that Blocks inside the list, tuple or dict will not be ' 337 'registered automatically. Make sure to register them using ' 338 'register_child() or switching to ' 339 'nn.Sequential/nn.HybridSequential instead. ' 340 .format(name=self.__class__.__name__ + "." + k), stacklevel=3) 341 342 def _alias(self): 343 return self.__class__.__name__.lower() 344 345 @property 346 def prefix(self): 347 """Prefix of this :py:class:`Block`.""" 348 return self._prefix 349 350 @property 351 def name(self): 352 """Name of this :py:class:`Block`, without '_' in the end.""" 353 return self._name 354 355 def name_scope(self): 356 """Returns a name space object managing a child :py:class:`Block` and parameter 357 names. Should be used within a ``with`` statement:: 358 359 with self.name_scope(): 360 self.dense = nn.Dense(20) 361 362 Please refer to 363 `the naming tutorial </api/python/docs/tutorials/packages/gluon/blocks/naming.html>`_ 364 for more info on prefix and naming. 365 """ 366 return self._scope 367 368 @property 369 def params(self): 370 """Returns this :py:class:`Block`'s parameter dictionary (does not include its 371 children's parameters).""" 372 return self._params 373 374 def collect_params(self, select=None): 375 """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its 376 children's Parameters(default), also can returns the select :py:class:`ParameterDict` 377 which match some given regular expressions. 378 379 For example, collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight', 380 'fc_bias']:: 381 382 model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias') 383 384 or collect all parameters whose names end with 'weight' or 'bias', this can be done 385 using regular expressions:: 386 387 model.collect_params('.*weight|.*bias') 388 389 Parameters 390 ---------- 391 select : str 392 regular expressions 393 394 Returns 395 ------- 396 The selected :py:class:`ParameterDict` 397 """ 398 # We need to check here because blocks inside containers are not supported. 
399 self._check_container_with_block() 400 ret = ParameterDict(self._params.prefix) 401 if not select: 402 ret.update(self.params) 403 else: 404 pattern = re.compile(select) 405 ret.update({name:value for name, value in self.params.items() if pattern.match(name)}) 406 for cld in self._children.values(): 407 ret.update(cld.collect_params(select=select)) 408 return ret 409 410 def _collect_params_with_prefix(self, prefix=''): 411 if prefix: 412 prefix += '.' 413 ret = {prefix + key : val for key, val in self._reg_params.items()} 414 for name, child in self._children.items(): 415 ret.update(child._collect_params_with_prefix(prefix + name)) 416 return ret 417 418 def save_parameters(self, filename, deduplicate=False): 419 """Save parameters to file. 420 421 Saved parameters can only be loaded with `load_parameters`. Note that this 422 method only saves parameters, not model structure. If you want to save 423 model structures, please use :py:meth:`HybridBlock.export`. 424 425 Parameters 426 ---------- 427 filename : str 428 Path to file. 429 deduplicate : bool, default False 430 If True, save shared parameters only once. Otherwise, if a Block 431 contains multiple sub-blocks that share parameters, each of the 432 shared parameters will be separately saved for every sub-block. 433 434 References 435 ---------- 436 `Saving and Loading Gluon Models \ 437 <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_ 438 """ 439 params = self._collect_params_with_prefix() 440 441 if deduplicate: 442 # Shared parameters are stored only a single time as of MXNet 1.6. 443 # Shared parameters are registered under multiple prefixes returned by 444 # _collect_params_with_prefix. We select a single one and only store 445 # it. In load_parameters it is sufficient for a shared parameter to 446 # only set it for a single prefix. 447 reverse_params = {v: k for k, v in params.items()} 448 params = {v: k for k, v in reverse_params.items()} 449 450 arg_dict = {key: val._reduce() for key, val in params.items()} 451 save_fn = _mx_npx.save if is_np_array() else ndarray.save 452 save_fn(filename, arg_dict) 453 454 def save_params(self, filename): 455 """[Deprecated] Please use save_parameters. Note that if you want load 456 from SymbolBlock later, please use export instead. 457 458 Save parameters to file. 459 460 filename : str 461 Path to file. 462 """ 463 warnings.warn("save_params is deprecated. Please use save_parameters. " 464 "Note that if you want load from SymbolBlock later, please " 465 "use export instead. For details, see " 466 "https://mxnet.apache.org/tutorials/gluon/save_lo" 467 "ad_params.html") 468 try: 469 self.collect_params().save(filename, strip_prefix=self.prefix) 470 except ValueError as e: 471 raise ValueError('%s\nsave_params is deprecated. Using ' \ 472 'save_parameters may resolve this error.'%e.message) 473 474 def load_parameters(self, filename, ctx=None, allow_missing=False, 475 ignore_extra=False, cast_dtype=False, dtype_source='current'): 476 """Load parameters from file previously saved by `save_parameters`. 477 478 Parameters 479 ---------- 480 filename : str 481 Path to parameter file. 482 ctx : Context or list of Context, default cpu() 483 Context(s) to initialize loaded parameters on. 484 allow_missing : bool, default False 485 Whether to silently skip loading parameters not represents in the file. 486 ignore_extra : bool, default False 487 Whether to silently ignore parameters from the file that are not 488 present in this Block. 
489 cast_dtype : bool, default False 490 Cast the data type of the NDArray loaded from the checkpoint to the dtype 491 provided by the Parameter if any. 492 dtype_source : str, default 'current' 493 must be in {'current', 'saved'} 494 Only valid if cast_dtype=True, specify the source of the dtype for casting 495 the parameters 496 References 497 ---------- 498 `Saving and Loading Gluon Models \ 499 <https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_ 500 """ 501 if is_np_array(): 502 # failure may happen when loading parameters saved as NDArrays within 503 # NumPy semantics. Check the failure type and recover from it if it happens. 504 try: 505 loaded = _mx_npx.load(filename) 506 except MXNetError as e: 507 err_msg = str(e) 508 if 'is_np_shape' in err_msg: 509 # Loading failure due to parameters saved without numpy semantics. 510 # Temporarily disable numpy semantics and load parameters. After it's 511 # done, resume the numpy semantics. This is fine because the cases 512 # numpy ndarray covers is a superset of the legacy ndarray's. 513 with np_array(False): 514 with np_shape(False): 515 loaded_nds = ndarray.load(filename) 516 assert isinstance(loaded_nds, dict),\ 517 'expecting a dict type, got {}'.format(str(type(loaded_nds))) 518 loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds} 519 else: 520 raise ValueError(err_msg) 521 else: 522 loaded = ndarray.load(filename) 523 params = self._collect_params_with_prefix() 524 if not loaded and not params: 525 return 526 527 if not any('.' in i for i in loaded.keys()): 528 # legacy loading 529 loaded = None # This should be changed to `del loaded` when dropping Python 2 530 self.collect_params().load( 531 filename, ctx, allow_missing, ignore_extra, self.prefix, 532 cast_dtype=cast_dtype, dtype_source=dtype_source) 533 return 534 535 if not allow_missing: 536 # Shared parameters are stored only a single time as of MXNet 1.6. 537 # We thus retrieve all prefixes (through _collect_params_with_prefix) 538 # that a shared parameter is used with. Check that there are no 539 # missing parameters that were not yet already loaded from the 540 # shared version. 541 params_inv = defaultdict(list) 542 for k, v in params.items(): 543 params_inv[v].append(k) 544 545 for name, param in params.items(): 546 assert any(p in loaded for p in params_inv[param]), \ 547 "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \ 548 "Set allow_missing=True to ignore missing parameters."%( 549 name, filename, _brief_print_list(loaded.keys())) 550 for name in loaded: 551 if not ignore_extra and name not in params: 552 raise ValueError( 553 "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \ 554 "which contains parameters %s. Set ignore_extra=True to ignore. "%( 555 name, filename, _brief_print_list(self._params.keys()))) 556 if name in params: 557 params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source) 558 559 def load_params(self, filename, ctx=None, allow_missing=False, 560 ignore_extra=False): 561 """[Deprecated] Please use load_parameters. 562 563 Load parameters from file. 564 565 filename : str 566 Path to parameter file. 567 ctx : Context or list of Context, default cpu() 568 Context(s) to initialize loaded parameters on. 569 allow_missing : bool, default False 570 Whether to silently skip loading parameters not represents in the file. 
571 ignore_extra : bool, default False 572 Whether to silently ignore parameters from the file that are not 573 present in this Block. 574 """ 575 warnings.warn("load_params is deprecated. Please use load_parameters.") 576 self.load_parameters(filename, ctx, allow_missing, ignore_extra) 577 578 def register_child(self, block, name=None): 579 """Registers block as a child of self. :py:class:`Block` s assigned to self as 580 attributes will be registered automatically.""" 581 if name is None: 582 name = str(len(self._children)) 583 self._children[name] = block 584 585 def register_forward_pre_hook(self, hook): 586 r"""Registers a forward pre-hook on the block. 587 588 The hook function is called immediately before :func:`forward`. 589 It should not modify the input or output. 590 591 Parameters 592 ---------- 593 hook : callable 594 The forward hook function of form `hook(block, input) -> None`. 595 596 Returns 597 ------- 598 :class:`mxnet.gluon.utils.HookHandle` 599 """ 600 handle = HookHandle() 601 handle.attach(self._forward_pre_hooks, hook) 602 return handle 603 604 def register_forward_hook(self, hook): 605 r"""Registers a forward hook on the block. 606 607 The hook function is called immediately after :func:`forward`. 608 It should not modify the input or output. 609 610 Parameters 611 ---------- 612 hook : callable 613 The forward hook function of form `hook(block, input, output) -> None`. 614 615 Returns 616 ------- 617 :class:`mxnet.gluon.utils.HookHandle` 618 """ 619 handle = HookHandle() 620 handle.attach(self._forward_hooks, hook) 621 return handle 622 623 def apply(self, fn): 624 r"""Applies ``fn`` recursively to every child block as well as self. 625 626 Parameters 627 ---------- 628 fn : callable 629 Function to be applied to each submodule, of form `fn(block)`. 630 631 Returns 632 ------- 633 this block 634 """ 635 for cld in self._children.values(): 636 cld.apply(fn) 637 fn(self) 638 return self 639 640 def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False, 641 force_reinit=False): 642 """Initializes :py:class:`Parameter` s of this :py:class:`Block` and its children. 643 Equivalent to ``block.collect_params().initialize(...)`` 644 645 Parameters 646 ---------- 647 init : Initializer 648 Global default Initializer to be used when :py:meth:`Parameter.init` is ``None``. 649 Otherwise, :py:meth:`Parameter.init` takes precedence. 650 ctx : Context or list of Context 651 Keeps a copy of Parameters on one or many context(s). 652 verbose : bool, default False 653 Whether to verbosely print out details on initialization. 654 force_reinit : bool, default False 655 Whether to force re-initialization if parameter is already initialized. 656 """ 657 self.collect_params().initialize(init, ctx, verbose, force_reinit) 658 659 def hybridize(self, active=True, **kwargs): 660 """ Please refer description of HybridBlock hybridize(). 661 """ 662 for cld in self._children.values(): 663 cld.hybridize(active, **kwargs) 664 665 def save(self, prefix): 666 """Save the model architecture and parameters to load again later 667 668 Saves the model architecture as a nested dictionary where each Block 669 in the model is a dictionary and its children are sub-dictionaries. 670 671 Each Block is uniquely identified by Block class name and a unique ID. 672 We save the child's name that that parent uses for it to restore later 673 in order to match the saved parameters. 
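        As a rough illustration only (the key names below are a sketch of what the logic in
        this method produces, not a stable file format), a small non-hybridized HybridBlock
        with a single child ``nn.Dense`` might be serialized as::

            {"model0": {"orig_name": "model0",
                        "hybridized": false,
                        "children": {"model0_dense0": "dense0"},
                        "dense1": {"orig_name": "model0_dense0",
                                   "hybridized": false,
                                   "children": {}}}}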
674 675 Recursively traverses a Block's children in order (since its an 676 OrderedDict) and uses the unique ID to denote that specific Block. 677 Assumes that the model is created in an identical order every time. 678 If the model is not able to be recreated deterministically do not 679 use this set of APIs to save/load your model. 680 681 For HybridBlocks, the cached_graph (Symbol & inputs) is saved if 682 it has already been hybridized. 683 684 Parameters 685 ---------- 686 prefix : str 687 The prefix to use in filenames for saving this model: 688 <prefix>-model.json and <prefix>-model.params 689 """ 690 # create empty model structure 691 model = {} 692 def _save_cached_graphs(blk, index, structure): 693 # create new entry for this block 694 mdl = {'orig_name': blk.name} 695 # encode unique name based on block type and ID 696 name = type(blk).__name__.lower() 697 structure[name+str(index[0])] = mdl 698 if isinstance(blk, HybridBlock): 699 if blk._cached_graph: 700 # save in/out formats 701 mdl['in_format'] = blk._in_format 702 mdl['out_format'] = blk._out_format 703 # save cached graph & input symbols 704 syms, out = blk._cached_graph 705 mdl_syms = [] 706 for sym in syms: 707 mdl_syms.append(sym.tojson()) 708 mdl['inputs'] = mdl_syms 709 mdl['symbol'] = out.tojson() 710 mdl['hybridized'] = True 711 else: 712 mdl['hybridized'] = False 713 children = dict() 714 mdl['children'] = children 715 # recursively save children 716 for ch_name, child in blk._children.items(): 717 index[0] += 1 718 # save child's original name in this block's map 719 children[child.name] = ch_name 720 _save_cached_graphs(child, index, mdl) 721 # save top-level block 722 index = [0] 723 _save_cached_graphs(self, index, model) 724 # save model 725 with open(prefix+'-model.json', 'w') as fp: 726 json.dump(model, fp) 727 # save params 728 self.save_parameters(prefix+'-model.params') 729 730 def load(self, prefix): 731 """Load a model saved using the `save` API 732 733 Reconfigures a model using the saved configuration. This function 734 does not regenerate the model architecture. It resets the children's 735 names as they were when saved in order to match the names of the 736 saved parameters. 737 738 This function assumes the Blocks in the model were created in the same 739 order they were when the model was saved. This is because each Block is 740 uniquely identified by Block class name and a unique ID in order (since 741 its an OrderedDict) and uses the unique ID to denote that specific Block. 742 Assumes that the model is created in an identical order every time. 743 If the model is not able to be recreated deterministically do not 744 use this set of APIs to save/load your model. 745 746 For HybridBlocks, the cached_graph (Symbol & inputs) and settings are 747 restored if it had been hybridized before saving. 
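        A typical (hypothetical) round trip, where ``Model`` stands for your own Block
        subclass, looks like::

            net = Model()
            net.initialize()
            net(mx.nd.ones((1, 10)))   # run forward once so deferred shapes are known
            net.save('mymodel')        # -> mymodel-model.json, mymodel-model.params

            net2 = Model()             # children must be created in the same order
            net2.load('mymodel')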
748 749 Parameters 750 ---------- 751 prefix : str 752 The prefix to use in filenames for loading this model: 753 <prefix>-model.json and <prefix>-model.params 754 """ 755 # load model json from file 756 with open(prefix+'-model.json') as fp: 757 model = json.load(fp) 758 759 def _load_cached_graphs(blk, index, structure): 760 # get block name 761 name = type(blk).__name__.lower() 762 # lookup previous encoded name based on block type and ID 763 mdl = structure[name+str(index[0])] 764 # rename block to what it was when saved 765 blk._name = mdl['orig_name'] 766 if isinstance(blk, HybridBlock): 767 if mdl['hybridized']: 768 # restore in/out formats 769 blk._in_format = mdl['in_format'] 770 blk._out_format = mdl['out_format'] 771 # get saved symbol 772 out = load_json(mdl['symbol']) 773 syms = [] 774 # recreate inputs for this symbol 775 for inp in mdl['inputs']: 776 syms.append(load_json(inp)) 777 # reset cached_graph and active status 778 blk._cached_graph = (syms, out) 779 blk._active = True 780 # rename params with updated block name 781 pnames = list(blk.params.keys()) 782 for p in pnames: 783 param = blk.params._params[p] 784 new_name = blk.name +'_'+ p[len(blk.params._prefix):] 785 blk.params._params.pop(p) 786 blk.params._params[new_name] = param 787 # recursively reload children 788 for ch_name, child in blk._children.items(): 789 index[0] += 1 790 _load_cached_graphs(child, index, mdl) 791 # current set of child names 792 ch_names = list(blk._children.keys()) 793 # original child names 794 children = mdl['children'] 795 # loop and remap children with original names 796 for ch_name in ch_names: 797 child = blk._children[ch_name] 798 blk._children.pop(ch_name) 799 orig_name = children[child.name] 800 blk._children[orig_name] = child 801 # load top-level block 802 index = [0] 803 _load_cached_graphs(self, index, model) 804 # load params 805 self.load_parameters(prefix+'-model.params') 806 807 def cast(self, dtype): 808 """Cast this Block to use another data type. 809 810 Parameters 811 ---------- 812 dtype : str or numpy.dtype 813 The new data type. 814 """ 815 for child in self._children.values(): 816 child.cast(dtype) 817 for _, param in self.params.items(): 818 param.cast(dtype) 819 820 def __call__(self, *args): 821 """Calls forward. Only accepts positional arguments.""" 822 for hook in self._forward_pre_hooks.values(): 823 hook(self, args) 824 825 out = self.forward(*args) 826 827 for hook in self._forward_hooks.values(): 828 hook(self, args, out) 829 if _mx_npx.is_np_array(): 830 _check_all_np_ndarrays(out) 831 return out 832 833 def forward(self, *args): 834 """Overrides to implement forward computation using :py:class:`NDArray`. Only 835 accepts positional arguments. 836 837 Parameters 838 ---------- 839 *args : list of NDArray 840 Input tensors. 841 """ 842 # pylint: disable= invalid-name 843 raise NotImplementedError 844 845 def register_op_hook(self, callback, monitor_all=False): 846 """Install callback monitor. 847 848 Parameters 849 ---------- 850 callback : function 851 Takes a string and a NDArrayHandle. 852 monitor_all : bool, default False 853 If true, monitor both input and output, otherwise monitor output only. 854 """ 855 for cld in self._children.values(): 856 cld.register_op_hook(callback, monitor_all) 857 858 def summary(self, *inputs): 859 """Print the summary of the model's output and parameters. 860 861 The network must have been initialized, and must not have been hybridized. 
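        A minimal sketch of typical usage::

            from mxnet import nd
            from mxnet.gluon import nn

            net = nn.HybridSequential()
            with net.name_scope():
                net.add(nn.Dense(128, activation='relu'), nn.Dense(10))
            net.initialize()
            net.summary(nd.zeros((16, 100)))  # prints per-layer output shapes and param counts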
862 863 Parameters 864 ---------- 865 inputs : object 866 Any input that the model supports. For any tensor in the input, only 867 :class:`mxnet.ndarray.NDArray` is supported. 868 """ 869 summary = OrderedDict() 870 seen = set() 871 hooks = [] 872 873 def _get_shape_str(args): 874 def flatten(args): 875 if not isinstance(args, (list, tuple)): 876 return [args], int(0) 877 flat = [] 878 fmts = [] 879 for i in args: 880 arg, fmt = flatten(i) 881 flat.extend(arg) 882 fmts.append(fmt) 883 return flat, fmts 884 885 def regroup(args, fmt): 886 if isinstance(fmt, int): 887 if fmt == 0: 888 return args[0], args[1:] 889 return args[:fmt], args[fmt:] 890 ret = [] 891 for i in fmt: 892 res, args = regroup(args, i) 893 ret.append(res) 894 return ret, args 895 896 flat_args, fmts = flatten(args) 897 flat_arg_shapes = [x.shape if isinstance(x, ndarray.NDArray) else x 898 for x in flat_args] 899 shapes = regroup(flat_arg_shapes, fmts)[0] 900 if isinstance(shapes, list): 901 shape_str = str(shapes)[1:-1] 902 else: 903 shape_str = str(shapes) 904 return shape_str.replace('L', '') 905 906 def _register_summary_hook(block): 907 assert not isinstance(block, HybridBlock) or not block._active, \ 908 '"{}" must not be hybridized to print summary.'.format(block.name) 909 def _summary_hook(block, _, outputs): 910 class_name = block.__class__.__name__ 911 block_idx = len(summary) - 1 912 913 m_key = '%s-%i' % (class_name, block_idx+1) 914 summary[m_key] = OrderedDict() 915 summary[m_key]['output_shape'] = _get_shape_str(outputs) 916 917 params = 0 918 summary[m_key]['trainable'] = 0 919 summary[m_key]['shared'] = 0 920 for p in block.params.values(): 921 params += p.data().size 922 summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size 923 if p in seen: 924 summary[m_key]['shared'] += p.data().size 925 else: 926 seen.add(p) 927 summary[m_key]['n_params'] = params 928 929 from .nn.basic_layers import Sequential, HybridSequential 930 if not isinstance(block, (Sequential, HybridSequential)): 931 hooks.append(block.register_forward_hook(_summary_hook)) 932 933 summary['Input'] = OrderedDict() 934 summary['Input']['output_shape'] = _get_shape_str(inputs) 935 summary['Input']['n_params'] = 0 936 summary['Input']['trainable'] = 0 937 summary['Input']['shared'] = 0 938 939 try: 940 self.apply(_register_summary_hook) 941 self(*inputs) 942 943 line_format = '{:>20} {:>42} {:>15}' 944 print('-'*80) 945 print(line_format.format('Layer (type)', 'Output Shape', 'Param #')) 946 print('='*80) 947 total_params = 0 948 trainable_params = 0 949 shared_params = 0 950 for layer in summary: 951 print(line_format.format(layer, 952 str(summary[layer]['output_shape']), 953 summary[layer]['n_params'])) 954 total_params += summary[layer]['n_params'] 955 trainable_params += summary[layer]['trainable'] 956 shared_params += summary[layer]['shared'] 957 print('='*80) 958 print('Parameters in forward computation graph, duplicate included') 959 print(' Total params: ' + str(total_params)) 960 print(' Trainable params: ' + str(trainable_params)) 961 print(' Non-trainable params: ' + str(total_params - trainable_params)) 962 print('Shared params in forward computation graph: ' + str(shared_params)) 963 print('Unique parameters in model: ' + str(total_params - shared_params)) 964 print('-'*80) 965 finally: 966 for h in hooks: 967 h.detach() 968 969 970class HybridBlock(Block): 971 """`HybridBlock` supports forwarding with both Symbol and NDArray. 
972 973 `HybridBlock` is similar to `Block`, with a few differences:: 974 975 import mxnet as mx 976 from mxnet.gluon import HybridBlock, nn 977 978 class Model(HybridBlock): 979 def __init__(self, **kwargs): 980 super(Model, self).__init__(**kwargs) 981 # use name_scope to give child Blocks appropriate names. 982 with self.name_scope(): 983 self.dense0 = nn.Dense(20) 984 self.dense1 = nn.Dense(20) 985 986 def hybrid_forward(self, F, x): 987 x = F.relu(self.dense0(x)) 988 return F.relu(self.dense1(x)) 989 990 model = Model() 991 model.initialize(ctx=mx.cpu(0)) 992 model.hybridize() 993 model(mx.nd.zeros((10, 10), ctx=mx.cpu(0))) 994 995 Forward computation in :py:class:`HybridBlock` must be static to work with :py:class:`Symbol` s, 996 i.e. you cannot call :py:meth:`NDArray.asnumpy`, :py:attr:`NDArray.shape`, 997 :py:attr:`NDArray.dtype`, `NDArray` indexing (`x[i]`) etc on tensors. 998 Also, you cannot use branching or loop logic that bases on non-constant 999 expressions like random numbers or intermediate results, since they change 1000 the graph structure for each iteration. 1001 1002 Before activating with :py:meth:`hybridize()`, :py:class:`HybridBlock` works just like normal 1003 :py:class:`Block`. After activation, :py:class:`HybridBlock` will create a symbolic graph 1004 representing the forward computation and cache it. On subsequent forwards, 1005 the cached graph will be used instead of :py:meth:`hybrid_forward`. 1006 1007 Please see references for detailed tutorial. 1008 1009 References 1010 ---------- 1011 `Hybrid - Faster training and easy deployment 1012 <https://mxnet.io/tutorials/gluon/hybrid.html>`_ 1013 """ 1014 def __init__(self, prefix=None, params=None): 1015 super(HybridBlock, self).__init__(prefix=prefix, params=params) 1016 self._cached_graph = () 1017 self._cached_op = None 1018 self._cached_op_args = [] 1019 self._out_format = None 1020 self._in_format = None 1021 self._active = False 1022 self._flags = [] 1023 self._callback = None 1024 self._monitor_all = False 1025 self._backend = None 1026 self._backend_opts = {} 1027 1028 def __setattr__(self, name, value): 1029 """Registers parameters.""" 1030 super(HybridBlock, self).__setattr__(name, value) 1031 if isinstance(value, HybridBlock): 1032 self._clear_cached_op() 1033 1034 def _get_graph(self, *args): 1035 if not self._cached_graph: 1036 flatten_args, self._in_format = _flatten(args, "input") 1037 flatten_inputs = [] 1038 symbol_inputs = [] 1039 cnt = 0 1040 real_arg_num = sum([ele is not None for ele in flatten_args]) 1041 if real_arg_num == 0: 1042 raise ValueError('All args are None and we do not support such a case.' 
1043 ' Received args={}'.format(args)) 1044 for arg in flatten_args: 1045 if arg is not None: 1046 if real_arg_num > 1: 1047 arg_sym = symbol.var('data{}'.format(cnt)) 1048 else: 1049 arg_sym = symbol.var('data') 1050 if isinstance(arg, _mx_np.ndarray): 1051 arg_sym = arg_sym.as_np_ndarray() 1052 cnt += 1 1053 flatten_inputs.append(arg_sym) 1054 symbol_inputs.append(arg_sym) 1055 else: 1056 flatten_inputs.append(None) 1057 grouped_inputs = _regroup(flatten_inputs, self._in_format) 1058 params = {i: j.var() for i, j in self._reg_params.items()} 1059 with self.name_scope(): 1060 out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter 1061 out, self._out_format = _flatten(out, "output") 1062 1063 self._cached_graph = symbol_inputs, symbol.Group(out, _check_same_symbol_type(out)) 1064 1065 return self._cached_graph 1066 1067 def _build_cache(self, *args): 1068 data, out = self._get_graph(*args) 1069 data_names = {data.name: i for i, data in enumerate(data)} 1070 input_names = out.list_inputs() 1071 expected_names = set(input_names) 1072 1073 # try to reuse cached_op_args for params 1074 if len(self._cached_op_args) > 0: 1075 params = {param_tuple[1].name:param_tuple[1] 1076 for param_tuple in self._cached_op_args 1077 if isinstance(param_tuple[1], Parameter)} 1078 else: 1079 params = self.collect_params() 1080 param_names = set(params.keys()) 1081 for name in expected_names: 1082 assert name in param_names or name in data_names, \ 1083 "Unknown input to HybridBlock: %s" %name 1084 1085 used_data_names = [i for i in data_names if i in expected_names] 1086 if len(used_data_names) != len(data_names): 1087 unused = ', '.join(['%d-th'%i for name, i in data_names.items() 1088 if name not in expected_names]) 1089 warnings.warn("The %s input to HybridBlock is not used by any " 1090 "computation. Is this intended?"%unused, stacklevel=4) 1091 1092 used_param_names = [i for i in param_names if i in expected_names] 1093 if len(used_param_names) != len(param_names): 1094 unused = ', '.join(list(param_names - set(used_param_names))) 1095 warnings.warn("Parameter %s is not used by any computation. 
" 1096 "Is this intended?"%unused, stacklevel=4) 1097 1098 args, _ = _flatten(args, "input") 1099 try: 1100 for name in input_names: 1101 if name in params: 1102 params[name].data() 1103 except DeferredInitializationError: 1104 self._deferred_infer_shape(*args) 1105 for name in input_names: 1106 if name in params: 1107 params[name]._finish_deferred_init() 1108 1109 arg_dict, aux_dict = dict(), dict() 1110 if self._backend: 1111 # set context for inputs 1112 _, _, ctx_set, _ = _gather_type_ctx_info(list(args)) 1113 ctx = ctx_set.pop() if len(ctx_set) > 0 else None 1114 # get list of params in the order of out.list_arguments 1115 input_shapes = dict() 1116 for name in out.list_arguments(): 1117 if name in data_names.keys() and data_names[name] < len(args): 1118 if isinstance(args[data_names[name]], NDArray): 1119 arg_dict[name] = args[data_names[name]] 1120 elif (isinstance(args[data_names[name]], symbol.Symbol) and 1121 '__shape__' in args[data_names[name]].list_attr()): 1122 shape_str = args[data_names[name]].list_attr()['__shape__'] 1123 input_shapes[name] = tuple(map(int, shape_str.strip('()').split(','))) 1124 elif name in params: 1125 arg_dict[name] = params[name].data() 1126 1127 for name in out.list_auxiliary_states(): 1128 if name in data_names.keys() and data_names[name] < len(args): 1129 if isinstance(args[data_names[name]], NDArray): 1130 aux_dict[name] = args[data_names[name]] 1131 elif (isinstance(args[data_names[name]], symbol.Symbol) and 1132 '__shape__' in args[data_names[name]].list_attr()): 1133 shape_str = args[data_names[name]].list_attr()['__shape__'] 1134 input_shapes[name] = tuple(map(int, shape_str.strip('()').split(','))) 1135 elif name in params: 1136 aux_dict[name] = params[name].data() 1137 1138 # Partition the graph 1139 out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts) 1140 1141 # convert to numpy symbol if needed 1142 if _mx_npx.is_np_array(): 1143 out = out.as_np_ndarray() 1144 1145 #update cached graph with partitioned graph 1146 self._cached_graph = data, out 1147 1148 input_names = out.list_inputs() 1149 data_indices = [] 1150 param_indices = [] 1151 1152 # In the default case, _cached_ops_args contains all the parameters from params (the sets are identical) 1153 # In the case of Partition API optimized graph _cached_ops_args might contain some parameters from params, 1154 # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`, 1155 # and might not contain some parameters that were deleted during optimization. 
1156 self._cached_op_args = [] 1157 for i, name in enumerate(input_names): 1158 pair = None 1159 if name in data_names: 1160 data_indices.append(i) 1161 pair = (True, data_names[name]) 1162 else: 1163 param_indices.append(i) 1164 if name in params: 1165 param = params[name] 1166 else: 1167 # The param is missing from the original params dictionary, which means the param must have 1168 # been added by the Partition API backend 1169 if name in arg_dict or name: 1170 param_data = arg_dict[name] 1171 elif name in aux_dict: 1172 param_data = aux_dict[name] 1173 else: 1174 raise RuntimeError('A parameter was added to the graph during optimization but it was not ' 1175 'added to the parameter dicts.\n' 1176 'Please check the backend.') 1177 1178 param = Parameter(name, dtype=param_data.dtype) 1179 param._load_init(param_data, param_data.context) 1180 pair = (False, param) 1181 1182 self._cached_op_args.append(pair) 1183 1184 flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \ 1185 self._flags 1186 1187 self._cached_op = ndarray.CachedOp(out, flags) 1188 1189 1190 def _deferred_infer_shape(self, *args): 1191 try: 1192 self.infer_shape(*args) 1193 except Exception as e: 1194 error_msg = "Deferred initialization failed because shape"\ 1195 " cannot be inferred. {}".format(e) 1196 raise ValueError(error_msg) 1197 1198 def _call_cached_op(self, *args): 1199 if self._cached_op is None: 1200 self._build_cache(*args) 1201 assert self._cached_op, "Gluon failed to build the cache. " \ 1202 "This should never happen. " \ 1203 "Please submit an issue on Github" \ 1204 " https://github.com/apache/incubator-mxnet." 1205 if self._callback: 1206 self._cached_op._register_op_hook(self._callback, self._monitor_all) 1207 if len(self._flags) >= 2 and (self._flags[1] or self._flags[0]): 1208 warnings.warn("register_op_hook is experimental when static_alloc=True / static_shape=True " 1209 " and may not work correctly") 1210 1211 args, fmt = _flatten(args, "input") 1212 if fmt != self._in_format: 1213 # Do not raise in the case that the fmt or stored_fmt ends with None and 1214 # We are relying on the default values. 1215 if len(self._in_format) > len(fmt): 1216 valid = all([self._in_format[i] == -1 1217 for i in range(len(fmt), len(self._in_format))]) 1218 valid = valid and (fmt == self._in_format[:len(fmt)]) 1219 elif len(self._in_format) < len(fmt): 1220 valid = all([fmt[i] == -1 1221 for i in range(len(self._in_format), len(fmt))]) 1222 valid = valid and (fmt[:len(self._in_format)] == self._in_format) 1223 else: 1224 valid = False 1225 if not valid: 1226 raise ValueError("The argument structure of HybridBlock does not match" 1227 " the cached version. Stored format = {}, input format = {}" 1228 .format(fmt, self._in_format)) 1229 1230 args_without_none = [ele for ele in args if ele is not None] 1231 cargs = [args_without_none[i] if is_arg else i.data() 1232 for is_arg, i in self._cached_op_args] 1233 out = self._cached_op(*cargs) 1234 if isinstance(out, NDArray): 1235 out = [out] 1236 return _regroup(out, self._out_format) 1237 1238 def optimize_for(self, x, *args, backend=None, clear=False, 1239 static_alloc=False, 1240 static_shape=False, 1241 inline_limit=2, 1242 forward_bulk_size=None, 1243 backward_bulk_size=None, 1244 **kwargs): 1245 """Partitions the current HybridBlock and optimizes it for a given backend 1246 without executing a forward pass. Modifies the HybridBlock in-place. 1247 1248 Immediately partitions a HybridBlock using the specified backend. 
Combines 1249 the work done in the hybridize API with part of the work done in the forward 1250 pass without calling the CachedOp. Can be used in place of hybridize, 1251 afterwards `export` can be called or inference can be run. See README.md in 1252 example/extensions/lib_subgraph/README.md for more details. 1253 1254 Examples 1255 -------- 1256 # partition and then export to file 1257 block.optimize_for(x, backend='myPart') 1258 block.export('partitioned') 1259 1260 # partition and then run inference 1261 block.optimize_for(x, backend='myPart') 1262 block(x) 1263 1264 Parameters 1265 ---------- 1266 x : NDArray 1267 first input to model 1268 *args : NDArray 1269 other inputs to model 1270 backend : str 1271 The name of backend, as registered in `SubgraphBackendRegistry`, default None 1272 clear : bool, default False 1273 Clears any previous optimizations 1274 static_alloc : bool, default False 1275 Statically allocate memory to improve speed. Memory usage may increase. 1276 static_shape : bool, default False 1277 Optimize for invariant input shapes between iterations. Must also 1278 set static_alloc to True. Change of input shapes is still allowed 1279 but slower. 1280 inline_limit : optional int, default 2 1281 Maximum number of operators that can be inlined. 1282 forward_bulk_size : optional int, default None 1283 Segment size of bulk execution during forward pass. 1284 backward_bulk_size : optional int, default None 1285 Segment size of bulk execution during forward pass. 1286 **kwargs: The backend options, optional 1287 Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty` 1288 """ 1289 if len(kwargs) > 0: 1290 self._backend_opts = kwargs 1291 if not backend: 1292 raise ValueError('Must specify "backend" to optimize_for') 1293 1294 self.hybridize(True, backend, clear, static_alloc, static_shape, 1295 inline_limit, forward_bulk_size, backward_bulk_size) 1296 1297 # do part of forward API call 1298 has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args)) 1299 if not has_symbol and not has_ndarray: 1300 raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.' 1301 ' Please check the type of the args.\n') 1302 if len(ctx_set) > 1: 1303 raise ValueError('Found multiple contexts in the input, ' 1304 'After hybridized, the HybridBlock only supports one input ' 1305 'context. You can print the ele.ctx in the ' 1306 'input arguments to inspect their contexts. ' 1307 'Find all contexts = {}'.format(ctx_set)) 1308 1309 self._build_cache(x, *args) 1310 assert self._cached_op, "Gluon failed to build the cache. " \ 1311 "This should never happen. " \ 1312 "Please submit an issue on Github" \ 1313 " https://github.com/apache/incubator-mxnet." 1314 # do not actually call the cached_op 1315 1316 def _clear_cached_op(self): 1317 self._cached_graph = () 1318 self._cached_op = None 1319 self._cached_op_args = [] 1320 1321 def register_child(self, block, name=None): 1322 if not isinstance(block, HybridBlock): 1323 raise ValueError( 1324 "Children of HybridBlock must also be HybridBlock, " \ 1325 "but %s has type %s. 
If you are using Sequential, " \ 1326 "please try HybridSequential instead."%( 1327 str(block), str(type(block)))) 1328 super(HybridBlock, self).register_child(block, name) 1329 self._clear_cached_op() 1330 1331 def hybridize(self, active=True, backend=None, clear=True, 1332 static_alloc=False, static_shape=False, 1333 inline_limit=2, 1334 forward_bulk_size=None, 1335 backward_bulk_size=None, 1336 **kwargs): 1337 """Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on 1338 non-hybrid children. 1339 1340 Parameters 1341 ---------- 1342 active : bool, default True 1343 Whether to turn hybrid on or off. 1344 backend : str 1345 The name of backend, as registered in `SubgraphBackendRegistry`, default None 1346 clear : bool, default True 1347 Clears any previous optimizations 1348 static_alloc : optional bool, default False 1349 Statically allocate memory to improve speed. Memory usage may increase. 1350 static_shape : optional bool, default False 1351 Optimize for invariant input shapes between iterations. Must also 1352 set static_alloc to True. Change of input shapes is still allowed 1353 but slower. 1354 inline_limit : optional int, default 2 1355 Maximum number of operators that can be inlined. 1356 forward_bulk_size : optional int, default None 1357 Segment size of bulk execution during forward pass. 1358 backward_bulk_size : optional int, default None 1359 Segment size of bulk execution during forward pass. 1360 **kwargs: optional 1361 Backend options. 1362 """ 1363 if len(kwargs) > 0: 1364 self._backend_opts = kwargs 1365 1366 self._backend = backend 1367 1368 self._active = active 1369 self._flags = [("static_alloc", static_alloc), ("static_shape", static_shape), 1370 ("inline_limit", inline_limit)] 1371 if forward_bulk_size is not None: 1372 self._flags.append(("forward_bulk_size", forward_bulk_size)) 1373 if backward_bulk_size is not None: 1374 self._flags.append(("backward_bulk_size", backward_bulk_size)) 1375 if clear: 1376 self._clear_cached_op() 1377 if active and self._forward_hooks or self._forward_pre_hooks: 1378 warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. ' 1379 'If "{block}" is a child of HybridBlock, the hooks will not take effect.' 
1380 .format(block=self)) 1381 super(HybridBlock, self).hybridize(active, 1382 static_alloc=static_alloc, 1383 static_shape=static_shape, 1384 inline_limit=inline_limit, 1385 forward_bulk_size=forward_bulk_size, 1386 backward_bulk_size=backward_bulk_size) 1387 1388 def cast(self, dtype): 1389 self._clear_cached_op() 1390 super(HybridBlock, self).cast(dtype) 1391 1392 def _infer_attrs(self, infer_fn, attr, *args): 1393 """Generic infer attributes.""" 1394 inputs, out = self._get_graph(*args) 1395 args, _ = _flatten(args, "input") 1396 args_without_none = [ele for ele in args if ele is not None] 1397 with warnings.catch_warnings(record=True) as w: 1398 arg_attrs, _, aux_attrs = getattr(out, infer_fn)( 1399 **{i.name: getattr(j, attr) for i, j in zip(inputs, args_without_none)}) 1400 if arg_attrs is None: 1401 raise ValueError(w[0].message) 1402 sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)} 1403 sdict.update({name : attr for name, attr in \ 1404 zip(out.list_auxiliary_states(), aux_attrs)}) 1405 for i in self.collect_params().values(): 1406 setattr(i, attr, sdict[i.name]) 1407 1408 def infer_shape(self, *args): 1409 """Infers shape of Parameters from inputs.""" 1410 self._infer_attrs('infer_shape', 'shape', *args) 1411 1412 def infer_type(self, *args): 1413 """Infers data type of Parameters from inputs.""" 1414 self._infer_attrs('infer_type', 'dtype', *args) 1415 1416 def export(self, path, epoch=0, remove_amp_cast=True): 1417 """Export HybridBlock to json format that can be loaded by 1418 `gluon.SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface. 1419 1420 .. note:: When there are only one input, it will have name `data`. When there 1421 Are more than one inputs, they will be named as `data0`, `data1`, etc. 1422 1423 Parameters 1424 ---------- 1425 path : str 1426 Path to save model. Two files `path-symbol.json` and `path-xxxx.params` 1427 will be created, where xxxx is the 4 digits epoch number. 1428 epoch : int 1429 Epoch number of saved model. 1430 """ 1431 if not self._cached_graph: 1432 raise RuntimeError( 1433 "Please first call block.hybridize() and then run forward with " 1434 "this block at least once before calling export.") 1435 sym = self._cached_graph[1] 1436 sym.save('%s-symbol.json'%path, remove_amp_cast=remove_amp_cast) 1437 1438 arg_names = set(sym.list_arguments()) 1439 aux_names = set(sym.list_auxiliary_states()) 1440 arg_dict = {} 1441 for is_arg, param in self._cached_op_args: 1442 if not is_arg: 1443 name = param.name 1444 if name in arg_names: 1445 arg_dict['arg:{}'.format(name)] = param._reduce() 1446 else: 1447 if name not in aux_names: 1448 warnings.warn('Parameter "{name}" is not found in the graph. ' 1449 .format(name=name), stacklevel=3) 1450 else: 1451 arg_dict['aux:%s'%name] = param._reduce() 1452 save_fn = _mx_npx.save if is_np_array() else ndarray.save 1453 save_fn('%s-%04d.params'%(path, epoch), arg_dict) 1454 1455 def register_op_hook(self, callback, monitor_all=False): 1456 """Install op hook for block recursively. 1457 1458 Parameters 1459 ---------- 1460 callback : function 1461 Takes a string and a NDArrayHandle. 1462 monitor_all : bool, default False 1463 If true, monitor both input and output, otherwise monitor output only. 1464 """ 1465 self._callback = callback 1466 self._monitor_all = monitor_all 1467 for cld in self._children.values(): 1468 cld._callback = callback 1469 cld._monitor_all = monitor_all 1470 1471 def forward(self, x, *args): 1472 """Defines the forward computation. 
Arguments can be either 1473 :py:class:`NDArray` or :py:class:`Symbol`.""" 1474 1475 has_symbol, has_ndarray, ctx_set, first_ctx = _gather_type_ctx_info([x] + list(args)) 1476 if has_symbol and has_ndarray: 1477 raise ValueError('In HybridBlock, we do not support mixed NDArrays and Symbols' 1478 ' types for the input. Please check the type of the args.\n') 1479 if not has_symbol and not has_ndarray: 1480 raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.' 1481 ' Please check the type of the args.\n') 1482 if has_ndarray: 1483 ctx = first_ctx 1484 if self._active: 1485 if len(ctx_set) > 1: 1486 raise ValueError('Find multiple contexts in the input, ' 1487 'After hybridized, the HybridBlock only supports one input ' 1488 'context. You can print the ele.ctx in the ' 1489 'input arguments to inspect their contexts. ' 1490 'Find all contexts = {}'.format(ctx_set)) 1491 with ctx: 1492 return self._call_cached_op(x, *args) 1493 with ctx: 1494 try: 1495 params = {k: v.data(ctx) for k, v in self._reg_params.items()} 1496 except DeferredInitializationError: 1497 self._deferred_infer_shape(x, *args) 1498 for _, v in self.params.items(): 1499 v._finish_deferred_init() 1500 params = {k: v.data(ctx) for k, v in self._reg_params.items()} 1501 1502 return self.hybrid_forward(ndarray, x, *args, **params) 1503 params = {i: j.var() for i, j in self._reg_params.items()} 1504 with self.name_scope(): 1505 return self.hybrid_forward(symbol, x, *args, **params) 1506 1507 def hybrid_forward(self, F, x, *args, **kwargs): 1508 """Overrides to construct symbolic graph for this `Block`. 1509 1510 Parameters 1511 ---------- 1512 x : Symbol or NDArray 1513 The first input tensor. 1514 *args : list of Symbol or list of NDArray 1515 Additional input tensors. 1516 """ 1517 # pylint: disable= invalid-name 1518 raise NotImplementedError 1519 1520def _common_prefix(names): 1521 """Get the common prefix for all names""" 1522 if not names: 1523 return '' 1524 prefix = names[0] 1525 for name in names: 1526 i = 0 1527 while i < len(prefix) and i < len(name) and prefix[i] == name[i]: 1528 i += 1 1529 prefix = prefix[:i] 1530 return prefix 1531 1532 1533class SymbolBlock(HybridBlock): 1534 """Construct block from symbol. This is useful for using pre-trained models 1535 as feature extractors. For example, you may want to extract the output 1536 from fc2 layer in AlexNet. 1537 1538 Parameters 1539 ---------- 1540 outputs : Symbol or list of Symbol 1541 The desired output for SymbolBlock. 1542 inputs : Symbol or list of Symbol 1543 The Variables in output's argument that should be used as inputs. 1544 params : ParameterDict 1545 Parameter dictionary for arguments and auxililary states of outputs 1546 that are not inputs. 1547 1548 Examples 1549 -------- 1550 >>> # To extract the feature from fc1 and fc2 layers of AlexNet: 1551 >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(), 1552 prefix='model_') 1553 >>> inputs = mx.sym.var('data') 1554 >>> out = alexnet(inputs) 1555 >>> internals = out.get_internals() 1556 >>> print(internals.list_outputs()) 1557 ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...] 


class SymbolBlock(HybridBlock):
    """Construct block from symbol. This is useful for using pre-trained models
    as feature extractors. For example, you may want to extract the output
    from fc2 layer in AlexNet.

    Parameters
    ----------
    outputs : Symbol or list of Symbol
        The desired output for SymbolBlock.
    inputs : Symbol or list of Symbol
        The Variables in output's argument that should be used as inputs.
    params : ParameterDict
        Parameter dictionary for arguments and auxiliary states of outputs
        that are not inputs.

    Examples
    --------
    >>> # To extract the feature from fc1 and fc2 layers of AlexNet:
    >>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
                                                 prefix='model_')
    >>> inputs = mx.sym.var('data')
    >>> out = alexnet(inputs)
    >>> internals = out.get_internals()
    >>> print(internals.list_outputs())
    ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
    >>> outputs = [internals['model_dense0_relu_fwd_output'],
                   internals['model_dense1_relu_fwd_output']]
    >>> # Create SymbolBlock that shares parameters with alexnet
    >>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
    >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
    >>> print(feat_model(x))
    """
    @staticmethod
    def imports(symbol_file, input_names, param_file=None, ctx=None, allow_missing=False,
                ignore_extra=False):
        """Import model previously saved by `gluon.HybridBlock.export` or
        `Module.save_checkpoint` as a `gluon.SymbolBlock` for use in Gluon.

        Parameters
        ----------
        symbol_file : str
            Path to symbol file.
        input_names : list of str
            List of input variable names.
        param_file : str, optional
            Path to parameter file.
        ctx : Context, default None
            The context to initialize `gluon.SymbolBlock` on.
        allow_missing : bool, default False
            Whether to silently skip loading parameters not represented in the file.
        ignore_extra : bool, default False
            Whether to silently ignore parameters from the file that are not
            present in this Block.

        Returns
        -------
        gluon.SymbolBlock
            `gluon.SymbolBlock` loaded from symbol and parameter files.

        Examples
        --------
        >>> net1 = gluon.model_zoo.vision.resnet18_v1(
        ...     prefix='resnet', pretrained=True)
        >>> net1.hybridize()
        >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
        >>> out1 = net1(x)
        >>> net1.export('net1', epoch=1)
        >>>
        >>> net2 = gluon.SymbolBlock.imports(
        ...     'net1-symbol.json', ['data'], 'net1-0001.params')
        >>> out2 = net2(x)
        """
        if is_np_array():
            sym = np_symbol.load(symbol_file)
        else:
            sym = symbol.load(symbol_file)
        if isinstance(input_names, str):
            input_names = [input_names]
        if param_file is None:
            # Get a valid type inference by using fp32
            inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names]
        else:
            # Do not specify type, rely on saved params type instead
            inputs = [symbol.var(i).as_np_ndarray() if is_np_array() else symbol.var(i)
                      for i in input_names]
        ret = SymbolBlock(sym, inputs)
        if param_file is not None:
            ret.collect_params().load(param_file, ctx, allow_missing, ignore_extra,
                                      cast_dtype=True, dtype_source='saved')
        return ret
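
    # A minimal comment-only sketch (illustrative; the file name and input shape
    # are hypothetical) of importing a symbol file without a parameter file, in
    # which case the parameters must be initialized before the first forward:
    #
    #     net = SymbolBlock.imports('net1-symbol.json', ['data'])
    #     net.initialize(ctx=mx.cpu())
    #     out = net(mx.nd.ones((1, 3, 32, 32)))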

    def __repr__(self):
        s = '{name}(\n{modstr}\n)'
        modstr = '\n'.join(['{block} : {numinputs} -> {numoutputs}'
                            .format(block=self._cached_graph[1],
                                    numinputs=len(self._cached_graph[0]),
                                    numoutputs=len(self._cached_graph[1].list_outputs()))])
        return s.format(name=self.__class__.__name__,
                        modstr=modstr)

    def __init__(self, outputs, inputs, params=None):
        super(SymbolBlock, self).__init__(prefix=None, params=None)
        self._prefix = ''
        self._params = ParameterDict('', params)
        if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
            inputs = [inputs]
        if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
            outputs = outputs[0]

        syms, self._in_format = _flatten(inputs, "input")
        out, self._out_format = _flatten(outputs, "output")
        input_names = set()
        for i in syms:
            assert len(i.get_internals().list_outputs()) == 1, \
                "Input symbols must be variables, but %s is an output of operators"%str(i)
            input_names.add(i.name)

        # check if any symbol is row_sparse
        row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']

        for i in out:
            for j in i.get_internals():
                assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
                    "SymbolBlock doesn't support Parameter '%s' because its storage " \
                    "type is 'row_sparse'." % j.name
        if len(out) > 1:
            out = symbol.Group(out, _check_same_symbol_type(out))
        else:
            out = out[0]

        # Infer type of parameters. Without this, every parameter will be created with
        # default type i.e., fp32
        arg_params = out.list_arguments()
        aux_params = out.list_auxiliary_states()

        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)

        for i, arg in enumerate(arg_params):
            if arg not in input_names:
                self.params.get(arg, allow_deferred_init=True, dtype=arg_types[i])

        for i, aux in enumerate(aux_params):
            if aux not in input_names:
                self.params.get(aux, grad_req='null', allow_deferred_init=True, dtype=aux_types[i])

        self._cached_graph = syms, out
        len_prefix = len(_common_prefix(list(self._params.keys())))
        self._reg_params = {key[len_prefix:]: val for key, val in self._params.items()}

    def forward(self, x, *args):
        if isinstance(x, NDArray):
            with x.ctx:
                return self._call_cached_op(x, *args)

        assert isinstance(x, Symbol), \
            "HybridBlock requires the first argument to forward be either " \
            "Symbol or NDArray, but got %s"%type(x)
        args, in_fmt = _flatten([x] + list(args), "input")
        assert in_fmt == self._in_format, "Invalid input format"
        ret = copy.copy(self._cached_graph[1])
        ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
        return _regroup(list(ret), self._out_format)
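
    # Unlike a regular HybridBlock, whose cached graph can be re-traced from
    # hybrid_forward, a SymbolBlock's graph *is* its definition (built once in
    # __init__), so clearing the cached op must preserve _cached_graph.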
    def _clear_cached_op(self):
        tmp = self._cached_graph
        super(SymbolBlock, self)._clear_cached_op()
        self._cached_graph = tmp

    def cast(self, dtype):
        self._clear_cached_op()
        super(SymbolBlock, self).cast(dtype)
        if np.dtype(dtype).name == 'float16':
            # correct BatchNorm types back to float32 due to its special requirement
            out = self._cached_graph[1]
            params_list = out.get_internals().list_inputs()
            for node in params_list:
                if node.endswith('running_var'):
                    prefix = node[:-11]
                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
                    is_bn = all(p in params_list for p in sibs)
                    if is_bn:
                        self.params.get(node).cast('float32')
                        for sib in sibs:
                            self.params.get(sib).cast('float32')
                if node.endswith('moving_var'):
                    # another convention used
                    prefix = node[:-10]
                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
                    is_bn = all(p in params_list for p in sibs)
                    if is_bn:
                        self.params.get(node).cast('float32')
                        for sib in sibs:
                            self.params.get(sib).cast('float32')

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError

    def reset_ctx(self, ctx):
        """Re-assign all Parameters to other contexts. If the Block is hybridized,
        it will reset the _cached_op_args.

        Parameters
        ----------
        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
            Assign Parameter to given context. If ctx is a list of Context, a
            copy will be made for each context.
        """
        params = self.collect_params()
        if self._cached_op:
            for p in self._cached_op_args:
                # reset parameters created by the partitioning backend
                if p.name not in params:
                    p.reset_ctx(ctx)
        for p in params.values():
            p.reset_ctx(ctx)


def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dtype=mx_real_t):
    """Utility function that helps in inferring the dtypes of arg and aux params
    from the given input params.

    Parameters
    ----------
    in_params: List of Symbol
        List of input symbol variables.
    out_params: Symbol
        Output symbol variable.
    arg_params: List of Str
        List of names of argument parameters.
    aux_params: List of Str
        List of names of auxiliary parameters.
    default_dtype: numpy.dtype or str, default 'float32'
        Default data type for arg_params and aux_params, if unable to infer the type.

    Returns
    -------
    arg_types: List of numpy.dtype
        List of arg_params types. Order is same as arg_params.
        Defaults to 'float32', if unable to infer type.
    aux_types: List of numpy.dtype
        List of aux_params types. Order is same as aux_params.
        Defaults to 'float32', if unable to infer type.
    """
    arg_types = None
    aux_types = None

    # Get input symbol details. This will be used to infer types of
    # other parameters.
    input_sym_names = [in_param.name for in_param in in_params]

    # Try to infer input types. If not successful, we will set default dtype.
    # If successful, we will try to infer other params in the graph.
    input_sym_arg_types = []
    can_infer_input_type = True
    for in_param in in_params:
        input_sym_arg_type = in_param.infer_type()[0]
        if not input_sym_arg_type or len(input_sym_arg_type) < 1:
            can_infer_input_type = False
            break
        else:
            input_sym_arg_types.append(in_param.infer_type()[0][0])

    # Try to infer types of other parameters.
    if can_infer_input_type:
        params = {k: v for k, v in zip(input_sym_names, input_sym_arg_types)}
        try:
            arg_types, _, aux_types = out_params.infer_type(**params)
        except MXNetError:
            # Cannot infer type with current input
            arg_types, aux_types = None, None

    if arg_types is None or len(arg_types) != len(arg_params):
        arg_types = []
        for _ in arg_params:
            arg_types.append(default_dtype)

    if aux_types is None or len(aux_types) != len(aux_params):
        aux_types = []
        for _ in aux_params:
            aux_types.append(default_dtype)

    return (arg_types, aux_types)
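

# An illustrative, self-contained sketch (not part of the public API) of how
# ``_infer_param_types`` propagates a declared input dtype to the remaining
# arguments; the variable names and the FullyConnected operator are arbitrary
# choices made for this example.
def _example_infer_param_types():
    data = symbol.var('data', dtype='float16')
    weight = symbol.var('weight')
    out = symbol.FullyConnected(data, weight=weight, no_bias=True, num_hidden=2)
    arg_types, aux_types = _infer_param_types([data], out,
                                              out.list_arguments(),
                                              out.list_auxiliary_states())
    # 'data' is declared float16, so the weight dtype is inferred as float16.
    # With an untyped input, every entry would fall back to default_dtype (fp32).
    return arg_types, aux_types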