# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

package AI::MXNet::RNN::Params;
use Mouse;
use AI::MXNet::Function::Parameters;

=head1 NAME

    AI::MXNet::RNN::Params - A container for holding variables.
=cut

=head1 DESCRIPTION

    A container for holding variables.
    Used by RNN cells for parameter sharing between cells.

    Parameters
    ----------
    prefix : str
        All variables name created by this container will
        be prepended with the prefix
=cut

has '_prefix' => (is => 'ro', init_arg => 'prefix', isa => 'Str', default => '');
# name => AI::MXNet::Symbol variable cache; populated lazily by get()
has '_params' => (is => 'rw', init_arg => undef);

# Allow a single positional argument: AI::MXNet::RNN::Params->new($prefix)
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    $self->_params({});
}

=head2 get

    Get a variable with the name or create a new one if does not exist.

    Parameters
    ----------
    $name : str
        name of the variable
    @kwargs:
        more arguments that are passed to mx->sym->Variable call
=cut

method get(Str $name, @kwargs)
{
    $name = $self->_prefix . $name;
    if(not exists $self->_params->{$name})
    {
        $self->_params->{$name} = AI::MXNet::Symbol->Variable($name, @kwargs);
    }
    return $self->_params->{$name};
}

package AI::MXNet::RNN::Cell::Base;
=head1 NAME

    AI::MXNet::RNNCell::Base
=cut

=head1 DESCRIPTION

    Abstract base class for RNN cells

    Parameters
    ----------
    prefix : str
        prefix for name of layers
        (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

use AI::MXNet::Base;
use Mouse;
# Make the cell object itself callable: &{$cell}($inputs, $states) delegates
# to $cell->call(...).
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } };
has '_prefix' => (is => 'rw', init_arg => 'prefix', isa => 'Str', default => '');
has '_params' => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::RNN::Params]');
has [qw/_own_params
        _modified
        _init_counter
        _counter
       /] => (is => 'rw', init_arg => undef);

# Allow a single positional argument: ->new($prefix)
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    if(not defined $self->_params)
    {
        # no shared container supplied: create and own one
        $self->_own_params(1);
        $self->_params(AI::MXNet::RNN::Params->new($self->_prefix));
    }
    else
    {
        $self->_own_params(0);
    }
    $self->_modified(0);
    $self->reset;
}

=head2 reset

    Reset before re-using the cell for another graph
=cut

method reset()
{
    $self->_init_counter(-1);
    $self->_counter(-1);
}

=head2 call

    Construct symbol for one step of RNN.

    Parameters
    ----------
    $inputs : mx->sym->Variable
        input symbol, 2D, batch * num_units
    $states : mx->sym->Variable or ArrayRef[AI::MXNet::Symbol]
        state from previous step or begin_state().
151 152 Returns 153 ------- 154 $output : AI::MXNet::Symbol 155 output symbol 156 $states : ArrayRef[AI::MXNet::Symbol] 157 state to next step of RNN. 158 Can be called via overloaded &{}: &{$cell}($inputs, $states); 159=cut 160 161method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states) 162{ 163 confess("Not Implemented"); 164} 165 166method _gate_names() 167{ 168 ['']; 169} 170 171=head2 params 172 173 Parameters of this cell 174=cut 175 176method params() 177{ 178 $self->_own_params(0); 179 return $self->_params; 180} 181 182=head2 state_shape 183 184 shape(s) of states 185=cut 186 187method state_shape() 188{ 189 return [map { $_->{shape} } @{ $self->state_info }]; 190} 191 192=head2 state_info 193 194 shape and layout information of states 195=cut 196 197method state_info() 198{ 199 confess("Not Implemented"); 200} 201 202=head2 begin_state 203 204 Initial state for this cell. 205 206 Parameters 207 ---------- 208 :$func : sub ref, default is AI::MXNet::Symbol->can('zeros') 209 Function for creating initial state. 210 Can be AI::MXNet::Symbol->can('zeros'), 211 AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc. 212 Use AI::MXNet::Symbol->can('Variable') if you want to directly 213 feed the input as states. 214 @kwargs : 215 more keyword arguments passed to func. For example 216 mean, std, dtype, etc. 217 218 Returns 219 ------- 220 $states : ArrayRef[AI::MXNet::Symbol] 221 starting states for first RNN step 222=cut 223 224method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs) 225{ 226 assert( 227 (not $self->_modified), 228 "After applying modifier cells (e.g. DropoutCell) the base " 229 ."cell cannot be called directly. Call the modifier cell instead." 
230 ); 231 my @states; 232 my $func_needs_named_name = $func ne AI::MXNet::Symbol->can('Variable'); 233 for my $info (@{ $self->state_info }) 234 { 235 $self->_init_counter($self->_init_counter + 1); 236 my @name = (sprintf("%sbegin_state_%d", $self->_prefix, $self->_init_counter)); 237 my %info = %{ $info//{} }; 238 if($func_needs_named_name) 239 { 240 unshift(@name, 'name'); 241 } 242 else 243 { 244 if(exists $info{__layout__}) 245 { 246 $info{kwargs} = { __layout__ => delete $info{__layout__} }; 247 } 248 } 249 my %kwargs = (@kwargs, %info); 250 my $state = $func->( 251 'AI::MXNet::Symbol', 252 @name, 253 %kwargs 254 ); 255 push @states, $state; 256 } 257 return \@states; 258} 259 260=head2 unpack_weights 261 262 Unpack fused weight matrices into separate 263 weight matrices 264 265 Parameters 266 ---------- 267 $args : HashRef[AI::MXNet::NDArray] 268 hash ref containing packed weights. 269 usually from AI::MXNet::Module->get_output() 270 271 Returns 272 ------- 273 $args : HashRef[AI::MXNet::NDArray] 274 hash ref with weights associated with 275 this cell, unpacked. 
=cut

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Work on a shallow copy so the caller's hash is left untouched.
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        # remove the fused weight/bias and re-add one slice per gate
        my $weight = delete $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) };
        my $bias = delete $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) };
        enumerate(sub {
            my ($j, $name) = @_;
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            # store into the local copy (%args), not the caller's $args ref,
            # so the returned hash actually contains the unpacked weights
            $args{$wname} = $weight->slice([$j*$h,($j+1)*$h-1])->copy;
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            $args{$bname} = $bias->slice([$j*$h,($j+1)*$h-1])->copy;
        }, $self->_gate_names);
    }
    return \%args;
}

=head2 pack_weights

    Pack separate per-gate weight matrices into the fused
    weight matrices

    Parameters
    ----------
    args : HashRef[AI::MXNet::NDArray]
        hash ref containing unpacked weights.

    Returns
    -------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref with weights associated with
        this cell, packed.
