# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
package AI::MXNet::Gluon::RNN::RecurrentCell;
use Mouse::Role;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;

# Flattens state_info of every child cell into one array ref.
method _cells_state_info($cells, $batch_size)
{
    return [map { @{ $_->state_info($batch_size) } } $cells->values];
}

# Flattens begin_state of every child cell into one array ref.
method _cells_begin_state($cells, %kwargs)
{
    return [map { @{ $_->begin_state(%kwargs) } } $cells->values];
}

# Returns $begin_state untouched when the caller supplied one; otherwise
# builds zero-initialized starting states. In NDArray mode the zeros are
# allocated on the same context as $inputs.
method _get_begin_state(GluonClass $F, $begin_state, GluonInput $inputs, $batch_size)
{
    return $begin_state if defined $begin_state;
    if($F =~ /AI::MXNet::NDArray/)
    {
        my $ctx = blessed $inputs ? $inputs->context : $inputs->[0]->context;
        # Temporarily make $ctx the default context so the state arrays
        # are created where the inputs live; restored on scope exit.
        local($AI::MXNet::current_ctx) = $ctx;
        my $zeros = sub {
            my %kwargs = @_;
            my $shape = delete $kwargs{shape};
            return AI::MXNet::NDArray->zeros($shape, %kwargs);
        };
        return $self->begin_state(batch_size => $batch_size, func => $zeros);
    }
    return $self->begin_state(batch_size => $batch_size, func => sub { return $F->zeros(@_) });
}


# Normalizes $inputs to either a per-step list or one merged tensor
# (depending on $merge), deduces the operating package (Symbol/NDArray)
# and the batch size. Returns ($inputs, $axis, $F, $batch_size).
method _format_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert(
        (defined $inputs),
        "unroll(inputs=None) has been deprecated. ".
        "Please create input variables outside unroll."
    );

    my $axis       = index($layout, 'T');
    my $batch_axis = index($layout, 'N');
    my $batch_size = 0;
    my $in_axis    = defined $in_layout ? index($in_layout, 'T') : $axis;
    my $F;
    if(blessed $inputs and $inputs->isa('AI::MXNet::Symbol'))
    {
        $F = 'AI::MXNet::Symbol';
        if(not $merge)
        {
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please convert ".
                "to list with list(inputs) first or let unroll handle splitting"
            );
            $inputs = [
                AI::MXNet::Symbol->split(
                    $inputs, axis => $in_axis, num_outputs => $length, squeeze_axis => 1
                )
            ];
        }
    }
    elsif(blessed $inputs and $inputs->isa('AI::MXNet::NDArray'))
    {
        $F = 'AI::MXNet::NDArray';
        $batch_size = $inputs->shape->[$batch_axis];
        if(not $merge)
        {
            assert(not defined $length or $length == $inputs->shape->[$in_axis]);
            $inputs = as_array(
                AI::MXNet::NDArray->split(
                    $inputs, axis => $in_axis,
                    num_outputs  => $inputs->shape->[$in_axis],
                    squeeze_axis => 1
                )
            );
        }
    }
    else
    {
        # $inputs is already an array ref of per-step tensors.
        assert(not defined $length or @{ $inputs } == $length);
        if($inputs->[0]->isa('AI::MXNet::Symbol'))
        {
            $F = 'AI::MXNet::Symbol';
        }
        else
        {
            $F = 'AI::MXNet::NDArray';
            $batch_size = $inputs->[0]->shape->[$batch_axis];
        }
        if($merge)
        {
            $inputs  = [map { $F->expand_dims($_, axis => $axis) } @{ $inputs }];
            $inputs  = $F->stack(@{ $inputs }, axis => $axis);
            $in_axis = $axis;
        }
    }
    if(blessed $inputs and $axis != $in_axis)
    {
        $inputs = $F->swapaxes($inputs, dim1 => $axis, dim2 => $in_axis);
    }
    return ($inputs, $axis, $F, $batch_size);
}

# Masks everything past each sequence's valid length and, when $merge is
# false, splits the result back into a per-step list.
method _mask_sequence_variable_length($F, $data, $length, $valid_length, $time_axis, $merge)
{
    assert(defined $valid_length);
    if(not blessed $data)
    {
        $data = $F->stack(@$data, axis => $time_axis);
    }
    my $outputs = $F->SequenceMask($data, {
        sequence_length => $valid_length, use_sequence_length => 1,
        axis => $time_axis
    });
    if(not $merge)
    {
        $outputs = $F->split($outputs, {
            num_outputs => $length, axis => $time_axis,
            squeeze_axis => 1
        });
        if(not ref $outputs eq 'ARRAY')
        {
            $outputs = [$outputs];
        }
    }
    return $outputs;
}

