# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


# Scope for collecting child 'Block's
use strict;
use warnings;
use AI::MXNet::Gluon::Parameter;
package AI::MXNet::Gluon::BlockScope;
use AI::MXNet::Function::Parameters;
# The BlockScope whose name_scope is currently active (undef when no
# name_scope is open); maintained by __enter__/__exit__ below.
my $_current;
use Mouse;
# The Block owning this scope. Weak, because the Block in turn holds this
# scope in its _scope attribute.
has '_block' => (is => 'rw', init_arg => 'block', weak_ref => 1) if 0;
has '_block' => (is => 'ro', init_arg => 'block', weak_ref => 1);
# _counter:    per-hint counters used to number children ("dense0_", "dense1_", ...)
# _old_scope:  the scope that was current before __enter__
# _name_scope: the symbol NameManager that was current before __enter__
has [qw/_counter _old_scope
    _name_scope/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->_counter({});
}

# Creates prefix and params for a new Block.
#
#   $prefix - requested prefix, or undef to derive one from $hint
#   $params - ParameterDict whose parameters should be shared, or undef
#   $hint   - default name stem for the new block
#
# Returns ($prefix, $params): the full prefix of the new block and a new
# ParameterDict using it (sharing $params when one was supplied).
method create($prefix, $params, $hint)
{
    my $current = $_current;
    if(not defined $current)
    {
        # No name_scope is open: names come from the global symbol NameManager.
        if(not defined $prefix)
        {
            $prefix = AI::MXNet::Symbol::NameManager->current->get(undef, $hint) . '_';
        }
        if(not defined $params)
        {
            $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $prefix);
        }
        else
        {
            $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $params->prefix, shared => $params);
        }
        return ($prefix, $params);
    }

    # Inside a parent's name_scope: children are numbered per hint.
    if(not defined $prefix)
    {
        my $count = $current->_counter->{ $hint } // 0;
        $prefix = sprintf('%s%d_', $hint, $count);
        $current->_counter->{$hint} = $count + 1;
    }
    if(not defined $params)
    {
        my $parent = $current->_block->params;
        $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $parent->prefix.$prefix, shared => $parent->_shared);
    }
    else
    {
        # Pass the supplied dict as 'shared', exactly as in the top-level
        # branch above. It was previously passed as a bare positional value,
        # which produced an odd-sized argument list and no sharing.
        $params = AI::MXNet::Gluon::ParameterDict->new(prefix => $params->prefix, shared => $params);
    }
    return ($current->_block->prefix.$prefix, $params);
}

# Makes this scope current. Blocks created until __exit__ are prefixed with
# this block's prefix, and the symbol NameManager is switched to a Prefix
# manager using the same prefix. Blocks created with an empty prefix ('')
# do not open a scope.
method __enter__()
{
    return $self if $self->_block->_empty_prefix;
    $self->_old_scope($_current);
    $_current = $self;
    $self->_name_scope(AI::MXNet::Symbol::NameManager->current);
    AI::MXNet::Symbol::NameManager->set_current(AI::MXNet::Symbol::Prefix->new(prefix => $self->_block->prefix));
    return $self;
}

# Restores the NameManager and the scope that were current before __enter__.
method __exit__()
{
    return if $self->_block->_empty_prefix;
    AI::MXNet::Symbol::NameManager->set_current($self->_name_scope);
    $self->_name_scope(undef);
    $_current = $self->_old_scope;
}

package AI::MXNet::Gluon::Block;
use AI::MXNet::Gluon::Mouse;
use Scalar::Util qw(refaddr);

=head2 NAME

    AI::MXNet::Gluon::Block - Base class for all neural network layers and models.

=head2 DESCRIPTION

    Base class for all neural network layers and models. Your models should
    subclass this class.

    AI::MXNet::Gluon::Block can be nested recursively in a tree structure.