=cut

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Work on a shallow copy so the caller's hash is left untouched.
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        my @weight;
        my @bias;
        # collect the per-gate pieces in gate order, removing them
        for my $name (@{ $self->_gate_names })
        {
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            push @weight, delete $args{$wname};
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            push @bias, delete $args{$bname};
        }
        # concatenate along the first axis into the fused arrays
        $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@weight
        );
        $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@bias
        );
    }
    return \%args;
}

=head2 unroll

    Unroll an RNN cell across time steps.

    Parameters
    ----------
    :$length : Int
        number of steps to unroll
    :$inputs : AI::MXNet::Symbol, array ref of Symbols, or undef
        if inputs is a single Symbol (usually the output
        of Embedding symbol), it should have shape
        of [$batch_size, $length, ...] if layout == 'NTC' (batch, time series)
        or ($length, $batch_size, ...) if layout == 'TNC' (time series, batch).

        If inputs is an array ref of symbols (usually output of
        previous unroll), they should all have shape
        ($batch_size, ...).

        If inputs is undef, placeholder variables are
        automatically created.
    :$begin_state : array ref of Symbol
        input states. Created by begin_state()
        or output state of another cell. Created
        from begin_state() if undef.
    :$input_prefix : str
        prefix for automatically created input
        placeholders.
    :$layout : str
        layout of input symbol. Only used if the input
        is a single Symbol.
    :$merge_outputs : Bool
        If 0, returns outputs as an array ref of Symbols.
        If 1, concatenates the output across the time steps
        and returns a single symbol with the shape
        [$batch_size, $length, ...] if the layout equals 'NTC',
        or [$length, $batch_size, ...] if the layout equals 'TNC'.
        If undef, output whatever is faster

    Returns
    -------
    $outputs : array ref of Symbol or Symbol
        output symbols.