# Reverses a list of per-step tensors in time; when $valid_length is
# given, only each sequence's valid prefix is reversed (padding stays put).
method _reverse_sequences($sequences, $unroll_step, $valid_length=)
{
    my $F = $sequences->[0]->isa('AI::MXNet::Symbol')
          ? 'AI::MXNet::Symbol'
          : 'AI::MXNet::NDArray';

    my $reversed_sequences;
    if(not defined $valid_length)
    {
        $reversed_sequences = [reverse(@$sequences)];
    }
    else
    {
        $reversed_sequences = $F->SequenceReverse(
            $F->stack(@$sequences, axis => 0),
            { sequence_length => $valid_length, use_sequence_length => 1 }
        );
        $reversed_sequences = $F->split(
            $reversed_sequences,
            { axis => 0, num_outputs => $unroll_step, squeeze_axis => 1 }
        );
    }
    return $reversed_sequences;
}

=head1 NAME

    AI::MXNet::Gluon::RNN::RecurrentCell
=cut

=head1 DESCRIPTION

    Abstract role for RNN cells

    Parameters
    ----------
    prefix : str, optional
        Prefix for names of `Block`s
        (this prefix is also used for names of weights if `params` is `None`
        i.e. if `params` are being created and not reused)
    params : Parameter or None, optional
        Container for weight sharing between cells.
        A new Parameter container is created if `params` is `None`.
=cut

=head2 reset

    Reset before re-using the cell for another graph.
=cut

method reset()
{
    $self->init_counter(-1);
    $self->counter(-1);
    $_->reset for $self->_children->values;
}

=head2 state_info

    Shape and layout information of states
=cut
method state_info(Int $batch_size=0)
{
    confess('Not Implemented');
}

=head2 begin_state

    Initial state for this cell.

    Parameters
    ----------
    $func : CodeRef, default sub { AI::MXNet::Symbol->zeros(@_) }
        Function for creating initial state.

        For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
        `symbol.var etc`. Use `symbol.var` if you want to directly
        feed input as states.

        For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
    $batch_size: int, default 0
        Only required for NDArray API. Size of the batch ('N' in layout)
        dimension of input.

    %kwargs :
        Additional keyword arguments passed to func. For example
        `mean`, `std`, `dtype`, etc.

    Returns
    -------
    states : nested array ref of Symbol
        Starting states for the first RNN step.
=cut

method begin_state(Int :$batch_size=0, CodeRef :$func=, %kwargs)
{
    # Default state factory: NDArray zeros of the requested shape.
    $func //= sub {
        my %kwargs = @_;
        my $shape = delete $kwargs{shape};
        return AI::MXNet::NDArray->zeros($shape, %kwargs);
    };
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. ZoneoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    my @states;
    for my $info (@{ $self->state_info($batch_size) })
    {
        $self->init_counter($self->init_counter + 1);
        if(defined $info)
        {
            %$info = (%$info, %kwargs);
        }
        else
        {
            $info = \%kwargs;
        }
        my $state = $func->(
            name => "${\ $self->_prefix }begin_state_${\ $self->init_counter }",
            %$info
        );
        push @states, $state;
    }
    return \@states;
}

=head2 unroll

    Unrolls an RNN cell across time steps.

    Parameters
    ----------
    $length : int
        Number of steps to unroll.
    $inputs : Symbol, list of Symbol, or None
        If `inputs` is a single Symbol (usually the output
        of Embedding symbol), it should have shape
        (batch_size, length, ...) if `layout` is 'NTC',
        or (length, batch_size, ...) if `layout` is 'TNC'.

        If `inputs` is a list of symbols (usually output of
        previous unroll), they should all have shape
        (batch_size, ...).
    :$begin_state : nested list of Symbol, optional
        Input states created by `begin_state()`
        or output state of another cell.
        Created from `begin_state()` if `None`.
    :$layout : str, optional
        `layout` of input symbol. Only used if inputs
        is a single Symbol.
    :$merge_outputs : bool, optional
        If `False`, returns outputs as a list of Symbols.
        If `True`, concatenates output across time steps
        and returns a single symbol with shape
        (batch_size, length, ...) if layout is 'NTC',
        or (length, batch_size, ...) if layout is 'TNC'.
        If `None`, output whatever is faster.
    :$valid_length : Symbol, NDArray or None
        Per-sequence valid lengths ('N' sized tensor). When given, outputs
        past each sequence's length are masked out and the returned states
        come from each sequence's last valid step.

    Returns
    -------
    outputs : list of Symbol or Symbol
        Symbol (if `merge_outputs` is True) or list of Symbols
        (if `merge_outputs` is False) corresponding to the output from
        the RNN from this unrolling.

    states : list of Symbol
        The new state of this RNN after this unrolling.
        The type of this symbol is same as the output of `begin_state()`.
=cut

