1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 * Hua Zhang mz24cn@hotmail.com
 * This code implements a C++ version of charRNN (mxnet\example\rnn\char-rnn.ipynb) using the MXNet.cpp API.
 * The generated params file is compatible with the Python version.
 * train() and predict() have been verified with the original data samples.
 * 2017/1/23:
 * Added a faster charRNN based on the built-in cuDNN RNN operator (about 10 times faster).
 * Added a time-major computation graph, although it shows no substantial performance difference.
 * Support continuing training from the last params file.
 * Params file epoch numbers now start from zero.
 */
31 
32 #if _MSC_VER
33 #pragma warning(disable: 4996)  // VS2015 complains on 'std::copy' ...
34 #endif
#include <algorithm>
#include <chrono>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"
#include "utils.h"
48 
49 using namespace mxnet::cpp;
50 
51 struct LSTMState {
52   Symbol C;
53   Symbol h;
54 };
55 
56 struct LSTMParam {
57   Symbol i2h_weight;
58   Symbol i2h_bias;
59   Symbol h2h_weight;
60   Symbol h2h_bias;
61 };
62 
// When set, the training graphs transpose the (batch, time) input to time-major before embedding; main() derives it from the task name.
bool TIME_MAJOR{true};
64 
65 // LSTM Cell symbol
/*!
 * \brief Builds one LSTM step for time step `seqidx` of layer `layeridx`.
 *
 * Node names are prefixed "t<seqidx>_l<layeridx>". Both projections produce
 * 4 * num_hidden values that are split into input gate, input transform,
 * forget gate and output gate (in that order).
 * \param dropout when > 0, dropout is applied to `indata` first.
 * \return the new cell state C and hidden output h.
 * Note: the order in which symbols are created below is kept as is, because
 * unnamed operators are auto-named in creation order.
 */
LSTMState LSTM(int num_hidden, const Symbol& indata, const LSTMState& prev_state,
    const LSTMParam& param, int seqidx, int layeridx, mx_float dropout = 0) {
  Symbol x = indata;
  if (dropout > 0)
    x = Dropout(indata, dropout);

  const std::string tag = "t" + std::to_string(seqidx) + "_l" + std::to_string(layeridx);
  const int gate_units = num_hidden * 4;
  Symbol from_input = FullyConnected(tag + "_i2h", x, param.i2h_weight, param.i2h_bias,
      gate_units);
  Symbol from_hidden = FullyConnected(tag + "_h2h", prev_state.h, param.h2h_weight,
      param.h2h_bias, gate_units);
  Symbol gate_sum = from_input + from_hidden;
  Symbol slices = SliceChannel(tag + "_slice", gate_sum, 4);

  Symbol in_gate = Activation(slices[0], ActivationActType::kSigmoid);
  Symbol in_transform = Activation(slices[1], ActivationActType::kTanh);
  Symbol forget_gate = Activation(slices[2], ActivationActType::kSigmoid);
  Symbol out_gate = Activation(slices[3], ActivationActType::kSigmoid);

  LSTMState next;
  next.C = (forget_gate * prev_state.C) + (in_gate * in_transform);
  next.h = out_gate * Activation(next.C, ActivationActType::kTanh);
  return next;
}
86 
/*!
 * \brief Builds the explicitly unrolled, multi-layer LSTM character model.
 *
 * sequence_length > 1 builds the training graph and returns only the
 * SoftmaxOutput symbol. sequence_length == 1 builds the single-step inference
 * graph and returns Group([softmax, l0 C, l0 h, l1 C, l1 h, ...]), so callers
 * can feed the per-layer states back in as "l<k>_init_c"/"l<k>_init_h".
 *
 * Variables: "data", "softmax_label", "embed_weight",
 * "l<k>_{i2h,h2h}_{weight,bias}", "l<k>_init_{c,h}", "cls_weight", "cls_bias".
 * \param num_lstm_layer  number of stacked LSTM layers
 * \param sequence_length unrolled time steps (1 selects inference mode)
 * \param input_dim       vocabulary size (embedding input / classifier output)
 * \param num_hidden      hidden units per LSTM layer
 * \param num_embed       embedding width
 * \param dropout         dropout on the input of layers >= 1 and on the top layer's output
 */