381 $states : Symbol or nested list of Symbol 382 has the same structure as begin_state() 383=cut 384 385 386method unroll( 387 Int $length, 388 Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=, 389 Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=, 390 Str :$input_prefix='', 391 Str :$layout='NTC', 392 Maybe[Bool] :$merge_outputs= 393) 394{ 395 $self->reset; 396 my $axis = index($layout, 'T'); 397 if(not defined $inputs) 398 { 399 $inputs = [ 400 map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1) 401 ]; 402 } 403 elsif(blessed($inputs)) 404 { 405 assert( 406 (@{ $inputs->list_outputs() } == 1), 407 "unroll doesn't allow grouped symbol as input. Please " 408 ."convert to list first or let unroll handle slicing" 409 ); 410 $inputs = AI::MXNet::Symbol->SliceChannel( 411 $inputs, 412 axis => $axis, 413 num_outputs => $length, 414 squeeze_axis => 1 415 ); 416 } 417 else 418 { 419 assert(@$inputs == $length); 420 } 421 $begin_state //= $self->begin_state; 422 my $states = $begin_state; 423 my $outputs; 424 my @inputs = @{ $inputs }; 425 for my $i (0..$length-1) 426 { 427 my $output; 428 ($output, $states) = $self->( 429 $inputs[$i], 430 $states 431 ); 432 push @$outputs, $output; 433 } 434 if($merge_outputs) 435 { 436 @$outputs = map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @$outputs; 437 $outputs = AI::MXNet::Symbol->Concat(@$outputs, dim => $axis); 438 } 439 return($outputs, $states); 440} 441 442method _get_activation($inputs, $activation, @kwargs) 443{ 444 if(not ref $activation) 445 { 446 return AI::MXNet::Symbol->Activation($inputs, act_type => $activation, @kwargs); 447 } 448 else 449 { 450 return $activation->($inputs, @kwargs); 451 } 452} 453 454method _cells_state_shape($cells) 455{ 456 return [map { @{ $_->state_shape } } @$cells]; 457} 458 459method _cells_state_info($cells) 460{ 461 return [map { @{ $_->state_info } } @$cells]; 462} 463 464method _cells_begin_state($cells, @kwargs) 
465{ 466 return [map { @{ $_->begin_state(@kwargs) } } @$cells]; 467} 468 469method _cells_unpack_weights($cells, $args) 470{ 471 $args = $_->unpack_weights($args) for @$cells; 472 return $args; 473} 474 475method _cells_pack_weights($cells, $args) 476{ 477 $args = $_->pack_weights($args) for @$cells; 478 return $args; 479} 480 481package AI::MXNet::RNN::Cell; 482use Mouse; 483extends 'AI::MXNet::RNN::Cell::Base'; 484 485=head1 NAME 486 487 AI::MXNet::RNN::Cell 488=cut 489 490=head1 DESCRIPTION 491 492 Simple recurrent neural network cell 493 494 Parameters 495 ---------- 496 num_hidden : int 497 number of units in output symbol 498 activation : str or Symbol, default 'tanh' 499 type of activation function 500 prefix : str, default 'rnn_' 501 prefix for name of layers 502 (and name of weight if params is undef) 503 params : AI::MXNet::RNNParams or undef 504 container for weight sharing between cells. 505 created if undef. 506=cut 507 508has '_num_hidden' => (is => 'ro', init_arg => 'num_hidden', isa => 'Int', required => 1); 509has 'forget_bias' => (is => 'ro', isa => 'Num'); 510has '_activation' => ( 511 is => 'ro', 512 init_arg => 'activation', 513 isa => 'Activation', 514 default => 'tanh' 515); 516has '+_prefix' => (default => 'rnn_'); 517has [qw/_iW _iB 518 _hW _hB/] => (is => 'rw', init_arg => undef); 519 520around BUILDARGS => sub { 521 my $orig = shift; 522 my $class = shift; 523 return $class->$orig(num_hidden => $_[0]) if @_ == 1; 524 return $class->$orig(@_); 525}; 526 527sub BUILD 528{ 529 my $self = shift; 530 $self->_iW($self->params->get('i2h_weight')); 531 $self->_iB( 532 $self->params->get( 533 'i2h_bias', 534 (defined($self->forget_bias) 535 ? 
(init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias)) 536 : () 537 ) 538 ) 539 ); 540 $self->_hW($self->params->get('h2h_weight')); 541 $self->_hB($self->params->get('h2h_bias')); 542} 543 544method state_info() 545{ 546 return [{ shape => [0, $self->_num_hidden], __layout__ => 'NC' }]; 547} 548 549method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states) 550{ 551 $self->_counter($self->_counter + 1); 552 my $name = sprintf('%st%d_', $self->_prefix, $self->_counter); 553 my $i2h = AI::MXNet::Symbol->FullyConnected( 554 data => $inputs, 555 weight => $self->_iW, 556 bias => $self->_iB, 557 num_hidden => $self->_num_hidden, 558 name => "${name}i2h" 559 ); 560 my $h2h = AI::MXNet::Symbol->FullyConnected( 561 data => @{$states}[0], 562 weight => $self->_hW, 563 bias => $self->_hB, 564 num_hidden => $self->_num_hidden, 565 name => "${name}h2h" 566 ); 567 my $output = $self->_get_activation( 568 $i2h + $h2h, 569 $self->_activation, 570 name => "${name}out" 571 ); 572 return ($output, [$output]); 573} 574 575package AI::MXNet::RNN::LSTMCell; 576use Mouse; 577use AI::MXNet::Base; 578extends 'AI::MXNet::RNN::Cell'; 579 580=head1 NAME 581 582 AI::MXNet::RNN::LSTMCell 583=cut 584 585=head1 DESCRIPTION 586 587 Long-Short Term Memory (LSTM) network cell. 588 589 Parameters 590 ---------- 591 num_hidden : int 592 number of units in output symbol 593 prefix : str, default 'lstm_' 594 prefix for name of layers 595 (and name of weight if params is undef) 596 params : AI::MXNet::RNN::Params or None 597 container for weight sharing between cells. 598 created if undef. 599 forget_bias : bias added to forget gate, default 1.0. 600 Jozefowicz et al. 
2015 recommends setting this to 1.0 601=cut 602 603has '+_prefix' => (default => 'lstm_'); 604has '+_activation' => (init_arg => undef); 605has '+forget_bias' => (is => 'ro', isa => 'Num', default => 1); 606 607method state_info() 608{ 609 return [{ shape => [0, $self->_num_hidden], __layout__ => 'NC' } , { shape => [0, $self->_num_hidden], __layout__ => 'NC' }]; 610} 611 612method _gate_names() 613{ 614 [qw/_i _f _c _o/]; 615} 616 617method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states) 618{ 619 $self->_counter($self->_counter + 1); 620 my $name = sprintf('%st%d_', $self->_prefix, $self->_counter); 621 my @states = @{ $states }; 622 my $i2h = AI::MXNet::Symbol->FullyConnected( 623 data => $inputs, 624 weight => $self->_iW, 625 bias => $self->_iB, 626 num_hidden => $self->_num_hidden*4, 627 name => "${name}i2h" 628 ); 629 my $h2h = AI::MXNet::Symbol->FullyConnected( 630 data => $states[0], 631 weight => $self->_hW, 632 bias => $self->_hB, 633 num_hidden => $self->_num_hidden*4, 634 name => "${name}h2h" 635 ); 636 my $gates = $i2h + $h2h; 637 my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel( 638 $gates, num_outputs => 4, name => "${name}slice" 639 ) }; 640 my $in_gate = AI::MXNet::Symbol->Activation( 641 $slice_gates[0], act_type => "sigmoid", name => "${name}i" 642 ); 643 my $forget_gate = AI::MXNet::Symbol->Activation( 644 $slice_gates[1], act_type => "sigmoid", name => "${name}f" 645 ); 646 my $in_transform = AI::MXNet::Symbol->Activation( 647 $slice_gates[2], act_type => "tanh", name => "${name}c" 648 ); 649 my $out_gate = AI::MXNet::Symbol->Activation( 650 $slice_gates[3], act_type => "sigmoid", name => "${name}o" 651 ); 652 my $next_c = AI::MXNet::Symbol->_plus( 653 $forget_gate * $states[1], $in_gate * $in_transform, 654 name => "${name}state" 655 ); 656 my $next_h = AI::MXNet::Symbol->_mul( 657 $out_gate, 658 AI::MXNet::Symbol->Activation( 659 $next_c, act_type => "tanh" 660 ), 661 name => "${name}out" 662 ); 663 return ($next_h, 
[$next_h, $next_c]); 664 665} 666 667package AI::MXNet::RNN::GRUCell; 668use Mouse; 669use AI::MXNet::Base; 670extends 'AI::MXNet::RNN::Cell'; 671 672=head1 NAME 673 674 AI::MXNet::RNN::GRUCell 675=cut 676 677=head1 DESCRIPTION 678 679 Gated Rectified Unit (GRU) network cell. 680 Note: this is an implementation of the cuDNN version of GRUs 681 (slight modification compared to Cho et al. 2014). 682 683 Parameters 684 ---------- 685 num_hidden : int 686 number of units in output symbol 687 prefix : str, default 'gru_' 688 prefix for name of layers 689 (and name of weight if params is undef) 690 params : AI::MXNet::RNN::Params or undef 691 container for weight sharing between cells. 692 created if undef. 693=cut 694 695has '+_prefix' => (default => 'gru_'); 696 697method _gate_names() 698{ 699 [qw/_r _z _o/]; 700} 701 702method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states) 703{ 704 $self->_counter($self->_counter + 1); 705 my $name = sprintf('%st%d_', $self->_prefix, $self->_counter); 706 my $prev_state_h = @{ $states }[0]; 707 my $i2h = AI::MXNet::Symbol->FullyConnected( 708 data => $inputs, 709 weight => $self->_iW, 710 bias => $self->_iB, 711 num_hidden => $self->_num_hidden*3, 712 name => "${name}i2h" 713 ); 714 my $h2h = AI::MXNet::Symbol->FullyConnected( 715 data => $prev_state_h, 716 weight => $self->_hW, 717 bias => $self->_hB, 718 num_hidden => $self->_num_hidden*3, 719 name => "${name}h2h" 720 ); 721 my ($i2h_r, $i2h_z); 722 ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel( 723 $i2h, num_outputs => 3, name => "${name}_i2h_slice" 724 ) }; 725 my ($h2h_r, $h2h_z); 726 ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel( 727 $h2h, num_outputs => 3, name => "${name}_h2h_slice" 728 ) }; 729 my $reset_gate = AI::MXNet::Symbol->Activation( 730 $i2h_r + $h2h_r, act_type => "sigmoid", name => "${name}_r_act" 731 ); 732 my $update_gate = AI::MXNet::Symbol->Activation( 733 $i2h_z + $h2h_z, act_type => "sigmoid", name => 
"${name}_z_act" 734 ); 735 my $next_h_tmp = AI::MXNet::Symbol->Activation( 736 $i2h + $reset_gate * $h2h, act_type => "tanh", name => "${name}_h_act" 737 ); 738 my $next_h = AI::MXNet::Symbol->_plus( 739 (1 - $update_gate) * $next_h_tmp, $update_gate * $prev_state_h, 740 name => "${name}out" 741 ); 742 return ($next_h, [$next_h]); 743} 744 745package AI::MXNet::RNN::FusedCell; 746use Mouse; 747use AI::MXNet::Types; 748use AI::MXNet::Base; 749extends 'AI::MXNet::RNN::Cell::Base'; 750 751=head1 NAME 752 753 AI::MXNet::RNN::FusedCell 754=cut 755 756=head1 DESCRIPTION 757 758 Fusing RNN layers across time step into one kernel. 759 Improves speed but is less flexible. Currently only 760 supported if using cuDNN on GPU. 761=cut 762 763has '_num_hidden' => (is => 'ro', isa => 'Int', init_arg => 'num_hidden', required => 1); 764has '_num_layers' => (is => 'ro', isa => 'Int', init_arg => 'num_layers', default => 1); 765has '_dropout' => (is => 'ro', isa => 'Num', init_arg => 'dropout', default => 0); 766has '_get_next_state' => (is => 'ro', isa => 'Bool', init_arg => 'get_next_state', default => 0); 767has '_bidirectional' => (is => 'ro', isa => 'Bool', init_arg => 'bidirectional', default => 0); 768has 'forget_bias' => (is => 'ro', isa => 'Num', default => 1); 769has 'initializer' => (is => 'rw', isa => 'Maybe[Initializer]'); 770has '_mode' => ( 771 is => 'ro', 772 isa => enum([qw/rnn_relu rnn_tanh lstm gru/]), 773 init_arg => 'mode', 774 default => 'lstm' 775); 776has [qw/_parameter 777 _directions/] => (is => 'rw', init_arg => undef); 778 779around BUILDARGS => sub { 780 my $orig = shift; 781 my $class = shift; 782 return $class->$orig(num_hidden => $_[0]) if @_ == 1; 783 return $class->$orig(@_); 784}; 785 786sub BUILD 787{ 788 my $self = shift; 789 if(not $self->_prefix) 790 { 791 $self->_prefix($self->_mode.'_'); 792 } 793 if(not defined $self->initializer) 794 { 795 $self->initializer( 796 AI::MXNet::Xavier->new( 797 factor_type => 'in', 798 magnitude => 2.34 799 ) 
800 ); 801 } 802 if(not $self->initializer->isa('AI::MXNet::FusedRNN')) 803 { 804 $self->initializer( 805 AI::MXNet::FusedRNN->new( 806 init => $self->initializer, 807 num_hidden => $self->_num_hidden, 808 num_layers => $self->_num_layers, 809 mode => $self->_mode, 810 bidirectional => $self->_bidirectional, 811 forget_bias => $self->forget_bias 812 ) 813 ); 814 } 815 $self->_parameter($self->params->get('parameters', init => $self->initializer)); 816 $self->_directions($self->_bidirectional ? [qw/l r/] : ['l']); 817} 818 819 820method state_info() 821{ 822 my $b = @{ $self->_directions }; 823 my $n = $self->_mode eq 'lstm' ? 2 : 1; 824 return [map { +{ shape => [$b*$self->_num_layers, 0, $self->_num_hidden], __layout__ => 'LNC' } } 0..$n-1]; 825} 826 827method _gate_names() 828{ 829 return { 830 rnn_relu => [''], 831 rnn_tanh => [''], 832 lstm => [qw/_i _f _c _o/], 833 gru => [qw/_r _z _o/] 834 }->{ $self->_mode }; 835} 836 837method _num_gates() 838{ 839 return scalar(@{ $self->_gate_names }) 840} 841 842method _slice_weights($arr, $li, $lh) 843{ 844 my %args; 845 my @gate_names = @{ $self->_gate_names }; 846 my @directions = @{ $self->_directions }; 847 848 my $b = @directions; 849 my $p = 0; 850 for my $layer (0..$self->_num_layers-1) 851 { 852 for my $direction (@directions) 853 { 854 for my $gate (@gate_names) 855 { 856 my $name = sprintf('%s%s%d_i2h%s_weight', $self->_prefix, $direction, $layer, $gate); 857 my $size; 858 if($layer > 0) 859 { 860 $size = $b*$lh*$lh; 861 $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $b*$lh]); 862 } 863 else 864 { 865 $size = $li*$lh; 866 $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $li]); 867 } 868 $p += $size; 869 } 870 for my $gate (@gate_names) 871 { 872 my $name = sprintf('%s%s%d_h2h%s_weight', $self->_prefix, $direction, $layer, $gate); 873 my $size = $lh**2; 874 $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $lh]); 875 $p += $size; 876 } 877 } 878 } 879 for my $layer 
(0..$self->_num_layers-1) 880 { 881 for my $direction (@directions) 882 { 883 for my $gate (@gate_names) 884 { 885 my $name = sprintf('%s%s%d_i2h%s_bias', $self->_prefix, $direction, $layer, $gate); 886 $args{$name} = $arr->slice([$p,$p+$lh-1]); 887 $p += $lh; 888 } 889 for my $gate (@gate_names) 890 { 891 my $name = sprintf('%s%s%d_h2h%s_bias', $self->_prefix, $direction, $layer, $gate); 892 $args{$name} = $arr->slice([$p,$p+$lh-1]); 893 $p += $lh; 894 } 895 } 896 } 897 assert($p == $arr->size, "Invalid parameters size for FusedRNNCell"); 898 return %args; 899} 900 901method unpack_weights(HashRef[AI::MXNet::NDArray] $args) 902{ 903 my %args = %{ $args }; 904 my $arr = delete $args{ $self->_parameter->name }; 905 my $b = @{ $self->_directions }; 906 my $m = $self->_num_gates; 907 my $h = $self->_num_hidden; 908 my $num_input = int(int(int($arr->size/$b)/$h)/$m) - ($self->_num_layers - 1)*($h+$b*$h+2) - $h - 2; 909 my %nargs = $self->_slice_weights($arr, $num_input, $self->_num_hidden); 910 %args = (%args, map { $_ => $nargs{$_}->copy } keys %nargs); 911 return \%args 912} 913 914method pack_weights(HashRef[AI::MXNet::NDArray] $args) 915{ 916 my %args = %{ $args }; 917 my $b = @{ $self->_directions }; 918 my $m = $self->_num_gates; 919 my @c = @{ $self->_gate_names }; 920 my $h = $self->_num_hidden; 921 my $w0 = $args{ sprintf('%sl0_i2h%s_weight', $self->_prefix, $c[0]) }; 922 my $num_input = $w0->shape->[1]; 923 my $total = ($num_input+$h+2)*$h*$m*$b + ($self->_num_layers-1)*$m*$h*($h+$b*$h+2)*$b; 924 my $arr = AI::MXNet::NDArray->zeros([$total], ctx => $w0->context, dtype => $w0->dtype); 925 my %nargs = $self->_slice_weights($arr, $num_input, $h); 926 while(my ($name, $nd) = each %nargs) 927 { 928 $nd .= delete $args{ $name }; 929 } 930 $args{ $self->_parameter->name } = $arr; 931 return \%args; 932} 933 934method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states) 935{ 936 confess("AI::MXNet::RNN::FusedCell cannot be stepped. 
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    $self->reset;
    my $axis = index($layout, 'T');
    $inputs //= AI::MXNet::Symbol->Variable("${input_prefix}data");
    if(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        if($axis == 1)
        {
            AI::MXNet::Logging->warning(
                "NTC layout detected. Consider using "
                ."TNC for RNN::FusedCell for faster speed"
            );
            # fused RNN operator wants time-major (TNC) input
            $inputs = AI::MXNet::Symbol->SwapAxis($inputs, dim1 => 0, dim2 => 1);
        }
        else
        {
            assert($axis == 0, "Unsupported layout $layout");
        }
    }
    else
    {
        assert(@$inputs == $length);
        # stack per-step symbols into a single time-major sequence
        $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => 0) } @{ $inputs }];
        $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => 0);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my @states = @{ $states };
    my %states;
    if($self->_mode eq 'lstm')
    {
        %states = (state => $states[0], state_cell => $states[1]);
    }
    else
    {
        %states = (state => $states[0]);
    }
    my $rnn = AI::MXNet::Symbol->RNN(
        data          => $inputs,
        parameters    => $self->_parameter,
        state_size    => $self->_num_hidden,
        num_layers    => $self->_num_layers,
        bidirectional => $self->_bidirectional,
        p             => $self->_dropout,
        state_outputs => $self->_get_next_state,
        mode          => $self->_mode,
        name          => $self->_prefix.'rnn',
        %states
    );
    my $outputs;
    my %attr = (__layout__ => 'LNC');
    if(not $self->_get_next_state)
    {
        ($outputs, $states) = ($rnn, []);
    }
    elsif($self->_mode eq 'lstm')
    {
        # RNN op returns (output, state_h, state_cell) for LSTM
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        $rnn[2]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1], $rnn[2]]);
    }
    else
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1]]);
    }
    if(defined $merge_outputs and not $merge_outputs)
    {
        AI::MXNet::Logging->warning(
            "Call RNN::FusedCell->unroll with merge_outputs=1 "
            ."for faster speed"
        );
        # split the merged output back into one symbol per time step
        $outputs = [@{
            AI::MXNet::Symbol->SliceChannel(
                $outputs,
                axis         => 0,
                num_outputs  => $length,
                squeeze_axis => 1
            )
        }];
    }
    elsif($axis == 1)
    {
        # convert back to the caller's batch-major (NTC) layout
        $outputs = AI::MXNet::Symbol->SwapAxis($outputs, dim1 => 0, dim2 => 1);
    }
    return ($outputs, $states);
}

