1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // This file contains a set of tests related to a toy neural network
18 // that can hyphenate words. This isn't meant to represent great strides
19 // in machine learning research, but rather to exercise the CPU JIT
20 // compiler with a small end-to-end example. The toy network is small
21 // enough that it can be trained as part of the unit test suite.
22 
23 #include "BackendTestUtils.h"
24 
25 #include "glow/ExecutionEngine/ExecutionEngine.h"
26 #include "glow/Graph/Graph.h"
27 #include "glow/IR/IR.h"
28 #include "glow/IR/IRBuilder.h"
29 #include "glow/IR/Instrs.h"
30 #include "glow/Support/Random.h"
31 
32 #include "gtest/gtest.h"
33 
34 #include <cctype>
35 #include <string>
36 
37 using namespace glow;
38 using llvm::cast;
39 using std::string;
40 using std::vector;
41 
42 // Network architecture
43 // ====================
44 //
45 // The network is a simple multi-layer perceptron with 27 x 6 input
46 // nodes, 10 inner nodes, and a 2-way soft-max output node.
47 //
48 // The input nodes represent 6 letters of a candidate word, and the
49 // output node indicates the probability that the word can be hyphenated
50 // between the 3rd and 4th letters. As a word slides past the 6-letter
51 // window, the network classifies each possible hyphen position.
52 //
53 // Example: "hyphenate"
54 //
55 // "..hyph" -> 0, h-yphenate is wrong.
56 // ".hyphe" -> 1, hy-phenate is right.
57 // "hyphen" -> 0, hyp-henate is wrong.
58 // "yphena" -> 0, hyph-enate is wrong.
59 // "phenat" -> 0, hyphe-nate is wrong.
60 // "henate" -> 1, hyphen-ate is right.
61 // "enate." -> 0, hyphena-te is wrong.
62 // "nate.." -> 0, hyphenat-e is wrong.
63 
64 /// Parse an already hyphenated word into word windows and hyphen labels.
65 ///
66 /// Given a word with embedded hyphens, generate a sequence of sliding
67 /// 6-character windows and associated boolean labels, like the table above.
68 ///
dehyphenate(const char * hword,vector<string> & words,vector<bool> & hyphens)69 static void dehyphenate(const char *hword, vector<string> &words,
70                         vector<bool> &hyphens) {
71   EXPECT_EQ(words.size(), hyphens.size());
72 
73   // The first character can't be a hyphen, and the word can't be null.
74   EXPECT_TRUE(std::islower(*hword));
75   string word = "..";
76   word.push_back(*hword++);
77 
78   // Parse `hword` and add all the letters to `word` and hyphen/no-hyphen
79   // entries to `hyphens`.
80   for (; *hword; hword++) {
81     bool hyph = (*hword == '-');
82     hyphens.push_back(hyph);
83     if (hyph) {
84       hword++;
85     }
86     // There can't be multiple adjacent hyphens, and the word can't
87     // end with a hyphen.
88     EXPECT_TRUE(std::islower(*hword));
89     word.push_back(*hword);
90   }
91   word += "..";
92 
93   // Now `word` contains the letters of `hword` surrounded by '..' on both
94   // sides. Generate all 6-character windows and append them to `words`.
95   for (size_t i = 0, e = word.size(); i + 5 < e; i++) {
96     words.push_back(word.substr(i, 6));
97   }
98   EXPECT_EQ(words.size(), hyphens.size());
99 }
100 
// Sanity-check dehyphenate() against hand-computed windows and labels.
TEST(HyphenTest, dehyphenate) {
  vector<string> words;
  vector<bool> hyphens;

  // A single letter has no interior gaps, so nothing is produced.
  dehyphenate("x", words, hyphens);
  EXPECT_EQ(words.size(), 0);
  EXPECT_EQ(hyphens.size(), 0);

  // Results accumulate across calls; "xy" has exactly one gap.
  dehyphenate("xy", words, hyphens);
  EXPECT_EQ(words, (vector<string>{"..xy.."}));
  EXPECT_EQ(hyphens, (vector<bool>{false}));

  // A marked hyphen point yields a true label for its window.
  dehyphenate("y-z", words, hyphens);
  EXPECT_EQ(words, (vector<string>{"..xy..", "..yz.."}));
  EXPECT_EQ(hyphens, (vector<bool>{false, true}));

  // Start fresh and check the full example from the header comment.
  words.clear();
  hyphens.clear();
  dehyphenate("hy-phen-ate", words, hyphens);
  const vector<string> expectedWindows{"..hyph", ".hyphe", "hyphen", "yphena",
                                       "phenat", "henate", "enate.", "nate.."};
  const vector<bool> expectedLabels{false, true,  false, false,
                                    false, true,  false, false};
  EXPECT_EQ(words, expectedWindows);
  EXPECT_EQ(hyphens, expectedLabels);
}
124 
/// Map a lower-case letter to an input index in the range 0-26.
/// Use 0 to represent any characters outside the a-z range.
static size_t mapLetter(char l) {
  // 'a'..'z' map to 1..26. The unsigned subtraction wraps around for
  // characters below 'a', so a single comparison rejects both directions
  // of out-of-range input (e.g. '.' and uppercase letters map to 0).
  const unsigned offset = static_cast<unsigned>(l - 'a');
  return offset <= 25 ? offset + 1 : 0;
}
131 
// Spot-check the one-hot letter indices, including the '.' padding char.
TEST(HyphenTest, mapLetter) {
  // Lower-case letters map to their 1-based alphabet position.
  for (char c : {'a', 'd', 'z'}) {
    EXPECT_EQ(mapLetter(c), size_t(c - 'a' + 1));
  }
  // The '.' padding character lies outside a-z and maps to index 0.
  EXPECT_EQ(mapLetter('.'), 0);
}
138 
139 /// Map a 6-letter window of a word to an input tensor using a one-hot encoding.
140 ///
141 /// The tensor must be N x 6 x 27: batch x position x letter.
mapLetterWindow(const string & window,dim_t idx,Handle<float> tensor)142 static void mapLetterWindow(const string &window, dim_t idx,
143                             Handle<float> tensor) {
144   EXPECT_EQ(window.size(), 6);
145   for (dim_t row = 0; row < 6; row++) {
146     dim_t col = mapLetter(window[row]);
147     tensor.at({idx, row, col}) = 1;
148   }
149 }
150 
// Training data consisting of pre-hyphenated common words.
// Each '-' marks a legal hyphenation point; dehyphenate() expands every
// word into sliding 6-letter windows with 0/1 labels for training.
const vector<const char *> TrainingData{
    "ad-mi-ni-stra-tion",
    "ad-mit",
    "al-low",
    "al-though",
    "an-i-mal",
    "any-one",
    "ar-rive",
    "art",
    "at-tor-ney",
    "be-cause",
    "be-fore",
    "be-ha-vior",
    "can-cer",
    "cer-tain-ly",
    "con-gress",
    "coun-try",
    "cul-tural",
    "cul-ture",
    "de-cide",
    "de-fense",
    "de-gree",
    "de-sign",
    "de-spite",
    "de-velop",
    "di-rec-tion",
    "di-rec-tor",
    "dis-cus-sion",
    "eco-nomy",
    "elec-tion",
    "en-vi-ron-men-tal",
    "es-tab-lish",
    "ev-ery-one",
    "ex-actly",
    "ex-ec-u-tive",
    "ex-ist",
    "ex-pe-ri-ence",
    "ex-plain",
    "fi-nally",
    "for-get",
    "hun-dred",
    "in-crease",
    "in-di-vid-ual",
    "it-self",
    "lan-guage",
    "le-gal",
    "lit-tle",
    "lo-cal",
    "ma-jo-ri-ty",
    "ma-te-rial",
    "may-be",
    "me-di-cal",
    "meet-ing",
    "men-tion",
    "mid-dle",
    "na-tion",
    "na-tional",
    "oc-cur",
    "of-fi-cer",
    "par-tic-u-lar-ly",
    "pat-tern",
    "pe-riod",
    "phy-si-cal",
    "po-si-tion",
    "pol-icy",
    "pos-si-ble",
    "pre-vent",
    "pres-sure",
    "pro-per-ty",
    "pur-pose",
    "re-cog-nize",
    "re-gion",
    "re-la-tion-ship",
    "re-main",
    "re-sponse",
    "re-sult",
    "rea-son",
    "sea-son",
    "sex-ual",
    "si-mi-lar",
    "sig-ni-fi-cant",
    "sim-ple",
    "sud-den-ly",
    "sum-mer",
    "thou-sand",
    "to-day",
    "train-ing",
    "treat-ment",
    "va-ri-ous",
    "value",
    "vi-o-lence",
};
244 
namespace {
/// A small multi-layer perceptron that classifies 6-letter windows as
/// hyphen / no-hyphen, bundled with its placeholders and both the forward
/// (inference) and derived gradient (training) functions. See the network
/// architecture description at the top of this file.
struct HyphenNetwork {
  /// The execution context.
  PlaceholderBindings bindings_;