Symbol LSTMUnroll(int num_lstm_layer, int sequence_length, int input_dim,
        int num_hidden, int num_embed, mx_float dropout = 0) {
  auto isTrain = sequence_length > 1;
  auto data = Symbol::Variable("data");
  // Time-major mode: transpose the input (train() binds it as (batch, time)) to (time, batch).
  if (TIME_MAJOR && isTrain)
    data = transpose(data);
  auto embed_weight = Symbol::Variable("embed_weight");
  auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
  // Training: one symbol per time step, sliced along the time axis (0 when
  // time-major, else 1) with that axis squeezed. Inference: the single step itself.
  auto wordvec = isTrain? SliceChannel(embed, sequence_length, TIME_MAJOR? 0 : 1, true) : embed;

  // Per-layer parameter variables and initial-state variables.
  std::vector<LSTMState> last_states;
  std::vector<LSTMParam> param_cells;
  for (int l = 0; l < num_lstm_layer; l++) {
    std::string layer = "l" + std::to_string(l);
    LSTMParam param;
    param.i2h_weight = Symbol::Variable(layer + "_i2h_weight");
    param.i2h_bias = Symbol::Variable(layer + "_i2h_bias");
    param.h2h_weight = Symbol::Variable(layer + "_h2h_weight");
    param.h2h_bias = Symbol::Variable(layer + "_h2h_bias");
    param_cells.push_back(param);
    LSTMState state;
    state.C = Symbol::Variable(layer + "_init_c");
    state.h = Symbol::Variable(layer + "_init_h");
    last_states.push_back(state);
  }

  // Unroll over time; every layer's state threads through last_states.
  std::vector<Symbol> hidden_all;
  for (int i = 0; i < sequence_length; i++) {
    auto hidden = wordvec[i];
    for (int layer = 0; layer < num_lstm_layer; layer++) {
      double dp_ratio = layer == 0? 0 : dropout;  // no dropout on the embedding input
      auto next_state = LSTM(num_hidden, hidden, last_states[layer], param_cells[layer],
          i, layer, dp_ratio);
      hidden = next_state.h;
      last_states[layer] = next_state;
    }
    if (dropout > 0)
      hidden = Dropout(hidden, dropout);
    hidden_all.push_back(hidden);
  }

  // Concatenate the step outputs along axis 0, so rows are grouped by time step.
  auto hidden_concat = isTrain? Concat(hidden_all, hidden_all.size(), 0) : hidden_all[0];
  auto cls_weight = Symbol::Variable("cls_weight");
  auto cls_bias = Symbol::Variable("cls_bias");
  auto pred = FullyConnected("pred", hidden_concat, cls_weight, cls_bias, input_dim);

  // Labels arrive as (batch, time); transpose and flatten them so they follow
  // the same time-step-major row order as pred.
  auto label = Symbol::Variable("softmax_label");
  label = transpose(label);
  label = Reshape(label, Shape(), false, Shape(0), false);  // deprecated target_shape (0): flatten to 1-D
  auto sm = SoftmaxOutput("softmax", pred, label);
  if (isTrain)
    return sm;

  // Inference: also expose each layer's final (C, h) so the caller can carry state.
  std::vector<Symbol> outputs = { sm };
  for (auto& state : last_states) {
    outputs.push_back(state.C);
    outputs.push_back(state.h);
  }
  return Symbol::Group(outputs);
}
147 
148 // Currently mxnet GPU version RNN operator is implemented via *fast* NVIDIA cuDNN.
/*!
 * \brief Builds the LSTM character model on top of MXNet's built-in RNN operator.
 *
 * sequence_length > 1 builds the training graph and returns only the
 * SoftmaxOutput symbol. sequence_length == 1 builds the single-step inference
 * graph and returns Group([softmax, rnn[1], rnn[2]]), i.e. the RNN state and
 * cell-state outputs (indices noted on the return statement).
 *
 * Variables: "data", "softmax_label", "embed_weight", "LSTM_init_h",
 * "LSTM_init_c", "LSTM_parameters", "sequence_length", "cls_weight", "cls_bias".
 */
Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_dim,
 int num_hidden, int num_embed, mx_float dropout = 0) {
  auto isTrain = sequence_length > 1;  // sequence_length == 1 selects the inference graph
  auto data = Symbol::Variable("data");
  // Time-major mode: transpose the (batch, time) input to (time, batch) up front.
  if (TIME_MAJOR && isTrain)
    data = transpose(data);

  auto embed_weight = Symbol::Variable("embed_weight");
  auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
  auto label = Symbol::Variable("softmax_label");
  label = transpose(label);
  // Flatten the transposed labels to 1-D (deprecated target_shape with one inferred
  // dim) so they line up with the rows of pred.
  label = Reshape(label, Shape(), false,
                  Shape(0), false);  // FullyConnected requires one dimension
  if (!TIME_MAJOR && isTrain)
    embed = SwapAxis(embed, 0, 1);  // Change to time-major as cuDNN requires

  // In TIME_MAJOR mode the transpose of data above already made the input
  // time-major, so the SwapAxis used by the python version is not needed.
  auto rnn_h_init = Symbol::Variable("LSTM_init_h");
  auto rnn_c_init = Symbol::Variable("LSTM_init_c");
  auto rnn_params = Symbol::Variable("LSTM_parameters");  // See explanations near RNNXavier class
  auto variable_sequence_length = Symbol::Variable("sequence_length");
  // Presumably (per the RNN operator signature): bidirectional=false, dropout p,
  // state_outputs=!isTrain so inference can read the final states; confirm.
  auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
                 num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
  // Collapse the RNN output to (rows, num_hidden) for the classifier.
  auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);

  auto cls_weight = Symbol::Variable("cls_weight");
  auto cls_bias = Symbol::Variable("cls_bias");
  auto pred = FullyConnected("pred", hidden, cls_weight, cls_bias, input_dim);
  /*In rnn-time-major/rnn_cell_demo.py, the author claimed the time-major version speeds up
   * 1.5~2 times versus the batch version. I doubt that conclusion. In my test, the performance
   * of both codes is almost the same. In fact, there are no substantial differences between
   * the two codes. They are both based on time-major cuDNN; the computation graphs only differ
   * slightly in where the Reshape/SwapAxis/transpose operations are placed. Here I don't
   * use Reshape on pred and keep the label shape on SoftmaxOutput like the time-major code,
   * but Reshape the label for simplification. It doesn't affect performance. */

  auto sm = SoftmaxOutput("softmax", pred, label);
  if (isTrain)
    return sm;
  else
    return Symbol::Group({ sm, rnn[1/*RNNOpOutputs::kStateOut=1*/],
    rnn[2/*RNNOpOutputs::kStateCellOut=2*/] });
}
192 
/*!
 * \brief A permutation of the indices [0, size) that can be reshuffled in place.
 *
 * Starts as the identity permutation. shuffle() uses std::shuffle with a
 * private std::mt19937 engine. std::random_shuffle, used previously, was
 * deprecated in C++14 and removed in C++17. The engine is seeded with `seed`,
 * so equal seeds give equal sequences of permutations.
 */