=head2 unfuse

    Unfuse the fused RNN

    Returns
    -------
    $cell : AI::MXNet::RNN::SequentialCell
        unfused cell that can be used for stepping, and can run on CPU.
=cut

method unfuse()
{
    my $stack = AI::MXNet::RNN::SequentialCell->new;
    # factory producing an unfused cell of the matching mode for a prefix
    my $get_cell = {
        rnn_relu => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'relu',
                prefix     => shift
            )
        },
        rnn_tanh => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'tanh',
                prefix     => shift
            )
        },
        lstm => sub {
            AI::MXNet::RNN::LSTMCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
        gru => sub {
            AI::MXNet::RNN::GRUCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
    }->{ $self->_mode };
    for my $i (0..$self->_num_layers-1)
    {
        if($self->_bidirectional)
        {
            $stack->add(
                AI::MXNet::RNN::BidirectionalCell->new(
                    $get_cell->(sprintf('%sl%d_', $self->_prefix, $i)),
                    $get_cell->(sprintf('%sr%d_', $self->_prefix, $i)),
                    output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $i)
                )
            );
        }
        else
        {
            $stack->add($get_cell->(sprintf('%sl%d_', $self->_prefix, $i)));
        }
    }
    return $stack;
}

package AI::MXNet::RNN::SequentialCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::SequentialCell
=cut