  /// The input variable is N x 6 x 27 as encoded by mapLetterWindow().
  Placeholder *input_;

  /// The expected output index when training: 0 = no hyphen, 1 = hyphen.
  Placeholder *expected_;

  /// The forward inference function.
  Function *infer_;

  /// The result of the forward inference. N x 1 float with a probability.
  SaveNode *result_;

  /// The corresponding gradient function for training.
  Function *train_;

  /// Build the network inside \p mod: a 10-node fully connected RELU hidden
  /// layer followed by a 2-way soft-max, then derive the gradient function
  /// using the hyperparameters in \p conf (batch size, learning rate, ...).
  HyphenNetwork(Module &mod, TrainingConfig &conf)
      : input_(mod.createPlaceholder(ElemKind::FloatTy, {conf.batchSize, 6, 27},
                                     "input", false)),
        expected_(mod.createPlaceholder(ElemKind::Int64ITy, {conf.batchSize, 1},
                                        "expected", false)),
        infer_(mod.createFunction("infer")), result_(nullptr), train_(nullptr) {
    bindings_.allocate(input_);
    bindings_.allocate(expected_);
    Node *n;

    n = infer_->createFullyConnected(bindings_, "hidden_fc", input_, 10);
    n = infer_->createRELU("hidden", n);
    n = infer_->createFullyConnected(bindings_, "output_fc", n, 2);
    n = infer_->createSoftMax("output", n, expected_);
    result_ = infer_->createSave("result", n);
    bindings_.allocate(result_->getPlaceholder());
    // Derive the gradient function used for SGD training of infer_.
    train_ = glow::differentiate(infer_, conf);
  }

