1
2 /******************************************************
3 * Presage, an extensible predictive text entry system
4 * ---------------------------------------------------
5 *
6 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
7
8 This program is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License along
19 with this program; if not, write to the Free Software Foundation, Inc.,
20 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 *
22 **********(*)*/
23
24
25 #include "smoothedNgramPredictor.h"
26
27 #include <sstream>
28 #include <algorithm>
29
30
SmoothedNgramPredictor(Configuration * config,ContextTracker * ct,const char * name)31 SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
32 : Predictor(config,
33 ct,
34 name,
35 "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
36 "SmoothedNgramPredictor, long description." ),
37 db (0),
38 cardinality (0),
39 learn_mode_set (false),
40 dispatcher (this)
41 {
42 LOGGER = PREDICTORS + name + ".LOGGER";
43 DBFILENAME = PREDICTORS + name + ".DBFILENAME";
44 DELTAS = PREDICTORS + name + ".DELTAS";
45 LEARN = PREDICTORS + name + ".LEARN";
46 DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";
47
48 // build notification dispatch map
49 dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
50 dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
51 dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
52 dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
53 dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
54 }
55
56
57
~SmoothedNgramPredictor()58 SmoothedNgramPredictor::~SmoothedNgramPredictor()
59 {
60 delete db;
61 }
62
63
set_dbfilename(const std::string & filename)64 void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
65 {
66 dbfilename = filename;
67 logger << INFO << "DBFILENAME: " << dbfilename << endl;
68
69 init_database_connector_if_ready ();
70 }
71
72
set_database_logger_level(const std::string & value)73 void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
74 {
75 dbloglevel = value;
76 }
77
78
set_deltas(const std::string & value)79 void SmoothedNgramPredictor::set_deltas (const std::string& value)
80 {
81 std::stringstream ss_deltas(value);
82 cardinality = 0;
83 std::string delta;
84 while (ss_deltas >> delta) {
85 logger << DEBUG << "Pushing delta: " << delta << endl;
86 deltas.push_back (Utility::toDouble (delta));
87 cardinality++;
88 }
89 logger << INFO << "DELTAS: " << value << endl;
90 logger << INFO << "CARDINALITY: " << cardinality << endl;
91
92 init_database_connector_if_ready ();
93 }
94
95
set_learn(const std::string & value)96 void SmoothedNgramPredictor::set_learn (const std::string& value)
97 {
98 learn_mode = Utility::isYes (value);
99 logger << INFO << "LEARN: " << value << endl;
100
101 learn_mode_set = true;
102
103 init_database_connector_if_ready ();
104 }
105
106
init_database_connector_if_ready()107 void SmoothedNgramPredictor::init_database_connector_if_ready ()
108 {
109 // we can only init the sqlite database connector once we know the
110 // following:
111 // - what database file we need to open
112 // - what cardinality we expect the database file to be
113 // - whether we need to open the database in read only or
114 // read/write mode (learning requires read/write access)
115 //
116 if (! dbfilename.empty()
117 && cardinality > 0
118 && learn_mode_set ) {
119
120 delete db;
121
122 if (dbloglevel.empty ()) {
123 // open database connector
124 db = new SqliteDatabaseConnector(dbfilename,
125 cardinality,
126 learn_mode);
127 } else {
128 // open database connector with logger lever
129 db = new SqliteDatabaseConnector(dbfilename,
130 cardinality,
131 learn_mode,
132 dbloglevel);
133 }
134 }
135 }
136
137
138 // convenience function to convert ngram to string
139 //
ngram_to_string(const Ngram & ngram)140 static std::string ngram_to_string(const Ngram& ngram)
141 {
142 const char separator[] = "|";
143 std::string result = separator;
144
145 for (Ngram::const_iterator it = ngram.begin();
146 it != ngram.end();
147 it++)
148 {
149 result += *it + separator;
150 }
151
152 return result;
153 }
154
155
156 /** \brief Builds the required n-gram and returns its count.
157 *
158 * \param tokens tokens[i] contains ContextTracker::getToken(i)
159 * \param offset entry point into tokens, must be a non-positive number
160 * \param ngram_size size of the ngram whose count is returned, must not be greater than tokens size
161 * \return count of the ngram built based on tokens, offset and ngram_size
162 *
163 * \verbatim
164 Let tokens = [ "how", "are", "you", "today" ];
165
166 count(tokens, 0, 3) returns the count associated with 3-gram [ "are", "you", "today" ].
167 count(tokens, -1, 2) returns the count associated with 2-gram [ "are", "you" ];
168 * \endverbatim
169 *
170 */
count(const std::vector<std::string> & tokens,int offset,int ngram_size) const171 unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
172 {
173 unsigned int result = 0;
174
175 assert(offset <= 0); // TODO: handle this better
176 assert(ngram_size >= 0);
177
178 if (ngram_size > 0) {
179 Ngram ngram(ngram_size);
180 copy(tokens.end() - ngram_size + offset , tokens.end() + offset, ngram.begin());
181 result = db->getNgramCount(ngram);
182 logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
183 } else {
184 result = db->getUnigramCountsSum();
185 logger << DEBUG << "unigram counts sum: " << result << endl;
186 }
187
188 return result;
189 }
190
predict(const size_t max_partial_prediction_size,const char ** filter) const191 Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
192 {
193 logger << DEBUG << "predict()" << endl;
194
195 // Result prediction
196 Prediction prediction;
197
198 // Cache all the needed tokens.
199 // tokens[k] corresponds to w_{i-k} in the generalized smoothed
200 // n-gram probability formula
201 //
202 std::vector<std::string> tokens(cardinality);
203 for (int i = 0; i < cardinality; i++) {
204 tokens[cardinality - 1 - i] = contextTracker->getToken(i);
205 logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
206 }
207
208 // Generate list of prefix completition candidates.
209 //
210 // The prefix completion candidates used to be obtained from the
211 // _1_gram table because in a well-constructed ngram database the
212 // _1_gram table (which contains all known tokens). However, this
213 // introduced a skew, since the unigram counts will take
214 // precedence over the higher-order counts.
215 //
216 // The current solution retrieves candidates from the highest
217 // n-gram table, falling back on lower order n-gram tables if
218 // initial completion set is smaller than required.
219 //
220 std::vector<std::string> prefixCompletionCandidates;
221 for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
222 logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
223 // create n-gram used to retrieve initial prefix completion table
224 Ngram prefix_ngram(k);
225 copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());
226
227 if (logger.shouldLog()) {
228 logger << DEBUG << "prefix_ngram: ";
229 for (size_t r = 0; r < prefix_ngram.size(); r++) {
230 logger << DEBUG << prefix_ngram[r] << ' ';
231 }
232 logger << DEBUG << endl;
233 }
234
235 // obtain initial prefix completion candidates
236 db->beginTransaction();
237
238 NgramTable partial;
239
240 if (filter == 0) {
241 partial = db->getNgramLikeTable(prefix_ngram,max_partial_prediction_size - prefixCompletionCandidates.size());
242 } else {
243 partial = db->getNgramLikeTableFiltered(prefix_ngram,filter, max_partial_prediction_size - prefixCompletionCandidates.size());
244 }
245
246 db->endTransaction();
247
248 if (logger.shouldLog()) {
249 logger << DEBUG << "partial prefixCompletionCandidates" << endl
250 << DEBUG << "----------------------------------" << endl;
251 for (size_t j = 0; j < partial.size(); j++) {
252 for (size_t k = 0; k < partial[j].size(); k++) {
253 logger << DEBUG << partial[j][k] << " ";
254 }
255 logger << endl;
256 }
257 }
258
259 logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;
260
261 // append newly discovered potential completions to prefix
262 // completion candidates array to fill it up to
263 // max_partial_prediction_size
264 //
265 std::vector<Ngram>::const_iterator it = partial.begin();
266 while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
267 // only add new candidates, iterator it points to Ngram,
268 // it->end() - 2 points to the token candidate
269 //
270 std::string candidate = *(it->end() - 2);
271 if (find(prefixCompletionCandidates.begin(),
272 prefixCompletionCandidates.end(),
273 candidate) == prefixCompletionCandidates.end()) {
274 prefixCompletionCandidates.push_back(candidate);
275 }
276 it++;
277 }
278 }
279
280 if (logger.shouldLog()) {
281 logger << DEBUG << "prefixCompletionCandidates" << endl
282 << DEBUG << "--------------------------" << endl;
283 for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
284 logger << DEBUG << prefixCompletionCandidates[j] << endl;
285 }
286 }
287
288 // compute smoothed probabilities for all candidates
289 //
290 db->beginTransaction();
291 // getUnigramCountsSum is an expensive SQL query
292 // caching it here saves much time later inside the loop
293 int unigrams_counts_sum = db->getUnigramCountsSum();
294 for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
295 // store w_i candidate at end of tokens
296 tokens[cardinality - 1] = prefixCompletionCandidates[j];
297
298 logger << DEBUG << "------------------" << endl;
299 logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;
300
301 double probability = 0;
302 for (int k = 0; k < cardinality; k++) {
303 double numerator = count(tokens, 0, k+1);
304 // reuse cached unigrams_counts_sum to speed things up
305 double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
306 double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
307 probability += deltas[k] * frequency;
308
309 logger << DEBUG << "numerator: " << numerator << endl;
310 logger << DEBUG << "denominator: " << denominator << endl;
311 logger << DEBUG << "frequency: " << frequency << endl;
312 logger << DEBUG << "delta: " << deltas[k] << endl;
313
314 // for some sanity checks
315 assert(numerator <= denominator);
316 assert(frequency <= 1);
317 }
318
319 logger << DEBUG << "____________" << endl;
320 logger << DEBUG << "probability: " << probability << endl;
321
322 if (probability > 0) {
323 prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
324 }
325 }
326 db->endTransaction();
327
328 logger << DEBUG << "Prediction:" << endl;
329 logger << DEBUG << "-----------" << endl;
330 logger << DEBUG << prediction << endl;
331
332 return prediction;
333 }
334
/** \brief Learns the tokens in change, updating n-gram counts in the database.
 *
 * \param change tokens newly committed to the context, oldest first
 *
 * All n-grams of size 1..cardinality that occur within change are
 * counted in memory, n-grams straddling the boundary between the past
 * stream and change are added by pulling extra tokens from the past
 * stream, and the accumulated counts are then committed to the
 * database in a single transaction (rolled back and rethrown on
 * database error).  No-op when learning is disabled.
 */