=head1 DESCRIPTION

    Sequentially stacking multiple RNN cells

    Parameters
    ----------
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my ($self, $original_arguments) = @_;
    # remember whether an explicit params container was supplied, so add()
    # can push it down into child cells
    $self->_override_cell_params(defined $original_arguments->{params});
    $self->_cells([]);
}

=head2 add

    Append a cell to the stack.

    Parameters
    ----------
    $cell : AI::MXNet::RNN::Cell::Base
=cut

method add(AI::MXNet::RNN::Cell::Base $cell)
{
    push @{ $self->_cells }, $cell;
    if($self->_override_cell_params)
    {
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        # merge our shared parameters into the child's container
        %{ $cell->params->_params } = (%{ $cell->params->_params }, %{ $self->params->_params });
    }
    # and pick up whatever the child already declared
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $cell->params->_params });
}

method state_info()
{
    return $self->_cells_state_info($self->_cells);
}

method begin_state(@kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
1168 ); 1169 return $self->_cells_begin_state($self->_cells, @kwargs); 1170} 1171 1172method unpack_weights(HashRef[AI::MXNet::NDArray] $args) 1173{ 1174 return $self->_cells_unpack_weights($self->_cells, $args) 1175} 1176 1177method pack_weights(HashRef[AI::MXNet::NDArray] $args) 1178{ 1179 return $self->_cells_pack_weights($self->_cells, $args); 1180} 1181 1182method call($inputs, $states) 1183{ 1184 $self->_counter($self->_counter + 1); 1185 my @next_states; 1186 my $p = 0; 1187 for my $cell (@{ $self->_cells }) 1188 { 1189 assert(not $cell->isa('AI::MXNet::BidirectionalCell')); 1190 my $n = scalar(@{ $cell->state_info }); 1191 my $state = [@{ $states }[$p..$p+$n-1]]; 1192 $p += $n; 1193 ($inputs, $state) = $cell->($inputs, $state); 1194 push @next_states, $state; 1195 } 1196 return ($inputs, [map { @$_} @next_states]); 1197} 1198 1199method unroll( 1200 Int $length, 1201 Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=, 1202 Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=, 1203 Str :$input_prefix='', 1204 Str :$layout='NTC', 1205 Maybe[Bool] :$merge_outputs= 1206) 1207{ 1208 my $num_cells = @{ $self->_cells }; 1209 $begin_state //= $self->begin_state; 1210 my $p = 0; 1211 my $states; 1212 my @next_states; 1213 enumerate(sub { 1214 my ($i, $cell) = @_; 1215 my $n = @{ $cell->state_info }; 1216 $states = [@{$begin_state}[$p..$p+$n-1]]; 1217 $p += $n; 1218 ($inputs, $states) = $cell->unroll( 1219 $length, 1220 inputs => $inputs, 1221 input_prefix => $input_prefix, 1222 begin_state => $states, 1223 layout => $layout, 1224 merge_outputs => ($i < $num_cells-1) ? 
undef : $merge_outputs
        );
        push @next_states, $states;
    }, $self->_cells);
    return ($inputs, [map { @{ $_ } } @next_states]);
}