method unroll(
    Int $length,
    Maybe[GluonInput] $inputs,
    Maybe[GluonInput] :$begin_state=,
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=,
    Maybe[GluonInput] :$valid_length=
)
{
    # NOTE(fix): $valid_length carries per-sequence lengths (an NDArray or
    # Symbol fed to SequenceLast/SequenceMask), so it is constrained as
    # GluonInput; the previous Maybe[Bool] constraint rejected every
    # legitimate value.
    $self->reset();
    my ($F, $batch_size, $axis);
    ($inputs, $axis, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, 0);
    $begin_state //= $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);

    my $states = $begin_state;
    my $outputs = [];
    my $all_states = [];
    for my $i (0..$length-1)
    {
        my $output;
        ($output, $states) = $self->($inputs->[$i], $states);
        push @$outputs, $output;
        if(defined $valid_length)
        {
            push @$all_states, $states;
        }
    }
    if(defined $valid_length)
    {
        # For each state tensor, pick the value at every sequence's last
        # valid time step rather than the final unrolled step.
        $states = [];
        for(zip(@$all_states))
        {
            push @$states, $F->SequenceLast($F->stack(@$_, axis=>0),
                                            sequence_length=>$valid_length,
                                            use_sequence_length=>1,
                                            axis=>0);
        }
        $outputs = $self->_mask_sequence_variable_length($F, $outputs, $length, $valid_length, $axis, 1);
    }
    ($outputs) = $self->_format_sequence($length, $outputs, $layout, $merge_outputs);
    return ($outputs, $states);
}

# Applies $activation to $inputs: plain-string activations with dedicated
# operators go through them, other strings through the generic Activation
# op, LeakyReLU blocks through LeakyReLU, and any other blessed activation
# is simply invoked.
method _get_activation(GluonClass $F, GluonInput $inputs, Activation $activation, %kwargs)
{
    if(not blessed $activation)
    {
        my %simple = map { $_ => 1 } qw(tanh relu sigmoid softsign);
        if(exists $simple{$activation})
        {
            return $F->$activation($inputs, %kwargs)
        }
        return $F->Activation($inputs, act_type=>$activation, %kwargs);
    }
    elsif(ref($activation) =~ /LeakyReLU/)
    {
        return $F->LeakyReLU($inputs, act_type=>'leaky', slope => $activation->alpha, %kwargs);
    }
    else
    {
        return $activation->($inputs, %kwargs);
    }
}

=head2 forward

    Unrolls the recurrent cell for one time step.

    Parameters
    ----------
    inputs : sym.Variable
        Input symbol, 2D, of shape (batch_size * num_units).
    states : list of sym.Variable
        RNN state from previous step or the output of begin_state().

    Returns
    -------
    output : Symbol
        Symbol corresponding to the output from the RNN when unrolling
        for a single time step.
    states : list of Symbol
        The new state of this RNN after this unrolling.
        The type of this symbol is same as the output of `begin_state()`.
        This can be used as an input state to the next time step
        of this RNN.

    See Also
    --------
    begin_state: This function can provide the states for the first time step.
    unroll: This function unrolls an RNN for a given number of (>=1) time steps.
=cut

package AI::MXNet::Gluon::RNN::HybridRecurrentCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::HybridBlock';
with 'AI::MXNet::Gluon::RNN::RecurrentCell';
has 'modified' => (is => 'rw', isa => 'Bool', default => 0);
has [qw/counter
        init_counter/] => (is => 'rw', isa => 'Int', default => -1);

sub BUILD
{
    my $self = shift;
    $self->reset;
}

use overload '""' => sub {
    my $self = shift;
    my $s = '%s(%s';
    if($self->can('activation'))
    {
        $s .= ", ${\ $self->activation }";
    }
    $s .= ')';
    my $mapping = $self->input_size ? $self->input_size . " -> " . $self->hidden_size : $self->hidden_size;
    return sprintf($s, $self->_class_name, $mapping);
};

method forward(GluonInput $inputs, Maybe[GluonInput|ArrayRef[GluonInput]] $states)
{
    $self->counter($self->counter + 1);
    $self->SUPER::forward($inputs, $states);
}

package AI::MXNet::Gluon::RNN::RNNCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::RNNCell
=cut

=head1 DESCRIPTION

    Simple recurrent neural network cell.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol
    activation : str or Symbol, default 'tanh'
        Type of activation function.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str, default 'rnn_'
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
=cut

has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1);
has 'activation'  => (is => 'rw', isa => 'Activation', default => 'tanh');
has [qw/
        i2h_weight_initializer
        h2h_weight_initializer
    /] => (is => 'rw', isa => 'Maybe[Initializer]');
has [qw/
        i2h_bias_initializer
        h2h_bias_initializer
    /] => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros');
has 'input_size' => (is => 'rw', isa => 'Int', default => 0);
has [qw/
        i2h_weight
        h2h_weight
        i2h_bias
        h2h_bias
    /] => (is => 'rw', init_arg => undef);

method python_constructor_arguments()
{
    [qw/
        hidden_size activation
        i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer
        input_size
    /];
}

