1#!/usr/bin/env perl
2
3# Licensed to the Apache Software Foundation (ASF) under one
4# or more contributor license agreements.  See the NOTICE file
5# distributed with this work for additional information
6# regarding copyright ownership.  The ASF licenses this file
7# to you under the Apache License, Version 2.0 (the
8# "License"); you may not use this file except in compliance
9# with the License.  You may obtain a copy of the License at
10#
11#   http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing,
14# software distributed under the License is distributed on an
15# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16# KIND, either express or implied.  See the License for the
17# specific language governing permissions and limitations
18# under the License.
19
20use strict;
21use warnings;
22use PDL;
23use AI::MXNet qw(mx);
24use AI::MXNet::Function::Parameters;
25use Getopt::Long qw(HelpMessage);
26
# Parse command-line options.  Each option is bound to a freshly declared
# lexical with its default inlined in the \(my $x = ...) reference; an
# unknown option prints the POD usage (below) and exits non-zero.
GetOptions(
    'num-layers=i'   => \(my $num_layers   = 2       ),
    'num-hidden=i'   => \(my $num_hidden   = 200     ),
    'num-embed=i'    => \(my $num_embed    = 200     ),
    'gpus=s'         => \(my $gpus                   ),  # undef => run on CPU
    'kv-store=s'     => \(my $kv_store     = 'device'),
    'num-epoch=i'    => \(my $num_epoch    = 25      ),
    'lr=f'           => \(my $lr           = 0.01    ),
    'optimizer=s'    => \(my $optimizer    = 'sgd'   ),
    'mom=f'          => \(my $mom          = 0       ),
    'wd=f'           => \(my $wd           = 0.00001 ),
    'batch-size=i'   => \(my $batch_size   = 32      ),
    'disp-batches=i' => \(my $disp_batches = 50      ),
    'chkp-prefix=s'  => \(my $chkp_prefix  = 'lstm_' ),
    'chkp-epoch=i'   => \(my $chkp_epoch   = 0       ),  # 0 disables checkpoints
    'help'           => sub { HelpMessage(0) },
) or HelpMessage(1);
44
45=head1 NAME
46
47    lstm_bucketing.pl - Example of training LSTM RNN on Sherlock Holmes data using high level RNN interface
48
49=head1 SYNOPSIS
50
51    --num-layers     number of stacked RNN layers, default=2
52    --num-hidden     hidden layer size, default=200
53    --num-embed      embedding layer size, default=200
54    --gpus           list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.
55                     Increase batch size when using multiple gpus for best performance.
56    --kv-store       key-value store type, default='device'
    --num-epoch      max num of epochs, default=25
58    --lr             initial learning rate, default=0.01
59    --optimizer      the optimizer type, default='sgd'
60    --mom            momentum for sgd, default=0.0
61    --wd             weight decay for sgd, default=0.00001
    --batch-size     the batch size, default=32
63    --disp-batches   show progress for every n batches, default=50
64    --chkp-prefix    prefix for checkpoint files, default='lstm_'
65    --chkp-epoch     save checkpoint after this many epoch, default=0 (saving checkpoints is disabled)
66
67=cut
# Read a pre-tokenized text file (one sentence per line, tokens separated by
# single spaces) and encode the tokens as integer ids.
#
# Positional: $fname — path to the corpus file.
# Named:      vocab         — existing token=>id hash to reuse/extend
#                             (default: build a fresh vocabulary),
#             invalid_label — id reserved for padding (default -1),
#             start_label   — first id assigned to new tokens (default 0).
# Returns ($sentences, $vocab) as produced by mx->rnn->encode_sentences.
func tokenize_text($fname, :$vocab=, :$invalid_label=-1, :$start_label=0)
{
    # Three-arg open with a lexical handle: the original bareword 2-arg
    # open(F, $fname) would misbehave if $fname ever began with '<', '>' or '|'.
    open(my $fh, '<', $fname) or die "Can't open $fname: $!";
    # Corpus lines start with a space, so split(/ /) yields an empty leading
    # field; shift() drops it.  NOTE(review): tokens are not chomp'ed — the
    # last token of each line keeps its newline, matching prior behaviour.
    my @lines = map { my $l = [split(/ /)]; shift(@$l); $l } (<$fh>);
    close($fh);
    my $sentences;
    ($sentences, $vocab) = mx->rnn->encode_sentences(
        \@lines,
        vocab         => $vocab,
        invalid_label => $invalid_label,
        start_label   => $start_label
    );
    return ($sentences, $vocab);
}
81
# Bucket sizes: each sentence is padded up to the nearest of these lengths so
# only a handful of unrolled graphs need to be built.
my $buckets = [10, 20, 30, 40, 50, 60];
# Id 0 is reserved as the padding ("invalid") label, so real tokens start at 1.
my $start_label   = 1;
my $invalid_label = 0;