package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell

    Parameters
    ----------
    l_cell : AI::MXNet::RNN::Cell::Base
        cell for forward unrolling
    r_cell : AI::MXNet::RNN::Cell::Base
        cell for backward unrolling
    output_prefix : str, default 'bi_'
        prefix for name of output
=cut

has 'l_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has 'r_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has '_output_prefix' => (is => 'ro', init_arg => 'output_prefix', isa => 'Str', default => 'bi_');
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    # support positional construction: ->new($l_cell, $r_cell, %kwargs)
    if(@_ >= 2 and blessed $_[0] and blessed $_[1])
    {
        my $l_cell = shift(@_);
        my $r_cell = shift(@_);
        return $class->$orig(
            l_cell => $l_cell,
            r_cell => $r_cell,
            @_
        );
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    if($self->_override_cell_params)
    {
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell "
            ."or child cells, not both."
        );
        # push the shared container's variables down into both children
        %{ $self->l_cell->params->_params } = (%{ $self->l_cell->params->_params }, %{ $self->params->_params });
        %{ $self->r_cell->params->_params } = (%{ $self->r_cell->params->_params }, %{ $self->params->_params });
    }
    # and mirror the children's variables back up for weight sharing
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->l_cell->params->_params });
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->r_cell->params->_params });
    $self->_cells([$self->l_cell, $self->r_cell]);
}

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_unpack_weights($self->_cells, $args);
}

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_pack_weights($self->_cells, $args);
}

method call($inputs, $states)
{
    confess("Bidirectional cannot be stepped. Please use unroll");
}

method state_info()
{
    return $self->_cells_state_info($self->_cells);
}

