
/******************************************************
 *  Presage, an extensible predictive text entry system
 *  ---------------------------------------------------
 *
 *  Copyright (C) 2008  Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
                                                                             *
                                                                **********(*)*/


#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>
#include <cassert>   // assert() is used in count() and predict()


SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description."),
      db (0),
      cardinality (0),
      learn_mode_set (false),
      dispatcher (this)
{
    LOGGER     = PREDICTORS + name + ".LOGGER";
    DBFILENAME = PREDICTORS + name + ".DBFILENAME";
    DELTAS     = PREDICTORS + name + ".DELTAS";
    LEARN      = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";

    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
}
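
// Note (illustrative, not normative): assuming the PREDICTORS prefix
// expands to something like "Presage.Predictors.", the constructor above
// ends up watching configuration variables such as
// "Presage.Predictors.SmoothedNgramPredictor.DELTAS", so that a change
// to any of them is dispatched to the matching setter below.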


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    cardinality = 0;
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
        cardinality++;
    }
    logger << INFO << "DELTAS: " << value << endl;
    logger << INFO << "CARDINALITY: " << cardinality << endl;

    init_database_connector_if_ready ();
}
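
// For example, a DELTAS value of "0.25 0.25 0.5" yields cardinality 3,
// i.e. a trigram model whose unigram, bigram and trigram relative
// frequencies are weighted 0.25, 0.25 and 0.5 respectively by predict().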


void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    learn_mode = Utility::isYes (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::init_database_connector_if_ready ()
{
    // we can only init the sqlite database connector once we know the
    // following:
    //  - which database file we need to open
    //  - what cardinality we expect the database to have
    //  - whether we need to open the database in read-only or
    //    read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;

        if (dbloglevel.empty ()) {
            // open database connector
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode);
        } else {
            // open database connector with logger level
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode,
                                             dbloglevel);
        }
    }
}


// convenience function to convert ngram to string
//
static std::string ngram_to_string(const Ngram& ngram)
{
    const char separator[] = "|";
    std::string result = separator;

    for (Ngram::const_iterator it = ngram.begin();
         it != ngram.end();
         it++)
    {
        result += *it + separator;
    }

    return result;
}
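
// For instance, an ngram holding { "how", "are", "you" } is rendered as
// "|how|are|you|", keeping token boundaries visible in log output.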


/** \brief Builds the required n-gram and returns its count.
 *
 * \param tokens tokens[i] contains ContextTracker::getToken(i)
 * \param offset entry point into tokens, must be a non-positive number
 * \param ngram_size size of the ngram whose count is returned, must not be greater than tokens size
 * \return count of the ngram built based on tokens, offset and ngram_size
 *
 * \verbatim
 Let tokens = [ "how", "are", "you", "today" ];

 count(tokens,  0, 3) returns the count associated with 3-gram [ "are", "you", "today" ].
 count(tokens, -1, 2) returns the count associated with 2-gram [ "are", "you" ].
 * \endverbatim
 *
 */
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    unsigned int result = 0;

    assert(offset <= 0);      // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());
        result = db->getNgramCount(ngram);
        logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
    } else {
        result = db->getUnigramCountsSum();
        logger << DEBUG << "unigram counts sum: " << result << endl;
    }

    return result;
}

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[cardinality - 1 - k] corresponds to w_{i-k} in the
    // generalized smoothed n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower-order n-gram tables if the
    // initial completion set is smaller than required.
    //
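    // For illustration (not a literal SQL trace): with cardinality 3
    // and context "how are y", candidates matching the "y" prefix would
    // be drawn from the _3_gram table for context ("how", "are") first,
    // then from the _2_gram table for ("are"), and finally from the
    // _1_gram table, stopping as soon as max_partial_prediction_size
    // candidates have been collected.
    //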
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t m = 0; m < partial[j].size(); m++) {
                    logger << DEBUG << partial[j][m] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; iterator it points to an Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
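    // The smoothed probability assigned to each candidate w_i is the
    // linearly interpolated sum of its n-gram relative frequencies:
    //
    //   P(w_i | w_{i-n+1} ... w_{i-1}) =
    //     sum_{k=0}^{n-1} deltas[k] * C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // where n is the configured cardinality, C() is a database count,
    // and the k == 0 denominator is the sum of all unigram counts.
    // Worked example with made-up counts: with DELTAS "0.25 0.75",
    // C("you") = 40, unigram counts sum = 1000, C("are", "you") = 10
    // and C("are") = 20, candidate "you" scores
    // 0.25 * 40/1000 + 0.75 * 10/20 = 0.385.
    //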
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query;
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator:   " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency:   " << frequency << endl;
            logger << DEBUG << "delta:       " << deltas[k] << endl;

            // sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
        // learning is turned on

        std::map<std::list<std::string>, int> ngramMap;

        // build up ngram map for all cardinalities,
        // i.e. learn all ngrams and counts in memory
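        //
        // For example, with cardinality 2 and change = [ "how", "are", "you" ],
        // ngramMap ends up holding the unigrams ("how"), ("are"), ("you")
        // and the bigrams ("how", "are"), ("are", "you"), each with count 1.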
        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++)
        {
            int change_idx = 0;
            int change_size = change.size();

            std::list<std::string> ngram_list;

            // take care of first N-1 tokens
            for (size_t i = 0;
                 (i < curr_cardinality - 1 && change_idx < change_size);
                 i++)
            {
                ngram_list.push_back(change[change_idx]);
                change_idx++;
            }

            while (change_idx < change_size)
            {
                ngram_list.push_back(change[change_idx++]);
                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
                ngram_list.pop_front();
            }
        }

        // use (past stream - change) to learn token at the boundary
        // of change, i.e.
        //
        // if change is "bar foobar", then "bar" will only occur in a
        // 1-gram, since there is no token before it. By dipping into
        // the past stream, we gain additional context to learn a
        // 2-gram, by grabbing extra tokens (assuming the past stream
        // ends with token "foo"):
        //
        // <"foo", "bar"> will be learnt
        //
        // We do this till we build up to n equal to cardinality.
        //
        // First check that change is not empty (nothing to learn) and
        // that change and past stream match, by sampling the first and
        // last tokens in change and comparing them with the
        // corresponding tokens from the past stream
        //
        if (change.size() > 0 &&
            change.back() == contextTracker->getToken(1) &&
            change.front() == contextTracker->getToken(change.size()))
        {
            // create ngram list with first (oldest) token from change
            std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

            // prepend tokens to ngram list by grabbing extra tokens
            // from the past stream (if there are any) till we have
            // built up to n==cardinality ngrams, and commit them to
            // ngramMap
            //
            for (int tk_idx = 1;
                 ngram_list.size() < cardinality;
                 tk_idx++)
            {
                // getExtraTokenToLearn returns tokens from the
                // past stream that come before and are not in
                // the change vector
                //
                std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                logger << DEBUG << "Adding extra token: " << extra_token << endl;

                if (extra_token.empty())
                {
                    break;
                }
                ngram_list.push_front(extra_token);

                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
            }
        }

        // then write out to language model database
        try
        {
            db->beginTransaction();

            std::map<std::list<std::string>, int>::const_iterator it;
            for (it = ngramMap.begin(); it != ngramMap.end(); it++)
            {
                // convert ngram from list to vector based Ngram
                Ngram ngram((it->first).begin(), (it->first).end());

                // update the counts
                int count = db->getNgramCount(ngram);
                if (count > 0)
                {
                    // ngram already in database, update count
                    db->updateNgram(ngram, count + it->second);
                    check_learn_consistency(ngram);
                }
                else
                {
                    // ngram not in database, insert it
                    db->insertNgram(ngram, it->second);
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        catch (SqliteDatabaseConnector::SqliteDatabaseConnectorException& ex)
        {
            db->rollbackTransaction();
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}

void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

    int size = ngram.size();
    for (int i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t n = 0; n < sub_ngram.size(); n++) {
                    logger << sub_ngram[n] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
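
// Note on consistency: an n-gram must never be counted more often than
// the (n-1)-gram it extends. Worked example with made-up counts: if a
// learning update raises count("foo", "bar") to 6 while count("foo") is
// still 5, check_learn_consistency() increments the count of ("foo") so
// that count("foo", "bar") <= count("foo") holds again.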

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}