1 // Copyright 2010-2018, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30 #include "prediction/predictor.h"
31
32 #include <cstddef>
33 #include <memory>
34 #include <string>
35
36 #include "base/logging.h"
37 #include "base/singleton.h"
38 #include "base/system_util.h"
39 #include "composer/composer.h"
40 #include "config/config_handler.h"
41 #include "converter/segments.h"
42 #include "data_manager/testing/mock_data_manager.h"
43 #include "dictionary/dictionary_mock.h"
44 #include "dictionary/pos_matcher.h"
45 #include "dictionary/suppression_dictionary.h"
46 #include "prediction/predictor_interface.h"
47 #include "prediction/user_history_predictor.h"
48 #include "protocol/commands.pb.h"
49 #include "protocol/config.pb.h"
50 #include "request/conversion_request.h"
51 #include "session/request_test_util.h"
52 #include "testing/base/public/gmock.h"
53 #include "testing/base/public/googletest.h"
54 #include "testing/base/public/gunit.h"
55
56 using std::unique_ptr;
57
58 using mozc::dictionary::DictionaryMock;
59 using mozc::dictionary::SuppressionDictionary;
60 using testing::AtMost;
61 using testing::Return;
62 using testing::_;
63
64 namespace mozc {
65 namespace {
66
67 class CheckCandSizePredictor : public PredictorInterface {
68 public:
CheckCandSizePredictor(int expected_cand_size)69 explicit CheckCandSizePredictor(int expected_cand_size)
70 : expected_cand_size_(expected_cand_size),
71 predictor_name_("CheckCandSizePredictor") {}
72
PredictForRequest(const ConversionRequest & request,Segments * segments) const73 bool PredictForRequest(const ConversionRequest &request,
74 Segments *segments) const override {
75 EXPECT_EQ(expected_cand_size_, segments->max_prediction_candidates_size());
76 return true;
77 }
78
GetPredictorName() const79 const string &GetPredictorName() const override {
80 return predictor_name_;
81 }
82
83 private:
84 const int expected_cand_size_;
85 const string predictor_name_;
86 };
87
88 class NullPredictor : public PredictorInterface {
89 public:
NullPredictor(bool ret)90 explicit NullPredictor(bool ret)
91 : return_value_(ret), predict_called_(false),
92 predictor_name_("NullPredictor") {}
93
PredictForRequest(const ConversionRequest & request,Segments * segments) const94 bool PredictForRequest(const ConversionRequest &request,
95 Segments *segments) const override {
96 predict_called_ = true;
97 return return_value_;
98 }
99
predict_called() const100 bool predict_called() const {
101 return predict_called_;
102 }
103
Clear()104 void Clear() {
105 predict_called_ = false;
106 }
107
GetPredictorName() const108 const string &GetPredictorName() const override {
109 return predictor_name_;
110 }
111
112 private:
113 bool return_value_;
114 mutable bool predict_called_;
115 const string predictor_name_;
116 };
117
118 class MockPredictor : public PredictorInterface {
119 public:
120 MockPredictor() = default;
121 ~MockPredictor() override = default;
122 MOCK_CONST_METHOD2(
123 PredictForRequest,
124 bool(const ConversionRequest &request, Segments *segments));
125 MOCK_CONST_METHOD0(GetPredictorName, const string &());
126 };
127
128 } // namespace
129
130 class MobilePredictorTest : public ::testing::Test {
131 protected:
SetUp()132 void SetUp() override {
133 config_.reset(new config::Config);
134 config::ConfigHandler::GetDefaultConfig(config_.get());
135
136 request_.reset(new commands::Request);
137 commands::RequestForUnitTest::FillMobileRequest(request_.get());
138 composer_.reset(new composer::Composer(
139 nullptr, request_.get(), config_.get()));
140
141 convreq_.reset(
142 new ConversionRequest(composer_.get(), request_.get(), config_.get()));
143 }
144
145 unique_ptr<mozc::composer::Composer> composer_;
146 unique_ptr<commands::Request> request_;
147 unique_ptr<config::Config> config_;
148 unique_ptr<ConversionRequest> convreq_;
149 };
150
TEST_F(MobilePredictorTest, CallPredictorsForMobileSuggestion) {
  // For a mobile SUGGESTION request, the first sub-predictor should receive
  // a budget of 20 candidates and the second (history) sub-predictor 3.
  CheckCandSizePredictor *const first = new CheckCandSizePredictor(20);
  CheckCandSizePredictor *const history = new CheckCandSizePredictor(3);
  // MobilePredictor takes ownership of both sub-predictors.
  unique_ptr<MobilePredictor> predictor(new MobilePredictor(first, history));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
164
TEST_F(MobilePredictorTest, CallPredictorsForMobilePartialSuggestion) {
  // For PARTIAL_SUGGESTION, the first sub-predictor should receive a budget
  // of 20 candidates. The history predictor must not be called at all. It is
  // given an expected size of -1, which is never a real candidate budget, so
  // any call to it fails the size check.
  unique_ptr<MobilePredictor> predictor(
      new MobilePredictor(new CheckCandSizePredictor(20),
                          // We don't call the history predictor.
                          new CheckCandSizePredictor(-1)));
  Segments segments;
  {
    segments.set_request_type(Segments::PARTIAL_SUGGESTION);
    Segment *segment;
    segment = segments.add_segment();
    CHECK(segment);
  }
  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
179
TEST_F(MobilePredictorTest, CallPredictorsForMobilePrediction) {
  // For a mobile PREDICTION request, the first sub-predictor should receive
  // a budget of 200 candidates and the second (history) sub-predictor 3.
  CheckCandSizePredictor *const first = new CheckCandSizePredictor(200);
  CheckCandSizePredictor *const history = new CheckCandSizePredictor(3);
  // MobilePredictor takes ownership of both sub-predictors.
  unique_ptr<MobilePredictor> predictor(new MobilePredictor(first, history));

  Segments segments;
  segments.set_request_type(Segments::PREDICTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
193
TEST_F(MobilePredictorTest, CallPredictorsForMobilePartialPrediction) {
  // The first sub-predictor checks for a 200-candidate budget. The second is
  // a real UserHistoryPredictor backed by a mock dictionary. These locals
  // must outlive |predictor|, so they are declared first.
  DictionaryMock dictionary_mock;
  testing::MockDataManager data_manager;
  const dictionary::POSMatcher pos_matcher(data_manager.GetPOSMatcherData());

  CheckCandSizePredictor *const size_checker = new CheckCandSizePredictor(200);
  // NOTE(review): the meaning of the trailing `true` is defined by
  // UserHistoryPredictor's constructor (not visible here); passed through
  // unchanged.
  UserHistoryPredictor *const history_predictor = new UserHistoryPredictor(
      &dictionary_mock, &pos_matcher, Singleton<SuppressionDictionary>::get(),
      true);
  // MobilePredictor takes ownership of both sub-predictors.
  unique_ptr<MobilePredictor> predictor(
      new MobilePredictor(size_checker, history_predictor));

  Segments segments;
  segments.set_request_type(Segments::PARTIAL_PREDICTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
215
TEST_F(MobilePredictorTest, CallPredictForRequetMobile) {
  // Both mocks are owned by MobilePredictor. The raw pointers are kept only
  // so that expectations can be set on them.
  MockPredictor *const mock1 = new MockPredictor;
  MockPredictor *const mock2 = new MockPredictor;
  unique_ptr<MobilePredictor> predictor(new MobilePredictor(mock1, mock2));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  // Each sub-predictor may be consulted at most once and reports success.
  EXPECT_CALL(*mock1, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));
  EXPECT_CALL(*mock2, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
235
236
237 class PredictorTest : public ::testing::Test {
238 protected:
SetUp()239 virtual void SetUp() {
240 config_.reset(new config::Config);
241 config::ConfigHandler::GetDefaultConfig(config_.get());
242
243 request_.reset(new commands::Request);
244 composer_.reset(new composer::Composer(
245 nullptr, request_.get(), config_.get()));
246
247 convreq_.reset(
248 new ConversionRequest(composer_.get(), request_.get(), config_.get()));
249 }
250
251 unique_ptr<mozc::composer::Composer> composer_;
252 unique_ptr<commands::Request> request_;
253 unique_ptr<config::Config> config_;
254 unique_ptr<ConversionRequest> convreq_;
255 };
256
TEST_F(PredictorTest, AllPredictorsReturnTrue) {
  // When every sub-predictor succeeds, DefaultPredictor succeeds as well.
  NullPredictor *const first = new NullPredictor(true);
  NullPredictor *const second = new NullPredictor(true);
  // DefaultPredictor takes ownership of both sub-predictors.
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(first, second));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
270
TEST_F(PredictorTest, MixedReturnValue) {
  // One successful sub-predictor is enough for DefaultPredictor to succeed.
  NullPredictor *const succeeding = new NullPredictor(true);
  NullPredictor *const failing = new NullPredictor(false);
  // DefaultPredictor takes ownership of both sub-predictors.
  unique_ptr<DefaultPredictor> predictor(
      new DefaultPredictor(succeeding, failing));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
284
TEST_F(PredictorTest, AllPredictorsReturnFalse) {
  // If no sub-predictor succeeds, DefaultPredictor reports failure.
  NullPredictor *const first = new NullPredictor(false);
  NullPredictor *const second = new NullPredictor(false);
  // DefaultPredictor takes ownership of both sub-predictors.
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(first, second));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_FALSE(predictor->PredictForRequest(*convreq_, &segments));
}
298
TEST_F(PredictorTest, CallPredictorsForSuggestion) {
  // For SUGGESTION, both sub-predictors should receive the default config's
  // suggestions_size as their candidate budget.
  const int expected_size =
      config::ConfigHandler::DefaultConfig().suggestions_size();
  CheckCandSizePredictor *const first =
      new CheckCandSizePredictor(expected_size);
  CheckCandSizePredictor *const second =
      new CheckCandSizePredictor(expected_size);
  // DefaultPredictor takes ownership of both sub-predictors.
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(first, second));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
315
TEST_F(PredictorTest, CallPredictorsForPrediction) {
  // For PREDICTION, both sub-predictors should receive a budget of 100.
  constexpr int kPredictionSize = 100;
  CheckCandSizePredictor *const first =
      new CheckCandSizePredictor(kPredictionSize);
  CheckCandSizePredictor *const second =
      new CheckCandSizePredictor(kPredictionSize);
  // DefaultPredictor takes ownership of both sub-predictors.
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(first, second));

  Segments segments;
  segments.set_request_type(Segments::PREDICTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
330
TEST_F(PredictorTest, CallPredictForRequet) {
  // Both mocks are owned by DefaultPredictor. The raw pointers are kept only
  // so that expectations can be set on them.
  MockPredictor *const mock1 = new MockPredictor;
  MockPredictor *const mock2 = new MockPredictor;
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(mock1, mock2));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  // Each sub-predictor may be consulted at most once and reports success.
  EXPECT_CALL(*mock1, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));
  EXPECT_CALL(*mock2, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));

  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
}
350
351
TEST_F(PredictorTest, DisableAllSuggestion) {
  // Raw pointers are retained to inspect the "called" flags. Ownership goes
  // to DefaultPredictor.
  NullPredictor *const first = new NullPredictor(true);
  NullPredictor *const second = new NullPredictor(true);
  unique_ptr<DefaultPredictor> predictor(new DefaultPredictor(first, second));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const seg = segments.add_segment();
  CHECK(seg);

  // convreq_ refers to config_, so toggling the flag here is seen by the
  // predictor. With presentation mode on, neither sub-predictor is consulted
  // and the overall result is false.
  config_->set_presentation_mode(true);
  EXPECT_FALSE(predictor->PredictForRequest(*convreq_, &segments));
  EXPECT_FALSE(first->predict_called());
  EXPECT_FALSE(second->predict_called());

  // With presentation mode off, both sub-predictors are consulted.
  config_->set_presentation_mode(false);
  EXPECT_TRUE(predictor->PredictForRequest(*convreq_, &segments));
  EXPECT_TRUE(first->predict_called());
  EXPECT_TRUE(second->predict_called());
}
375
376 } // namespace mozc
377