#!/usr/bin/env perl

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
use PDL;
use AI::MXNet qw(mx);
use AI::MXNet::Function::Parameters;
use Getopt::Long qw(HelpMessage);

# Each option stores directly into a file-scoped lexical, declared inline
# with its default value; --help prints the SYNOPSIS below and exits.
GetOptions(
    'num-layers=i'   => \(my $num_layers   = 2       ),
    'num-hidden=i'   => \(my $num_hidden   = 200     ),
    'num-embed=i'    => \(my $num_embed    = 200     ),
    'gpus=s'         => \(my $gpus                   ),
    'kv-store=s'     => \(my $kv_store     = 'device'),
    'num-epoch=i'    => \(my $num_epoch    = 25      ),
    'lr=f'           => \(my $lr           = 0.01    ),
    'optimizer=s'    => \(my $optimizer    = 'sgd'   ),
    'mom=f'          => \(my $mom          = 0       ),
    'wd=f'           => \(my $wd           = 0.00001 ),
    'batch-size=i'   => \(my $batch_size   = 32      ),
    'disp-batches=i' => \(my $disp_batches = 50      ),
    'chkp-prefix=s'  => \(my $chkp_prefix  = 'lstm_' ),
    'chkp-epoch=i'   => \(my $chkp_epoch   = 0       ),
    'help'           => sub { HelpMessage(0) },
) or HelpMessage(1);

=head1 NAME

    lstm_bucketing.pl - Example of training LSTM RNN on Sherlock Holmes data using high level RNN interface

=head1 SYNOPSIS

    --num-layers     number of stacked RNN layers, default=2
    --num-hidden     hidden layer size, default=200
    --num-embed      embedding layer size, default=200
    --gpus           list of gpus to run, e.g. 0 or 0,2,5.
empty means using cpu.
                     Increase batch size when using multiple gpus for best performance.
    --kv-store       key-value store type, default='device'
    --num-epoch      max num of epochs, default=25
    --lr             initial learning rate, default=0.01
    --optimizer      the optimizer type, default='sgd'
    --mom            momentum for sgd, default=0.0
    --wd             weight decay for sgd, default=0.00001
    --batch-size     the batch size, default=32
    --disp-batches   show progress for every n batches, default=50
    --chkp-prefix    prefix for checkpoint files, default='lstm_'
    --chkp-epoch     save checkpoint after this many epoch, default=0 (saving checkpoints is disabled)

=cut

# Read a whitespace-separated text corpus and encode it into integer id
# sequences via mx->rnn->encode_sentences.
#
# Positional: $fname - path to the corpus file.
# Named:      :$vocab         - existing vocabulary hashref to reuse
#                               (pass the training vocab when encoding
#                               validation data); built fresh when omitted.
#             :$invalid_label - id reserved for padding, default -1.
#             :$start_label   - first id assigned to real tokens, default 0.
# Returns ($sentences, $vocab).
func tokenize_text($fname, :$vocab=, :$invalid_label=-1, :$start_label=0)
{
    # Three-arg open with a lexical filehandle; the mode can no longer be
    # injected through the filename and the handle cannot clash with a
    # package-level bareword.
    open(my $fh, '<', $fname) or die "Can't open $fname: $!";
    # One arrayref of tokens per line; the shift drops the empty leading
    # field produced by split when a line starts with a space
    # (presumably how this corpus is formatted - verify against the data).
    my @lines = map { my $l = [split(/ /)]; shift(@$l); $l } (<$fh>);
    close($fh);
    my $sentences;
    ($sentences, $vocab) = mx->rnn->encode_sentences(
        \@lines,
        vocab         => $vocab,
        invalid_label => $invalid_label,
        start_label   => $start_label
    );
    return ($sentences, $vocab);
}

# Sentences are grouped by length into these buckets so each minibatch can
# be unrolled to a fixed number of time steps.
my $buckets       = [10, 20, 30, 40, 50, 60];
my $start_label   = 1;
my $invalid_label = 0;

my ($train_sentences, $vocabulary) = tokenize_text(
    './data/sherlockholmes.train.txt', start_label => $start_label,
    invalid_label => $invalid_label
);
# Reuse the training vocabulary so validation token ids stay consistent.
my ($validation_sentences) = tokenize_text(
    './data/sherlockholmes.test.txt', vocab => $vocabulary,
    start_label => $start_label, invalid_label => $invalid_label
);
my $data_train = mx->rnn->BucketSentenceIter(
    $train_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);
my $data_val = mx->rnn->BucketSentenceIter(
    $validation_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);

# Stack --num-layers LSTM cells; the per-layer prefix keeps parameter names
# unique across layers.
my $stack = mx->rnn->SequentialRNNCell();
for my $i (0..$num_layers-1)
{
    $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
}

# Symbol generator for the BucketingModule: given a bucket's sequence
# length, build the unrolled network and return
# (symbol, data names, label names).
my $sym_gen = sub {
    my $seq_len = shift;
    my $data    = mx->sym->Variable('data');
    my $label   = mx->sym->Variable('softmax_label');
    my $embed   = mx->sym->Embedding(
        data => $data, input_dim => scalar(keys %$vocabulary),
        output_dim => $num_embed, name => 'embed'
    );
    # Clear any state left from unrolling a previous bucket.
    $stack->reset;
    my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
    # Flatten the merged outputs to 2-D (-1, num_hidden) so a single
    # FullyConnected layer scores every time step against the vocabulary.
    my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
    $pred = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
    $label = mx->sym->Reshape($label, shape => [-1]);
    $pred  = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
    return ($pred, ['data'], ['softmax_label']);
};

# Run on the GPUs listed in --gpus (comma-separated ids), or on CPU when
# the option is absent.
my $contexts;
if(defined $gpus)
{
    $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
    $contexts = mx->cpu(0);
}

my $model = mx->mod->BucketingModule(
    sym_gen            => $sym_gen,
    default_bucket_key => $data_train->default_bucket_key,
    context            => $contexts
);

$model->fit(
    $data_train,
    eval_data   => $data_val,
    eval_metric => mx->metric->Perplexity($invalid_label),
    kvstore     => $kv_store,
    optimizer   => $optimizer,
    optimizer_params => {
        learning_rate => $lr,
        momentum      => $mom,
        wd            => $wd,
    },
    initializer        => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch          => $num_epoch,
    batch_end_callback => mx->callback->Speedometer($batch_size, $disp_batches),
    # Checkpointing is opt-in: only register the callback when
    # --chkp-epoch is non-zero.
    ($chkp_epoch ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);