class Shuffler {
  std::vector<int> sequence;
  std::mt19937 engine;

 public:
  /*! \param size number of indices (negative values are treated as 0). */
  explicit Shuffler(int size, std::mt19937::result_type seed = std::mt19937::default_seed)
      : sequence(static_cast<std::size_t>(size > 0 ? size : 0)), engine(seed) {
    std::iota(sequence.begin(), sequence.end(), 0);
  }

  /*!
   * \brief Reshuffles the permutation.
   * \param lambda if set, called as lambda(position, index) for every position in order.
   */
  void shuffle(std::function<void(int, int)> lambda = nullptr) {
    std::shuffle(sequence.begin(), sequence.end(), engine);
    if (lambda) {
      int n = 0;
      for (int i : sequence)
        lambda(n++, i);
    }
  }

  /*! \brief Pointer to the current permutation (size() entries). */
  const int* data() const {
    return sequence.data();
  }
};
212 
213 class BucketSentenceIter : public DataIter {
214   Shuffler* random;
215   int batch, current, end;
216   unsigned int sequence_length;
217   Context device;
218   std::vector<std::vector<mx_float>> sequences;
219   std::vector<wchar_t> index2chars;
220   std::unordered_map<wchar_t, mx_float> charIndices;
221 
222  public:
BucketSentenceIter(std::string filename,int minibatch,Context context)223   BucketSentenceIter(std::string filename, int minibatch, Context context) : batch(minibatch),
224   current(-1), device(context) {
225     auto content = readContent(filename);
226     buildCharIndex(content);
227     sequences = convertTextToSequences(content, '\n');
228 
229     int N = sequences.size() / batch * batch;  // total used samples
230     sequences.resize(N);
231     sort(sequences.begin(), sequences.end(), [](const std::vector<mx_float>& a,
232         const std::vector<mx_float>& b) { return a.size() < b.size(); });
233 
234     sequence_length = sequences.back().size();
235     random = new Shuffler(N);
236     // We still can get random results if call Reset() firstly
237 //    std::vector<vector<mx_float>>* target = &sequences;
238 //    random->shuffle([target](int n, int i) { (*target)[n].swap((*target)[i]); });
239     end = N / batch;
240   }
~BucketSentenceIter()241   virtual ~BucketSentenceIter() {
242     delete random;
243   }
244 
maxSequenceLength()245   unsigned int maxSequenceLength() {
246     return sequence_length;
247   }
248 
characterSize()249   size_t characterSize() {
250     return charIndices.size();
251   }
252 
Next(void)253   virtual bool Next(void) {
254     return ++current < end;
255   }
GetData(void)256   virtual NDArray GetData(void) {
257     const int* indices = random->data();
258     mx_float *data = new mx_float[sequence_length * batch], *pdata = data;
259 
260     for (int i = current * batch, end = i + batch; i < end; i++) {
261       memcpy(pdata, sequences[indices[i]].data(), sequences[indices[i]].size() * sizeof(mx_float));
262       if (sequences[indices[i]].size() < sequence_length)
263         memset(pdata + sequences[indices[i]].size(), 0,
264             (sequence_length - sequences[indices[i]].size()) * sizeof(mx_float));
265       pdata += sequence_length;
266     }
267     NDArray array(Shape(batch, sequence_length), device, false);
268     array.SyncCopyFromCPU(data, batch * sequence_length);
269     return array;
270   }
GetLabel(void)271   virtual NDArray GetLabel(void) {
272     const int* indices = random->data();
273     mx_float *label = new mx_float[sequence_length * batch], *plabel = label;
274 
275     for (int i = current * batch, end = i + batch; i < end; i++) {
276       memcpy(plabel, sequences[indices[i]].data() + 1,
277           (sequences[indices[i]].size() - 1) * sizeof(mx_float));
278       memset(plabel + sequences[indices[i]].size() - 1, 0,
279           (sequence_length - sequences[indices[i]].size() + 1) * sizeof(mx_float));
280       plabel += sequence_length;
281     }
282     NDArray array(Shape(batch, sequence_length), device, false);
283     array.SyncCopyFromCPU(label, batch * sequence_length);
284     return array;
285   }
GetPadNum(void)286   virtual int GetPadNum(void) {
287     return sequence_length - sequences[random->data()[current * batch]].size();
288   }
GetIndex(void)289   virtual std::vector<int> GetIndex(void) {
290     const int* indices = random->data();
291     std::vector<int> list(indices + current * batch, indices + current * batch + batch);
292     return list;
293   }
BeforeFirst(void)294   virtual void BeforeFirst(void) {
295     current = -1;
296     random->shuffle(nullptr);
297   }
298 
readContent(const std::string file)299   std::wstring readContent(const std::string file) {
300     std::wifstream ifs(file, std::ios::binary);
301     if (ifs) {
302       std::wostringstream os;
303       os << ifs.rdbuf();
304       return os.str();
305     }
306     return L"";
307   }
308 
buildCharIndex(const std::wstring & content)309   void buildCharIndex(const std::wstring& content) {
310   // This version buildCharIndex() Compatiable with python version char_rnn dictionary
311     int n = 1;
312     charIndices['\0'] = 0;  // padding character
313     index2chars.push_back(0);  // padding character index
314     for (auto c : content)
315       if (charIndices.find(c) == charIndices.end()) {
316         charIndices[c] = n++;
317         index2chars.push_back(c);
318       }
319   }
320 //  void buildCharIndex(wstring& content) {
321 //    for (auto c : content)
322 //      charIndices[c]++; // char-frequency map; then char-index map
323 //    std::vector<tuple<wchar_t, mx_float>> characters;
324 //    for (auto& iter : charIndices)
325 //      characters.push_back(make_tuple(iter.first, iter.second));
326 //    sort(characters.begin(), characters.end(), [](const tuple<wchar_t, mx_float>& a,
327 //      const tuple<wchar_t, mx_float>& b) { return get<1>(a) > get<1>(b); });
328 //    mx_float index = 1; //0 is left for zero-padding
329 //    index2chars.clear();
330 //    index2chars.push_back(0); //zero-padding
331 //    for (auto& t : characters) {
332 //      charIndices[get<0>(t)] = index++;
333 //      index2chars.push_back(get<0>(t));
334 //    }s
335 //  }
336 
character(int i)337   inline wchar_t character(int i) {
338     return index2chars[i];
339   }
340 
index(wchar_t c)341   inline mx_float index(wchar_t c) {
342     return charIndices[c];
343   }
344 
saveCharIndices(const std::string file)345   void saveCharIndices(const std::string file) {
346     std::wofstream ofs(file, std::ios::binary);
347     if (ofs) {
348       ofs.write(index2chars.data() + 1, index2chars.size() - 1);
349       ofs.close();
350     }
351   }
352 
loadCharIndices(const std::string file)353   static std::tuple<std::unordered_map<wchar_t, mx_float>, std::vector<wchar_t>> loadCharIndices(
354       const std::string file) {
355     std::wifstream ifs(file, std::ios::binary);
356     std::unordered_map<wchar_t, mx_float> map;
357     std::vector<wchar_t> chars;
358     if (ifs) {
359       std::wostringstream os;
360       os << ifs.rdbuf();
361       int n = 1;
362       map[L'\0'] = 0;
363       chars.push_back(L'\0');
364       for (auto c : os.str()) {
365         map[c] = (mx_float) n++;
366         chars.push_back(c);
367       }
368     }
369     // Note: Can't use {} because this would hit the explicit constructor
370     return std::tuple<std::unordered_map<wchar_t, mx_float>, std::vector<wchar_t>>(map, chars);
371   }
372 
373   std::vector<std::vector<mx_float>>
convertTextToSequences(const std::wstring & content,wchar_t spliter)374   convertTextToSequences(const std::wstring& content, wchar_t spliter) {
375     std::vector<std::vector<mx_float>> sequences;
376     sequences.push_back(std::vector<mx_float>());
377     for (auto c : content)
378       if (c == spliter && !sequences.back().empty())
379         sequences.push_back(std::vector<mx_float>());
380       else
381         sequences.back().push_back(charIndices[c]);
382     return sequences;
383   }
384 };
385 
OutputPerplexity(NDArray * labels,NDArray * output)386 void OutputPerplexity(NDArray* labels, NDArray* output) {
387   std::vector<mx_float> charIndices, a;
388   labels->SyncCopyToCPU(&charIndices, 0L);  // 0L indicates all
389   output->SyncCopyToCPU(&a, 0L)/*4128*84*/;
390   mx_float loss = 0;
391   int batchSize = labels->GetShape()[0]/*32*/, sequenceLength = labels->GetShape()[1]/*129*/,
392       nSamples = output->GetShape()[0]/*4128*/, vocabSize = output->GetShape()[1]/*84*/;
393   for (int n = 0; n < nSamples; n++) {
394     int row = n % batchSize, column = n / batchSize, labelOffset = column +
395         row * sequenceLength;  // Search based on column storage: labels.T
396     mx_float safe_value = std::max(1e-10f, a[vocabSize * n +
397                                     static_cast<int>(charIndices[labelOffset])]);
398     loss += -log(safe_value);  // Calculate negative log-likelihood
399   }
400   loss = exp(loss / nSamples);
401   std::cout << "Train-Perplexity=" << loss << std::endl;
402 }
403 
SaveCheckpoint(const std::string filepath,Symbol net,Executor * exe)404 void SaveCheckpoint(const std::string filepath, Symbol net, Executor* exe) {
405   std::map<std::string, NDArray> params;
406   for (auto iter : exe->arg_dict())
407     if (iter.first.find("_init_") == std::string::npos
408         && iter.first.rfind("data") != iter.first.length() - 4
409         && iter.first.rfind("label") != iter.first.length() - 5)
410       params.insert({"arg:" + iter.first, iter.second});
411   for (auto iter : exe->aux_dict())
412       params.insert({"aux:" + iter.first, iter.second});
413   NDArray::Save(filepath, params);
414 }
415 
LoadCheckpoint(const std::string filepath,Executor * exe)416 void LoadCheckpoint(const std::string filepath, Executor* exe) {
417   std::map<std::string, NDArray> params = NDArray::LoadToMap(filepath);
418   for (auto iter : params) {
419     std::string type = iter.first.substr(0, 4);
420     std::string name = iter.first.substr(4);
421     NDArray target;
422     if (type == "arg:")
423       target = exe->arg_dict()[name];
424     else if (type == "aux:")
425       target = exe->aux_dict()[name];
426     else
427       continue;
428     iter.second.CopyTo(&target);
429   }
430 }
431 
432 int input_dim = 0;/*84*/
433 int sequence_length_max = 0;/*129*/
434 int num_embed = 256;
435 int num_lstm_layer = 3;
436 int num_hidden = 512;
437 mx_float dropout = 0.2;
train(const std::string file,int batch_size,int max_epoch,int start_epoch)438 void train(const std::string file, int batch_size, int max_epoch, int start_epoch) {
439   Context device(DeviceType::kGPU, 0);
440   BucketSentenceIter dataIter(file, batch_size, device);
441   std::string prefix = file.substr(0, file.rfind("."));
442   dataIter.saveCharIndices(prefix + ".dictionary");
443 
444   input_dim = static_cast<int>(dataIter.characterSize());
445   sequence_length_max = dataIter.maxSequenceLength();
446 
447   auto RNN = LSTMUnroll(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
448       num_embed, dropout);
449   std::map<std::string, NDArray> args_map;
450   args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
451   args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
452   for (int i = 0; i < num_lstm_layer; i++) {
453     std::string key = "l" + std::to_string(i) + "_init_";
454     args_map[key + "c"] = NDArray(Shape(batch_size, num_hidden), device, false);
455     args_map[key + "h"] = NDArray(Shape(batch_size, num_hidden), device, false);
456   }
457   std::vector<mx_float> zeros(batch_size * num_hidden, 0);
458   // RNN.SimpleBind(device, args_map, {}, {{"data", kNullOp}});
459   Executor* exe = RNN.SimpleBind(device, args_map);
460 
461   if (start_epoch == -1) {
462     Xavier xavier = Xavier(Xavier::gaussian, Xavier::in, 2.34);
463     for (auto &arg : exe->arg_dict())
464       xavier(arg.first, &arg.second);
465   } else {
466     LoadCheckpoint(prefix + "-" + std::to_string(start_epoch) + ".params", exe);
467   }
468   start_epoch++;
469 
470   mx_float learning_rate = 0.0002;
471   mx_float weight_decay = 0.000002;
472   Optimizer* opt = OptimizerRegistry::Find("sgd");
473   opt->SetParam("lr", learning_rate)
474      ->SetParam("wd", weight_decay);
475 //  opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
476 //  ->SetParam("clip_gradient", 10);
477 
478   for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
479     dataIter.Reset();
480     auto tic =  std::chrono::system_clock::now();
481     while (dataIter.Next()) {
482       auto data_batch = dataIter.GetDataBatch();
483       data_batch.data.CopyTo(&exe->arg_dict()["data"]);
484       data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
485       for (int l = 0; l < num_lstm_layer; l++) {
486         std::string key = "l" + std::to_string(l) + "_init_";
487         exe->arg_dict()[key + "c"].SyncCopyFromCPU(zeros);
488         exe->arg_dict()[key + "h"].SyncCopyFromCPU(zeros);
489       }
490       NDArray::WaitAll();
491 
492       exe->Forward(true);
493       exe->Backward();
494       for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
495         opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
496       }
497 
498       NDArray::WaitAll();
499     }
500     auto toc =  std::chrono::system_clock::now();
501     std::cout << "Epoch[" << epoch << "] Time Cost:" <<
502          std::chrono::duration_cast< std::chrono::seconds>(toc - tic).count() << " seconds ";
503     OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
504     std::string filepath = prefix + "-" + std::to_string(epoch) + ".params";
505     SaveCheckpoint(filepath, RNN, exe);
506   }
507 
508   delete exe;
509   delete opt;
510 }
511 
/*The original example, rnn_cell_demo.py, uses the default Xavier initializer, which relies on
 * variable names and therefore cannot initialize LSTM_parameters. Thus it was renamed to
 * LSTM_bias, which can be initialized to zero, but then the model does not converge after
 * 100 epochs on this corpus. With RNNXavier, after about 15 oscillating epochs, it converges
 * rapidly, like the old LSTMUnroll version. */
