/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Hua Zhang mz24cn@hotmail.com
 * The code implements a C++ version of charRNN for mxnet\example\rnn\char-rnn.ipynb
 * with the MXNet.cpp API. The generated params file is compatible with the Python version.
 * train() and predict() have been verified with the original data samples.
 * 2017/1/23:
 * Added a faster charRNN version based on the built-in cuDNN RNN operator, about 10 times faster.
 * Added a time-major computation graph, although there is no substantial performance difference.
 * Support continuing training from the last params file.
 * Renamed params files so that the epoch number starts from zero.
 */

#if _MSC_VER
#pragma warning(disable: 4996)  // VS2015 complains on 'std::copy' ...
#endif
#include <cstring>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <vector>
#include <string>
#include <tuple>
#include <algorithm>
#include <functional>
#include <thread>
#include <chrono>
#include "mxnet-cpp/MxNetCpp.h"
#include "utils.h"

using namespace mxnet::cpp;

struct LSTMState {
  Symbol C;
  Symbol h;
};

struct LSTMParam {
  Symbol i2h_weight;
  Symbol i2h_bias;
  Symbol h2h_weight;
  Symbol h2h_bias;
};

bool TIME_MAJOR = true;

// LSTM Cell symbol
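// For reference, the cell below computes the standard LSTM update, in the gate order
// produced by SliceChannel (a sketch of the math, using textbook notation):
//   in_gate:      i_t = sigmoid(W_i x_t + U_i h_{t-1} + b_i)
//   in_transform: g_t = tanh(W_g x_t + U_g h_{t-1} + b_g)
//   forget_gate:  f_t = sigmoid(W_f x_t + U_f h_{t-1} + b_f)
//   out_gate:     o_t = sigmoid(W_o x_t + U_o h_{t-1} + b_o)
//   C_t = f_t * C_{t-1} + i_t * g_t
//   h_t = o_t * tanh(C_t)
// A single i2h/h2h FullyConnected pair packs all four gates, hence num_hidden * 4.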
LSTMState LSTM(int num_hidden, const Symbol& indata, const LSTMState& prev_state,
               const LSTMParam& param, int seqidx, int layeridx, mx_float dropout = 0) {
  auto input = dropout > 0? Dropout(indata, dropout) : indata;
  auto prefix = std::string("t") + std::to_string(seqidx) + "_l" + std::to_string(layeridx);
  auto i2h = FullyConnected(prefix + "_i2h", input, param.i2h_weight, param.i2h_bias,
                            num_hidden * 4);
  auto h2h = FullyConnected(prefix + "_h2h", prev_state.h, param.h2h_weight, param.h2h_bias,
                            num_hidden * 4);
  auto gates = i2h + h2h;
  auto slice_gates = SliceChannel(prefix + "_slice", gates, 4);
  auto in_gate = Activation(slice_gates[0], ActivationActType::kSigmoid);
  auto in_transform = Activation(slice_gates[1], ActivationActType::kTanh);
  auto forget_gate = Activation(slice_gates[2], ActivationActType::kSigmoid);
  auto out_gate = Activation(slice_gates[3], ActivationActType::kSigmoid);

  LSTMState state;
  state.C = (forget_gate * prev_state.C) + (in_gate * in_transform);
  state.h = out_gate * Activation(state.C, ActivationActType::kTanh);
  return state;
}

Symbol LSTMUnroll(int num_lstm_layer, int sequence_length, int input_dim,
                  int num_hidden, int num_embed, mx_float dropout = 0) {
  auto isTrain = sequence_length > 1;
  auto data = Symbol::Variable("data");
  if (TIME_MAJOR && isTrain)
    data = transpose(data);
  auto embed_weight = Symbol::Variable("embed_weight");
  auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
  auto wordvec = isTrain? SliceChannel(embed, sequence_length, TIME_MAJOR? 0 : 1, true) : embed;

  std::vector<LSTMState> last_states;
  std::vector<LSTMParam> param_cells;
  for (int l = 0; l < num_lstm_layer; l++) {
    std::string layer = "l" + std::to_string(l);
    LSTMParam param;
    param.i2h_weight = Symbol::Variable(layer + "_i2h_weight");
    param.i2h_bias = Symbol::Variable(layer + "_i2h_bias");
    param.h2h_weight = Symbol::Variable(layer + "_h2h_weight");
    param.h2h_bias = Symbol::Variable(layer + "_h2h_bias");
    param_cells.push_back(param);
    LSTMState state;
    state.C = Symbol::Variable(layer + "_init_c");
    state.h = Symbol::Variable(layer + "_init_h");
    last_states.push_back(state);
  }

  std::vector<Symbol> hidden_all;
  for (int i = 0; i < sequence_length; i++) {
    auto hidden = wordvec[i];
    for (int layer = 0; layer < num_lstm_layer; layer++) {
      double dp_ratio = layer == 0? 0 : dropout;
      auto next_state = LSTM(num_hidden, hidden, last_states[layer], param_cells[layer],
                             i, layer, dp_ratio);
      hidden = next_state.h;
      last_states[layer] = next_state;
    }
    if (dropout > 0)
      hidden = Dropout(hidden, dropout);
    hidden_all.push_back(hidden);
  }

  auto hidden_concat = isTrain? Concat(hidden_all, hidden_all.size(), 0) : hidden_all[0];
  auto cls_weight = Symbol::Variable("cls_weight");
  auto cls_bias = Symbol::Variable("cls_bias");
  auto pred = FullyConnected("pred", hidden_concat, cls_weight, cls_bias, input_dim);

  auto label = Symbol::Variable("softmax_label");
  label = transpose(label);
  label = Reshape(label, Shape(), false, Shape(0), false);  // -1: infer from graph
  auto sm = SoftmaxOutput("softmax", pred, label);
  if (isTrain)
    return sm;

  std::vector<Symbol> outputs = { sm };
  for (auto& state : last_states) {
    outputs.push_back(state.C);
    outputs.push_back(state.h);
  }
  return Symbol::Group(outputs);
}
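
// Shape flow of the unrolled graph above on the time-major training path (a sketch,
// assuming "data" of shape (batch, seq)):
//   data (batch, seq) -> transpose -> (seq, batch) -> Embedding -> (seq, batch, num_embed)
//   -> SliceChannel -> seq slices of (batch, num_embed) -> LSTM stack -> (batch, num_hidden)
//   -> Concat on axis 0 -> (seq*batch, num_hidden) -> pred -> (seq*batch, input_dim),
// matched against the transposed, flattened labels of length seq*batch.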

// Currently the mxnet GPU-version RNN operator is implemented via the *fast* NVIDIA cuDNN.
Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_dim,
                            int num_hidden, int num_embed, mx_float dropout = 0) {
  auto isTrain = sequence_length > 1;
  auto data = Symbol::Variable("data");
  if (TIME_MAJOR && isTrain)
    data = transpose(data);

  auto embed_weight = Symbol::Variable("embed_weight");
  auto embed = Embedding("embed", data, embed_weight, input_dim, num_embed);
  auto label = Symbol::Variable("softmax_label");
  label = transpose(label);
  label = Reshape(label, Shape(), false,
                  Shape(0), false);  // FullyConnected requires one dimension
  if (!TIME_MAJOR && isTrain)
    embed = SwapAxis(embed, 0, 1);  // Change to time-major as cuDNN requires

  // We need not do the SwapAxis op that the Python version does: feeding time-major data
  // directly is simpler and performs better in C++!
  auto rnn_h_init = Symbol::Variable("LSTM_init_h");
  auto rnn_c_init = Symbol::Variable("LSTM_init_c");
  auto rnn_params = Symbol::Variable("LSTM_parameters");  // See explanations near RNNXavier class
  auto variable_sequence_length = Symbol::Variable("sequence_length");
  auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
                 num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
  auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);

  auto cls_weight = Symbol::Variable("cls_weight");
  auto cls_bias = Symbol::Variable("cls_bias");
  auto pred = FullyConnected("pred", hidden, cls_weight, cls_bias, input_dim);
  /* In rnn-time-major/rnn_cell_demo.py, the author claimed the time-major version speeds up
   * 1.5~2 times versus the batch version. I doubt that conclusion. In my tests, the performance
   * of both codes is almost the same. In fact, there are no substantial differences between the
   * two: they are both based on time-major cuDNN, and the computation graphs only differ
   * slightly in the choices of where to put the Reshape/SwapAxis/transpose operations. Here I
   * don't use Reshape on pred and keep the label shape in SoftmaxOutput like the time-major
   * version code, but Reshape on label for simplification. It has no influence on performance. */

  auto sm = SoftmaxOutput("softmax", pred, label);
  if (isTrain)
    return sm;
  else
    return Symbol::Group({ sm, rnn[1/*RNNOpOutputs::kStateOut=1*/],
                           rnn[2/*RNNOpOutputs::kStateCellOut=2*/] });
}

class Shuffler {
  std::vector<int> sequence;

 public:
  explicit Shuffler(int size) : sequence(size) {
    int* p = sequence.data();
    for (int i = 0; i < size; i++)
      *p++ = i;
  }
  void shuffle(std::function<void(int, int)> lambda = nullptr) {
    random_shuffle(sequence.begin(), sequence.end());
    int n = 0;
    if (lambda != nullptr)
      for (int i : sequence)
        lambda(n++, i);
  }
  const int* data() {
    return sequence.data();
  }
};

class BucketSentenceIter : public DataIter {
  Shuffler* random;
  int batch, current, end;
  unsigned int sequence_length;
  Context device;
  std::vector<std::vector<mx_float>> sequences;
  std::vector<wchar_t> index2chars;
  std::unordered_map<wchar_t, mx_float> charIndices;

 public:
  BucketSentenceIter(std::string filename, int minibatch, Context context) : batch(minibatch),
      current(-1), device(context) {
    auto content = readContent(filename);
    buildCharIndex(content);
    sequences = convertTextToSequences(content, '\n');

    int N = sequences.size() / batch * batch;  // total used samples
    sequences.resize(N);
    sort(sequences.begin(), sequences.end(), [](const std::vector<mx_float>& a,
        const std::vector<mx_float>& b) { return a.size() < b.size(); });

    sequence_length = sequences.back().size();
    random = new Shuffler(N);
    // We can still get random results if Reset() is called first:
    // std::vector<std::vector<mx_float>>* target = &sequences;
    // random->shuffle([target](int n, int i) { (*target)[n].swap((*target)[i]); });
    end = N / batch;
  }
  virtual ~BucketSentenceIter() {
    delete random;
  }

  unsigned int maxSequenceLength() {
    return sequence_length;
  }

  size_t characterSize() {
    return charIndices.size();
  }

  virtual bool Next(void) {
    return ++current < end;
  }
  virtual NDArray GetData(void) {
    const int* indices = random->data();
    mx_float *data = new mx_float[sequence_length * batch], *pdata = data;

    for (int i = current * batch, end = i + batch; i < end; i++) {
      memcpy(pdata, sequences[indices[i]].data(), sequences[indices[i]].size() * sizeof(mx_float));
      if (sequences[indices[i]].size() < sequence_length)
        memset(pdata + sequences[indices[i]].size(), 0,
               (sequence_length - sequences[indices[i]].size()) * sizeof(mx_float));
      pdata += sequence_length;
    }
    NDArray array(Shape(batch, sequence_length), device, false);
    array.SyncCopyFromCPU(data, batch * sequence_length);
    delete [] data;  // the copy is synchronous, so the staging buffer can be freed here
    return array;
  }
  virtual NDArray GetLabel(void) {
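    // Labels are the inputs shifted left by one: the target for character t is
    // character t+1, zero-padded past the end of each sequence.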
    const int* indices = random->data();
    mx_float *label = new mx_float[sequence_length * batch], *plabel = label;

    for (int i = current * batch, end = i + batch; i < end; i++) {
      memcpy(plabel, sequences[indices[i]].data() + 1,
             (sequences[indices[i]].size() - 1) * sizeof(mx_float));
      memset(plabel + sequences[indices[i]].size() - 1, 0,
             (sequence_length - sequences[indices[i]].size() + 1) * sizeof(mx_float));
      plabel += sequence_length;
    }
    NDArray array(Shape(batch, sequence_length), device, false);
    array.SyncCopyFromCPU(label, batch * sequence_length);
    delete [] label;  // the copy is synchronous, so the staging buffer can be freed here
    return array;
  }
  virtual int GetPadNum(void) {
    return sequence_length - sequences[random->data()[current * batch]].size();
  }
  virtual std::vector<int> GetIndex(void) {
    const int* indices = random->data();
    std::vector<int> list(indices + current * batch, indices + current * batch + batch);
    return list;
  }
  virtual void BeforeFirst(void) {
    current = -1;
    random->shuffle(nullptr);
  }

  std::wstring readContent(const std::string file) {
    std::wifstream ifs(file, std::ios::binary);
    if (ifs) {
      std::wostringstream os;
      os << ifs.rdbuf();
      return os.str();
    }
    return L"";
  }

  void buildCharIndex(const std::wstring& content) {
    // This version of buildCharIndex() is compatible with the python version char_rnn dictionary
    int n = 1;
    charIndices['\0'] = 0;  // padding character
    index2chars.push_back(0);  // padding character index
    for (auto c : content)
      if (charIndices.find(c) == charIndices.end()) {
        charIndices[c] = n++;
        index2chars.push_back(c);
      }
  }
  // void buildCharIndex(wstring& content) {
  //   for (auto c : content)
  //     charIndices[c]++;  // char-frequency map; then char-index map
  //   std::vector<tuple<wchar_t, mx_float>> characters;
  //   for (auto& iter : charIndices)
  //     characters.push_back(make_tuple(iter.first, iter.second));
  //   sort(characters.begin(), characters.end(), [](const tuple<wchar_t, mx_float>& a,
  //       const tuple<wchar_t, mx_float>& b) { return get<1>(a) > get<1>(b); });
  //   mx_float index = 1;  // 0 is left for zero-padding
  //   index2chars.clear();
  //   index2chars.push_back(0);  // zero-padding
  //   for (auto& t : characters) {
  //     charIndices[get<0>(t)] = index++;
  //     index2chars.push_back(get<0>(t));
  //   }
  // }

  inline wchar_t character(int i) {
    return index2chars[i];
  }

  inline mx_float index(wchar_t c) {
    return charIndices[c];
  }

  void saveCharIndices(const std::string file) {
    std::wofstream ofs(file, std::ios::binary);
    if (ofs) {
      ofs.write(index2chars.data() + 1, index2chars.size() - 1);
      ofs.close();
    }
  }

  static std::tuple<std::unordered_map<wchar_t, mx_float>, std::vector<wchar_t>> loadCharIndices(
      const std::string file) {
    std::wifstream ifs(file, std::ios::binary);
    std::unordered_map<wchar_t, mx_float> map;
    std::vector<wchar_t> chars;
    if (ifs) {
      std::wostringstream os;
      os << ifs.rdbuf();
      int n = 1;
      map[L'\0'] = 0;
      chars.push_back(L'\0');
      for (auto c : os.str()) {
        map[c] = (mx_float) n++;
        chars.push_back(c);
      }
    }
    // Note: Can't use {} because this would hit the explicit constructor
    return std::tuple<std::unordered_map<wchar_t, mx_float>, std::vector<wchar_t>>(map, chars);
  }

  std::vector<std::vector<mx_float>>
      convertTextToSequences(const std::wstring& content, wchar_t spliter) {
    std::vector<std::vector<mx_float>> sequences;
    sequences.push_back(std::vector<mx_float>());
    for (auto c : content)
      if (c == spliter && !sequences.back().empty())
        sequences.push_back(std::vector<mx_float>());
      else
        sequences.back().push_back(charIndices[c]);
    return sequences;
  }
};

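// Perplexity is the exponential of the mean negative log-likelihood over all predicted
// characters: PPL = exp(-(1/N) * sum_n log p(label_n)). The labels are stored batch-major
// but the prediction rows are laid out time-major, hence the transposed indexing below.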
void OutputPerplexity(NDArray* labels, NDArray* output) {
  std::vector<mx_float> charIndices, a;
  labels->SyncCopyToCPU(&charIndices, 0L);  // 0L indicates all
  output->SyncCopyToCPU(&a, 0L)/*4128*84*/;
  mx_float loss = 0;
  int batchSize = labels->GetShape()[0]/*32*/, sequenceLength = labels->GetShape()[1]/*129*/,
      nSamples = output->GetShape()[0]/*4128*/, vocabSize = output->GetShape()[1]/*84*/;
  for (int n = 0; n < nSamples; n++) {
    int row = n % batchSize, column = n / batchSize, labelOffset = column +
        row * sequenceLength;  // Search based on column storage: labels.T
    mx_float safe_value = std::max(1e-10f, a[vocabSize * n +
        static_cast<int>(charIndices[labelOffset])]);
    loss += -log(safe_value);  // Calculate negative log-likelihood
  }
  loss = exp(loss / nSamples);
  std::cout << "Train-Perplexity=" << loss << std::endl;
}

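// Checkpoint layout, as produced below: a flat name->NDArray map saved via NDArray::Save,
// with "arg:"/"aux:" prefixes separating arguments from auxiliary states, e.g.
// {"arg:l0_i2h_weight": ..., "arg:cls_bias": ...}; init-state and data/label arrays are skipped.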
void SaveCheckpoint(const std::string filepath, Symbol net, Executor* exe) {
  std::map<std::string, NDArray> params;
  for (auto iter : exe->arg_dict())
    if (iter.first.find("_init_") == std::string::npos
        && iter.first.rfind("data") != iter.first.length() - 4
        && iter.first.rfind("label") != iter.first.length() - 5)
      params.insert({"arg:" + iter.first, iter.second});
  for (auto iter : exe->aux_dict())
    params.insert({"aux:" + iter.first, iter.second});
  NDArray::Save(filepath, params);
}

void LoadCheckpoint(const std::string filepath, Executor* exe) {
  std::map<std::string, NDArray> params = NDArray::LoadToMap(filepath);
  for (auto iter : params) {
    std::string type = iter.first.substr(0, 4);
    std::string name = iter.first.substr(4);
    NDArray target;
    if (type == "arg:")
      target = exe->arg_dict()[name];
    else if (type == "aux:")
      target = exe->aux_dict()[name];
    else
      continue;
    iter.second.CopyTo(&target);
  }
}

int input_dim = 0;/*84*/
int sequence_length_max = 0;/*129*/
int num_embed = 256;
int num_lstm_layer = 3;
int num_hidden = 512;
mx_float dropout = 0.2;
void train(const std::string file, int batch_size, int max_epoch, int start_epoch) {
  Context device(DeviceType::kGPU, 0);
  BucketSentenceIter dataIter(file, batch_size, device);
  std::string prefix = file.substr(0, file.rfind("."));
  dataIter.saveCharIndices(prefix + ".dictionary");

  input_dim = static_cast<int>(dataIter.characterSize());
  sequence_length_max = dataIter.maxSequenceLength();

  auto RNN = LSTMUnroll(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
                        num_embed, dropout);
  std::map<std::string, NDArray> args_map;
  args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
  args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
  for (int i = 0; i < num_lstm_layer; i++) {
    std::string key = "l" + std::to_string(i) + "_init_";
    args_map[key + "c"] = NDArray(Shape(batch_size, num_hidden), device, false);
    args_map[key + "h"] = NDArray(Shape(batch_size, num_hidden), device, false);
  }
  std::vector<mx_float> zeros(batch_size * num_hidden, 0);
  // RNN.SimpleBind(device, args_map, {}, {{"data", kNullOp}});
  Executor* exe = RNN.SimpleBind(device, args_map);

  if (start_epoch == -1) {
    Xavier xavier = Xavier(Xavier::gaussian, Xavier::in, 2.34);
    for (auto &arg : exe->arg_dict())
      xavier(arg.first, &arg.second);
  } else {
    LoadCheckpoint(prefix + "-" + std::to_string(start_epoch) + ".params", exe);
  }
  start_epoch++;

  mx_float learning_rate = 0.0002;
  mx_float weight_decay = 0.000002;
  Optimizer* opt = OptimizerRegistry::Find("sgd");
  opt->SetParam("lr", learning_rate)
     ->SetParam("wd", weight_decay);
  // opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
  //    ->SetParam("clip_gradient", 10);

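  // Each minibatch is trained independently: the init states are zeroed before every
  // forward pass, so no hidden state is carried across batches (per-sentence training).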
  for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
    dataIter.Reset();
    auto tic = std::chrono::system_clock::now();
    while (dataIter.Next()) {
      auto data_batch = dataIter.GetDataBatch();
      data_batch.data.CopyTo(&exe->arg_dict()["data"]);
      data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
      for (int l = 0; l < num_lstm_layer; l++) {
        std::string key = "l" + std::to_string(l) + "_init_";
        exe->arg_dict()[key + "c"].SyncCopyFromCPU(zeros);
        exe->arg_dict()[key + "h"].SyncCopyFromCPU(zeros);
      }
      NDArray::WaitAll();

      exe->Forward(true);
      exe->Backward();
      for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
        opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
      }

      NDArray::WaitAll();
    }
    auto toc = std::chrono::system_clock::now();
    std::cout << "Epoch[" << epoch << "] Time Cost:" <<
        std::chrono::duration_cast<std::chrono::seconds>(toc - tic).count() << " seconds ";
    OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
    std::string filepath = prefix + "-" + std::to_string(epoch) + ".params";
    SaveCheckpoint(filepath, RNN, exe);
  }

  delete exe;
  delete opt;
}

/* The original example, rnn_cell_demo.py, uses the default Xavier as the initializer, which
 * relies on variable names and thus cannot initialize LSTM_parameters. For that reason it was
 * renamed to LSTM_bias, which can be initialized to zero. But then it cannot converge after
 * 100 epochs on this corpus example. Using RNNXavier, after 15 oscillating epochs it rapidly
 * converges like the old LSTMUnroll version. */
class RNNXavier : public Xavier {
 public:
  RNNXavier(RandType rand_type = gaussian, FactorType factor_type = avg,
            float magnitude = 3) : Xavier(rand_type, factor_type, magnitude) {
  }
  virtual ~RNNXavier() {}

 protected:
  virtual void InitDefault(NDArray* arr) {
    Xavier::InitWeight(arr);
  }
};
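// The mechanism, inferred from the override above: the base initializer dispatches on
// variable-name suffixes and routes unrecognized names such as "LSTM_parameters" to
// InitDefault(). RNNXavier redirects that fallback to InitWeight(), so the single flat
// cuDNN parameter blob still receives Xavier initialization.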

void trainWithBuiltInRNNOp(const std::string file, int batch_size, int max_epoch, int start_epoch) {
  Context device(DeviceType::kGPU, 0);
  BucketSentenceIter dataIter(file, batch_size, device);
  std::string prefix = file.substr(0, file.rfind("."));
  dataIter.saveCharIndices(prefix + ".dictionary");

  input_dim = static_cast<int>(dataIter.characterSize());
  sequence_length_max = dataIter.maxSequenceLength();

  auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, sequence_length_max, input_dim, num_hidden,
                                  num_embed, dropout);
  std::map<std::string, NDArray> args_map;
  args_map["data"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
  // Avoiding SwapAxis: batch_size is the second dimension.
  args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
  args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, batch_size, num_hidden), device, false);
  args_map["softmax_label"] = NDArray(Shape(batch_size, sequence_length_max), device, false);
  std::vector<mx_float> zeros(batch_size * num_lstm_layer * num_hidden, 0);
  Executor* exe = RNN.SimpleBind(device, args_map);

  if (start_epoch == -1) {
    RNNXavier xavier = RNNXavier(Xavier::gaussian, Xavier::in, 2.34);
    for (auto &arg : exe->arg_dict())
      xavier(arg.first, &arg.second);
  } else {
    LoadCheckpoint(prefix + "-" + std::to_string(start_epoch) + ".params", exe);
  }
  start_epoch++;

  Optimizer* opt = OptimizerRegistry::Find("ccsgd");
  // opt->SetParam("momentum", 0.9)->SetParam("rescale_grad", 1.0 / batch_size)
  //    ->SetParam("clip_gradient", 10);

  for (int epoch = start_epoch; epoch < max_epoch; ++epoch) {
    dataIter.Reset();
    auto tic = std::chrono::system_clock::now();
    while (dataIter.Next()) {
      auto data_batch = dataIter.GetDataBatch();
      data_batch.data.CopyTo(&exe->arg_dict()["data"]);
      data_batch.label.CopyTo(&exe->arg_dict()["softmax_label"]);
      exe->arg_dict()["LSTM_init_c"].SyncCopyFromCPU(zeros);
      exe->arg_dict()["LSTM_init_h"].SyncCopyFromCPU(zeros);
      NDArray::WaitAll();

      exe->Forward(true);
      exe->Backward();
      for (size_t i = 0; i < exe->arg_arrays.size(); ++i) {
        opt->Update(i, exe->arg_arrays[i], exe->grad_arrays[i]);
      }
      NDArray::WaitAll();
    }
    auto toc = std::chrono::system_clock::now();
    std::cout << "Epoch[" << epoch << "] Time Cost:" <<
        std::chrono::duration_cast<std::chrono::seconds>(toc - tic).count() << " seconds ";
    OutputPerplexity(&exe->arg_dict()["softmax_label"], &exe->outputs[0]);
    std::string filepath = prefix + "-" + std::to_string(epoch) + ".params";
    SaveCheckpoint(filepath, RNN, exe);
  }

  delete exe;
  delete opt;
}

void predict(std::wstring* ptext, int sequence_length, const std::string param_file,
             const std::string dictionary_file) {
  Context device(DeviceType::kGPU, 0);
  auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
  auto dictionary = std::get<0>(results);
  auto charIndices = std::get<1>(results);
  input_dim = static_cast<int>(charIndices.size());
  auto RNN = LSTMUnroll(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);

  std::map<std::string, NDArray> args_map;
  args_map["data"] = NDArray(Shape(1, 1), device, false);
  args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
  std::vector<mx_float> zeros(1 * num_hidden, 0);
  for (int l = 0; l < num_lstm_layer; l++) {
    std::string key = "l" + std::to_string(l) + "_init_";
    args_map[key + "c"] = NDArray(Shape(1, num_hidden), device, false);
    args_map[key + "h"] = NDArray(Shape(1, num_hidden), device, false);
    args_map[key + "c"].SyncCopyFromCPU(zeros);
    args_map[key + "h"].SyncCopyFromCPU(zeros);
  }
  Executor* exe = RNN.SimpleBind(device, args_map);
  LoadCheckpoint(param_file, exe);

  mx_float index;
  wchar_t next = 0;
  std::vector<mx_float> softmax;
  softmax.resize(input_dim);
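  // Prime the LSTM state on the seed text first, feeding each character and copying the
  // output states back into the init-state arrays; afterwards, generate greedily by
  // feeding back the argmax character index at each step.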
  for (auto c : *ptext) {
    exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
    exe->Forward(false);

    exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
    for (int l = 0; l < num_lstm_layer; l++) {
      std::string key = "l" + std::to_string(l) + "_init_";
      exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
      exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
    }

    size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
    index = (mx_float) n;
    next = charIndices[n];
  }
  ptext->push_back(next);

  for (int i = 0; i < sequence_length; i++) {
    exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
    exe->Forward(false);

    exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
    for (int l = 0; l < num_lstm_layer; l++) {
      std::string key = "l" + std::to_string(l) + "_init_";
      exe->outputs[l * 2 + 1].CopyTo(&args_map[key + "c"]);
      exe->outputs[l * 2 + 2].CopyTo(&args_map[key + "h"]);
    }

    size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
    index = (mx_float) n;
    next = charIndices[n];
    ptext->push_back(next);
  }

  delete exe;
}

void predictWithBuiltInRNNOp(std::wstring* ptext, int sequence_length, const std::string param_file,
                             const std::string dictionary_file) {
  Context device(DeviceType::kGPU, 0);
  auto results = BucketSentenceIter::loadCharIndices(dictionary_file);
  auto dictionary = std::get<0>(results);
  auto charIndices = std::get<1>(results);
  input_dim = static_cast<int>(charIndices.size());
  auto RNN = LSTMWithBuiltInRNNOp(num_lstm_layer, 1, input_dim, num_hidden, num_embed, 0);

  std::map<std::string, NDArray> args_map;
  args_map["data"] = NDArray(Shape(1, 1), device, false);
  args_map["softmax_label"] = NDArray(Shape(1, 1), device, false);
  std::vector<mx_float> zeros(1 * num_lstm_layer * num_hidden, 0);
  // Avoiding SwapAxis: batch_size = 1 is the second dimension.
  args_map["LSTM_init_c"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
  args_map["LSTM_init_h"] = NDArray(Shape(num_lstm_layer, 1, num_hidden), device, false);
  args_map["LSTM_init_c"].SyncCopyFromCPU(zeros);
  args_map["LSTM_init_h"].SyncCopyFromCPU(zeros);
  Executor* exe = RNN.SimpleBind(device, args_map);
  LoadCheckpoint(param_file, exe);

  mx_float index;
  wchar_t next = 0;
  std::vector<mx_float> softmax;
  softmax.resize(input_dim);
  for (auto c : *ptext) {
    exe->arg_dict()["data"].SyncCopyFromCPU(&dictionary[c], 1);
    exe->Forward(false);

    exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
    exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
    exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);

    size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
    index = (mx_float) n;
    next = charIndices[n];
  }
  ptext->push_back(next);

  for (int i = 0; i < sequence_length; i++) {
    exe->arg_dict()["data"].SyncCopyFromCPU(&index, 1);
    exe->Forward(false);

    exe->outputs[0].SyncCopyToCPU(softmax.data(), input_dim);
    exe->outputs[1].CopyTo(&args_map["LSTM_init_h"]);
    exe->outputs[2].CopyTo(&args_map["LSTM_init_c"]);

    size_t n = max_element(softmax.begin(), softmax.end()) - softmax.begin();
    index = (mx_float) n;
    next = charIndices[n];
    ptext->push_back(next);
  }

  delete exe;
}

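// Example invocations (a sketch; "obama.txt" and the derived file names are hypothetical):
//   charRNN trainBuiltInTimeMajor obama.txt 32 75
//   charRNN predictBuiltInTimeMajor obama-74.params obama.dictionary "The United States"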
int main(int argc, char** argv) {
  if (argc < 5) {
    std::cout << "Usage for training: charRNN train[BuiltIn][TimeMajor] {corpus file}"
        " {batch size} {max epoch} [{starting epoch}]" << std::endl;
    std::cout << "Usage for prediction: charRNN predict[BuiltIn][TimeMajor] {params file}"
        " {dictionary file} {beginning of text}" << std::endl;
    std::cout << "Note: The {params file} of train/trainBuiltIn/trainTimeMajor/trainBuiltInTimeMajor"
        " are not compatible with each other." << std::endl;
    return 0;
  }

  std::string task = argv[1];
  bool builtIn = task.find("BuiltIn") != std::string::npos;
  TIME_MAJOR = task.find("TimeMajor") != std::string::npos;
  std::cout << "use BuiltIn cuDNN RNN: " << builtIn << std::endl
            << "use data as TimeMajor: " << TIME_MAJOR << std::endl;
  TRY
  if (task.find("train") == 0) {
    std::cout << "train batch size: " << argv[3] << std::endl
              << "train max epoch: " << argv[4] << std::endl;
    int start_epoch = argc > 5? atoi(argv[5]) : -1;
    // this function will generate dictionary file and params file.
    if (builtIn)
      trainWithBuiltInRNNOp(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch);
    else
      train(argv[2], atoi(argv[3]), atoi(argv[4]), start_epoch);  // ditto
  } else if (task.find("predict") == 0) {
    std::wstring text;  // = L"If there is anyone out there who still doubts ";
    // Considering extending to Chinese samples in the future, use wchar_t instead of char
    for (char c : std::string(argv[4]))
      text.push_back((wchar_t) c);
    /* The Python version by default predicts text with random selection. Here I didn't write
       the random code and always choose the 'best' character, so the text length is reduced
       to 600. Longer sizes often lead to repeated sentences, since the training sequence
       length is only 129 for the obama corpus. */
    if (builtIn)
      predictWithBuiltInRNNOp(&text, 600, argv[2], argv[3]);
    else
      predict(&text, 600, argv[2], argv[3]);
    std::wcout << text << std::endl;
  }

  MXNotifyShutdown();
  CATCH
  return 0;
}