1 // Copyright 2010-2018, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 //     * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 //     * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 //     * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 
30 #include <iostream>  // NOLINT
31 #include <map>
32 #include <numeric>  // accumulate
33 #include <string>
34 #include <vector>
35 
36 #include "base/file_stream.h"
37 #include "base/flags.h"
38 #include "base/init_mozc.h"
39 #include "base/logging.h"
40 #include "base/multifile.h"
41 #include "base/port.h"
42 #include "base/util.h"
43 #include "client/client.h"
44 #include "evaluation/scorer.h"
45 #include "protocol/commands.pb.h"
46 
47 // Test data automatically generated by gen_client_quality_test_data.py
48 // TestCase test_cases[] is defined.
49 #include "client/client_quality_test_data.inc"
50 
// Command line flags of this quality test driver.
//
// server_path: when non-empty, main() hands this value to
// Client::set_server_program() before the session is established.
51 DEFINE_string(server_path, "", "specify server path");
// log_path: when non-empty, main() writes the per-source mean scores to this
// file; otherwise they are written to std::cout.
52 DEFINE_string(log_path, "", "specify log output file path");
// max_case_for_source: main() stops scoring further test cases of a source
// once this many scores have been collected for it.
53 DEFINE_int32(max_case_for_source, 500,
54              "specify max test case number for each test sources");
55 
56 namespace mozc {
IsValidSourceSentence(const string & str)57 bool IsValidSourceSentence(const string &str) {
58   // TODO(noriyukit) Treat alphabets by changing to Eisu-mode
59   if (Util::ContainsScriptType(str, Util::ALPHABET)) {
60     LOG(WARNING) << "contains ALPHABET: " << str;
61     return false;
62   }
63 
64   // Source should not contain kanji
65   if (Util::ContainsScriptType(str, Util::KANJI)) {
66     LOG(WARNING) << "contains KANJI: " << str;
67     return false;
68   }
69 
70   // Source should not contain katakana
71   string tmp, tmp2;
72   Util::StringReplace(str, "ー", "", true, &tmp);
73   Util::StringReplace(tmp, "・", "", true, &tmp2);
74   if (Util::ContainsScriptType(tmp2, Util::KATAKANA)) {
75     LOG(WARNING) << "contain KATAKANA: " << str;
76     return false;
77   }
78   return true;
79 }
80 
GenerateKeySequenceFrom(const string & hiragana_sentence,std::vector<commands::KeyEvent> * keys)81 bool GenerateKeySequenceFrom(const string& hiragana_sentence,
82                              std::vector<commands::KeyEvent>* keys) {
83   CHECK(keys);
84   keys->clear();
85 
86   string tmp, input;
87   Util::HiraganaToRomanji(hiragana_sentence, &tmp);
88   Util::FullWidthToHalfWidth(tmp, &input);
89 
90   for (ConstChar32Iterator iter(input); !iter.Done(); iter.Next()) {
91     const char32 ucs4 = iter.Get();
92 
93     // TODO(noriyukit) Improve key sequence generation; currently, a few ucs4
94     // codes, like FF5E and 300E, cannot be handled.
95     commands::KeyEvent key;
96     if (ucs4 >= 0x20 && ucs4 <= 0x7F) {
97       key.set_key_code(static_cast<int>(ucs4));
98     } else if (ucs4 == 0x3001 || ucs4 == 0xFF64) {
99       key.set_key_code(0x002C);  // Full-width comma -> Half-width comma
100     } else if (ucs4 == 0x3002 || ucs4 == 0xFF0E || ucs4 == 0xFF61) {
101       key.set_key_code(0x002E);  // Full-width period -> Half-width period
102     } else if (ucs4 == 0x2212 || ucs4 == 0x2015) {
103       key.set_key_code(0x002D);  // "−" -> "-"
104     } else if (ucs4 == 0x300C || ucs4 == 0xff62) {
105       key.set_key_code(0x005B);  // "「" -> "["
106     } else if (ucs4 == 0x300D || ucs4 == 0xff63) {
107       key.set_key_code(0x005D);  // "」" -> "]"
108     } else if (ucs4 == 0x30FB || ucs4 == 0xFF65) {
109       key.set_key_code(0x002F);  // "・" -> "/"  "・" -> "/"
110     } else {
111       LOG(WARNING) << "Unexpected character: " << std::hex << ucs4 << ": in "
112                    << input << " (" << hiragana_sentence << ")";
113       return false;
114     }
115     keys->push_back(key);
116   }
117 
118   // Conversion key
119   {
120     commands::KeyEvent key;
121     key.set_special_key(commands::KeyEvent::SPACE);
122     keys->push_back(key);
123   }
124   return true;
125 }
126 
GetPreedit(const commands::Output & output,string * str)127 bool GetPreedit(const commands::Output &output, string* str) {
128   CHECK(str);
129 
130   if (!output.has_preedit()) {
131     LOG(WARNING) << "No result";
132     return false;
133   }
134 
135   str->clear();
136   for (size_t i = 0; i < output.preedit().segment_size(); ++i) {
137     str->append(output.preedit().segment(i).value());
138   }
139 
140   return true;
141 }
142 
CalculateBLEU(client::Client * client,const string & hiragana_sentence,const string & expected_result,double * score)143 bool CalculateBLEU(client::Client* client,
144                    const string& hiragana_sentence,
145                    const string& expected_result, double* score) {
146   // Prepare key events
147   std::vector<commands::KeyEvent> keys;
148   if (!GenerateKeySequenceFrom(hiragana_sentence, &keys)) {
149     LOG(WARNING) << "Failed to generated key events from: "
150                << hiragana_sentence;
151     return false;
152   }
153 
154   // Must send ON first
155   commands::Output output;
156   {
157     commands::KeyEvent key;
158     key.set_special_key(commands::KeyEvent::ON);
159     client->SendKey(key, &output);
160   }
161 
162   // Send keys
163   for (size_t i = 0; i < keys.size(); ++i) {
164     client->SendKey(keys[i], &output);
165   }
166   VLOG(2) << "Server response: " << output.Utf8DebugString();
167 
168   // Calculate score
169   string expected_normalized;
170   Scorer::NormalizeForEvaluate(expected_result, &expected_normalized);
171   std::vector<string> goldens;
172   goldens.push_back(expected_normalized);
173   string preedit, preedit_normalized;
174   if (!GetPreedit(output, &preedit) || preedit.empty()) {
175     LOG(WARNING) << "Could not get output";
176     return false;
177   }
178   Scorer::NormalizeForEvaluate(preedit, &preedit_normalized);
179 
180   *score = Scorer::BLEUScore(goldens, preedit_normalized);
181 
182   VLOG(1) << hiragana_sentence << std::endl
183           << "   score: " << (*score) << std::endl
184           << " preedit: " << preedit_normalized << std::endl
185           << "expected: " << expected_normalized;
186 
187   // Revert session to prevent server from learning this conversion
188   commands::SessionCommand command;
189   command.set_type(commands::SessionCommand::REVERT);
190   client->SendCommand(command, &output);
191 
192   return true;
193 }
194 
CalculateMean(const std::vector<double> & scores)195 double CalculateMean(const std::vector<double>& scores) {
196   CHECK(!scores.empty());
197   const double sum = accumulate(scores.begin(), scores.end(), 0.0);
198   return sum / static_cast<double>(scores.size());
199 }
200 }  // namespace mozc
201 
202 
main(int argc,char * argv[])203 int main(int argc, char* argv[]) {
204   mozc::InitMozc(argv[0], &argc, &argv, true);
205 
206   mozc::client::Client client;
207   if (!FLAGS_server_path.empty()) {
208     client.set_server_program(FLAGS_server_path);
209   }
210 
211   CHECK(client.IsValidRunLevel()) << "IsValidRunLevel failed";
212   CHECK(client.EnsureSession()) << "EnsureSession failed";
213   CHECK(client.NoOperation()) << "Server is not respoinding";
214 
215   std::map<string, std::vector<double> > scores;    // Results to be averaged
216 
217   for (mozc::TestCase* test_case = mozc::test_cases; test_case->source != NULL;
218        ++test_case) {
219     const string &source = test_case->source;
220     const string &hiragana_sentence = test_case->hiragana_sentence;
221     const string &expected_result = test_case->expected_result;
222 
223     if (scores.find(source) == scores.end()) {
224       scores[source] = std::vector<double>();
225     }
226     if (scores[source].size() >= FLAGS_max_case_for_source) {
227       continue;
228     }
229 
230     VLOG(1) << "Processing " << hiragana_sentence;
231     if (!mozc::IsValidSourceSentence(hiragana_sentence)) {
232       LOG(WARNING) << "Invalid test case: " << std::endl
233                    << "    source: " << source << std::endl
234                    << "  hiragana: " << hiragana_sentence << std::endl
235                    << "  expected: " << expected_result;
236       continue;
237     }
238 
239     double score;
240     if (!mozc::CalculateBLEU(&client, hiragana_sentence,
241                              expected_result, &score)) {
242       LOG(WARNING) << "Failed to calculate BLEU score: " << std::endl
243                    << "    source: " << source << std::endl
244                    << "  hiragana: " << hiragana_sentence << std::endl
245                    << "  expected: " << expected_result;
246       continue;
247     }
248     scores[source].push_back(score);
249   }
250 
251   std::ostream* ofs = &std::cout;
252   if (!FLAGS_log_path.empty()) {
253     ofs = new mozc::OutputFileStream(FLAGS_log_path.c_str());
254   }
255 
256   // Average the scores
257   for (std::map<string, std::vector<double> >::iterator it = scores.begin();
258        it != scores.end(); ++it) {
259     const double mean = mozc::CalculateMean(it->second);
260     (*ofs) << it->first << " : " << mean << std::endl;
261   }
262   if (ofs != &std::cout) {
263     delete ofs;
264   }
265 
266   return 0;
267 }
268