1 // Copyright 2010-2018, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30 #include <iostream> // NOLINT
31 #include <map>
32 #include <numeric> // accumulate
33 #include <string>
34 #include <vector>
35
36 #include "base/file_stream.h"
37 #include "base/flags.h"
38 #include "base/init_mozc.h"
39 #include "base/logging.h"
40 #include "base/multifile.h"
41 #include "base/port.h"
42 #include "base/util.h"
43 #include "client/client.h"
44 #include "evaluation/scorer.h"
45 #include "protocol/commands.pb.h"
46
// Test data automatically generated by gen_client_quality_test_data.py.
// It defines TestCase test_cases[], terminated by an entry whose source is
// NULL.
49 #include "client/client_quality_test_data.inc"
50
// Path of the conversion server program. When non-empty, main() passes it
// to Client::set_server_program(); when empty, the client is left unchanged.
DEFINE_string(server_path, "", "specify server path");
// File that receives the per-source mean BLEU scores. When empty, the
// scores are written to stdout instead.
DEFINE_string(log_path, "", "specify log output file path");
// Maximum number of successfully scored test cases per test source; once a
// source reaches this count, its remaining test cases are skipped.
DEFINE_int32(max_case_for_source, 500,
             "specify max test case number for each test sources");
55
56 namespace mozc {
IsValidSourceSentence(const string & str)57 bool IsValidSourceSentence(const string &str) {
58 // TODO(noriyukit) Treat alphabets by changing to Eisu-mode
59 if (Util::ContainsScriptType(str, Util::ALPHABET)) {
60 LOG(WARNING) << "contains ALPHABET: " << str;
61 return false;
62 }
63
64 // Source should not contain kanji
65 if (Util::ContainsScriptType(str, Util::KANJI)) {
66 LOG(WARNING) << "contains KANJI: " << str;
67 return false;
68 }
69
70 // Source should not contain katakana
71 string tmp, tmp2;
72 Util::StringReplace(str, "ー", "", true, &tmp);
73 Util::StringReplace(tmp, "・", "", true, &tmp2);
74 if (Util::ContainsScriptType(tmp2, Util::KATAKANA)) {
75 LOG(WARNING) << "contain KATAKANA: " << str;
76 return false;
77 }
78 return true;
79 }
80
GenerateKeySequenceFrom(const string & hiragana_sentence,std::vector<commands::KeyEvent> * keys)81 bool GenerateKeySequenceFrom(const string& hiragana_sentence,
82 std::vector<commands::KeyEvent>* keys) {
83 CHECK(keys);
84 keys->clear();
85
86 string tmp, input;
87 Util::HiraganaToRomanji(hiragana_sentence, &tmp);
88 Util::FullWidthToHalfWidth(tmp, &input);
89
90 for (ConstChar32Iterator iter(input); !iter.Done(); iter.Next()) {
91 const char32 ucs4 = iter.Get();
92
93 // TODO(noriyukit) Improve key sequence generation; currently, a few ucs4
94 // codes, like FF5E and 300E, cannot be handled.
95 commands::KeyEvent key;
96 if (ucs4 >= 0x20 && ucs4 <= 0x7F) {
97 key.set_key_code(static_cast<int>(ucs4));
98 } else if (ucs4 == 0x3001 || ucs4 == 0xFF64) {
99 key.set_key_code(0x002C); // Full-width comma -> Half-width comma
100 } else if (ucs4 == 0x3002 || ucs4 == 0xFF0E || ucs4 == 0xFF61) {
101 key.set_key_code(0x002E); // Full-width period -> Half-width period
102 } else if (ucs4 == 0x2212 || ucs4 == 0x2015) {
103 key.set_key_code(0x002D); // "−" -> "-"
104 } else if (ucs4 == 0x300C || ucs4 == 0xff62) {
105 key.set_key_code(0x005B); // "「" -> "["
106 } else if (ucs4 == 0x300D || ucs4 == 0xff63) {
107 key.set_key_code(0x005D); // "」" -> "]"
108 } else if (ucs4 == 0x30FB || ucs4 == 0xFF65) {
109 key.set_key_code(0x002F); // "・" -> "/" "・" -> "/"
110 } else {
111 LOG(WARNING) << "Unexpected character: " << std::hex << ucs4 << ": in "
112 << input << " (" << hiragana_sentence << ")";
113 return false;
114 }
115 keys->push_back(key);
116 }
117
118 // Conversion key
119 {
120 commands::KeyEvent key;
121 key.set_special_key(commands::KeyEvent::SPACE);
122 keys->push_back(key);
123 }
124 return true;
125 }
126
GetPreedit(const commands::Output & output,string * str)127 bool GetPreedit(const commands::Output &output, string* str) {
128 CHECK(str);
129
130 if (!output.has_preedit()) {
131 LOG(WARNING) << "No result";
132 return false;
133 }
134
135 str->clear();
136 for (size_t i = 0; i < output.preedit().segment_size(); ++i) {
137 str->append(output.preedit().segment(i).value());
138 }
139
140 return true;
141 }
142
CalculateBLEU(client::Client * client,const string & hiragana_sentence,const string & expected_result,double * score)143 bool CalculateBLEU(client::Client* client,
144 const string& hiragana_sentence,
145 const string& expected_result, double* score) {
146 // Prepare key events
147 std::vector<commands::KeyEvent> keys;
148 if (!GenerateKeySequenceFrom(hiragana_sentence, &keys)) {
149 LOG(WARNING) << "Failed to generated key events from: "
150 << hiragana_sentence;
151 return false;
152 }
153
154 // Must send ON first
155 commands::Output output;
156 {
157 commands::KeyEvent key;
158 key.set_special_key(commands::KeyEvent::ON);
159 client->SendKey(key, &output);
160 }
161
162 // Send keys
163 for (size_t i = 0; i < keys.size(); ++i) {
164 client->SendKey(keys[i], &output);
165 }
166 VLOG(2) << "Server response: " << output.Utf8DebugString();
167
168 // Calculate score
169 string expected_normalized;
170 Scorer::NormalizeForEvaluate(expected_result, &expected_normalized);
171 std::vector<string> goldens;
172 goldens.push_back(expected_normalized);
173 string preedit, preedit_normalized;
174 if (!GetPreedit(output, &preedit) || preedit.empty()) {
175 LOG(WARNING) << "Could not get output";
176 return false;
177 }
178 Scorer::NormalizeForEvaluate(preedit, &preedit_normalized);
179
180 *score = Scorer::BLEUScore(goldens, preedit_normalized);
181
182 VLOG(1) << hiragana_sentence << std::endl
183 << " score: " << (*score) << std::endl
184 << " preedit: " << preedit_normalized << std::endl
185 << "expected: " << expected_normalized;
186
187 // Revert session to prevent server from learning this conversion
188 commands::SessionCommand command;
189 command.set_type(commands::SessionCommand::REVERT);
190 client->SendCommand(command, &output);
191
192 return true;
193 }
194
CalculateMean(const std::vector<double> & scores)195 double CalculateMean(const std::vector<double>& scores) {
196 CHECK(!scores.empty());
197 const double sum = accumulate(scores.begin(), scores.end(), 0.0);
198 return sum / static_cast<double>(scores.size());
199 }
200 } // namespace mozc
201
202
main(int argc,char * argv[])203 int main(int argc, char* argv[]) {
204 mozc::InitMozc(argv[0], &argc, &argv, true);
205
206 mozc::client::Client client;
207 if (!FLAGS_server_path.empty()) {
208 client.set_server_program(FLAGS_server_path);
209 }
210
211 CHECK(client.IsValidRunLevel()) << "IsValidRunLevel failed";
212 CHECK(client.EnsureSession()) << "EnsureSession failed";
213 CHECK(client.NoOperation()) << "Server is not respoinding";
214
215 std::map<string, std::vector<double> > scores; // Results to be averaged
216
217 for (mozc::TestCase* test_case = mozc::test_cases; test_case->source != NULL;
218 ++test_case) {
219 const string &source = test_case->source;
220 const string &hiragana_sentence = test_case->hiragana_sentence;
221 const string &expected_result = test_case->expected_result;
222
223 if (scores.find(source) == scores.end()) {
224 scores[source] = std::vector<double>();
225 }
226 if (scores[source].size() >= FLAGS_max_case_for_source) {
227 continue;
228 }
229
230 VLOG(1) << "Processing " << hiragana_sentence;
231 if (!mozc::IsValidSourceSentence(hiragana_sentence)) {
232 LOG(WARNING) << "Invalid test case: " << std::endl
233 << " source: " << source << std::endl
234 << " hiragana: " << hiragana_sentence << std::endl
235 << " expected: " << expected_result;
236 continue;
237 }
238
239 double score;
240 if (!mozc::CalculateBLEU(&client, hiragana_sentence,
241 expected_result, &score)) {
242 LOG(WARNING) << "Failed to calculate BLEU score: " << std::endl
243 << " source: " << source << std::endl
244 << " hiragana: " << hiragana_sentence << std::endl
245 << " expected: " << expected_result;
246 continue;
247 }
248 scores[source].push_back(score);
249 }
250
251 std::ostream* ofs = &std::cout;
252 if (!FLAGS_log_path.empty()) {
253 ofs = new mozc::OutputFileStream(FLAGS_log_path.c_str());
254 }
255
256 // Average the scores
257 for (std::map<string, std::vector<double> >::iterator it = scores.begin();
258 it != scores.end(); ++it) {
259 const double mean = mozc::CalculateMean(it->second);
260 (*ofs) << it->first << " : " << mean << std::endl;
261 }
262 if (ofs != &std::cout) {
263 delete ofs;
264 }
265
266 return 0;
267 }
268