You can create and 108 assign child AI::MXNet::Gluon::Block as regular attributes 109 110 use AI::MXNet::Gluon::NN qw(nn); 111 use AI::MXNet qw(mx); 112 113 package Model; 114 use AI::MXNet::Gluon::Mouse; 115 use AI::MXNet::Function::Parameters; 116 extends 'AI::MXNet::Gluon::Block'; 117 118 sub BUILD 119 { 120 my $self = shift; 121 $self->name_scope(sub { 122 $self->dense0(nn->Dense(5, in_units=>5)); 123 $self->dense1(nn->Dense(5, in_units=>5)); 124 }); 125 } 126 127 method forward($x) 128 { 129 return $self->dense1->($self->dense0->($x)); 130 } 131 132 my $model = Model->new() 133 $model->initialize(ctx=>mx->cpu(0)) 134 $model->(nd->zeros([10, 10], ctx=>mx->cpu(0))); 135 136 137 Child AI::MXNet::Gluon::Block assigned this way will be registered and ->collect_params 138 will collect their Parameters recursively. 139 140 Parameters 141 ---------- 142 Prefix acts like a name space. All children blocks created in parent block's 143 name_scope will have parent block's prefix in their name. 144 Please refer to 145 naming tutorial https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/naming.html 146 for more info on prefix and naming. 147 148 params : AI::MXNet::Gluon::ParameterDict or undef 149 AI::MXNet::Gluon::ParameterDict for sharing weights with the new AI::MXNet::Gluon::Block. For example, 150 if you want `dense1` to share `dense0`'s weights, you can do 151 152 $dense0 = nn->Dense(20); 153 $dense1 = nn->Dense(20, params=>dense0->collect_params()); 154=cut 155 156method _flatten( 157 $args 158) 159{ 160 if(blessed $args and $args->isa('AI::MXNet::NDArray')) 161 { 162 return ([$args], 0); 163 } 164 elsif(blessed $args and $args->isa('AI::MXNet::Symbol')) 165 { 166 my $length = @{ $args->list_outputs() }; 167 $length = $length > 1 ? 
$length : 0; 168 return ([$args], $length) 169 } 170 my @flat; 171 my @fmts; 172 for my $i (@{ $args }) 173 { 174 my ($arg, $fmt) = __PACKAGE__->_flatten($i); 175 push @flat, @{ $arg }; 176 push @fmts, $fmt; 177 } 178 return (\@flat, \@fmts); 179} 180 181method _regroup( 182 $args, $fmt 183) 184{ 185 my $in_symbol = (blessed $args and $args->isa('AI::MXNet::Symbol')); 186 my @ret; 187 if(not ref $fmt) 188 { 189 my $len = @{$args} - 1; 190 if($fmt == 0) 191 { 192 @ret = ([@{$args}[1..$len]]); 193 if($in_symbol) 194 { 195 $ret[0] = AI::MXNet::Symbol->Group($ret[0]); 196 } 197 return (@{$args}[0], $ret[0]); 198 } 199 @ret = ([@{$args}[0..$fmt-1]], [@{$args}[$fmt..$len]]); 200 if($in_symbol) 201 { 202 @ret = map { AI::MXNet::Symbol->Group($_) } @ret; 203 } 204 return @ret; 205 } 206 for my $i (@{ $fmt }) 207 { 208 my $res; 209 ($res, $args) = __PACKAGE__->_regroup($args, $i); 210 push @ret, $res; 211 } 212 return (\@ret, $args); 213} 214 215has _prefix => (is => 'rw', init_arg => 'prefix', isa => 'Str'); 216has _params => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::Gluon::ParameterDict]'); 217has [qw/_name _scope _empty_prefix/] => (is => 'rw', init_arg => undef); 218has [qw/_children _forward_hooks _forward_pre_hooks/] => (is => 'rw', init_arg => undef, default => sub { Hash::Ordered->new }); 219has '_reg_params' => (is => 'rw', init_arg => undef, default => sub { +{} }); 220around BUILDARGS => \&AI::MXNet::Base::process_arguments; 221 222sub AUTOLOAD { 223 my $name = $AI::MXNet::Gluon::Block::AUTOLOAD; 224 $name =~ s/.*:://; 225 my $self = shift; 226 AI::MXNet::Gluon::Mouse::has($name => (is => 'rw', 'init_arg' => undef, 'caller' => ref $self)); 227 $self->$name(@_); 228} 229 230sub BUILD 231{ 232 my $self = shift; 233 $self->_empty_prefix(defined $self->_prefix and $self->_prefix eq ''); 234 my ($prefix, $params) = AI::MXNet::Gluon::BlockScope->create($self->_prefix, $self->_params, $self->_alias); 235 $self->_prefix($prefix); 236 
$self->_params($params); 237 my $name = $prefix; 238 $name =~ s/_$//; 239 $self->_name($name); 240 $self->_scope(AI::MXNet::Gluon::BlockScope->new(block => $self)); 241} 242 243method _class_name() 244{ 245 my $class = ref $self || $self; 246 $class =~ s/^.+:://; 247 $class; 248} 249 250method __setattr__($name, $current, $prev=) 251{ 252 if(defined $prev) 253 { 254 if( 255 ( 256 blessed $prev 257 and 258 ($prev->isa('AI::MXNet::Gluon::Parameter') or $prev->isa('AI::MXNet::Gluon::Block')) 259 ) 260 and not (blessed $current and (ref($prev) eq ref($current))) 261 ) 262 { 263 confess( 264 sprintf( 265 "Changing attribute type for %s from %s to %s is not allowed.", 266 $self->name, 267 ref($prev), 268 ref($current)||'no ref' 269 ) 270 ); 271 } 272 } 273 if(blessed $current and $current->isa('AI::MXNet::Gluon::Block')) 274 { 275 $self->register_child($current, $name); 276 } 277 elsif(blessed $current and $current->isa('AI::MXNet::Gluon::Parameter')) 278 { 279 if(exists $self->_reg_params->{ $name }) 280 { 281 confess("Overriding Parameter attribute $name is not allowed. ". 282 "If you want to share parameters between blocks, please set". 283 "'params' at Block construction instead." 
284 ); 285 } 286 $self->_reg_params->{ $name } = $current; 287 } 288} 289 290method _check_container_with_block() 291{ 292 my $_find_unregistered_block_in_container; 293 my %children = map { refaddr($_) => 1 } $self->_children->values; 294 $_find_unregistered_block_in_container = sub { my ($data) = @_; 295 # Find whether a nested container structure contains Blocks 296 if(ref $data eq 'ARRAY') 297 { 298 for my $ele (@{ $data }) 299 { 300 if($_find_unregistered_block_in_container->($ele)) 301 { 302 return 1 303 } 304 } 305 return 0; 306 } 307 elsif(ref $data eq 'HASH') 308 { 309 for my $v (values %$data) 310 { 311 if($_find_unregistered_block_in_container->($v)) 312 { 313 return 1; 314 } 315 } 316 return 0; 317 } 318 elsif(blessed $data and $data->isa('AI::MXNet::Gluon::Block')) 319 { 320 return not exists $children{ refaddr($data) }; 321 } 322 else 323 { 324 return 0; 325 } 326 }; 327 my $attributes_hash = $self->attributes_hash(); 328 while(my ($k, $v) = each %{ $attributes_hash }) 329 { 330 if((ref $v eq 'HASH' or ref $v eq 'ARRAY') and not $k =~ /^__/) 331 { 332 if($_find_unregistered_block_in_container->($v)) 333 { 334 AI::MXNet::Logging->warning( 335 '"%s" is a unregsitered container with Blocks. '. 336 'Note that Blocks inside the list, tuple or dict will not be '. 337 'registered automatically. Make sure to register them using '. 338 'register_child() or switching to '. 339 'nn->Sequential/nn->HybridSequential instead. 
', 340 $self->_class_name.'.'.$k 341 ); 342 } 343 } 344 } 345} 346 347method _alias() 348{ 349 lc $self->_class_name; 350} 351 352method attributes_hash() 353{ 354 +{ map { $_ => $self->$_ } $self->meta->get_attribute_list }; 355} 356 357use overload 358 '""' => sub 359 { 360 my $self = shift; 361 my $s = "%s(\n%s\n)"; 362 my @blocks; 363 my %attributes_hash = %{ $self->attributes_hash }; 364 while(my ($k, $v) = each %attributes_hash) 365 { 366 if(blessed $v and $v->isa(__PACKAGE__)) 367 { 368 push @blocks, " ($k): ".AI::MXNet::Base::_indent("$v", 2); 369 } 370 } 371 sprintf("%s(\n%s\n)", $self->_class_name, join("\n", @blocks)); 372 }, 373 '&{}' => sub { my $self = shift; sub { $self->call(@_) } }; 374 375method prefix() 376{ 377 $self->_prefix; 378} 379 380method name() 381{ 382 $self->_name; 383} 384 385method class() 386{ 387 __PACKAGE__; 388} 389 390method name_scope(CodeRef $sub) 391{ 392 $self->_scope->__enter__; 393 eval { $sub->(); }; 394 my $err = $@; 395 $self->_scope->__exit__; 396 confess($err) if $err; 397} 398 399=head2 params 400 401 Returns this `Block`'s parameter dictionary (does not include its 402 children's parameters). 403=cut 404 405method params() 406{ 407 return $self->_params; 408} 409 410=head2 collect_params 411 412 Returns a AI::MXNet::Gluon::ParameterDict containing this AI::MXNet::Gluon::Block and all of its 413 children's Parameters(default), also can returns the ParameterDict 414 with parameters that match a regular expression. 415 416 For example, collects parameters specified in ['conv1_weight', 'conv1_bias', 'fc_weight', 417 'fc_bias' 418 419 $model->collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias') 420 421 or collects all parameters that have the name end with 'weight' or 'bias', this can be done 422 using regular expressions. 
423 424 $model->collect_params('.*weight|.*bias') 425 426=cut 427 428method collect_params(Maybe[Str] $select=) 429{ 430 $self->_check_container_with_block(); 431 my $ret = AI::MXNet::Gluon::ParameterDict->new(prefix => $self->_params->prefix); 432 $ret->update($self->params, $select); 433 for my $cld ($self->_children->values) 434 { 435 $ret->update($cld->collect_params($select)); 436 } 437 return $ret; 438} 439 440 441method _collect_params_with_prefix(Str $prefix='') 442{ 443 if($prefix) 444 { 445 $prefix .= '.'; 446 } 447 my %ret = map { $prefix.$_ => $self->_reg_params->{ $_ } } keys %{ $self->_reg_params }; 448 my $iter = $self->_children->iterator; 449 while(my ($name, $child) = $iter->()) 450 { 451 %ret = (%ret, %{ $child->_collect_params_with_prefix("$prefix$name") }); 452 } 453 return \%ret; 454} 455 456=head2 save_parameters 457 458 Save parameters to file. 459 460 filename : str 461 Path to file. 462=cut 463 464method save_parameters(Str $filename) 465{ 466 my $params = $self->_collect_params_with_prefix(); 467 my %arg_dict = map { $_ => $params->{$_}->_reduce } keys %{ $params }; 468 AI::MXNet::NDArray->save($filename, \%arg_dict); 469} 470 471=head2 load_parameters 472 473 Load parameters from file. 474 475 $filename : str 476 Path to parameter file. 477 :$ctx= : Context or list of Context 478 Context(s) initialize loaded parameters on. 479 :$allow_missing : bool, default False 480 Whether to silently skip loading parameters not represents in the file. 481 :$ignore_extra : bool, default False 482 Whether to silently ignore parameters from the file that are not 483 present in this Block. 
=cut

method load_parameters(
    Str $filename,
    AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx,
    Bool :$allow_missing=0,
    Bool :$ignore_extra=0
)
{
    my $loaded = AI::MXNet::NDArray->load($filename);
    my $params = $self->_collect_params_with_prefix;
    # Nothing to do when both the file and the block are empty.
    return if not keys %$loaded and not keys %$params;

    if(not grep { /\./ } keys %$loaded)
    {
        # legacy loading: no key contains '.', so the file does not use the
        # structural names save_parameters writes (child names joined with
        # '.'). Release the loaded arrays and let the ParameterDict loader
        # read the file, restoring this block's prefix.
        %$loaded = ();
        $self->collect_params->load(
            $filename,
            ($ctx ? (ctx => $ctx) : ()),
            allow_missing  => $allow_missing,
            ignore_extra   => $ignore_extra,
            restore_prefix => $self->prefix
        );
        return;
    }

    # Key lists are sorted so the error messages are deterministic.
    if(not $allow_missing)
    {
        for my $name (sort keys %$params)
        {
            if(not exists $loaded->{$name})
            {
                confess(
                    "Parameter $name is missing in file $filename, which contains parameters: ".
                    join(',', sort keys %$loaded)."\n".
                    "Set allow_missing=>1 to ignore missing parameters."
                );
            }
        }
    }
    for my $name (sort keys %$loaded)
    {
        if(not $ignore_extra and not exists $params->{ $name })
        {
            confess(
                "Parameter $name loaded from file $filename is not present in ParameterDict, ".
                "which contains parameters ".
                join(',', sort keys %$params)."\n".
                "Set ignore_extra=>1 to ignore."
            );
        }
        $params->{$name}->_load_init($loaded->{$name}, $ctx) if exists $params->{$name};
    }
}

=head2 register_child

    Registers block as a child of self. `Block`s assigned to self as
    attributes will be registered automatically. When $name is omitted the
    child is keyed by the number of children registered so far.
=cut

method register_child(AI::MXNet::Gluon::Block $block, Maybe[Str] $name=)
{
    if(not defined $name)
    {
        # Count explicitly in list context instead of relying on what
        # Hash::Ordered::keys returns in scalar context.
        my $n_children = () = $self->_children->keys;
        $name = $n_children;
    }
    $self->_children->set($name, $block);
}

=head2 register_forward_pre_hook

    Registers a forward pre-hook on the block.

    The hook function is called immediately before 'forward'.
    It should not modify the input or output.
558 559 Parameters 560 ---------- 561 $hook : CodeRef or callable object 562 The forward hook function of form $hook->($block, $input). 563 564 Returns 565 ------- 566 AI::MXNet::Gluon::Utils::HookHandle 567=cut 568 569method register_forward_pre_hook($hook) 570{ 571 my $handle = AI::MXNet::Gluon::Utils::HookHandle->new; 572 $handle->attach($self->_forward_pre_hooks, $hook); 573 return $handle; 574} 575 576=head2 register_forward_hook 577 578 Registers a forward hook on the block. 579 580 The hook function is called immediately after 'forward'. 581 It should not modify the input or output. 582 583 Parameters 584 ---------- 585 $hook : CodeRef or callable object 586 The forward hook function of form $hook->($block, $input). 587 588 Returns 589 ------- 590 AI::MXNet::Gluon::Utils::HookHandle 591=cut 592 593method register_forward_hook($hook) 594{ 595 my $handle = AI::MXNet::Gluon::Utils::HookHandle->new; 596 $handle->attach($self->_forward_hooks, $hook); 597 return $handle; 598} 599 600=head2 apply 601 602 Applies $fn recursively to every child block as well as self. 603 604 Parameters 605 ---------- 606 $fn : callable 607 Function to be applied to each submodule, of form `$fn->($block)`. 608 609 Returns 610 ------- 611 this block 612=cut 613 614method apply($fn) 615{ 616 for my $cld ($self->_children->values) 617 { 618 $cld->apply($fn); 619 } 620 $fn->($self); 621 return $self; 622} 623 624=head2 initialize 625 626 627 Initializes AI::MXNet::Gluon::Parameters of this AI::MXNet::Gluon::Block and its children. 628 Equivalent to $block->collect_params()->initialize(...) 629 630 Parameters 631 ---------- 632 $init : Initializer 633 Global default Initializer to be used when Parameter->init is undefined`. 634 Otherwise, Parameter->init takes precedence. 635 ctx : Context or array ref of Context 636 Keeps a copy of Parameters on one or many context(s). 637 verbose : bool, default False 638 Whether to verbosely print out details on initialization. 
639 force_reinit : bool, default False 640 Whether to force re-initialization if parameter is already initialized. 641=cut 642 643method initialize( 644 Initializer $init=AI::MXNet::Initializer->Uniform(), 645 AI::MXNet::Context|ArrayRef[AI::MXNet::Context] :$ctx=AI::MXNet::Context->current_ctx, 646 Bool :$verbose=0, 647 Bool :$force_reinit=0 648) 649{ 650 $self->collect_params->initialize(init => $init, ctx => $ctx, verbose => $verbose, force_reinit => $force_reinit); 651} 652 653 654=head2 hybridize 655 656 Activates or deactivates `HybridBlock`s recursively. Has no effect on 657 non-hybrid children. 658 659 Parameters 660 ---------- 661 $active : bool, default True 662 Whether to turn hybrid on or off. 663 :$static_alloc : bool, default False 664 Statically allocate memory to improve speed. Memory usage may increase. 665 :$static_shape : bool, default False 666 Optimize for invariant input shapes between iterations. Must also 667 set static_alloc to True. Change of input shapes is still allowed 668 but slower. 669=cut 670 671method hybridize( 672 Bool $active=1, 673 %args 674) 675{ 676 $_->hybridize( 677 $active, 678 %args 679 ) for $self->_children->values; 680} 681 682=head2 cast 683 684 Cast this Block to use another data type. 685 686 Parameters 687 ---------- 688 dtype : Dtype 689 The new data type. 690=cut 691 692method cast(Dtype $dtype) 693{ 694 for my $child ($self->_children->values) 695 { 696 $child->cast($dtype); 697 } 698 $_->cast($dtype) for $self->params->values; 699} 700 701method call(@args) 702{ 703 for my $hook ($self->_forward_pre_hooks->values) 704 { 705 $hook->($self, \@args); 706 } 707 my @out = $self->forward(@args); 708 for my $hook ($self->_forward_hooks->values) 709 { 710 $hook->($self, \@args, \@out); 711 } 712 return wantarray ? @out : $out[0]; 713} 714 715=head2 forward 716 717 Overrides to implement forward computation using `NDArray`. Only 718 accepts positional arguments. 
719 720 Parameters 721 ---------- 722 @args : array of NDArray 723 Input tensors. 724=cut 725 726method forward(@args) 727{ 728 confess("Not Implemented"); 729} 730 731method register(Str $container) 732{ 733 my $sub_name = $self->_class_name; 734 my $dest = $self->can('new'); 735 my $func = sub { 736 splice @_, 0, 1, $self; 737 goto $dest; 738 }; 739 no strict 'refs'; 740 *{"$container\::$sub_name"} = $func; 741} 742 743=head2 summary 744 745 Print the summary of the model's output and parameters. 746 747 The network must have been initialized, and must not have been hybridized. 748 749 Parameters 750 ---------- 751 @inputs : objects 752 Any inputs that the model supports. For any tensor in the input, only 753 AI::MXNet::NDArray is supported. 754=cut 755 756method summary(@inputs) 757{ 758 my $summary = Hash::Ordered->new; 759 my %seen; 760 my @hooks; 761 my $stringify; 762 $stringify = sub { 763 my $in = shift; 764 if(ref($in) eq 'ARRAY') 765 { 766 return '('.join(', ', map { $stringify->($_) } @$in).')'; 767 } 768 else 769 { 770 return "$in"; 771 } 772 }; 773 my $_get_shape_str = sub { my ($args) = @_; 774 $args = $args->[0] if(ref $args eq 'ARRAY' and @$args == 1); 775 my ($flat_args, $fmts) = __PACKAGE__->_flatten($args); 776 my $flat_arg_shapes = [map { (blessed($_) and $_->isa('AI::MXNet::NDArray')) ? 
$_->shape : $_ } @$flat_args]; 777 my $shapes = (__PACKAGE__->_regroup($flat_arg_shapes, $fmts))[0]; 778 my $shape_str = $stringify->($shapes); 779 $shape_str =~ s/L//g; 780 return $shape_str; 781 }; 782 783 my $_register_summary_hook = sub { my ($block) = @_; 784 unless(not $block->isa('AI::MXNet::Gluon:::HybridBlock') or not $block->_active) 785 { 786 confess("\"${\ $block->name }\" must not be hybridized to print summary."); 787 } 788 my $_summary_hook = sub { my ($block, undef, $outputs) = @_; 789 my $class_name = $block->_class_name; 790 my $block_idx = $summary->keys - 1; 791 792 my $m_key = sprintf('%s-%i', $class_name, $block_idx+1); 793 $summary->set($m_key, Hash::Ordered->new); 794 $summary->get($m_key)->set('output_shape', $_get_shape_str->($outputs)); 795 796 my $params = 0; 797 $summary->get($m_key)->set('trainable', 0); 798 $summary->get($m_key)->set('shared', 0); 799 for my $p (values %{ $block->_reg_params }) 800 { 801 $params += $p->data->size; 802 $summary->get($m_key)->set('trainable', $summary->get($m_key)->get('trainable') + ($p->grad_req eq 'null' ? 
0 : $p->data->size)); 803 if(exists $seen{$p}) 804 { 805 $summary->get($m_key)->set('shared', $summary->get($m_key)->get('shared') + $p->data->size); 806 } 807 else 808 { 809 $seen{$p} = 1; 810 } 811 } 812 $summary->get($m_key)->set('n_params', $params); 813 }; 814 815 if(not $block->isa('AI::MXNet::Gluon::NN::Sequential') and not $block->isa('AI::MXNet::Gluon::NN::HybridSequential')) 816 { 817 push @hooks, $block->register_forward_hook($_summary_hook); 818 } 819 }; 820 821 my $input = Hash::Ordered->new; 822 $summary->set('Input', $input); 823 $input->set('output_shape', $_get_shape_str->(\@inputs)); 824 $input->set('n_params', 0); 825 $input->set('trainable', 0); 826 $input->set('shared', 0); 827 828 eval { 829 $self->apply($_register_summary_hook); 830 $self->(@inputs); 831 832 my $line_format = "%20s %42s %15s\n"; 833 print (('-')x80, "\n"); 834 printf($line_format, 'Layer (type)', 'Output Shape', 'Param #'); 835 print (('=')x80, "\n"); 836 my $total_params = 0; 837 my $trainable_params = 0; 838 my $shared_params = 0; 839 for my $layer ($summary->keys) 840 { 841 printf($line_format, $layer, $summary->get($layer)->get('output_shape'), $summary->get($layer)->get('n_params')); 842 $total_params += $summary->get($layer)->get('n_params'); 843 $trainable_params += $summary->get($layer)->get('trainable'); 844 $shared_params += $summary->get($layer)->get('shared'); 845 } 846 print (('=')x80, "\n"); 847 print "Parameters in forward computation graph, duplicate included\n"; 848 print " Total params: $total_params\n"; 849 print " Non-trainable params: ", $total_params - $trainable_params, "\n"; 850 print "Shared params in forward computation graph: $shared_params\n"; 851 print "Unique parameters in model: ", $total_params - $shared_params, "\n"; 852 print (('-')x80, "\n"); 853 }; 854 $_->detach for @hooks; 855} 856 857__PACKAGE__->register('AI::MXNet::Gluon'); 858 859package AI::MXNet::Gluon::HybridBlock; 860=head2 NAME 861 862 AI::MXNet::Gluon::HybridBlock 863 864=head2 
DESCRIPTION

    HybridBlock supports forwarding with both Symbol and NDArray.

    Forward computation in HybridBlock must be static to work with Symbols,
    i.e. you cannot call aspdl, shape, dtype, etc on tensors.
    Also, you cannot use branching or loop logic that is based on non-constant
    expressions like random numbers or intermediate results, since they change
    the graph structure for each iteration.

    Before activating with hybridize(), HybridBlock works just like normal
    Block. After activation, HybridBlock will create a symbolic graph
    representing the forward computation and cache it. On subsequent forwards,
    the cached graph will be used instead of hybrid_forward.

    Refer to the Hybrid tutorial L<https://mxnet.io/tutorials/gluon/hybrid.html> to see
    the end-to-end usage.
=cut

use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::Block';
# _cached_graph:          (\@input_symbols, grouped output Symbol) once built by _get_graph
# _cached_op:             AI::MXNet::CachedOp built by _build_cache
# _in_format/_out_format: nesting formats returned by _flatten
# _active:                whether hybridize() switched this block on
# _flags:                 key/value flags given to hybridize(), passed to CachedOp
# _cached_op_args:        per graph input, [1, data index] or [0, Parameter]
has [qw/
    _cached_graph
    _cached_op
    _out_format _in_format
    _active _flags _cached_op_args
/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->_active(0);
    $self->_flags([]);
    $self->_cached_graph([]);
    $self->_cached_op_args([]);
}

# Assigning a HybridBlock attribute changes the graph, so drop the cache.
method __setattr__($name, $current, $prev=)
{
    $self->SUPER::__setattr__($name, $current, $prev);
    if(blessed $current and $current->isa('AI::MXNet::Gluon::HybridBlock'))
    {
        $self->_clear_cached_op();
    }
}

method register_child(AI::MXNet::Gluon::HybridBlock $block, Maybe[Str] $name=)
{
    $self->SUPER::register_child($block, $name);
    $self->_clear_cached_op();
}

# Accepts an optional leading $active flag (default 1) followed by
# key/value flags that are forwarded to the CachedOp.
method hybridize(@args)
{
    my $active;
    if(@args%2)
    {
        $active = shift(@args);
    }
    else
    {
        $active = 1;
    }
    $self->_active($active);
    @{ $self->_flags } = @args;
    $self->_clear_cached_op();
    # The hook containers are Hash::Ordered objects and therefore always
    # true; count their entries instead of testing the objects.
    my $n_hooks = () = ($self->_forward_hooks->values, $self->_forward_pre_hooks->values);
    if($self->_active and $n_hooks)
    {
        AI::MXNet::Logging->warning(
            "$self is being hybridized while still having forward hook/pre-hook. ".
            "If $self is a child of HybridBlock, the hooks will not take effect."
        );
    }
    $self->SUPER::hybridize($self->_active, @args);
}

method cast(Dtype $dtype)
{
    $self->_clear_cached_op;
    $self->SUPER::cast($dtype);
}

# Runs the symbolic $infer_fn ('infer_shape' or 'infer_type') on the cached
# graph, seeded with the $attr ('shape' or 'dtype') of the actual inputs, and
# stores the inferred values on every collected Parameter.
method _infer_attrs($infer_fn, $attr, @args)
{
    my ($inputs, $out) = $self->_get_graph(@args);
    my ($args) = __PACKAGE__->_flatten([@args]);
    my %in;
    zip(sub {
        my ($i, $j) = @_;
        $in{ $i->name } = $j->$attr;
    }, $inputs, $args);
    my ($arg_attrs, $aux_attrs);
    ($arg_attrs, undef, $aux_attrs) = $out->$infer_fn(%in);
    if(not defined $arg_attrs)
    {
        # $@ may be empty here (the call above is not inside an eval).
        confess($@ || "Cannot infer the $attr of the parameters of ".$self->name." from the given inputs.");
    }
    my %sdict;
    zip(sub {
        my ($i, $j) = @_;
        $sdict{ $i } = $j;
    }, $out->list_arguments, $arg_attrs);
    zip(sub {
        my ($i, $j) = @_;
        $sdict{ $i } = $j;
    }, $out->list_auxiliary_states, $aux_attrs);

    for my $i ($self->collect_params->values)
    {
        $i->$attr($sdict{ $i->name });
    }
}

=head2 infer_shape

    Infers shape of Parameters from inputs.
=cut

method infer_shape(@args)
{
    $self->_infer_attrs('infer_shape', 'shape', @args);
}

=head2 infer_type

    Infers data type of Parameters from inputs.
=cut

method infer_type(@args)
{
    $self->_infer_attrs('infer_type', 'dtype', @args);
}

# Builds (once) and returns (\@input_symbols, grouped output Symbol) by
# running hybrid_forward with Symbol variables in place of the inputs.
method _get_graph(@args)
{
    if(not @{ $self->_cached_graph })
    {
        my $args = [@args];
        my ($in_format, $out_format);
        ($args, $in_format) = __PACKAGE__->_flatten($args);
        $self->_in_format($in_format);
        # One variable per *flattened* input. The count must come from the
        # flattened list. Testing the raw @args gave a single 'data' variable
        # when one nested array ref held several tensors.
        my @inputs;
        if(@$args > 1)
        {
            @inputs = map { AI::MXNet::Symbol->var("data_$_") } 0 .. @$args-1;
        }
        else
        {
            @inputs = (AI::MXNet::Symbol->var("data"))
        }
        my ($grouped_inputs) = __PACKAGE__->_regroup(\@inputs, $self->_in_format);
        my %params = map { $_ => $self->_reg_params->{$_}->var } keys %{ $self->_reg_params };
        my @out;
        $self->name_scope(sub {
            @out = $self->hybrid_forward('AI::MXNet::Symbol', @{ $grouped_inputs }, %params);
        });
        my $out = @out > 1 ? [@out] : $out[0];
        ($out, $out_format) = __PACKAGE__->_flatten($out);
        $self->_out_format($out_format);
        @{ $self->_cached_graph } = (\@inputs, AI::MXNet::Symbol->Group($out));
    }
    return @{ $self->_cached_graph };
}

# Builds the CachedOp for the graph. Every graph input must be either a data
# variable or a collected parameter. Warns about unused data inputs and
# unused parameters.
method _build_cache(@args)
{
    my ($data, $out) = $self->_get_graph(@args);
    my $i = 0;
    my %data_names = map { $_->name => $i++ } @{ $data };
    my $params = $self->collect_params;
    my $input_names = $out->list_inputs;
    my %param_names = map { $_ => 1 } $params->keys;
    my %expected_names = map { $_ => 1 } @{ $input_names };
    for my $name (keys %expected_names)
    {
        assert(
            (exists $param_names{ $name } or exists $data_names{ $name }),
            "Unknown input to HybridBlock: $name"
        );
    }
    # Use 'warning' as elsewhere in this file, and actually supply the values
    # for the %s placeholders (the parameter warning previously had none).
    my $unused = join(', ',
        map  { "$data_names{$_}-th" }
        sort { $data_names{$a} <=> $data_names{$b} }
        grep { !exists $expected_names{ $_ } } keys %data_names
    );
    AI::MXNet::Logging->warning(
        "The %s input to HybridBlock is not used by any ".
        "computation. Is this intended?",
        $unused
    ) if $unused;
    $unused = join(', ', sort grep { !exists $expected_names{ $_ } } keys %param_names);
    AI::MXNet::Logging->warning(
        "Parameter %s is not used by any computation. " .
        "Is this intended?",
        $unused
    ) if $unused;

    my @data_indices;
    my @param_indices;
    $self->_cached_op_args([]);
    enumerate(sub {
        my ($i, $name) = @_;
        if(exists $data_names{ $name })
        {
            push @data_indices, $i;
            push @{ $self->_cached_op_args }, [1, $data_names{$name}];
        }
        else
        {
            push @param_indices, $i;
            push @{ $self->_cached_op_args }, [0, $params->params->get($name)];
        }
    }, $input_names);
    my %flags = (
        data_indices  => \@data_indices,
        param_indices => \@param_indices,
        @{ $self->_flags }
    );
    $self->_cached_op(AI::MXNet::CachedOp->new($out, \%flags));
}

method _deferred_infer_shape(@args)
{
    eval {
        $self->infer_shape(@args)
    };
    if($@)
    {
        confess(
            "Deferred initialization failed because shape".
            " cannot be inferred. $@"
        );
    }
}

method _clear_cached_op()
{
    $self->_cached_graph([]);
    $self->_cached_op(undef);
}

use Data::Dumper;
method _call_cached_op(@args)
{
    if(not defined $self->_cached_op)
    {
        $self->_build_cache(@args);
    }
    my $args = [@args];
    my $fmt;
    ($args, $fmt) = __PACKAGE__->_flatten($args);
    assert((Dumper($fmt) eq Dumper($self->_in_format)), "Invalid input format");
    my @cargs;
    eval {
        @cargs = map { (not $_->[0]) ?
$_->[1]->data() : $args->[$_->[1]] } @{ $self->_cached_op_args }; 1109 }; 1110 if($@) 1111 { 1112 if($@ =~ /DeferredInitializationError/) 1113 { 1114 $self->_deferred_infer_shape(@$args); 1115 @cargs = (); 1116 map { 1117 if($_->[0]) 1118 { 1119 push @cargs, $args->[$_->[1]]; 1120 } 1121 else 1122 { 1123 $_->[1]->_finish_deferred_init(); 1124 push @cargs, $_->[1]->data; 1125 } 1126 } @{ $self->_cached_op_args }; 1127 } 1128 else 1129 { 1130 confess($@); 1131 } 1132 } 1133 my $out = $self->_cached_op->(@cargs); 1134 if(blessed $out and $out->isa('AI::MXNet::NDArray')) 1135 { 1136 $out = [$out]; 1137 } 1138 my $ret = (__PACKAGE__->_regroup($out, $self->_out_format))[0]; 1139 if(ref($ret) eq 'ARRAY' and wantarray) 1140 { 1141 return @$ret; 1142 } 1143 else 1144 { 1145 return $ret; 1146 } 1147} 1148 1149=head2 forward 1150 1151 Defines the forward computation. Arguments can be either 1152 NDArray or Symbol 1153=cut 1154 1155method forward($x, @args) 1156{ 1157 if(blessed $x and $x->isa('AI::MXNet::NDArray')) 1158 { 1159 my @out; 1160 my $out; 1161 my $ctx = $x->context; 1162 my $current_ctx = AI::MXNet::Context->current_ctx; 1163 AI::MXNet::Context->set_current($ctx); 1164 if($self->_active) 1165 { 1166 if(wantarray) 1167 { 1168 my @out = $self->_call_cached_op($x, @args); 1169 AI::MXNet::Context->set_current($current_ctx); 1170 return @out; 1171 } 1172 else 1173 { 1174 my $out = $self->_call_cached_op($x, @args); 1175 AI::MXNet::Context->set_current($current_ctx); 1176 return $out; 1177 } 1178 } 1179 my %params; 1180 eval { 1181 %params = map { $_ => $self->_reg_params->{ $_ }->data($ctx) } keys %{ $self->_reg_params }; 1182 }; 1183 if($@) 1184 { 1185 if($@ =~ /DeferredInitializationError/) 1186 { 1187 $self->_deferred_infer_shape($x, @args); 1188 $_->_finish_deferred_init for $self->params->values; 1189 %params = map { $_ => $self->_reg_params->{ $_ }->data($ctx) } keys %{ $self->_reg_params }; 1190 } 1191 else 1192 { 1193 confess($@); 1194 } 1195 } 1196 @out = 
$self->hybrid_forward('AI::MXNet::NDArray', $x, @args, %params); 1197 AI::MXNet::Context->set_current($current_ctx); 1198 return wantarray ? @out : $out[0]; 1199 } 1200 assert( 1201 (blessed $x and $x->isa('AI::MXNet::Symbol')), 1202 "HybridBlock requires the first argument to forward be either ". 1203 "Symbol or NDArray, but got [".ref($x)."]" 1204 ); 1205 my %params = map { $_ => $self->_reg_params->{ $_ }->var } keys %{ $self->_reg_params }; 1206 my @ret; 1207 $self->name_scope(sub { 1208 @ret = $self->hybrid_forward('AI::MXNet::Symbol', $x, @args, %params); 1209 }); 1210 return wantarray ? @ret : $ret[0]; 1211} 1212 1213=head2 hybrid_forward 1214 1215 Overrides to construct symbolic graph for this `Block`. 1216 1217 Parameters 1218 ---------- 1219 x : Symbol or NDArray 1220 The first input tensor. 1221 *args : list of Symbol or list of NDArray 1222 Additional input tensors. 1223=cut 1224 1225method hybrid_forward($F, $x, @args) 1226{ 1227 confess("NotImplementedError"); 1228} 1229 1230=head2 export 1231 1232 Export HybridBlock to json format that can be loaded by AI::MXNet::Module 1233 or the C++ interface. 1234 1235 When there are only one input, it will have name 'data'. When there 1236 Are more than one inputs, they will be named as 'data0', 'data1', etc. 1237 1238 Parameters 1239 ---------- 1240 $path : str 1241 Path to save model. Two files 'path-symbol.json' and 'path-xxxx.params' 1242 will be created, where xxxx is the 4 digits epoch number. 1243 :$epoch=0 : Int 1244 Epoch number of saved model. 1245=cut 1246 1247method export(Str $path, :$epoch=0) 1248{ 1249 if(not @{ $self->_cached_graph }) 1250 { 1251 confess( 1252 "Please first call \$block->hybridize() and then run forward with ". 1253 "this block at least once before calling export." 
1254 ); 1255 } 1256 my $sym = $self->_cached_graph->[1]; 1257 $sym->save("$path-symbol.json"); 1258 1259 my %arg_names = map { $_ => 1 } @{ $sym->list_arguments }; 1260 my %aux_names = map { $_ => 1 } @{ $sym->list_auxiliary_states }; 1261 my %arg_dict; 1262 my $params = $self->collect_params; 1263 for my $name ($params->keys) 1264 { 1265 my $param = $params->get($name); 1266 if(exists $arg_names{ $name }) 1267 { 1268 $arg_dict{ "arg:$name" } = $param->_reduce; 1269 } 1270 else 1271 { 1272 assert(exists $aux_names{ $name }); 1273 $arg_dict{ "aux:$name" } = $param->_reduce; 1274 } 1275 } 1276 AI::MXNet::NDArray->save(sprintf('%s-%04d.params', $path, $epoch), \%arg_dict); 1277} 1278 1279__PACKAGE__->register('AI::MXNet::Gluon'); 1280 1281package AI::MXNet::Gluon::SymbolBlock; 1282use AI::MXNet::Gluon::Mouse; 1283use AI::MXNet::Base; 1284extends 'AI::MXNet::Gluon::HybridBlock'; 1285 1286=head1 NAME 1287 1288 AI::MXNet::Gluon::SymbolBlock - Construct block from symbol. 1289=cut 1290 1291=head1 DESCRIPTION 1292 1293 Construct block from symbol. This is useful for using pre-trained models 1294 as feature extractors. For example, you may want to extract get the output 1295 from fc2 layer in AlexNet. 1296 1297 Parameters 1298 ---------- 1299 outputs : Symbol or list of Symbol 1300 The desired output for SymbolBlock. 1301 inputs : Symbol or list of Symbol 1302 The Variables in output's argument that should be used as inputs. 1303 params : ParameterDict 1304 Parameter dictionary for arguments and auxililary states of outputs 1305 that are not inputs. 
1306 1307 Examples 1308 -------- 1309 >>> # To extract the feature from fc1 and fc2 layers of AlexNet 1310 >>> $alexnet = gluon->model_zoo->vision->alexnet(pretrained=>1, ctx=>mx->cpu(), 1311 prefix=>'model_'); 1312 >>> $inputs = mx->sym->var('data'); 1313 >>> $out = $alexnet->($inputs); 1314 >>> $internals = $out->get_internals() 1315 >>> print($internals->list_outputs()) 1316 ['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...] 1317 >>> $outputs = [$internals->slice('model_dense0_relu_fwd_output'), 1318 $internals->slice('model_dense1_relu_fwd_output')]; 1319 >>> # Create SymbolBlock that shares parameters with alexnet 1320 >>> $feat_model = gluon->SymbolBlock($outputs, $inputs, params=>$alexnet->collect_params()); 1321 >>> $x = mx->nd->random_normal(shape=>[16, 3, 224, 224]); 1322 >>> print($feat_model->($x)); 1323=cut 1324 1325has [qw/outputs inputs/] => (is => 'rw', isa => 'AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]'); 1326method python_constructor_arguments() { [qw/outputs inputs/] } 1327 1328sub BUILD 1329{ 1330 my ($self, $orig_params) = @_; 1331 return unless defined $self->outputs and defined $self->inputs; 1332 $self->_prefix(''); 1333 $self->_params(AI::MXNet::Gluon::ParameterDict->new(prefix => '', shared => $orig_params->{params})); 1334 if(blessed $self->inputs and @{ $self->inputs->list_outputs } == 1) 1335 { 1336 $self->inputs([$self->inputs]); 1337 } 1338 if(not blessed $self->outputs and @{ $self->outputs } == 1) 1339 { 1340 $self->outputs($self->outputs->[0]); 1341 } 1342 my ($syms, $in_format) = __PACKAGE__->_flatten($self->inputs); 1343 my ($out, $out_format) = __PACKAGE__->_flatten($self->outputs); 1344 $self->_in_format($in_format); 1345 $self->_out_format($out_format); 1346 $out = AI::MXNet::Symbol->Group($out); 1347 1348 my %input_names; 1349 for my $i (@{ $syms }) 1350 { 1351 assert( 1352 (@{ $i->get_internals->list_outputs() } == 1), 1353 "Input symbols must be variable, but $i is an output of 
operators" 1354 ); 1355 $input_names{ $i->name } = 1; 1356 } 1357 1358 # check if any symbol is row_sparse 1359 my $row_sparse_storage = STORAGE_TYPE_STR_TO_ID->{row_sparse}; 1360 for my $i (@{ $out }) 1361 { 1362 for my $j (@{ $i->get_internals }) 1363 { 1364 assert( 1365 (not defined $j->attr("__storage_type__") or $j->attr("__storage_type__") ne $row_sparse_storage), 1366 "SymbolBlock doesn't support Parameter ${\ $j->name } because its storage ". 1367 "type is 'row_sparse'." 1368 ); 1369 } 1370 } 1371 1372 my $arg_params = $out->list_arguments; 1373 my $aux_params = $out->list_auxiliary_states; 1374 my ($arg_types, $aux_types) = _infer_param_types($syms, $out, $arg_params, $aux_params); 1375 1376 for(enumerate($arg_params)) 1377 { 1378 my ($i, $arg) = @$_; 1379 if(not exists $input_names{ $arg }) 1380 { 1381 $self->params->get($arg, allow_deferred_init => 1, dtype => $arg_types->[$i]); 1382 } 1383 } 1384 1385 for(enumerate($aux_params)) 1386 { 1387 my ($i, $arg) = @$_; 1388 if(not exists $input_names{ $arg }) 1389 { 1390 $self->params->get($arg, grad_req => 'null', allow_deferred_init => 1, dtype => $aux_types->[$i]); 1391 } 1392 } 1393 1394 $self->_cached_graph([$syms, $out]); 1395 my $prefix = _common_prefix($self->_params->keys); 1396 my %params = $self->_params->items; 1397 while(my ($key, $val) = each %params) 1398 { 1399 $key =~ s/^$prefix//; 1400 $self->_reg_params->{ $key } = $val; 1401 } 1402 $self->_prefix($prefix); 1403} 1404 1405 1406func _infer_param_types($in_params, $out_params, $arg_params, $aux_params, $default_dtype='float32') 1407{ 1408 # Utility function that helps in inferring DType of args and auxs params 1409 # from given input param. 1410 # Parameters 1411 # ---------- 1412 # in_params: array ref of AI::MXNet::Symbol objects 1413 # List of input symbol variables. 1414 # out_params: AI::MXNet::Symbol 1415 # Output symbol variable. 1416 # arg_params: array ref of Str 1417 # List of names of argument parametrs. 
1418 # aux_params: array ref of Str 1419 # List of names of auxiliary parameters. 1420 # default_dtype: Dtype, default 'float32' 1421 # Default data type for arg_params and aux_params, if unable to infer the type. 1422 # Returns 1423 # ------- 1424 # arg_types: Array ref of Dtype 1425 # List of arg_params type. Order is same as arg_params. 1426 # Defaults to 'float32', if unable to infer type. 1427 # aux_types: Array ref of Dtype 1428 # List of aux_params type. Order is same as aux_params. 1429 # Defaults to 'float32', if unable to infer type. 1430 1431 my $arg_types; 1432 my $aux_types; 1433 # Get Input symbol details. This will be used to infer types of 1434 # other parameters. 1435 my @input_sym_names = map { $_->name } @{ $in_params }; 1436 # Try to infer input types. If not successful, we will set default dtype. 1437 # If successful, we will try to infer other params in the graph. 1438 my @input_sym_arg_types; 1439 my $can_infer_input_type = 1; 1440 for my $in_param(@{ $in_params }) 1441 { 1442 my $input_sym_arg_type = ($in_param->infer_type)[0]; 1443 if(not $input_sym_arg_type or @$input_sym_arg_type < 1) 1444 { 1445 $can_infer_input_type = 0; 1446 last; 1447 } 1448 else 1449 { 1450 push @input_sym_arg_types, $input_sym_arg_type->[0]; 1451 } 1452 } 1453 # Try to infer types of other parameters. 
1454 if($can_infer_input_type) 1455 { 1456 my %params = map { $_->[0] => $_->[1] } zip(\@input_sym_names, \@input_sym_arg_types); 1457 ($arg_types, undef, $aux_types) = $out_params->infer_type(%params); 1458 if(not defined $arg_types or @$arg_types != @$arg_params) 1459 { 1460 $arg_types = [($default_dtype)x@$arg_params]; 1461 } 1462 if(not defined $aux_types or @$aux_types != @$aux_params) 1463 { 1464 $aux_types = [($default_dtype)x@$aux_params]; 1465 } 1466 } 1467 return ($arg_types, $aux_types); 1468} 1469 1470func _common_prefix(@names) 1471{ 1472 if(not @names) 1473 { 1474 return '' 1475 } 1476 my $prefix = $names[0]; 1477 for my $name (@names) 1478 { 1479 my $i = 0; 1480 while($i < length($prefix) and $i < length($name) and substr($prefix, $i, 1) eq substr($name, $i, 1)) 1481 { 1482 $i++; 1483 } 1484 $prefix = substr($prefix, 0, $i); 1485 } 1486 return $prefix; 1487} 1488 1489method forward($x, @args) 1490{ 1491 if(blessed $x and $x->isa('AI::MXNet::NDArray')) 1492 { 1493 my @out; 1494 my $out; 1495 my $ctx = $x->context; 1496 my $current_ctx = AI::MXNet::Context->current_ctx; 1497 AI::MXNet::Context->set_current($ctx); 1498 if(wantarray) 1499 { 1500 my @out = $self->_call_cached_op($x, @args); 1501 AI::MXNet::Context->set_current($current_ctx); 1502 return @out; 1503 } 1504 else 1505 { 1506 my $out = $self->_call_cached_op($x, @args); 1507 AI::MXNet::Context->set_current($current_ctx); 1508 return $out; 1509 } 1510 } 1511 assert( 1512 (blessed $x and $x->isa('AI::MXNet::Symbol')), 1513 "HybridBlock requires the first argument to forward be either ". 
1514 "Symbol or NDArray, but got [".ref($x)."]" 1515 ); 1516 my $args = \@args; 1517 my $in_fmt; 1518 ($args, $in_fmt) = __PACKAGE__->_flatten([$x, @$args]); 1519 assert((Data::Dumper::Dumper($in_fmt) eq Data::Dumper::Dumper($self->_in_format)), "Invalid input format"); 1520 my $ret = $self->_cached_graph->[1]->deepcopy; 1521 my %in; 1522 for(zip($self->_cached_graph->[0], $args)) { 1523 my ($k, $v) = @$_; 1524 $in{$k->name} = $v; 1525 } 1526 $ret->_compose(%in); 1527 $ret = (__PACKAGE__->_regroup($ret, $self->_out_format))[0]; 1528 if(ref($ret) eq 'ARRAY' and wantarray) 1529 { 1530 return @$ret; 1531 } 1532 else 1533 { 1534 return $ret; 1535 } 1536} 1537 1538method _clear_cached_op() 1539{ 1540 my $tmp = $self->_cached_graph; 1541 $self->SUPER::_clear_cached_op; 1542 $self->_cached_graph($tmp); 1543} 1544 1545method hybrid_forward(@args) 1546{ 1547 confess('NotImplementedError'); 1548} 1549 1550=head2 imports 1551 1552 Import model previously saved by HybridBlock->export or 1553 Module->save_checkpoint as a SymbolBlock for use in Gluon. 1554 1555 Parameters 1556 ---------- 1557 $symbol_file : Str 1558 Path to symbol file. 1559 $input_names : Str|ArrayRef[Str] 1560 List of input variable names 1561 :$param_file : Str, optional 1562 Path to parameter file. 1563 $ctx : Context, default undef 1564 The context to initialize SymbolBlock on. 1565 1566 Returns 1567 ------- 1568 SymbolBlock 1569 SymbolBlock loaded from symbol and parameter files. 1570=cut 1571 1572method imports(Str $symbol_file, Str|ArrayRef[Str] $input_names, Maybe [Str] $param_file=, Maybe[AI::MXNet::Context] $ctx=) 1573{ 1574 my $sym = AI::MXNet::Symbol->load($symbol_file); 1575 $input_names = [$input_names] unless ref $input_names; 1576 my @inputs = map { AI::MXNet::Symbol->var($_) } @{ $input_names }; 1577 my $ret = __PACKAGE__->new($sym, \@inputs); 1578 if(defined $param_file) 1579 { 1580 $ret->load_parameters($param_file, (defined $ctx ? 
(ctx=>$ctx) : ())); 1581 } 1582 return $ret 1583} 1584 1585__PACKAGE__->register('AI::MXNet::Gluon'); 1586 15871; 1588