1 // Copyright 2010-2018, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 //     * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 //     * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 //     * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 
30 #include "prediction/predictor.h"
31 
32 #include <cstddef>
33 #include <memory>
34 #include <string>
35 
36 #include "base/logging.h"
37 #include "base/singleton.h"
38 #include "base/system_util.h"
39 #include "composer/composer.h"
40 #include "config/config_handler.h"
41 #include "converter/segments.h"
42 #include "data_manager/testing/mock_data_manager.h"
43 #include "dictionary/dictionary_mock.h"
44 #include "dictionary/pos_matcher.h"
45 #include "dictionary/suppression_dictionary.h"
46 #include "prediction/predictor_interface.h"
47 #include "prediction/user_history_predictor.h"
48 #include "protocol/commands.pb.h"
49 #include "protocol/config.pb.h"
50 #include "request/conversion_request.h"
51 #include "session/request_test_util.h"
52 #include "testing/base/public/gmock.h"
53 #include "testing/base/public/googletest.h"
54 #include "testing/base/public/gunit.h"
55 
56 using std::unique_ptr;
57 
58 using mozc::dictionary::DictionaryMock;
59 using mozc::dictionary::SuppressionDictionary;
60 using testing::AtMost;
61 using testing::Return;
62 using testing::_;
63 
64 namespace mozc {
65 namespace {
66 
67 class CheckCandSizePredictor : public PredictorInterface {
68  public:
CheckCandSizePredictor(int expected_cand_size)69   explicit CheckCandSizePredictor(int expected_cand_size)
70       : expected_cand_size_(expected_cand_size),
71         predictor_name_("CheckCandSizePredictor") {}
72 
PredictForRequest(const ConversionRequest & request,Segments * segments) const73   bool PredictForRequest(const ConversionRequest &request,
74                          Segments *segments) const override {
75     EXPECT_EQ(expected_cand_size_, segments->max_prediction_candidates_size());
76     return true;
77   }
78 
GetPredictorName() const79   const string &GetPredictorName() const override {
80     return predictor_name_;
81   }
82 
83  private:
84   const int expected_cand_size_;
85   const string predictor_name_;
86 };
87 
88 class NullPredictor : public PredictorInterface {
89  public:
NullPredictor(bool ret)90   explicit NullPredictor(bool ret)
91       : return_value_(ret), predict_called_(false),
92         predictor_name_("NullPredictor") {}
93 
PredictForRequest(const ConversionRequest & request,Segments * segments) const94   bool PredictForRequest(const ConversionRequest &request,
95                          Segments *segments) const override {
96     predict_called_ = true;
97     return return_value_;
98   }
99 
predict_called() const100   bool predict_called() const {
101     return predict_called_;
102   }
103 
Clear()104   void Clear() {
105     predict_called_ = false;
106   }
107 
GetPredictorName() const108   const string &GetPredictorName() const override {
109     return predictor_name_;
110   }
111 
112  private:
113   bool return_value_;
114   mutable bool predict_called_;
115   const string predictor_name_;
116 };
117 
118 class MockPredictor : public PredictorInterface {
119  public:
120   MockPredictor() = default;
121   ~MockPredictor() override = default;
122   MOCK_CONST_METHOD2(
123       PredictForRequest,
124       bool(const ConversionRequest &request, Segments *segments));
125   MOCK_CONST_METHOD0(GetPredictorName, const string &());
126 };
127 
128 }  // namespace
129 
130 class MobilePredictorTest : public ::testing::Test {
131  protected:
SetUp()132   void SetUp() override {
133     config_.reset(new config::Config);
134     config::ConfigHandler::GetDefaultConfig(config_.get());
135 
136     request_.reset(new commands::Request);
137     commands::RequestForUnitTest::FillMobileRequest(request_.get());
138     composer_.reset(new composer::Composer(
139         nullptr, request_.get(), config_.get()));
140 
141     convreq_.reset(
142         new ConversionRequest(composer_.get(), request_.get(), config_.get()));
143   }
144 
145   unique_ptr<mozc::composer::Composer> composer_;
146   unique_ptr<commands::Request> request_;
147   unique_ptr<config::Config> config_;
148   unique_ptr<ConversionRequest> convreq_;
149 };
150 
// For a mobile SUGGESTION request, the first predictor must see a candidate
// limit of 20 and the second (history) predictor a limit of 3.
TEST_F(MobilePredictorTest, CallPredictorsForMobileSuggestion) {
  unique_ptr<MobilePredictor> mobile_predictor(
      new MobilePredictor(new CheckCandSizePredictor(20),
                          new CheckCandSizePredictor(3)));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(mobile_predictor->PredictForRequest(*convreq_, &segments));
}
164 
// For a mobile PARTIAL_SUGGESTION request, the first predictor must see a
// candidate limit of 20.
TEST_F(MobilePredictorTest, CallPredictorsForMobilePartialSuggestion) {
  // The history predictor is not expected to be called here; -1 is a limit
  // it should never observe.
  unique_ptr<MobilePredictor> mobile_predictor(
      new MobilePredictor(new CheckCandSizePredictor(20),
                          new CheckCandSizePredictor(-1)));

  Segments segments;
  segments.set_request_type(Segments::PARTIAL_SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(mobile_predictor->PredictForRequest(*convreq_, &segments));
}
179 
// For a mobile PREDICTION request, the first predictor must see a candidate
// limit of 200 and the second (history) predictor a limit of 3.
TEST_F(MobilePredictorTest, CallPredictorsForMobilePrediction) {
  unique_ptr<MobilePredictor> mobile_predictor(
      new MobilePredictor(new CheckCandSizePredictor(200),
                          new CheckCandSizePredictor(3)));

  Segments segments;
  segments.set_request_type(Segments::PREDICTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(mobile_predictor->PredictForRequest(*convreq_, &segments));
}
193 
// For a mobile PARTIAL_PREDICTION request, the first predictor must see a
// candidate limit of 200. The second slot holds a real UserHistoryPredictor
// backed by a mock dictionary, and the overall call must still succeed.
TEST_F(MobilePredictorTest, CallPredictorsForMobilePartialPrediction) {
  // These must outlive the predictor, which is handed pointers to them;
  // declaring them first guarantees that.
  DictionaryMock dictionary_mock;
  testing::MockDataManager data_manager;
  const dictionary::POSMatcher pos_matcher(data_manager.GetPOSMatcherData());

  // Ownership passes to the MobilePredictor below.
  UserHistoryPredictor *history_predictor = new UserHistoryPredictor(
      &dictionary_mock, &pos_matcher, Singleton<SuppressionDictionary>::get(),
      true);
  unique_ptr<MobilePredictor> mobile_predictor(
      new MobilePredictor(new CheckCandSizePredictor(200), history_predictor));

  Segments segments;
  segments.set_request_type(Segments::PARTIAL_PREDICTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(mobile_predictor->PredictForRequest(*convreq_, &segments));
}
215 
// MobilePredictor must succeed on a SUGGESTION request when each mocked
// sub-predictor is called at most once and returns true.
TEST_F(MobilePredictorTest, CallPredictForRequetMobile) {
  // Raw pointers stay valid because mobile_predictor takes ownership and
  // lives until the end of the test.
  MockPredictor *first_mock = new MockPredictor;
  MockPredictor *second_mock = new MockPredictor;
  unique_ptr<MobilePredictor> mobile_predictor(
      new MobilePredictor(first_mock, second_mock));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_CALL(*first_mock, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));
  EXPECT_CALL(*second_mock, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));

  EXPECT_TRUE(mobile_predictor->PredictForRequest(*convreq_, &segments));
}
235 
236 
237 class PredictorTest : public ::testing::Test {
238  protected:
SetUp()239   virtual void SetUp() {
240     config_.reset(new config::Config);
241     config::ConfigHandler::GetDefaultConfig(config_.get());
242 
243     request_.reset(new commands::Request);
244     composer_.reset(new composer::Composer(
245         nullptr, request_.get(), config_.get()));
246 
247     convreq_.reset(
248         new ConversionRequest(composer_.get(), request_.get(), config_.get()));
249   }
250 
251   unique_ptr<mozc::composer::Composer> composer_;
252   unique_ptr<commands::Request> request_;
253   unique_ptr<config::Config> config_;
254   unique_ptr<ConversionRequest> convreq_;
255 };
256 
// DefaultPredictor reports success when both sub-predictors succeed.
TEST_F(PredictorTest, AllPredictorsReturnTrue) {
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(new NullPredictor(true), new NullPredictor(true)));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
}
270 
// DefaultPredictor reports success when at least one sub-predictor
// succeeds (here the first succeeds and the second fails).
TEST_F(PredictorTest, MixedReturnValue) {
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(new NullPredictor(true), new NullPredictor(false)));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
}
284 
// DefaultPredictor reports failure when neither sub-predictor succeeds.
TEST_F(PredictorTest, AllPredictorsReturnFalse) {
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(new NullPredictor(false),
                           new NullPredictor(false)));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_FALSE(default_predictor->PredictForRequest(*convreq_, &segments));
}
298 
// For a SUGGESTION request, both sub-predictors must see the default
// config's suggestions_size() as their candidate limit.
TEST_F(PredictorTest, CallPredictorsForSuggestion) {
  const int expected_size =
      config::ConfigHandler::DefaultConfig().suggestions_size();
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(new CheckCandSizePredictor(expected_size),
                           new CheckCandSizePredictor(expected_size)));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
}
315 
// For a PREDICTION request, both sub-predictors must see a candidate limit
// of 100.
TEST_F(PredictorTest, CallPredictorsForPrediction) {
  constexpr int kExpectedPredictionSize = 100;
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(new CheckCandSizePredictor(kExpectedPredictionSize),
                           new CheckCandSizePredictor(kExpectedPredictionSize)));

  Segments segments;
  segments.set_request_type(Segments::PREDICTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
}
330 
// DefaultPredictor must succeed on a SUGGESTION request when each mocked
// sub-predictor is called at most once and returns true.
TEST_F(PredictorTest, CallPredictForRequet) {
  // Raw pointers stay valid because default_predictor takes ownership and
  // lives until the end of the test.
  MockPredictor *first_mock = new MockPredictor;
  MockPredictor *second_mock = new MockPredictor;
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(first_mock, second_mock));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  EXPECT_CALL(*first_mock, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));
  EXPECT_CALL(*second_mock, PredictForRequest(_, _))
      .Times(AtMost(1))
      .WillOnce(Return(true));

  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
}
350 
351 
// With presentation mode on, DefaultPredictor returns false without calling
// either sub-predictor. Once it is switched off, both are called and the
// result is true.
TEST_F(PredictorTest, DisableAllSuggestion) {
  // Raw pointers stay valid because default_predictor takes ownership and
  // lives until the end of the test.
  NullPredictor *first_null = new NullPredictor(true);
  NullPredictor *second_null = new NullPredictor(true);
  unique_ptr<DefaultPredictor> default_predictor(
      new DefaultPredictor(first_null, second_null));

  Segments segments;
  segments.set_request_type(Segments::SUGGESTION);
  Segment *const first_segment = segments.add_segment();
  CHECK(first_segment);

  // Presentation mode on: the request is rejected before delegation.
  config_->set_presentation_mode(true);
  EXPECT_FALSE(default_predictor->PredictForRequest(*convreq_, &segments));
  EXPECT_FALSE(first_null->predict_called());
  EXPECT_FALSE(second_null->predict_called());

  // Presentation mode off: both sub-predictors are consulted.
  config_->set_presentation_mode(false);
  EXPECT_TRUE(default_predictor->PredictForRequest(*convreq_, &segments));
  EXPECT_TRUE(first_null->predict_called());
  EXPECT_TRUE(second_null->predict_called());
}
375 
376 }  // namespace mozc
377