1 /******************************************************************************
2 IrstLM: IRST Language Model Toolkit, compile LM
3 Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
4
5 This library is free software; you can redistribute it and/or
6 modify it under the terms of the GNU Lesser General Public
7 License as published by the Free Software Foundation; either
8 version 2.1 of the License, or (at your option) any later version.
9
10 This library is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13 Lesser General Public License for more details.
14
15 You should have received a copy of the GNU Lesser General Public
16 License along with this library; if not, write to the Free Software
17 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18
19 ******************************************************************************/
20
21 using namespace std;
22
23 #include <iostream>
24 #include <fstream>
25 #include <sstream>
26 #include <stdexcept>
27 #include <vector>
28 #include <string>
29 #include <stdlib.h>
30 #include "cmd.h"
31 #include "util.h"
32 #include "math.h"
33 #include "lmContainer.h"
34 /********************************/
35
36
// Fatal-condition reporter: echo the message on stderr, then raise it as a
// std::runtime_error so callers (or the runtime) can terminate cleanly.
inline void error(const char* message)
{
  const std::string text(message);
  std::cerr << text << "\n";
  throw std::runtime_error(text);
}
42
// Forward declaration: loads one LM (ARPA or IRSTLM binary); defined at the end of this file.
lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf);
44
print_help(int TypeFlag=0)45 void print_help(int TypeFlag=0){
46 std::cerr << std::endl << "interpolate-lm - interpolates language models" << std::endl;
47 std::cerr << std::endl << "USAGE:" << std::endl;
48 std::cerr << " interpolate-lm [options] <lm-list-file> [lm-list-file.out]" << std::endl;
49
50 std::cerr << std::endl << "DESCRIPTION:" << std::endl;
51 std::cerr << " interpolate-lm reads a LM list file including interpolation weights " << std::endl;
52 std::cerr << " with the format: N\\n w1 lm1 \\n w2 lm2 ...\\n wN lmN\n" << std::endl;
53 std::cerr << " It estimates new weights on a development text, " << std::endl;
54 std::cerr << " computes the perplexity on an evaluation text, " << std::endl;
55 std::cerr << " computes probabilities of n-grams read from stdin." << std::endl;
56 std::cerr << " It reads LMs in ARPA and IRSTLM binary format." << std::endl;
57
58 std::cerr << std::endl << "OPTIONS:" << std::endl;
59 FullPrintParams(TypeFlag, 0, 1, stderr);
60
61 }
62
usage(const char * msg=0)63 void usage(const char *msg = 0)
64 {
65 if (msg){
66 std::cerr << msg << std::endl;
67 }
68 else{
69 print_help();
70 }
71 exit(1);
72 }
73
main(int argc,char ** argv)74 int main(int argc, char **argv)
75 {
76 char *slearn = NULL;
77 char *seval = NULL;
78 bool learn=false;
79 bool score=false;
80 bool sent_PP_flag = false;
81
82 int order = 0;
83 int debug = 0;
84 int memmap = 0;
85 int requiredMaxlev = 1000;
86 int dub = 10000000;
87 float ngramcache_load_factor = 0.0;
88 float dictionary_load_factor = 0.0;
89
90 bool help=false;
91 std::vector<std::string> files;
92
93 DeclareParams((char*)
94
95 "learn", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
96 "l", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
97
98 "order", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
99 "o", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
100
101 "eval", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
102 "e", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
103
104 "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
105
106 "score", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
107 "s", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
108
109 "debug", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
110 "d", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
111 "memmap", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
112 "mm", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
113 "sentence", CMDBOOLTYPE|CMDMSG, &sent_PP_flag, "computes perplexity at sentence level (identified through the end symbol)",
114 "dict_load_factor", CMDFLOATTYPE|CMDMSG, &dictionary_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is 0",
115 "ngram_load_factor", CMDFLOATTYPE|CMDMSG, &ngramcache_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is false",
116 "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
117 "lev", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
118
119 "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
120 "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
121
122 (char *)NULL
123 );
124
125 if (argc == 1){
126 usage();
127 }
128
129 for(int i=1; i < argc; i++) {
130 if(argv[i][0] != '-') files.push_back(argv[i]);
131 }
132
133 GetParams(&argc, &argv, (char*) NULL);
134
135 if (help){
136 usage();
137 }
138
139 if (files.size() > 2) {
140 usage("Warning: Too many arguments");
141 }
142
143 if (files.size() < 1) {
144 usage("Warning: specify a LM list file to read from");
145 }
146
147 std::string infile = files[0];
148 std::string outfile="";
149
150 if (files.size() == 1) {
151 outfile=infile;
152 //remove path information
153 std::string::size_type p = outfile.rfind('/');
154 if (p != std::string::npos && ((p+1) < outfile.size()))
155 outfile.erase(0,p+1);
156 outfile+=".out";
157 } else
158 outfile = files[1];
159
160 std::cerr << "inpfile: " << infile << std::endl;
161 learn = ((slearn != NULL)? true : false);
162
163 if (learn) std::cerr << "outfile: " << outfile << std::endl;
164 if (score) std::cerr << "interactive: " << score << std::endl;
165 if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
166 std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
167 std::cerr << "order: " << order << std::endl;
168 if (requiredMaxlev > 0) std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
169
170 std::cerr << "dub: " << dub<< std::endl;
171
172 lmContainer *lmt[100], *start_lmt[100]; //interpolated language models
173 std::string lmf[100]; //lm filenames
174
175 float w[100]; //interpolation weights
176 int N;
177
178
179 //Loading Language Models`
180 std::cerr << "Reading " << infile << "..." << std::endl;
181 std::fstream inptxt(infile.c_str(),std::ios::in);
182
183 //std::string line;
184 char line[BUFSIZ];
185 const char* words[3];
186 int tokenN;
187
188 inptxt.getline(line,BUFSIZ,'\n');
189 tokenN = parseWords(line,words,3);
190
191 if (tokenN != 2 || ((strcmp(words[0],"LMINTERPOLATION") != 0) && (strcmp(words[0],"lminterpolation")!=0)))
192 error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
193
194 N=atoi(words[1]);
195 std::cerr << "Number of LMs: " << N << "..." << std::endl;
196 if(N > 100) {
197 std::cerr << "Can't interpolate more than 100 language models." << std::endl;
198 exit(1);
199 }
200
201 for (int i=0; i<N; i++) {
202 inptxt.getline(line,BUFSIZ,'\n');
203 tokenN = parseWords(line,words,3);
204 if(tokenN != 2) {
205 std::cerr << "Wrong input format." << std::endl;
206 exit(1);
207 }
208 w[i] = (float) atof(words[0]);
209 lmf[i] = words[1];
210
211 std::cerr << "i:" << i << " w[i]:" << w[i] << " lmf[i]:" << lmf[i] << std::endl;
212 start_lmt[i] = lmt[i] = load_lm(lmf[i],requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
213 }
214
215 inptxt.close();
216
217 int maxorder = 0;
218 for (int i=0; i<N; i++) {
219 maxorder = (maxorder > lmt[i]->maxlevel())?maxorder:lmt[i]->maxlevel();
220 }
221
222 if (order <= 0) {
223 order = maxorder;
224 std::cerr << "order is not set or wrongly set to a non positive value; reset to the maximum order of LMs: " << order << std::endl;
225 } else if (order > maxorder) {
226 order = maxorder;
227 std::cerr << "order is too high; reset to the maximum order of LMs" << order << std::endl;
228 }
229
230 //Learning mixture weights
231 if (learn) {
232
233 std::vector< std::vector<float> > p(N); //LM probabilities
234 float c[N]; //expected counts
235 float den,norm; //inner denominator, normalization term
236 float variation=1.0; // global variation between new old params
237
238 dictionary* dict=new dictionary(slearn,1000000,dictionary_load_factor);
239 ngram ng(dict);
240 int bos=ng.dict->encode(ng.dict->BoS());
241 std::ifstream dev(slearn,std::ios::in);
242
243 for(;;) {
244 std::string line;
245 getline(dev, line);
246 if(dev.eof())
247 break;
248 if(dev.fail()) {
249 std::cerr << "Problem reading input file " << seval << std::endl;
250 exit(1);
251 }
252 std::istringstream lstream(line);
253 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
254 std::string token, newlm;
255 int id;
256 lstream >> token >> id >> newlm;
257 if(id <= 0 || id > N) {
258 std::cerr << "LM id out of range." << std::endl;
259 return 1;
260 }
261 id--; // count from 0 now
262 if(lmt[id] != start_lmt[id])
263 delete lmt[id];
264 lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
265 continue;
266 }
267 while(lstream >> ng) {
268
269 // reset ngram at begin of sentence
270 if (*ng.wordp(1)==bos) {
271 ng.size=1;
272 continue;
273 }
274 if (order > 0 && ng.size > order) ng.size=order;
275 for (int i=0; i<N; i++) {
276 ngram ong(lmt[i]->getDict());
277 ong.trans(ng);
278 double logpr;
279 logpr = lmt[i]->clprob(ong); //LM log-prob (using caches if available)
280 p[i].push_back(pow(10.0,logpr));
281 }
282 }
283
284 for (int i=0; i<N; i++) lmt[i]->check_caches_levels();
285 }
286 dev.close();
287
288 while( variation > 0.01 ) {
289
290 for (int i=0; i<N; i++) c[i]=0; //reset counters
291
292 for(unsigned i = 0; i < p[0].size(); i++) {
293 den=0.0;
294 for(int j = 0; j < N; j++)
295 den += w[j] * p[j][i]; //denominator of EM formula
296 //update expected counts
297 for(int j = 0; j < N; j++)
298 c[j] += w[j] * p[j][i] / den;
299 }
300
301 norm=0.0;
302 for (int i=0; i<N; i++) norm+=c[i];
303
304 //update weights and compute distance
305 variation=0.0;
306 for (int i=0; i<N; i++) {
307 c[i]/=norm; //c[i] is now the new weight
308 variation+=(w[i]>c[i]?(w[i]-c[i]):(c[i]-w[i]));
309 w[i]=c[i]; //update weights
310 }
311 std::cerr << "Variation " << variation << std::endl;
312 }
313
314 //Saving results
315 std::cerr << "Saving in " << outfile << "..." << std::endl;
316 //saving result
317 std::fstream outtxt(outfile.c_str(),std::ios::out);
318 outtxt << "LMINTERPOLATION " << N << "\n";
319 for (int i=0; i<N; i++) outtxt << w[i] << " " << lmf[i] << "\n";
320 outtxt.close();
321 }
322
323 for(int i = 0; i < N; i++)
324 if(lmt[i] != start_lmt[i]) {
325 delete lmt[i];
326 lmt[i] = start_lmt[i];
327 }
328
329 if (seval != NULL) {
330 std::cerr << "Start Eval" << std::endl;
331
332 std::cout.setf(ios::fixed);
333 std::cout.precision(2);
334 int i;
335 int Nw=0,Noov_all=0, Noov_any=0, Nbo=0;
336 double Pr,lPr;
337 double logPr=0,PP=0;
338
339 // variables for storing sentence-based Perplexity
340 int sent_Nw=0, sent_Noov_all=0, sent_Noov_any=0, sent_Nbo=0;
341 double sent_logPr=0,sent_PP=0;
342
343 //normalize weights
344 for (i=0,Pr=0; i<N; i++) Pr+=w[i];
345 for (i=0; i<N; i++) w[i]/=Pr;
346
347 dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
348 dict->incflag(1);
349 ngram ng(dict);
350 int bos=ng.dict->encode(ng.dict->BoS());
351 int eos=ng.dict->encode(ng.dict->EoS());
352
353 std::fstream inptxt(seval,std::ios::in);
354
355 for(;;) {
356 std::string line;
357 getline(inptxt, line);
358 if(inptxt.eof())
359 break;
360 if(inptxt.fail()) {
361 std::cerr << "Problem reading input file " << seval << std::endl;
362 return 1;
363 }
364 std::istringstream lstream(line);
365 if(line.substr(0, 26) == "###interpolate-lm:weights ") {
366 std::string token;
367 lstream >> token;
368 for(int i = 0; i < N; i++) {
369 if(lstream.eof()) {
370 std::cerr << "Not enough weights!" << std::endl;
371 return 1;
372 }
373 lstream >> w[i];
374 }
375 continue;
376 }
377 if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
378 std::string token, newlm;
379 int id;
380 lstream >> token >> id >> newlm;
381 if(id <= 0 || id > N) {
382 std::cerr << "LM id out of range." << std::endl;
383 return 1;
384 }
385 id--; // count from 0 now
386 delete lmt[id];
387 lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
388 continue;
389 }
390
391 double bow;
392 int bol=0;
393 char *msp;
394 unsigned int statesize;
395
396 while(lstream >> ng) {
397
398 // reset ngram at begin of sentence
399 if (*ng.wordp(1)==bos) {
400 ng.size=1;
401 continue;
402 }
403 if (order > 0 && ng.size > order) ng.size=order;
404
405
406 if (ng.size>=1) {
407
408 int minbol=MAX_NGRAM; //minimum backoff level of the mixture
409 bool OOV_all_flag=true; //OOV flag wrt all LM[i]
410 bool OOV_any_flag=false; //OOV flag wrt any LM[i]
411 float logpr;
412
413 Pr = 0.0;
414 for (i=0; i<N; i++) {
415
416 ngram ong(lmt[i]->getDict());
417 ong.trans(ng);
418 logpr = lmt[i]->clprob(ong,&bow,&bol,&msp,&statesize); //actual prob of the interpolation
419 //logpr = lmt[i]->clprob(ong,&bow,&bol); //LM log-prob
420
421 Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
422 if (bol < minbol) minbol=bol; //backoff of LM[i]
423
424 if (*ong.wordp(1) != lmt[i]->getDict()->oovcode()) OOV_all_flag=false; //OOV wrt LM[i]
425 if (*ong.wordp(1) == lmt[i]->getDict()->oovcode()) OOV_any_flag=true; //OOV wrt LM[i]
426 }
427
428 lPr=log(Pr)/M_LN10;
429 logPr+=lPr;
430 sent_logPr+=lPr;
431
432 if (debug==1) {
433 std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << ng.size-minbol << "]" << " ";
434 if (*ng.wordp(1)==eos) std::cout << std::endl;
435 }
436 if (debug==2)
437 std::cout << ng << " [" << ng.size-minbol << "-gram]" << " " << log(Pr) << std::endl;
438
439 if (minbol) {
440 Nbo++; //all LMs have back-offed by at least one
441 sent_Nbo++;
442 }
443
444 if (OOV_all_flag) {
445 Noov_all++; //word is OOV wrt all LM
446 sent_Noov_all++;
447 }
448 if (OOV_any_flag) {
449 Noov_any++; //word is OOV wrt any LM
450 sent_Noov_any++;
451 }
452
453 Nw++;
454 sent_Nw++;
455
456 if (*ng.wordp(1)==eos && sent_PP_flag) {
457 sent_PP=exp((-sent_logPr * log(10.0)) /sent_Nw);
458 std::cout << "%% sent_Nw=" << sent_Nw
459 << " sent_PP=" << sent_PP
460 << " sent_Nbo=" << sent_Nbo
461 << " sent_Noov=" << sent_Noov_all
462 << " sent_OOV=" << (float)sent_Noov_all/sent_Nw * 100.0 << "%"
463 << " sent_Noov_any=" << sent_Noov_any
464 << " sent_OOV_any=" << (float)sent_Noov_any/sent_Nw * 100.0 << "%" << std::endl;
465 //reset statistics for sentence based Perplexity
466 sent_Nw=sent_Noov_any=sent_Noov_all=sent_Nbo=0;
467 sent_logPr=0.0;
468 }
469
470
471 if ((Nw % 10000)==0) std::cerr << ".";
472 }
473 }
474 }
475
476 PP=exp((-logPr * M_LN10) /Nw);
477
478 std::cout << "%% Nw=" << Nw
479 << " PP=" << PP
480 << " Nbo=" << Nbo
481 << " Noov=" << Noov_all
482 << " OOV=" << (float)Noov_all/Nw * 100.0 << "%"
483 << " Noov_any=" << Noov_any
484 << " OOV_any=" << (float)Noov_any/Nw * 100.0 << "%" << std::endl;
485
486 };
487
488
489 if (score == true) {
490
491
492 dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
493 dict->incflag(1); // start generating the dictionary;
494 ngram ng(dict);
495 int bos=ng.dict->encode(ng.dict->BoS());
496
497 double Pr,logpr;
498
499 double bow;
500 int bol=0, maxbol=0;
501 unsigned int maxstatesize, statesize;
502 int i,n=0;
503 std::cout << "> ";
504 while(std::cin >> ng) {
505
506 // reset ngram at begin of sentence
507 if (*ng.wordp(1)==bos) {
508 ng.size=1;
509 continue;
510 }
511
512 if (ng.size>=maxorder) {
513
514 if (order > 0 && ng.size > order) ng.size=order;
515 n++;
516 maxstatesize=0;
517 maxbol=0;
518 Pr=0.0;
519 for (i=0; i<N; i++) {
520 ngram ong(lmt[i]->getDict());
521 ong.trans(ng);
522 logpr = lmt[i]->clprob(ong,&bow,&bol,NULL,&statesize); //LM log-prob (using caches if available)
523
524 Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
525 std::cout << "lm " << i << ":" << " logpr: " << logpr << " weight: " << w[i] << std::endl;
526 if (maxbol<bol) maxbol=bol;
527 if (maxstatesize<statesize) maxstatesize=statesize;
528 }
529
530 std::cout << ng << " p= " << log(Pr) << " bo= " << maxbol << " recombine= " << maxstatesize << std::endl;
531
532 if ((n % 10000000)==0) {
533 std::cerr << "." << std::endl;
534 for (i=0; i<N; i++) lmt[i]->check_caches_levels();
535 }
536
537 } else {
538 std::cout << ng << " p= NULL" << std::endl;
539 }
540 std::cout << "> ";
541 }
542
543
544 }
545
546 for (int i=0; i<N; i++) delete lmt[i];
547
548 return 0;
549 }
550
// Loads a single language model (ARPA or IRSTLM binary) from `file`.
//   requiredMaxlev : cap on the LM level to load
//   dub            : dictionary upper bound; when non-zero, sets the OOV penalty
//   memmap         : when non-zero, memory-maps binary LMs
//   nlf / dlf      : load factors for the n-gram cache and the dictionary
// Returns a ready-to-query lmContainer with caches initialized.
lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf)
{
  lmContainer* lmt=NULL;

  // NOTE(review): CreateLanguageModel is invoked through a NULL pointer —
  // presumably a static/factory member of lmContainer, so this works; if it
  // is ever made non-static this is undefined behavior. Verify in lmContainer.h.
  lmt = lmt->CreateLanguageModel(file,nlf,dlf);

  lmt->setMaxLoadedLevel(requiredMaxlev);

  lmt->load(file,memmap);

  if (dub) lmt->setlogOOVpenalty((int)dub);

  //use caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags)
  lmt->init_caches(lmt->maxlevel());
  return lmt;
}
567