517 class RNNXavier : public Xavier {
518  public:
RNNXavier(RandType rand_type=gaussian,FactorType factor_type=avg,float magnitude=3)519   RNNXavier(RandType rand_type = gaussian, FactorType factor_type = avg,
520     float magnitude = 3) : Xavier(rand_type, factor_type, magnitude) {
521   }
~RNNXavier()522   virtual ~RNNXavier() {}
523  protected:
InitDefault(NDArray * arr)524   virtual void InitDefault(NDArray* arr) {
525     Xavier::InitWeight(arr);
526   }
527 };
528 
trainWithBuiltInRNNOp(const std::string file,int batch_size,int max_epoch,int start_epoch)529 void trainWithBuiltInRNNOp(const std::string file, int batch_size, int max_epoch, int start_epoch) {
530   Context device(DeviceType::kGPU, 0);
531   BucketSentenceIter dataIter(file, batch_size, device);
532   std::string prefix = file.substr(0, file.rfind("."));
533   dataIter.saveCharIndices(prefix + ".dictionary");
534 
535   input_dim = static_cast<int>(dataIter.characterSize());
536   sequence_length_max = dataIter.maxSequenceLength();
537 
538   auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
539       num_embed, dropout);
540   std::map<std::string, NDArray> args_map;
541   args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
542   // Avoiding SwapAxis, batch_size is of second dimension.
543   args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
544   args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
545   args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
546   std::vector<mx_float> zeros(batch_size * num_lstm_layer * num_hidden, 0);
547   Executor* exe = RNN.SimpleBind(device, args_map);
548 
549   if (start_epoch == -1) {
550     RNNXavier xavier = RNNXavier(Xavier::gaussian, Xavier::in, 2.34);
551     for (auto &arg : exe->arg_dict())
552       xavier(arg.first, &arg.second);
553   } else {
554     LoadCheckpoint(prefix + "-" + std::to_string(start_epoch) + ".params", exe);
555   }
556   start_epoch++;
557 
558   Optimizer* opt = OptimizerRegistry::Find("ccsgd");
559 //  opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
560 //  ->SetParam("clip_gradient", 10);
561 
562   for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
563     dataIter.Reset();
564     auto tic =  std::chrono::system_clock::now();
565     while (dataIter.Next()) {
566       auto data_batch = dataIter.GetDataBatch();
567       data_batch.data.CopyTo(&exe->arg_dict()["data"]);
568       data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
569       exe->arg_dict()["LSTM_init_c"].SyncCopyFromCPU(zeros);
570       exe->arg_dict()["LSTM_init_h"].SyncCopyFromCPU(zeros);
571       NDArray::WaitAll();
572 
573       exe->Forward(true);
574       exe->Backward();
575       for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
576         opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
577       }
578       NDArray::WaitAll();
579     }
580     auto toc =  std::chrono::system_clock::now();
581     std::cout << "Epoch[" << epoch << "] Time Cost:" <<
582          std::chrono::duration_cast< std::chrono::seconds>(toc - tic).count() << " seconds ";
583     OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
584     std::string filepath = prefix + "-" + std::to_string(epoch) + ".params";
585     SaveCheckpoint(filepath, RNN, exe);
586   }
587 
588   delete exe;
589   delete opt;
590 }
591 
predict(std::wstring * ptext,int sequence_length,const std::string param_file,const std::string dictionary_file)592 void predict(std::wstring* ptext, int sequence_length, const std::string param_file,
593     const std::string dictionary_file) {
594   Context device(DeviceType::kGPU, 0);
595   auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
596   auto dictionary = std::get<0>(results);
597   auto charIndices = std::get<1>(results);
598   input_dim = static_cast<int>(charIndices.size());
599   auto RNN = LSTMUnroll(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);
600 
601   std::map<std::string, NDArray> args_map;
602   args_map["data"] = NDArray(Shape(1, 1), device, false);
603   args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
604   std::vector<mx_float> zeros(1 * num_hidden, 0);
605   for (int l = 0; l < num_lstm_layer; l++) {
606     std::string key = "l" + std::to_string(l) + "_init_";
607     args_map[key + "c"] = NDArray(Shape(1, num_hidden), device, false);
608     args_map[key + "h"] = NDArray(Shape(1, num_hidden), device, false);
609     args_map[key + "c"].SyncCopyFromCPU(zeros);
610     args_map[key + "h"].SyncCopyFromCPU(zeros);
611   }
612   Executor* exe = RNN.SimpleBind(device, args_map);
613   LoadCheckpoint(param_file, exe);
614 
615   mx_float index;
616   wchar_t next = 0;
617   std::vector<mx_float> softmax;
618   softmax.resize(input_dim);
619   for (auto c : *ptext) {
620     exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
621     exe->Forward(false);
622 
623     exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
624     for (int l = 0; l < num_lstm_layer; l++) {
625       std::string key = "l" + std::to_string(l) + "_init_";
626       exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
627       exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
628     }
629 
630     size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
631     index = (mx_float) n;
632     next = charIndices[n];
633   }
634   ptext->push_back(next);
635 
636   for (int i = 0; i < sequence_length; i++) {
637     exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
638     exe->Forward(false);
639 
640     exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
641     for (int l = 0; l < num_lstm_layer; l++) {
642       std::string key = "l" + std::to_string(l) + "_init_";
643       exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
644       exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
645     }
646 
647     size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
648     index = (mx_float) n;
649     next = charIndices[n];
650     ptext->push_back(next);
651   }
652 
653   delete exe;
654 }
655 
predictWithBuiltInRNNOp(std::wstring * ptext,int sequence_length,const std::string param_file,const std::string dictionary_file)656 void predictWithBuiltInRNNOp(std::wstring* ptext, int sequence_length, const std::string param_file,
657   const std::string dictionary_file) {
658   Context device(DeviceType::kGPU, 0);
659   auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
660   auto dictionary = std::get<0>(results);
661   auto charIndices = std::get<1>(results);
662   input_dim = static_cast<int>(charIndices.size());
663   auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);
664 
665   std::map<std::string, NDArray> args_map;
666   args_map["data"] = NDArray(Shape(1, 1), device, false);
667   args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
668   std::vector<mx_float> zeros(1 * num_lstm_layer * num_hidden, 0);
669   // Avoiding SwapAxis, batch_size=1 is of second dimension.
670   args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
671   args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
672   args_map["LSTM_init_c"].SyncCopyFromCPU(zeros);
673   args_map["LSTM_init_h"].SyncCopyFromCPU(zeros);
674   Executor* exe = RNN.SimpleBind(device, args_map);
675   LoadCheckpoint(param_file, exe);
676 
677   mx_float index;
678   wchar_t next = 0;
679   std::vector<mx_float> softmax;
680   softmax.resize(input_dim);
681   for (auto c : *ptext) {
682     exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
683     exe->Forward(false);
684 
685     exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
686     exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
687     exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);
688 
689     size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
690     index = (mx_float) n;
691     next = charIndices[n];
692   }
693   ptext->push_back(next);
694 
695   for (int i = 0; i < sequence_length; i++) {
696     exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
697     exe->Forward(false);
698 
699     exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
700     exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
701     exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);
702 
703     size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
704     index = (mx_float) n;
705     next = charIndices[n];
706     ptext->push_back(next);
707   }
708 
709   delete exe;
710 }
711 
main(int argc,char ** argv)712 int main(int argc, char** argv) {
713   if (argc < 5) {
714     std::cout << "Usage for training: charRNN train[BuiltIn][TimeMajor] {corpus file}"
715             " {batch size} {max epoch} [{starting epoch}]" << std::endl;
716     std::cout <<"Usage for prediction: charRNN predict[BuiltIn][TimeMajor] {params file}"
717             " {dictionary file} {beginning of text}" << std::endl;
718     std::cout <<"Note: The {params file} of train/trainBuiltIn/trainTimeMajor/trainBuiltInTimeMajor"
719             " are not compatible with each other." << std::endl;
720     return 0;
721   }
722 
723   std::string task = argv[1];
724   bool builtIn = task.find("BuiltIn") != std::string::npos;
725   TIME_MAJOR = task.find("TimeMajor") != std::string::npos;
726   std::cout << "use BuiltIn cuDNN RNN: " << builtIn << std::endl
727          << "use data as TimeMajor: " << TIME_MAJOR << std::endl;
728   TRY
729   if (task.find("train") == 0) {
730     std::cout << "train batch size:      " << argv[3] << std::endl
731            << "train max epoch:       " << argv[4] << std::endl;
732     int start_epoch = argc > 5? atoi(argv[5]) : -1;
733     // this function will generate dictionary file and params file.
734     if (builtIn)
735       trainWithBuiltInRNNOp(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch);
736     else
737       train(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch);  // ditto
738   } else if (task.find("predict") == 0) {
739     std::wstring text;  // = L"If there is anyone out there who still doubts ";
740     // Considering of extending to Chinese samples in future, use wchar_t instead of char
741     for (char c : std::string(argv[4]))
742       text.push_back((wchar_t) c);
743     /*Python version predicts text default to random selecltions. Here I didn't write the random
744     code, always choose the 'best' character. So the text length reduced to 600. Longer size often
745     leads to repeated sentances, since training sequence length is only 129 for obama corpus.*/
746     if (builtIn)
747       predictWithBuiltInRNNOp(&text, 600, argv[2], argv[3]);
748     else
749       predict(&text, 600, argv[2], argv[3]);
750     std::wcout << text << std::endl;
751   }
752 
753   MXNotifyShutdown();
754   CATCH
755   return 0;
756 }
757