method begin_state(@kwargs)
{
    assert((not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        # split the time axis into a list of per-step symbols
        $inputs = [ @{ AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis => $axis,
            num_outputs => $length,
            squeeze_axis => 1
        ) }];
    }
    else
    {
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my ($l_cell, $r_cell) = @{ $self->_cells };
    # forward pass consumes the first slice of the state list...
    my ($l_outputs, $l_states) = $l_cell->unroll(
        $length, inputs => $inputs,
        begin_state => [@{$states}[0..@{$l_cell->state_info}-1]],
        layout => $layout,
        merge_outputs => $merge_outputs
    );
    # ...backward pass consumes the remainder, over time-reversed inputs
    my ($r_outputs, $r_states) = $r_cell->unroll(
        $length, inputs => [reverse @{$inputs}],
        begin_state => [@{$states}[@{$l_cell->state_info}..@{$states}-1]],
        layout => $layout,
        merge_outputs => $merge_outputs
    );
    if(not defined $merge_outputs)
    {
        # merge only if both directions came back merged (single Symbol each)
        $merge_outputs = (
            blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol')
            and
            blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol')
        );
        if(not $merge_outputs)
        {
            if(blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol'))
            {
                $l_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $l_outputs, axis => $axis,
                        num_outputs => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
            if(blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol'))
            {
                $r_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $r_outputs, axis => $axis,
                        num_outputs => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
        }
    }
    if($merge_outputs)
    {
        $l_outputs = [@{ $l_outputs }];
        $r_outputs = [@{ AI::MXNet::Symbol->reverse(
            blessed $r_outputs ? $r_outputs : @{ $r_outputs },
            axis => $axis
        ) }];
    }
    else
    {
        # per-step lists: un-reverse the backward outputs to align in time
        $r_outputs = [reverse(@{ $r_outputs })];
    }
    my $outputs = [];
    for(zip([0..@{ $l_outputs }-1], [@{ $l_outputs }], [@{ $r_outputs }])) {
        my ($i, $l_o, $r_o) = @$_;
        push @$outputs, AI::MXNet::Symbol->Concat(
            $l_o, $r_o, dim => (1+($merge_outputs?1:0)),
            name => $merge_outputs
                ? sprintf('%sout', $self->_output_prefix)
                : sprintf('%st%d', $self->_output_prefix, $i)
        );
    }
    if($merge_outputs)
    {
        $outputs = $outputs->[0];
    }
    $states = [$l_states, $r_states];
    return ($outputs, $states);
}

package AI::MXNet::RNN::ConvCell::Base;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::Conv::Base
=cut

=head1 DESCRIPTION

    Abstract base class for Convolutional RNN cells

=cut

has '_h2h_kernel'  => (is => 'ro', isa => 'Shape', init_arg => 'h2h_kernel');
has '_h2h_dilate'  => (is => 'ro', isa => 'Shape', init_arg => 'h2h_dilate');
has '_h2h_pad'     => (is => 'rw', isa => 'Shape', init_arg => undef);
has '_i2h_kernel'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_kernel');
has '_i2h_stride'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_stride');
has '_i2h_dilate'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_dilate');
has '_i2h_pad'     => (is => 'ro', isa => 'Shape', init_arg => 'i2h_pad');
has '_num_hidden'  => (is => 'ro', isa => 'DimSize', init_arg => 'num_hidden');
has '_input_shape' => (is => 'ro', isa => 'Shape', init_arg => 'input_shape');
has '_conv_layout' => (is => 'ro', isa => 'Str', init_arg => 'conv_layout', default => 'NCHW');
has '_activation'  => (is => 'ro', init_arg => 'activation');
has '_state_shape' => (is => 'rw', init_arg => undef);
has [qw/i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer/] => (is => 'rw', isa => 'Maybe[Initializer]');

sub BUILD
{
    my $self = shift;
    assert(
        ($self->_h2h_kernel->[0] % 2 == 1 and $self->_h2h_kernel->[1] % 2 == 1),
        "Only support odd numbers, got h2h_kernel= (@{[ $self->_h2h_kernel ]})"
    );
    # "same" padding for the h2h convolution, accounting for dilation
    $self->_h2h_pad([
        int($self->_h2h_dilate->[0] * ($self->_h2h_kernel->[0] - 1) / 2),
        int($self->_h2h_dilate->[1] * ($self->_h2h_kernel->[1] - 1) / 2)
    ]);
    # Infer state shape by shape-inferring a dummy i2h convolution
    my $data = AI::MXNet::Symbol->Variable('data');
    my $state_shape = AI::MXNet::Symbol->Convolution(
        data => $data,
        num_filter => $self->_num_hidden,
        kernel => $self->_i2h_kernel,
        stride => $self->_i2h_stride,
        pad => $self->_i2h_pad,
        dilate => $self->_i2h_dilate,
        layout => $self->_conv_layout
    );
    $state_shape = ($state_shape->infer_shape(data => $self->_input_shape))[1]->[0];
    # leading (batch) dim left unknown
    $state_shape->[0] = 0;
    $self->_state_shape($state_shape);
}

method state_info()
{
    return [
        { shape => $self->_state_shape, __layout__ => $self->_conv_layout },
        { shape => $self->_state_shape, __layout__ => $self->_conv_layout }
    ];
}

method call($inputs, $states)
{
    confess("AI::MXNet::RNN::ConvCell::Base is abstract class for convolutional RNN");
}

package AI::MXNet::RNN::ConvCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell::Base';

=head1 NAME

    AI::MXNet::RNN::ConvCell
=cut

=head1 DESCRIPTION

    Convolutional RNN cells

    Parameters
    ----------
    input_shape : array ref of int
        Shape of input in single timestep.
    num_hidden : int
        Number of units in output symbol.
    h2h_kernel : array ref of int, default (3, 3)
        Kernel of Convolution operator in state-to-state transitions.
    h2h_dilate : array ref of int, default (1, 1)
        Dilation of Convolution operator in state-to-state transitions.
    i2h_kernel : array ref of int, default (3, 3)
        Kernel of Convolution operator in input-to-state transitions.
    i2h_stride : array ref of int, default (1, 1)
        Stride of Convolution operator in input-to-state transitions.
    i2h_pad : array ref of int, default (1, 1)
        Pad of Convolution operator in input-to-state transitions.
    i2h_dilate : array ref of int, default (1, 1)
        Dilation of Convolution operator in input-to-state transitions.
    activation : str or Symbol,
        default functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2)
        Type of activation function.
    prefix : str, default 'ConvRNN_'
        Prefix for name of layers (and name of weight if params is None).
    params : RNNParams, default None
        Container for weight sharing between cells. Created if None.
    conv_layout : str, default 'NCHW'
        Layout of ConvolutionOp
=cut

has '+_h2h_kernel' => (default => sub { [3, 3] });
has '+_h2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_kernel' => (default => sub { [3, 3] });
has '+_i2h_stride' => (default => sub { [1, 1] });
has '+_i2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_pad'    => (default => sub { [1, 1] });
has '+_prefix'     => (default => 'ConvRNN_');
has '+_activation' => (default => sub { sub { AI::MXNet::Symbol->LeakyReLU(@_, act_type => 'leaky', slope => 0.2) } });
has '+i2h_bias_initializer' => (default => 'zeros');
has '+h2h_bias_initializer' => (default => 'zeros');
has 'forget_bias' => (is => 'ro', isa => 'Num');
has [qw/_iW _iB _hW _hB/] => (is => 'rw', init_arg => undef);


sub BUILD
{
    my $self = shift;
    $self->_iW($self->_params->get('i2h_weight', init => $self->i2h_weight_initializer));
    $self->_hW($self->_params->get('h2h_weight', init => $self->h2h_weight_initializer));
    # NOTE(review): the condition below parses as defined($a and $b), which
    # evaluates the same as defined($self->forget_bias) in every case; the
    # parens were probably intended around defined($self->forget_bias) alone.
    # Kept as-is to preserve behavior — confirm intent before changing.
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            (defined($self->forget_bias and not defined $self->i2h_bias_initializer)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : (init => $self->i2h_bias_initializer)
            )
        )
    );
    $self->_hB($self->_params->get('h2h_bias', init => $self->h2h_bias_initializer));
}

method _num_gates()
{
    scalar(@{ $self->_gate_names() });
}

method _gate_names()
{
    return [''];
}

method _conv_forward($inputs, $states, $name)
{
    # input-to-hidden and hidden-to-hidden convolutions for one step
    my $i2h = AI::MXNet::Symbol->Convolution(
        name => "${name}i2h",
        data => $inputs,
        num_filter => $self->_num_hidden*$self->_num_gates(),
        kernel => $self->_i2h_kernel,
        stride => $self->_i2h_stride,
        pad => $self->_i2h_pad,
        dilate => $self->_i2h_dilate,
        weight => $self->_iW,
        bias => $self->_iB
    );
    my $h2h = AI::MXNet::Symbol->Convolution(
        name => "${name}h2h",
        data => @{ $states }[0],
        num_filter => $self->_num_hidden*$self->_num_gates(),
        kernel => $self->_h2h_kernel,
        stride => [1, 1],
        pad => $self->_h2h_pad,
        dilate => $self->_h2h_dilate,
        weight => $self->_hW,
        bias => $self->_hB
    );
    return ($i2h, $h2h);
}

method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    my $output = $self->_get_activation($i2h + $h2h, $self->_activation, name => "${name}out");
    return ($output, [$output]);
}

package AI::MXNet::RNN::ConvLSTMCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell';
has '+forget_bias' => (default => 1);
has '+_prefix' => (default => 'ConvLSTM_');

=head1 NAME

    AI::MXNet::RNN::ConvLSTMCell
=cut

=head1 DESCRIPTION

    Convolutional LSTM network cell.

    Reference:
    Xingjian et al.
NIPS2015
=cut

method _gate_names()
{
    return ['_i', '_f', '_c', '_o'];
}

method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    my $gates = $i2h + $h2h;
    # split into input, forget, cell-transform and output gates along C
    my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel(
        $gates,
        num_outputs => 4,
        axis => index($self->_conv_layout, 'C'),
        name => "${name}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[0],
        act_type => "sigmoid",
        name => "${name}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[1],
        act_type => "sigmoid",
        name => "${name}f"
    );
    my $in_transform = $self->_get_activation(
        $slice_gates[2],
        $self->_activation,
        name => "${name}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[3],
        act_type => "sigmoid",
        name => "${name}o"
    );
    # c_t = f * c_{t-1} + i * transform;  h_t = o * act(c_t)
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * @{$states}[1],
        $in_gate * $in_transform,
        name => "${name}state"
    );
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate, $self->_get_activation($next_c, $self->_activation),
        name => "${name}out"
    );
    return ($next_h, [$next_h, $next_c]);
}

package AI::MXNet::RNN::ConvGRUCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell';
has '+_prefix' => (default => 'ConvGRU_');

=head1 NAME

    AI::MXNet::RNN::ConvGRUCell
=cut

=head1 DESCRIPTION

    Convolutional GRU network cell.
=cut

method _gate_names()
{
    return ['_r', '_z', '_o'];
}

method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    # reset / update / candidate slices for both projections
    my ($i2h_r, $i2h_z, $h2h_r, $h2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel($i2h, num_outputs => 3, name => "${name}_i2h_slice") };
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel($h2h, num_outputs => 3, name => "${name}_h2h_slice") };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid",
        name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid",
        name => "${name}_z_act"
    );
    my $next_h_tmp = $self->_get_activation($i2h + $reset_gate * $h2h, $self->_activation, name => "${name}_h_act");
    # h_t = (1 - z) * candidate + z * h_{t-1}
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * @{$states}[0],
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}