void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
	// learning is turned on

	// maps each n-gram (as a token list) to its occurrence count in change
	std::map<std::list<std::string>, int> ngramMap;

	// build up ngram map for all cardinalities
	// i.e. learn all ngrams and counts in memory
	for (size_t curr_cardinality = 1;
	     curr_cardinality < cardinality + 1;
	     curr_cardinality++)
	{
	    int change_idx = 0;
	    int change_size = change.size();

	    // sliding window of curr_cardinality consecutive tokens
	    std::list<std::string> ngram_list;

	    // take care of first N-1 tokens
	    // (prime the window before counting full n-grams)
	    for (int i = 0;
		 (i < curr_cardinality - 1 && change_idx < change_size);
		 i++)
	    {
		ngram_list.push_back(change[change_idx]);
		change_idx++;
	    }

	    // slide the window across change, counting each full n-gram
	    while (change_idx < change_size)
	    {
		ngram_list.push_back(change[change_idx++]);
		ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
		ngram_list.pop_front();
	    }
	}

	// use (past stream - change) to learn token at the boundary
	// change, i.e.
	//

	// if change is "bar foobar", then "bar" will only occur in a
	// 1-gram, since there are no token before it. By dipping in
	// the past stream, we additional context to learn a 2-gram by
	// getting extra tokens (assuming past stream ends with token
	// "foo":
	//
	// <"foo", "bar"> will be learnt
	//
	// We do this till we build up to n equal to cardinality.
	//
	// First check that change is not empty (nothing to learn) and
	// that change and past stream match by sampling first and
	// last token in change and comparing them with corresponding
	// tokens from past stream
	//
	if (change.size() > 0 &&
	    change.back() == contextTracker->getToken(1) &&
	    change.front() == contextTracker->getToken(change.size()))
	{
	    // create ngram list with first (oldest) token from change
	    std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

	    // prepend token to ngram list by grabbing extra tokens
	    // from past stream (if there are any) till we have built
	    // up to n==cardinality ngrams, and commit them to
	    // ngramMap
	    //
	    for (int tk_idx = 1;
		 ngram_list.size() < cardinality;
		 tk_idx++)
	    {
		// getExtraTokenToLearn returns tokens from
		// past stream that come before and are not in
		// change vector
		//
		std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
		logger << DEBUG << "Adding extra token: " << extra_token << endl;

		// past stream exhausted: no more context is available
		if (extra_token.empty())
		{
		    break;
		}
		ngram_list.push_front(extra_token);

		ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
	    }
	}

	// then write out to language model database
	try
	{
	    db->beginTransaction();

	    std::map<std::list<std::string>, int>::const_iterator it;
	    for (it = ngramMap.begin(); it != ngramMap.end(); it++)
	    {
		// convert ngram from list to vector based Ngram
		Ngram ngram((it->first).begin(), (it->first).end());

		// update the counts
		int count = db->getNgramCount(ngram);
		if (count > 0)
		{
		    // ngram already in database, update count
		    db->updateNgram(ngram, count + it->second);
		    // ensure lower-order counts stay >= higher-order counts
		    check_learn_consistency(ngram);
		}
		else
		{
		    // ngram not in database, insert it
		    db->insertNgram(ngram, it->second);
		}
	    }

	    db->endTransaction();
	    logger << INFO << "Committed learning update to database" << endl;
	}
	catch (SqliteDatabaseConnector::SqliteDatabaseConnectorException& ex)
	{
	    // undo the partial update and propagate the error to the caller
	    db->rollbackTransaction();
	    logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
	    throw;
	}
    }

    logger << DEBUG << "end learn()" << endl;
}
463
check_learn_consistency(const Ngram & ngram) const464 void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
465 {
466 // no need to begin a new transaction, as we'll be called from
467 // within an existing transaction from learn()
468
469 // BEWARE: if the previous sentence is not true, then performance
470 // WILL suffer!
471
472 size_t size = ngram.size();
473 for (size_t i = 0; i < size; i++) {
474 if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
475 logger << INFO << "consistency adjustment needed!" << endl;
476
477 int offset = -(i + 1);
478 int sub_ngram_size = size - (i + 1);
479
480 logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;
481
482 Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
483 copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());
484
485 if (logger.shouldLog()) {
486 logger << "ngram to be count adjusted is: ";
487 for (size_t i = 0; i < sub_ngram.size(); i++) {
488 logger << sub_ngram[i] << ' ';
489 }
490 logger << endl;
491 }
492
493 db->incrementNgramCount(sub_ngram);
494 logger << DEBUG << "consistency adjusted" << endl;
495 }
496 }
497 }
498
update(const Observable * var)499 void SmoothedNgramPredictor::update (const Observable* var)
500 {
501 logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
502 dispatcher.dispatch (var);
503 }
504