  // Run `inputs` through the inference function and check the results against
  // `hyphens`. Return the number of errors.
  unsigned inferenceErrors(ExecutionEngine &EE, llvm::StringRef fName,
                           Tensor &inputs, const vector<bool> &hyphens,
                           TrainingConfig &TC) {
    dim_t batchSize = TC.batchSize;
    dim_t numSamples = inputs.dims()[0];
    EXPECT_LE(batchSize, numSamples);
    // The "result" placeholder was allocated in the constructor; this handle
    // views the soft-max output of the most recent run.
    auto resultHandle =
        bindings_.get(bindings_.getPlaceholderByNameSlow("result"))
            ->getHandle<>();
    unsigned errors = 0;

    for (dim_t bi = 0; bi < numSamples; bi += batchSize) {
      // Get a batch-sized slice of inputs and run them through the inference
      // function. Do a bit of overlapping if the batch size doesn't divide the
      // number of samples.
      if (bi + batchSize > numSamples) {
        bi = numSamples - batchSize;
      }
      auto batchInputs = inputs.getUnowned({batchSize, 6, 27}, {bi, 0, 0});
      updateInputPlaceholders(bindings_, {input_}, {&batchInputs});
      EE.run(bindings_, fName);

      // Check each output in the batch.
      for (dim_t i = 0; i != batchSize; i++) {
        // Note that the two softmax outputs always sum to 1, so we only look at
        // one.
        float value = resultHandle.at({i, 1});
        if ((value > 0.5) != hyphens[bi + i]) {
          errors++;
        }
      }
    }
    return errors;
  }
};
} // namespace
322 
// End-to-end test: train the hyphenation network on the CPU backend, verify
// it classifies the (deliberately overfitted) training set perfectly, then
// check that a second engine with copied weights reproduces the result.
TEST(HyphenTest, network) {
  ExecutionEngine EE("CPU");

  // Convert the training data to word windows and labels.
  vector<string> words;
  vector<bool> hyphens;
  for (auto *hword : TrainingData) {
    dehyphenate(hword, words, hyphens);
  }

  // This depends on the training data, of course.
  const dim_t numSamples = 566;
  ASSERT_EQ(hyphens.size(), numSamples);
  ASSERT_EQ(words.size(), numSamples);

  // Randomly shuffle the training data.
  // This is required for stochastic gradient descent training.
  // (Fisher-Yates shuffle driven by the module's PRNG so the test stays
  // deterministic across runs.)
  auto &PRNG = EE.getModule().getPRNG();
  for (size_t i = numSamples - 1; i > 0; i--) {
    size_t j = PRNG.nextRandInt(0, i);
    std::swap(words[i], words[j]);
    std::swap(hyphens[i], hyphens[j]);
  }

  // Convert words and hyphens to a tensor representation.
  // Inputs get the one-hot window encoding; expected holds the 0/1 labels.
  Tensor inputs(ElemKind::FloatTy, {numSamples, 6, 27});
  Tensor expected(ElemKind::Int64ITy, {numSamples, 1});
  inputs.zero();
  auto inputHandle = inputs.getHandle<float>();
  auto expectedHandle = expected.getHandle<int64_t>();
  for (dim_t i = 0; i != numSamples; i++) {
    mapLetterWindow(words[i], i, inputHandle);
    expectedHandle.at({i, 0}) = hyphens[i];
  }

  // Now build the network.
  TrainingConfig TC;
  TC.learningRate = 0.8;
  TC.batchSize = 50;
  HyphenNetwork net(EE.getModule(), TC);
  auto fName = net.infer_->getName();
  auto tfName = net.train_->getName();

  // This variable records the number of the next sample to be used for
  // training.
  size_t sampleCounter = 0;

  // Train using mini-batch SGD.
  EE.compile(CompilationMode::Train);
  runBatch(EE, net.bindings_, 1000, sampleCounter, {net.input_, net.expected_},
           {&inputs, &expected}, tfName);

  // Now test inference on the trained network.
  // Note that we have probably overfitted the data, so we expect 100% accuracy.
  EXPECT_EQ(net.inferenceErrors(EE, fName, inputs, hyphens, TC), 0);

  // See if the interpreter gets the same result.

  ExecutionEngine EE2("CPU");
  HyphenNetwork netInterpreter(EE2.getModule(), TC);
  EE2.compile(CompilationMode::Train);
  // Copy the trained weights from the CPU run. The placeholder names match
  // because both networks are constructed identically.
  net.bindings_.copyToTarget("bias", netInterpreter.bindings_);
  net.bindings_.copyToTarget("bias__1", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights", netInterpreter.bindings_);
  net.bindings_.copyToTarget("weights__1", netInterpreter.bindings_);

  EXPECT_EQ(netInterpreter.inferenceErrors(EE2, fName, inputs, hyphens, TC), 0);
}
392