1 /******************************************************************************
2  IrstLM: IRST Language Model Toolkit, compile LM
3  Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
4 
5  This library is free software; you can redistribute it and/or
6  modify it under the terms of the GNU Lesser General Public
7  License as published by the Free Software Foundation; either
8  version 2.1 of the License, or (at your option) any later version.
9 
10  This library is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  Lesser General Public License for more details.
14 
15  You should have received a copy of the GNU Lesser General Public
16  License along with this library; if not, write to the Free Software
17  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
18 
19  ******************************************************************************/
20 
21 using namespace std;
22 
23 #include <iostream>
24 #include <fstream>
25 #include <sstream>
26 #include <stdexcept>
27 #include <vector>
28 #include <string>
29 #include <stdlib.h>
30 #include "cmd.h"
31 #include "util.h"
32 #include "math.h"
33 #include "lmContainer.h"
34 /********************************/
35 
36 
// Report a fatal condition: echo the message on stderr, then surface it
// to the caller as a std::runtime_error (terminates if left uncaught).
inline void error(const char* message)
{
  const std::string text(message);
  std::cerr << text << "\n";
  throw std::runtime_error(text);
}
42 
43 lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf);
44 
print_help(int TypeFlag=0)45 void print_help(int TypeFlag=0){
46   std::cerr << std::endl << "interpolate-lm - interpolates language models" << std::endl;
47   std::cerr << std::endl << "USAGE:"  << std::endl;
48 	std::cerr << "       interpolate-lm [options] <lm-list-file> [lm-list-file.out]" << std::endl;
49 
50 	std::cerr << std::endl << "DESCRIPTION:" << std::endl;
51 	std::cerr << "       interpolate-lm reads a LM list file including interpolation weights " << std::endl;
52 	std::cerr << "       with the format: N\\n w1 lm1 \\n w2 lm2 ...\\n wN lmN\n" << std::endl;
53 	std::cerr << "       It estimates new weights on a development text, " << std::endl;
54 	std::cerr << "       computes the perplexity on an evaluation text, " << std::endl;
55 	std::cerr << "       computes probabilities of n-grams read from stdin." << std::endl;
56 	std::cerr << "       It reads LMs in ARPA and IRSTLM binary format." << std::endl;
57 
58   std::cerr << std::endl << "OPTIONS:" << std::endl;
59 	FullPrintParams(TypeFlag, 0, 1, stderr);
60 
61 }
62 
usage(const char * msg=0)63 void usage(const char *msg = 0)
64 {
65   if (msg){
66     std::cerr << msg << std::endl;
67 	}
68   else{
69 		print_help();
70 	}
71 	exit(1);
72 }
73 
main(int argc,char ** argv)74 int main(int argc, char **argv)
75 {
76 	char *slearn = NULL;
77 	char *seval = NULL;
78 	bool learn=false;
79 	bool score=false;
80 	bool sent_PP_flag = false;
81 
82 	int order = 0;
83 	int debug = 0;
84   int memmap = 0;
85   int requiredMaxlev = 1000;
86   int dub = 10000000;
87   float ngramcache_load_factor = 0.0;
88   float dictionary_load_factor = 0.0;
89 
90 	bool help=false;
91   std::vector<std::string> files;
92 
93 	DeclareParams((char*)
94 
95 								"learn", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
96 								"l", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
97 
98 								"order", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
99 								"o", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
100 
101                 "eval", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
102 								"e", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
103 
104                 "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
105 
106                 "score", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
107 								"s", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
108 
109                 "debug", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
110 								"d", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
111                 "memmap", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
112 								"mm", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
113 								"sentence", CMDBOOLTYPE|CMDMSG, &sent_PP_flag, "computes perplexity at sentence level (identified through the end symbol)",
114                 "dict_load_factor", CMDFLOATTYPE|CMDMSG, &dictionary_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is 0",
115                 "ngram_load_factor", CMDFLOATTYPE|CMDMSG, &ngramcache_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is false",
116                 "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
117 								"lev", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
118 
119 								"Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
120 								"h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
121 
122 								(char *)NULL
123 								);
124 
125 	if (argc == 1){
126 		usage();
127 	}
128 
129 	for(int i=1; i < argc; i++) {
130 		if(argv[i][0] != '-') files.push_back(argv[i]);
131 	}
132 
133   GetParams(&argc, &argv, (char*) NULL);
134 
135 	if (help){
136 		usage();
137 	}
138 
139   if (files.size() > 2) {
140     usage("Warning: Too many arguments");
141   }
142 
143   if (files.size() < 1) {
144     usage("Warning: specify a LM list file to read from");
145   }
146 
147   std::string infile = files[0];
148   std::string outfile="";
149 
150   if (files.size() == 1) {
151     outfile=infile;
152     //remove path information
153     std::string::size_type p = outfile.rfind('/');
154     if (p != std::string::npos && ((p+1) < outfile.size()))
155       outfile.erase(0,p+1);
156     outfile+=".out";
157   } else
158     outfile = files[1];
159 
160   std::cerr << "inpfile: " << infile << std::endl;
161   learn = ((slearn != NULL)? true : false);
162 
163   if (learn) std::cerr << "outfile: " << outfile << std::endl;
164   if (score) std::cerr << "interactive: " << score << std::endl;
165   if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
166   std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
167   std::cerr << "order: " << order << std::endl;
168   if (requiredMaxlev > 0) std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
169 
170   std::cerr << "dub: " << dub<< std::endl;
171 
172   lmContainer *lmt[100], *start_lmt[100]; //interpolated language models
173   std::string lmf[100]; //lm filenames
174 
175   float w[100]; //interpolation weights
176   int N;
177 
178 
179   //Loading Language Models`
180   std::cerr << "Reading " << infile << "..." << std::endl;
181   std::fstream inptxt(infile.c_str(),std::ios::in);
182 
183   //std::string line;
184   char line[BUFSIZ];
185   const char* words[3];
186   int tokenN;
187 
188   inptxt.getline(line,BUFSIZ,'\n');
189   tokenN = parseWords(line,words,3);
190 
191   if (tokenN != 2 || ((strcmp(words[0],"LMINTERPOLATION") != 0) && (strcmp(words[0],"lminterpolation")!=0)))
192     error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
193 
194   N=atoi(words[1]);
195   std::cerr << "Number of LMs: " << N << "..." << std::endl;
196   if(N > 100) {
197     std::cerr << "Can't interpolate more than 100 language models." << std::endl;
198     exit(1);
199   }
200 
201   for (int i=0; i<N; i++) {
202     inptxt.getline(line,BUFSIZ,'\n');
203     tokenN = parseWords(line,words,3);
204     if(tokenN != 2) {
205       std::cerr << "Wrong input format." << std::endl;
206       exit(1);
207     }
208     w[i] = (float) atof(words[0]);
209     lmf[i] = words[1];
210 
211     std::cerr << "i:" << i << " w[i]:" << w[i] << " lmf[i]:" << lmf[i] << std::endl;
212     start_lmt[i] = lmt[i] = load_lm(lmf[i],requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
213   }
214 
215   inptxt.close();
216 
217   int maxorder = 0;
218   for (int i=0; i<N; i++) {
219     maxorder = (maxorder > lmt[i]->maxlevel())?maxorder:lmt[i]->maxlevel();
220   }
221 
222   if (order <= 0) {
223     order = maxorder;
224     std::cerr << "order is not set or wrongly set to a non positive value; reset to the maximum order of LMs: " << order << std::endl;
225   } else if (order > maxorder) {
226     order = maxorder;
227     std::cerr << "order is too high; reset to the maximum order of LMs" << order << std::endl;
228   }
229 
230   //Learning mixture weights
231   if (learn) {
232 
233     std::vector< std::vector<float> > p(N); //LM probabilities
234     float c[N]; //expected counts
235     float den,norm; //inner denominator, normalization term
236     float variation=1.0; // global variation between new old params
237 
238     dictionary* dict=new dictionary(slearn,1000000,dictionary_load_factor);
239     ngram ng(dict);
240     int bos=ng.dict->encode(ng.dict->BoS());
241     std::ifstream dev(slearn,std::ios::in);
242 
243     for(;;) {
244       std::string line;
245       getline(dev, line);
246       if(dev.eof())
247         break;
248       if(dev.fail()) {
249         std::cerr << "Problem reading input file " << seval << std::endl;
250         exit(1);
251       }
252       std::istringstream lstream(line);
253       if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
254         std::string token, newlm;
255         int id;
256         lstream >> token >> id >> newlm;
257         if(id <= 0 || id > N) {
258           std::cerr << "LM id out of range." << std::endl;
259           return 1;
260         }
261         id--; // count from 0 now
262         if(lmt[id] != start_lmt[id])
263           delete lmt[id];
264         lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
265         continue;
266       }
267       while(lstream >> ng) {
268 
269         // reset ngram at begin of sentence
270         if (*ng.wordp(1)==bos) {
271           ng.size=1;
272           continue;
273         }
274         if (order > 0 && ng.size > order) ng.size=order;
275         for (int i=0; i<N; i++) {
276           ngram ong(lmt[i]->getDict());
277           ong.trans(ng);
278           double logpr;
279           logpr = lmt[i]->clprob(ong); //LM log-prob (using caches if available)
280           p[i].push_back(pow(10.0,logpr));
281         }
282       }
283 
284       for (int i=0; i<N; i++) lmt[i]->check_caches_levels();
285     }
286     dev.close();
287 
288     while( variation > 0.01 ) {
289 
290       for (int i=0; i<N; i++) c[i]=0;	//reset counters
291 
292       for(unsigned i = 0; i < p[0].size(); i++) {
293         den=0.0;
294         for(int j = 0; j < N; j++)
295           den += w[j] * p[j][i]; //denominator of EM formula
296         //update expected counts
297         for(int j = 0; j < N; j++)
298           c[j] += w[j] * p[j][i] / den;
299       }
300 
301       norm=0.0;
302       for (int i=0; i<N; i++) norm+=c[i];
303 
304       //update weights and compute distance
305       variation=0.0;
306       for (int i=0; i<N; i++) {
307         c[i]/=norm; //c[i] is now the new weight
308         variation+=(w[i]>c[i]?(w[i]-c[i]):(c[i]-w[i]));
309         w[i]=c[i]; //update weights
310       }
311       std::cerr << "Variation " << variation << std::endl;
312     }
313 
314     //Saving results
315     std::cerr << "Saving in " << outfile << "..." << std::endl;
316     //saving result
317     std::fstream outtxt(outfile.c_str(),std::ios::out);
318     outtxt << "LMINTERPOLATION " << N << "\n";
319     for (int i=0; i<N; i++) outtxt << w[i] << " " << lmf[i] << "\n";
320     outtxt.close();
321   }
322 
323   for(int i = 0; i < N; i++)
324     if(lmt[i] != start_lmt[i]) {
325       delete lmt[i];
326       lmt[i] = start_lmt[i];
327     }
328 
329   if (seval != NULL) {
330     std::cerr << "Start Eval" << std::endl;
331 
332     std::cout.setf(ios::fixed);
333     std::cout.precision(2);
334     int i;
335     int Nw=0,Noov_all=0, Noov_any=0, Nbo=0;
336     double Pr,lPr;
337     double logPr=0,PP=0;
338 
339     // variables for storing sentence-based Perplexity
340     int sent_Nw=0, sent_Noov_all=0, sent_Noov_any=0, sent_Nbo=0;
341     double sent_logPr=0,sent_PP=0;
342 
343     //normalize weights
344     for (i=0,Pr=0; i<N; i++) Pr+=w[i];
345     for (i=0; i<N; i++) w[i]/=Pr;
346 
347     dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
348     dict->incflag(1);
349     ngram ng(dict);
350     int bos=ng.dict->encode(ng.dict->BoS());
351     int eos=ng.dict->encode(ng.dict->EoS());
352 
353     std::fstream inptxt(seval,std::ios::in);
354 
355     for(;;) {
356       std::string line;
357       getline(inptxt, line);
358       if(inptxt.eof())
359         break;
360       if(inptxt.fail()) {
361         std::cerr << "Problem reading input file " << seval << std::endl;
362         return 1;
363       }
364       std::istringstream lstream(line);
365       if(line.substr(0, 26) == "###interpolate-lm:weights ") {
366         std::string token;
367         lstream >> token;
368         for(int i = 0; i < N; i++) {
369           if(lstream.eof()) {
370             std::cerr << "Not enough weights!" << std::endl;
371             return 1;
372           }
373           lstream >> w[i];
374         }
375         continue;
376       }
377       if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
378         std::string token, newlm;
379         int id;
380         lstream >> token >> id >> newlm;
381         if(id <= 0 || id > N) {
382           std::cerr << "LM id out of range." << std::endl;
383           return 1;
384         }
385         id--; // count from 0 now
386         delete lmt[id];
387         lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
388         continue;
389       }
390 
391       double bow;
392       int bol=0;
393       char *msp;
394       unsigned int statesize;
395 
396       while(lstream >> ng) {
397 
398         // reset ngram at begin of sentence
399         if (*ng.wordp(1)==bos) {
400           ng.size=1;
401           continue;
402         }
403         if (order > 0 && ng.size > order) ng.size=order;
404 
405 
406         if (ng.size>=1) {
407 
408           int  minbol=MAX_NGRAM; //minimum backoff level of the mixture
409           bool OOV_all_flag=true;  //OOV flag wrt all LM[i]
410           bool OOV_any_flag=false; //OOV flag wrt any LM[i]
411           float logpr;
412 
413           Pr = 0.0;
414           for (i=0; i<N; i++) {
415 
416             ngram ong(lmt[i]->getDict());
417             ong.trans(ng);
418             logpr = lmt[i]->clprob(ong,&bow,&bol,&msp,&statesize); //actual prob of the interpolation
419             //logpr = lmt[i]->clprob(ong,&bow,&bol); //LM log-prob
420 
421             Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
422             if (bol < minbol) minbol=bol; //backoff of LM[i]
423 
424             if (*ong.wordp(1) != lmt[i]->getDict()->oovcode()) OOV_all_flag=false; //OOV wrt LM[i]
425             if (*ong.wordp(1) == lmt[i]->getDict()->oovcode()) OOV_any_flag=true; //OOV wrt LM[i]
426           }
427 
428           lPr=log(Pr)/M_LN10;
429           logPr+=lPr;
430           sent_logPr+=lPr;
431 
432           if (debug==1) {
433             std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << ng.size-minbol << "]" << " ";
434             if (*ng.wordp(1)==eos) std::cout << std::endl;
435           }
436           if (debug==2)
437             std::cout << ng << " [" << ng.size-minbol << "-gram]" << " " << log(Pr) << std::endl;
438 
439           if (minbol) {
440             Nbo++;  //all LMs have back-offed by at least one
441             sent_Nbo++;
442           }
443 
444           if (OOV_all_flag) {
445             Noov_all++;  //word is OOV wrt all LM
446             sent_Noov_all++;
447           }
448           if (OOV_any_flag) {
449             Noov_any++;  //word is OOV wrt any LM
450             sent_Noov_any++;
451           }
452 
453           Nw++;
454           sent_Nw++;
455 
456           if (*ng.wordp(1)==eos && sent_PP_flag) {
457             sent_PP=exp((-sent_logPr * log(10.0)) /sent_Nw);
458             std::cout << "%% sent_Nw=" << sent_Nw
459                       << " sent_PP=" << sent_PP
460                       << " sent_Nbo=" << sent_Nbo
461                       << " sent_Noov=" << sent_Noov_all
462                       << " sent_OOV=" << (float)sent_Noov_all/sent_Nw * 100.0 << "%"
463                       << " sent_Noov_any=" << sent_Noov_any
464                       << " sent_OOV_any=" << (float)sent_Noov_any/sent_Nw * 100.0 << "%" << std::endl;
465             //reset statistics for sentence based Perplexity
466             sent_Nw=sent_Noov_any=sent_Noov_all=sent_Nbo=0;
467             sent_logPr=0.0;
468           }
469 
470 
471           if ((Nw % 10000)==0) std::cerr << ".";
472         }
473       }
474     }
475 
476     PP=exp((-logPr * M_LN10) /Nw);
477 
478     std::cout << "%% Nw=" << Nw
479               << " PP=" << PP
480               << " Nbo=" << Nbo
481               << " Noov=" << Noov_all
482               << " OOV=" << (float)Noov_all/Nw * 100.0 << "%"
483               << " Noov_any=" << Noov_any
484               << " OOV_any=" << (float)Noov_any/Nw * 100.0 << "%" << std::endl;
485 
486   };
487 
488 
489   if (score == true) {
490 
491 
492     dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
493     dict->incflag(1); // start generating the dictionary;
494     ngram ng(dict);
495     int bos=ng.dict->encode(ng.dict->BoS());
496 
497     double Pr,logpr;
498 
499     double bow;
500     int bol=0, maxbol=0;
501     unsigned int maxstatesize, statesize;
502     int i,n=0;
503     std::cout << "> ";
504     while(std::cin >> ng) {
505 
506       // reset ngram at begin of sentence
507       if (*ng.wordp(1)==bos) {
508         ng.size=1;
509         continue;
510       }
511 
512       if (ng.size>=maxorder) {
513 
514         if (order > 0 && ng.size > order) ng.size=order;
515         n++;
516         maxstatesize=0;
517         maxbol=0;
518         Pr=0.0;
519         for (i=0; i<N; i++) {
520           ngram ong(lmt[i]->getDict());
521           ong.trans(ng);
522           logpr = lmt[i]->clprob(ong,&bow,&bol,NULL,&statesize); //LM log-prob (using caches if available)
523 
524           Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
525           std::cout << "lm " << i << ":" << " logpr: " << logpr << " weight: " << w[i] << std::endl;
526           if (maxbol<bol) maxbol=bol;
527           if (maxstatesize<statesize) maxstatesize=statesize;
528         }
529 
530         std::cout << ng << " p= " << log(Pr) << " bo= " << maxbol << " recombine= " << maxstatesize << std::endl;
531 
532         if ((n % 10000000)==0) {
533           std::cerr << "." << std::endl;
534           for (i=0; i<N; i++) lmt[i]->check_caches_levels();
535         }
536 
537       } else {
538         std::cout << ng << " p= NULL" << std::endl;
539       }
540       std::cout << "> ";
541     }
542 
543 
544   }
545 
546   for (int i=0; i<N; i++) delete lmt[i];
547 
548   return 0;
549 }
550 
load_lm(std::string file,int requiredMaxlev,int dub,int memmap,float nlf,float dlf)551 lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf)
552 {
553   lmContainer* lmt=NULL;
554 
555   lmt = lmt->CreateLanguageModel(file,nlf,dlf);
556 
557   lmt->setMaxLoadedLevel(requiredMaxlev);
558 
559   lmt->load(file,memmap);
560 
561   if (dub) lmt->setlogOOVpenalty((int)dub);
562 
563   //use caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags)
564   lmt->init_caches(lmt->maxlevel());
565   return lmt;
566 }
567