package AI::MXNet::RNN::ModifierCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, apply modifications
    on it (e.g. Dropout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
=cut

has 'base_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);

around BUILDARGS => sub {
    my $orig = shift;
    my $class = shift;
    # odd argument count => first argument is the positional base cell
    if(@_ % 2)
    {
        my $base_cell = shift;
        return $class->$orig(base_cell => $base_cell, @_);
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    # mark the wrapped cell so direct calls on it are rejected
    $self->base_cell->_modified(1);
}

method params()
{
    $self->_own_params(0);
    return $self->base_cell->params;
}

method state_info()
{
    return $self->base_cell->state_info;
}

method begin_state(CodeRef :$init_sym=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    # temporarily lift the guard so the base cell can build its state
    $self->base_cell->_modified(0);
    my $begin_state = $self->base_cell->begin_state(func => $init_sym, @kwargs);
    $self->base_cell->_modified(1);
    return $begin_state;
}

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->base_cell->unpack_weights($args);
}

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->base_cell->pack_weights($args);
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    confess("Not Implemented");
}

package AI::MXNet::RNN::DropoutCell;
use Mouse;
extends 'AI::MXNet::RNN::ModifierCell';
has [qw/dropout_outputs dropout_states/] => (is => 'ro', isa => 'Num', default => 0);

=head1 NAME

    AI::MXNet::RNN::DropoutCell
=cut

=head1 DESCRIPTION

    Apply the dropout on base cell
=cut

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my ($output, $states) = $self->base_cell->($inputs, $states);
    if($self->dropout_outputs > 0)
    {
        $output = AI::MXNet::Symbol->Dropout(data => $output, p => $self->dropout_outputs);
    }
    if($self->dropout_states > 0)
    {
        $states = [map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) } @{ $states }];
    }
    return ($output, $states);
}

package AI::MXNet::RNN::ZoneoutCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
has 'prev_output' => (is => 'rw', init_arg => undef);

=head1 NAME

    AI::MXNet::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

    Apply Zoneout on base cell.
=cut

sub BUILD
{
    my $self = shift;
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. "
        ."Please unfuse first."
    );
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. "
        ."Please add ZoneoutCell to the cells underneath instead."
    );
    # fixed: the original read $self->_bidirectional, but ZoneoutCell has no
    # such accessor, so this assert died with "method not found" whenever the
    # base cell was a SequentialCell; the flag (if any) lives on the wrapped
    # cell, so probe it defensively there instead.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::SequentialCell')
            or not ($self->base_cell->can('_bidirectional') and $self->base_cell->_bidirectional)),
        "Bidirectional SequentialCell doesn't support zoneout. "
        ."Please add ZoneoutCell to the cells underneath instead."
    );
}

method reset()
{
    $self->SUPER::reset;
    $self->prev_output(undef);
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = $cell->($inputs, $states);
    # Bernoulli mask shaped like its argument, keep-probability 1-$p
    my $mask = sub {
        my ($p, $like) = @_;
        AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like($like),
            p => $p
        );
    };
    # first step has no previous output; use a broadcastable zero placeholder
    my $prev_output = $self->prev_output // AI::MXNet::Symbol->zeros(shape => [0, 0]);
    my $output = $p_outputs != 0
        ? AI::MXNet::Symbol->where(
            $mask->($p_outputs, $next_output),
            $next_output,
            $prev_output
          )
        : $next_output;
    my @states;
    if($p_states != 0)
    {
        # zone out each state: keep the new value where the mask fires,
        # otherwise carry the old state forward
        for(zip($next_states, $states)) {
            my ($new_s, $old_s) = @$_;
            push @states, AI::MXNet::Symbol->where(
                $mask->($p_states, $new_s),
                $new_s,
                $old_s
            );
        }
    }
    $self->prev_output($output);
    return ($output, @states ? \@states : $next_states);
}

package AI::MXNet::RNN::ResidualCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::RNN::ResidualCell
=cut

=head1 DESCRIPTION

    Adds residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is output of the base cell plus input.
=cut

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $output;
    ($output, $states) = $self->base_cell->($inputs, $states);
    $output = AI::MXNet::Symbol->elemwise_add($output, $inputs, name => $output->name.'_plus_residual');
    return ($output, $states);
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str :$input_prefix='',
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=
)
{
    $self->reset;
    # temporarily lift the modified guard so the base cell can unroll
    $self->base_cell->_modified(0);
    my ($outputs, $states) = $self->base_cell->unroll(
        $length, inputs => $inputs, begin_state => $begin_state,
        layout => $layout, merge_outputs => $merge_outputs
    );
    $self->base_cell->_modified(1);
    $merge_outputs //= (blessed($outputs) and $outputs->isa('AI::MXNet::Symbol'));
    # bring the inputs into the same merged/split form as the outputs
    ($inputs) = _normalize_sequence($length, $inputs, $layout, $merge_outputs);
    if($merge_outputs)
    {
        $outputs = AI::MXNet::Symbol->elemwise_add($outputs, $inputs, name => $outputs->name."_plus_residual");
    }
    else
    {
        my @summed;
        for(zip([@{ $outputs }], [@{ $inputs }])) {
            my ($output_sym, $input_sym) = @$_;
            push @summed, AI::MXNet::Symbol->elemwise_add(
                $output_sym, $input_sym,
                name => $output_sym->name."_plus_residual"
            );
        }
        $outputs = \@summed;
    }
    return ($outputs, $states);
}

# Normalize a sequence input to either one merged Symbol or a list of
# per-step Symbols, matching $merge; returns ($inputs, $axis).
func _normalize_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert((defined $inputs),
        "unroll(inputs=>undef) has been deprecated. "
        ."Please create input variables outside unroll."
    );

    my $axis = index($layout, 'T');
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    if(blessed($inputs))
    {
        if(not $merge)
        {
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please "
                ."convert to list first or let unroll handle splitting"
            );
            $inputs = [ @{ AI::MXNet::Symbol->split(
                $inputs,
                axis => $in_axis,
                num_outputs => $length,
                squeeze_axis => 1
            ) }];
        }
    }
    else
    {
        assert(not defined $length or @$inputs == $length);
        if($merge)
        {
            $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @{ $inputs }];
            $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => $axis);
            $in_axis = $axis;
        }
    }

    if(blessed($inputs) and $axis != $in_axis)
    {
        $inputs = AI::MXNet::Symbol->swapaxes($inputs, dim0 => $axis, dim1 => $in_axis);
    }
    return ($inputs, $axis);
}

1;