sub BUILD
{
    my $self = shift;
    # Parameters use deferred initialization so input_size may stay 0
    # (unknown) until the first forward pass fixes the input shape.
    $self->i2h_weight($self->params->get(
        'i2h_weight', shape => [$self->hidden_size, $self->input_size],
        init => $self->i2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_weight($self->params->get(
        'h2h_weight', shape => [$self->hidden_size, $self->hidden_size],
        init => $self->h2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->i2h_bias($self->params->get(
        'i2h_bias', shape => [$self->hidden_size],
        init => $self->i2h_bias_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_bias($self->params->get(
        'h2h_bias', shape => [$self->hidden_size],
        init => $self->h2h_bias_initializer,
        allow_deferred_init => 1
    ));
}

# One hidden state of shape (batch, hidden_size).
method state_info(Int $batch_size=0)
{
    return [{ shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }];
}

method _alias() { 'rnn' }

# h' = activation(W_i2h * x + b_i2h + W_h2h * h + b_h2h)
method hybrid_forward(
    GluonClass $F, GluonInput $inputs, GluonInput $states,
    GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias
)
{
    my $prefix = "t${\ $self->counter}_";
    my $i2h = $F->FullyConnected(
        data => $inputs, weight => $i2h_weight, bias => $i2h_bias,
        num_hidden => $self->hidden_size,
        name => "${prefix}i2h"
    );
    my $h2h = $F->FullyConnected(
        data => $states->[0], weight => $h2h_weight, bias => $h2h_bias,
        num_hidden => $self->hidden_size,
        name => "${prefix}h2h"
    );
    my $pre_activation = $F->elemwise_add($i2h, $h2h, name => "${prefix}plus0");
    my $output = $self->_get_activation($F, $pre_activation, $self->activation, name => "${prefix}out");
    return ($output, [$output]);
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');

package AI::MXNet::Gluon::RNN::LSTMCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::LSTMCell
=cut

=head1 DESCRIPTION

    Long-Short Term Memory (LSTM) network cell.

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    prefix : str, default 'lstm_'
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
=cut

has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1);
has [qw/
        i2h_weight_initializer
        h2h_weight_initializer
    /] => (is => 'rw', isa => 'Maybe[Initializer]');
has [qw/
        i2h_bias_initializer
        h2h_bias_initializer
    /] => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros');
has 'input_size' => (is => 'rw', isa => 'Int', default => 0);
has [qw/
        i2h_weight
        h2h_weight
        i2h_bias
        h2h_bias
    /] => (is => 'rw', init_arg => undef);

method python_constructor_arguments()
{
    [qw/
        hidden_size
        i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer
        input_size
    /];
}


sub BUILD
{
    my $self = shift;
    # Fused gate parameters: 4 gates (in, forget, cell, out), hence the
    # 4*hidden_size leading dimension. Deferred init allows input_size=0.
    $self->i2h_weight($self->params->get(
        'i2h_weight', shape => [4*$self->hidden_size, $self->input_size],
        init => $self->i2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_weight($self->params->get(
        'h2h_weight', shape => [4*$self->hidden_size, $self->hidden_size],
        init => $self->h2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->i2h_bias($self->params->get(
        'i2h_bias', shape => [4*$self->hidden_size],
        init => $self->i2h_bias_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_bias($self->params->get(
        'h2h_bias', shape => [4*$self->hidden_size],
        init => $self->h2h_bias_initializer,
        allow_deferred_init => 1
    ));
}

# Two states: hidden state h and cell state c, both (batch, hidden_size).
method state_info(Int $batch_size=0)
{
    return [
        { shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' },
        { shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }
    ];
}

method _alias() { 'lstm' }

# Standard LSTM step: fused FC projections are sliced into the four
# gates, then c' = f*c + i*g and h' = o*tanh(c').
method hybrid_forward(
    GluonClass $F, GluonInput $inputs, GluonInput $states,
    GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias
)
{
    my $prefix = "t${\ $self->counter}_";
    my $i2h = $F->FullyConnected(
        $inputs, $i2h_weight, $i2h_bias,
        num_hidden => $self->hidden_size*4,
        name => "${prefix}i2h"
    );
    my $h2h = $F->FullyConnected(
        $states->[0], $h2h_weight, $h2h_bias,
        num_hidden => $self->hidden_size*4,
        name => "${prefix}h2h"
    );
    my $gates = $F->elemwise_add($i2h, $h2h, name => "${prefix}plus0");
    my @gate_slices = @{ $F->SliceChannel($gates, num_outputs => 4, name => "${prefix}slice") };
    my $in_gate      = $F->Activation($gate_slices[0], act_type => "sigmoid", name => "${prefix}i");
    my $forget_gate  = $F->Activation($gate_slices[1], act_type => "sigmoid", name => "${prefix}f");
    my $in_transform = $F->Activation($gate_slices[2], act_type => "tanh",    name => "${prefix}c");
    my $out_gate     = $F->Activation($gate_slices[3], act_type => "sigmoid", name => "${prefix}o");
    my $next_c = $F->_plus(
        $F->elemwise_mul($forget_gate, $states->[1], name => "${prefix}mul0"),
        $F->elemwise_mul($in_gate, $in_transform, name => "${prefix}mul1"),
        name => "${prefix}state"
    );
    my $next_h = $F->_mul(
        $out_gate,
        $F->Activation($next_c, act_type => "tanh", name => "${prefix}activation0"),
        name => "${prefix}out"
    );
    return ($next_h, [$next_h, $next_c]);
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');

package AI::MXNet::Gluon::RNN::GRUCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::GRUCell
=cut

=head1 DESCRIPTION

    Gated Rectified Unit (GRU) network cell.
    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014).

    Parameters
    ----------
    hidden_size : int
        Number of units in output symbol.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
726 h2h_weight_initializer : str or Initializer 727 Initializer for the recurrent weights matrix, used for the linear 728 transformation of the recurrent state. 729 i2h_bias_initializer : str or Initializer 730 Initializer for the bias vector. 731 h2h_bias_initializer : str or Initializer 732 Initializer for the bias vector. 733 prefix : str, default 'gru_' 734 prefix for name of `Block`s 735 (and name of weight if params is `None`). 736 params : Parameter or None 737 Container for weight sharing between cells. 738 Created if `None`. 739=cut 740 741has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1); 742has [qw/ 743 i2h_weight_initializer 744 h2h_weight_initializer 745 /] => (is => 'rw', isa => 'Maybe[Initializer]'); 746has [qw/ 747 i2h_bias_initializer 748 h2h_bias_initializer 749 /] => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros'); 750has 'input_size' => (is => 'rw', isa => 'Int', default => 0); 751has [qw/ 752 i2h_weight 753 h2h_weight 754 i2h_bias 755 h2h_bias 756 /] => (is => 'rw', init_arg => undef); 757 758method python_constructor_arguments() 759{ 760 [qw/ 761 hidden_size 762 i2h_weight_initializer h2h_weight_initializer 763 i2h_bias_initializer h2h_bias_initializer 764 input_size 765 /]; 766} 767 768sub BUILD 769{ 770 my $self = shift; 771 $self->i2h_weight($self->params->get( 772 'i2h_weight', shape=>[3*$self->hidden_size, $self->input_size], 773 init => $self->i2h_weight_initializer, 774 allow_deferred_init => 1 775 )); 776 $self->h2h_weight($self->params->get( 777 'h2h_weight', shape=>[3*$self->hidden_size, $self->hidden_size], 778 init => $self->h2h_weight_initializer, 779 allow_deferred_init => 1 780 )); 781 $self->i2h_bias($self->params->get( 782 'i2h_bias', shape=>[3*$self->hidden_size], 783 init => $self->i2h_bias_initializer, 784 allow_deferred_init => 1 785 )); 786 $self->h2h_bias($self->params->get( 787 'h2h_bias', shape=>[3*$self->hidden_size], 788 init => $self->h2h_bias_initializer, 789 allow_deferred_init => 1 790 
)); 791} 792 793method state_info(Int $batch_size=0) 794{ 795 return [{ shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }]; 796} 797 798method _alias() { 'gru' } 799 800method hybrid_forward( 801 GluonClass $F, GluonInput $inputs, GluonInput $states, 802 GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias 803) 804{ 805 my $prefix = "t${\ $self->counter}_"; 806 my $prev_state_h = $states->[0]; 807 my $i2h = $F->FullyConnected( 808 $inputs, $i2h_weight, $i2h_bias, 809 num_hidden => $self->hidden_size*3, 810 name => "${prefix}i2h" 811 ); 812 my $h2h = $F->FullyConnected( 813 $states->[0], $h2h_weight, $h2h_bias, 814 num_hidden => $self->hidden_size*3, 815 name => "${prefix}h2h" 816 ); 817 my ($i2h_r, $i2h_z, $h2h_r, $h2h_z); 818 ($i2h_r, $i2h_z, $i2h) = @{ $F->SliceChannel($i2h, num_outputs => 3, name => "${prefix}i2h_slice") }; 819 ($h2h_r, $h2h_z, $h2h) = @{ $F->SliceChannel($h2h, num_outputs => 3, name => "${prefix}h2h_slice") }; 820 my $reset_gate = $F->Activation($F->elemwise_add($i2h_r, $h2h_r, name => "${prefix}plus0"), act_type=>"sigmoid", name => "${prefix}r_act"); 821 my $update_gate = $F->Activation($F->elemwise_add($i2h_z, $h2h_z, name => "${prefix}plus1"), act_type=>"sigmoid", name => "${prefix}z_act"); 822 my $next_h_tmp = $F->Activation( 823 $F->elemwise_add( 824 $i2h, 825 $F->elemwise_mul( 826 $reset_gate, $h2h, name => "${prefix}mul0" 827 ), 828 name => "${prefix}plus2" 829 ), 830 act_type => "tanh", 831 name => "${prefix}h_act" 832 ); 833 my $ones = $F->ones_like($update_gate, name => "${prefix}ones_like0"); 834 my $next_h = $F->_plus( 835 $F->elemwise_mul( 836 $F->elemwise_sub($ones, $update_gate, name => "${prefix}minus0"), 837 $next_h_tmp, 838 name => "${prefix}mul1" 839 ), 840 $F->elemwise_mul($update_gate, $prev_state_h, name => "${prefix}mul2"), 841 name => "${prefix}out" 842 ); 843 return ($next_h, [$next_h]); 844} 845 846__PACKAGE__->register('AI::MXNet::Gluon::RNN'); 847 848package 
AI::MXNet::Gluon::RNN::SequentialRNNCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
no warnings 'redefine';
extends 'AI::MXNet::Gluon::Block';
with 'AI::MXNet::Gluon::RNN::RecurrentCell';
has 'modified' => (is => 'rw', isa => 'Bool', default => 0);
has [qw/counter
        init_counter/] => (is => 'rw', isa => 'Int', default => -1);

sub BUILD
{
    my $self = shift;
    $self->reset;
}

=head1 NAME

    AI::MXNet::Gluon::RNN::SequentialRNNCell
=cut

=head1 DESCRIPTION

    Sequentially stacking multiple RNN cells.
=cut

=head2 add

    Appends a cell into the stack.

    Parameters
    ----------
    cell : rnn cell
=cut

method add(AI::MXNet::Gluon::Block $cell)
{
    $self->register_child($cell);
}

method state_info(Int $batch_size=0)
{
    return $self->_cells_state_info($self->_children, $batch_size);
}

method begin_state(%kwargs)
{
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. ZoneoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_children, %kwargs);
}

method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    $self->reset();
    my ($F, $batch_size);
    ($inputs, undef, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, undef);
    my $num_cells = $self->_children->keys;
    $begin_state = $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);
    my $p = 0;
    my @next_states;
    my $states;
    enumerate(sub {
        my ($i, $cell) = @_;
        # Each cell consumes its own contiguous run of states.
        my $n = @{ $cell->state_info() };
        $states = [@{ $begin_state }[$p..$p+$n-1]];
        $p += $n;
        # Only the last layer honors the caller's merge_outputs;
        # intermediate layers output whatever is fastest (undef).
        ($inputs, $states) = $cell->unroll(
            $length, $inputs, begin_state => $states, layout => $layout,
            merge_outputs => ($i < ($num_cells - 1)) ? undef : $merge_outputs
        );
        push @next_states, @{ $states };
    }, [$self->_children->values]);
    return ($inputs, \@next_states);
}

method call($inputs, $states)
{
    $self->counter($self->counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell ($self->_children->values)
    {
        assert(not $cell->isa('AI::MXNet::Gluon::RNN::BidirectionalCell'));
        my $n = @{ $cell->state_info() };
        # FIX: take the contiguous range [$p .. $p+$n-1]. The original
        # two-element slice [$p, $p+$n-1] duplicated the state for cells
        # with a single state (n=1) and dropped middle states for n>2.
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = $cell->($inputs, $state);
        push @next_states, @{ $state };
    }
    return ($inputs, \@next_states);
}

use overload '@{}' => sub { [shift->_children->values] };
use overload '""' => sub {
    my $self = shift;
    my $s = "%s(\n%s\n)";
    my @children;
    enumerate(sub {
        my ($i, $m) = @_;
        push @children, "($i): ". AI::MXNet::Base::_indent("$m", 2);
    }, [$self->_children->values]);
    return sprintf($s, $self->_class_name, join("\n", @children));
};

method hybrid_forward(@args)
{
    confess('Not Implemented');
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');

package AI::MXNet::Gluon::RNN::DropoutCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::DropoutCell
=cut

=head1 DESCRIPTION

    Applies dropout on input.

    Parameters
    ----------
    rate : float
        Percentage of elements to drop out, which
        is 1 - percentage to retain.
=cut

has 'rate' => (is => 'ro', isa => 'Num', required => 1);
method python_constructor_arguments() { ['rate'] }

# Dropout keeps no recurrent state.
method state_info(Int $batch_size=0) { [] }

method _alias() { 'dropout' }

method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    if($self->rate > 0)
    {
        $inputs = $F->Dropout($inputs, p => $self->rate, name => "t${\ $self->counter }_fwd");
    }
    return ($inputs, $states);
}

method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    $self->reset;
    my $F;
    ($inputs, undef, $F) = $self->_format_sequence($length, $inputs, $layout, $merge_outputs);
    if(blessed $inputs)
    {
        # Single merged tensor: dropout is applied in one shot.
        return $self->hybrid_forward($F, $inputs, $begin_state//[]);
    }
    else
    {
        return $self->SUPER::unroll(
            $length, $inputs, begin_state => $begin_state, layout => $layout,
            merge_outputs => $merge_outputs
        );
    }
}

use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'(rate ='.$self->rate.')';
};

__PACKAGE__->register('AI::MXNet::Gluon::RNN');

package AI::MXNet::Gluon::RNN::ModifierCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
has 'base_cell' => (is => 'rw', isa => 'AI::MXNet::Gluon::RNN::HybridRecurrentCell', required => 1);

=head1 NAME

    AI::MXNet::Gluon::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, apply modifications
    on it (e.g. Zoneout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
=cut

# Takes ownership of the base cell: a cell may only be wrapped by one
# modifier at a time, so mark it as modified here.
sub BUILD
{
    my $self = shift;
    assert(
        (not $self->base_cell->modified),
        "Cell ${\ $self->base_cell->name } is already modified. One cell cannot be modified twice"
    );
    $self->base_cell->modified(1);
}

# Parameters are owned by the wrapped base cell.
method params()
{
    return $self->base_cell->params;
}

# The modifier does not change the state layout; delegate to the base cell.
method state_info(Int $batch_size=0)
{
    return $self->base_cell->state_info($batch_size);
}

method begin_state(CodeRef :$func=sub{ AI::MXNet::Symbol->zeros(@_) }, %kwargs)
{
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. DropoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    # Temporarily lift the "modified" guard so the base cell is allowed to
    # build its own initial states, then restore it.
    $self->base_cell->modified(0);
    my $begin = $self->base_cell->begin_state(func => $func, %kwargs);
    $self->base_cell->modified(1);
    return $begin;
}

# Abstract: concrete modifier cells must override this.
method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    confess('Not Implemented');
}

use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'('.$self->base_cell.')';
};

package AI::MXNet::Gluon::RNN::ZoneoutCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

    Applies Zoneout on base cell.
=cut
has [qw/zoneout_outputs
        zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
has 'prev_output' => (is => 'rw', init_arg => undef);
method python_constructor_arguments() { ['base_cell', 'zoneout_outputs', 'zoneout_states'] }

sub BUILD
{
    my $self = shift;
    assert(
        (not $self->base_cell->isa('AI::MXNet::Gluon::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # BUGFIX: the class name was misspelled 'SequentialRNNCel', so this isa()
    # check was always false and the bidirectional guard below never fired.
    assert(
        (not $self->base_cell->isa('AI::MXNet::Gluon::RNN::SequentialRNNCell') or not $self->base_cell->bidirectional),
        "Bidirectional SequentialRNNCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}

use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'(p_out='.$self->zoneout_outputs.', p_state='.$self->zoneout_states.
           ', '.$self->base_cell.')';
};

method _alias() { 'zoneout' }

method reset()
{
    $self->SUPER::reset();
    $self->prev_output(undef);
}

method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = $cell->($inputs, $states);
    # Dropout over an all-ones tensor yields a (scaled) keep-mask shaped like $like.
    my $mask = sub { my ($p, $like) = @_; $F->Dropout($F->ones_like($like), p=>$p) };

    my $prev_output = $self->prev_output//$F->zeros_like($next_output);
    my $output = $p_outputs != 0 ?
$F->where($mask->($p_outputs, $next_output), $next_output, $prev_output) : $next_output; 1147 if($p_states != 0) 1148 { 1149 my @tmp; 1150 for(zip($next_states, $states)) { 1151 my ($new_s, $old_s) = @$_; 1152 push @tmp, $F->where($mask->($p_states, $new_s), $new_s, $old_s); 1153 } 1154 $states = \@tmp; 1155 } 1156 else 1157 { 1158 $states = $next_states; 1159 } 1160 $self->prev_output($output); 1161 return ($output, $states); 1162} 1163 1164__PACKAGE__->register('AI::MXNet::Gluon::RNN'); 1165 1166package AI::MXNet::Gluon::RNN::ResidualCell; 1167use AI::MXNet::Gluon::Mouse; 1168use AI::MXNet::Base; 1169extends 'AI::MXNet::Gluon::RNN::ModifierCell'; 1170method python_constructor_arguments() { ['base_cell'] } 1171 1172=head1 NAME 1173 1174 AI::MXNet::Gluon::RNN::ResidualCell 1175=cut 1176 1177=head1 DESCRIPTION 1178 1179 Adds residual connection as described in Wu et al, 2016 1180 (https://arxiv.org/abs/1609.08144). 1181 Output of the cell is output of the base cell plus input. 1182=cut 1183 1184method hybrid_forward(GluonClas $F, GluonInput $inputs, GluonInput $states) 1185{ 1186 my $output; 1187 ($output, $states) = $self->base_cell->($inputs, $states); 1188 $output = $F->elemwise_add($output, $inputs, name => "t${\ $self->counter }_fwd"); 1189 return ($output, $states); 1190} 1191 1192method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=) 1193{ 1194 $self->reset(); 1195 1196 $self->base_cell->modified(0); 1197 my ($outputs, $states) = $self->base_cell->unroll( 1198 $length, $inputs, begin_state => $begin_state, layout => $layout, merge_outputs => $merge_outputs 1199 ); 1200 $self->base_cell->modified(1); 1201 1202 $merge_outputs //= blessed $outputs ? 
1 : 0; 1203 my $F; 1204 ($inputs, undef, $F) = $self->_format_sequence($length, $inputs, $layout, $merge_outputs); 1205 if($merge_outputs) 1206 { 1207 $outputs = $F->elemwise_add($outputs, $inputs); 1208 } 1209 else 1210 { 1211 my @tmp; 1212 for(zip($outputs, $inputs)) { 1213 my ($i, $j) = @$_; 1214 push @tmp, $F->elemwise_add($i, $j); 1215 } 1216 $outputs = \@tmp; 1217 } 1218 return ($outputs, $states); 1219} 1220 1221__PACKAGE__->register('AI::MXNet::Gluon::RNN'); 1222 1223package AI::MXNet::Gluon::RNN::BidirectionalCell; 1224use AI::MXNet::Gluon::Mouse; 1225use AI::MXNet::Base; 1226extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell'; 1227has [qw/l_cell r_cell/] => (is => 'ro', isa => 'AI::MXNet::Gluon::RNN::HybridRecurrentCell', required => 1); 1228has 'output_prefix' => (is => 'ro', isa => 'Str', default => 'bi_'); 1229method python_constructor_arguments() { ['l_cell', 'r_cell', 'output_prefix'] } 1230 1231=head1 NAME 1232 1233 AI::MXNet::Gluon::RNN::BidirectionalCell 1234=cut 1235 1236=head1 DESCRIPTION 1237 1238 Bidirectional RNN cell. 1239 1240 Parameters 1241 ---------- 1242 l_cell : RecurrentCell 1243 Cell for forward unrolling 1244 r_cell : RecurrentCell 1245 Cell for backward unrolling 1246=cut 1247 1248method call($inputs, $states) 1249{ 1250 confess("Bidirectional cell cannot be stepped. Please use unroll"); 1251} 1252 1253use overload '""' => sub { 1254 my $self = shift; 1255 "${\ $self->_class_name }(forward=${\ $self->l_cell }, backward=${\ $self->r_cell })"; 1256}; 1257 1258method state_info(Int $batch_size=0) 1259{ 1260 return $self->_cells_state_info($self->_children, $batch_size); 1261} 1262 1263method begin_state(%kwargs) 1264{ 1265 assert( 1266 (not $self->modified), 1267 "After applying modifier cells (e.g. DropoutCell) the base ". 1268 "cell cannot be called directly. Call the modifier cell instead." 
1269 ); 1270 return $self->_cells_begin_state($self->_children, %kwargs); 1271} 1272 1273method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=) 1274{ 1275 $self->reset(); 1276 my ($axis, $F, $batch_size); 1277 ($inputs, $axis, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, 0); 1278 $begin_state //= $self->_get_begin_state($F, $begin_state, $inputs, $batch_size); 1279 1280 my $states = $begin_state; 1281 my ($l_cell, $r_cell) = $self->_children->values; 1282 $l_cell->state_info($batch_size); 1283 my ($l_outputs, $l_states) = $l_cell->unroll( 1284 $length, $inputs, 1285 begin_state => [@{ $states }[0..@{ $l_cell->state_info($batch_size) }-1]], 1286 layout => $layout, 1287 merge_outputs => $merge_outputs 1288 ); 1289 my ($r_outputs, $r_states) = $r_cell->unroll( 1290 $length, [reverse @{$inputs}], 1291 begin_state => [@{$states}[@{ $l_cell->state_info }..@{$states}-1]], 1292 layout => $layout, 1293 merge_outputs => $merge_outputs 1294 ); 1295 if(not defined $merge_outputs) 1296 { 1297 $merge_outputs = blessed $l_outputs and blessed $r_outputs; 1298 ($l_outputs) = $self->_format_sequence(undef, $l_outputs, $layout, $merge_outputs); 1299 ($r_outputs) = $self->_format_sequence(undef, $r_outputs, $layout, $merge_outputs); 1300 } 1301 my $outputs; 1302 if($merge_outputs) 1303 { 1304 $r_outputs = $F->reverse($r_outputs, axis=>$axis); 1305 $outputs = $F->concat($l_outputs, $r_outputs, dim=>2, name=>$self->output_prefix.'out'); 1306 } 1307 else 1308 { 1309 $outputs = []; 1310 enumerate(sub { 1311 my ($i, $l_o, $r_o) = @_; 1312 push @$outputs, $F->concat( 1313 $l_o, $r_o, dim=>1, 1314 name => sprintf('%st%d', $self->output_prefix, $i) 1315 ); 1316 }, [@{ $l_outputs }], [reverse(@{ $r_outputs })] 1317 ); 1318 } 1319 $states = [@{ $l_states }, @{ $r_states }]; 1320 return ($outputs, $states); 1321} 1322 1323__PACKAGE__->register('AI::MXNet::Gluon::RNN'); 1324 13251; 1326