# Build the vocabulary from the training corpus ...
my ($train_sentences, $vocabulary) = tokenize_text(
    './data/sherlockholmes.train.txt', start_label => $start_label,
    invalid_label => $invalid_label
);
# ... and reuse it for the test corpus so token ids agree across both sets.
my ($validation_sentences) = tokenize_text(
    './data/sherlockholmes.test.txt', vocab => $vocabulary,
    start_label => $start_label, invalid_label => $invalid_label
);
# Iterators that batch sentences of similar length into the buckets above,
# padding with $invalid_label.
my $data_train  = mx->rnn->BucketSentenceIter(
    $train_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);
my $data_val    = mx->rnn->BucketSentenceIter(
    $validation_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);
102
# Recurrent backbone: $num_layers stacked LSTM cells, each with $num_hidden
# units and a distinct parameter prefix (lstm_l0_, lstm_l1_, ...).
my $stack = mx->rnn->SequentialRNNCell();
$stack->add(
    mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${_}_")
) for 0 .. $num_layers - 1;
108
# Symbol generator for the BucketingModule: given a bucket key (the sequence
# length), builds embedding -> unrolled LSTM stack -> per-timestep softmax
# over the vocabulary, and returns the loss symbol plus data/label names.
my $sym_gen = sub {
    my ($seq_len) = @_;
    my $vocab_size = scalar(keys %$vocabulary);

    my $data  = mx->sym->Variable('data');
    my $label = mx->sym->Variable('softmax_label');
    my $embed = mx->sym->Embedding(
        data       => $data,
        input_dim  => $vocab_size,
        output_dim => $num_embed,
        name       => 'embed'
    );

    # Fresh unroll per bucket; reset clears state left by a previous unroll.
    $stack->reset;
    my ($outputs) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);

    # Collapse (batch, time, hidden) to (batch*time, hidden) and score every
    # timestep against the whole vocabulary.
    my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
    $pred = mx->sym->FullyConnected(data => $pred, num_hidden => $vocab_size, name => 'pred');
    my $flat_label = mx->sym->Reshape($label, shape => [-1]);
    $pred = mx->sym->SoftmaxOutput(data => $pred, label => $flat_label, name => 'softmax');

    return ($pred, ['data'], ['softmax_label']);
};
125
# Compute context(s): one mx->gpu per comma-separated id in --gpus,
# otherwise a single CPU context.
my $contexts = defined($gpus)
    ? [ map { mx->gpu($_) } split(/,/, $gpus) ]
    : mx->cpu(0);
135
# BucketingModule builds one executor per bucket via $sym_gen; the default
# bucket key (the largest bucket) sizes the shared parameter arrays.
my $model = mx->mod->BucketingModule(
    sym_gen             => $sym_gen,
    default_bucket_key  => $data_train->default_bucket_key,
    context             => $contexts
);

# Train the model.  Perplexity skips padded positions ($invalid_label);
# Speedometer logs throughput every $disp_batches batches.
$model->fit(
    $data_train,
    eval_data           => $data_val,
    eval_metric         => mx->metric->Perplexity($invalid_label),
    kvstore             => $kv_store,
    optimizer           => $optimizer,
    optimizer_params    => {
                                learning_rate => $lr,
                                momentum      => $mom,
                                wd            => $wd,
                        },
    initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch           => $num_epoch,
    batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
    # Checkpointing is opt-in: --chkp-epoch 0 (the default) disables it.
    ($chkp_epoch ? (epoch_end_callback  => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);
158