// $Id: ngt.cpp 245 2009-04-02 14:05:40Z fabio_brugnara $

/******************************************************************************
IrstLM: IRST Language Model Toolkit
Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA

******************************************************************************/

// dtsel
// by M. Federico
// Copyright Marcello Federico, Fondazione Bruno Kessler, 2012
29 using namespace std;
30 
#include <cassert>
#include <cmath>
#include <cstdio>
#include <sstream>
#include "util.h"
#include "mfstream.h"
#include "mempool.h"
#include "htable.h"
#include "dictionary.h"
#include "n_gram.h"
#include "ngramtable.h"
#include "cmd.h"
41 
42 #define YES   1
43 #define NO    0
44 
print_help(int TypeFlag=0)45 void print_help(int TypeFlag=0){
46 		std::cerr << std::endl << "dtsel - performs data selection" << std::endl;
47 		std::cerr << std::endl << "USAGE:"  << std::endl
48 			  << "       dtsel -s=<outfile> [options]" << std::endl;
49 		std::cerr << std::endl << "OPTIONS:" << std::endl;
50 	FullPrintParams(TypeFlag, 0, 1, stderr);
51 }
52 
usage(const char * msg=0)53 void usage(const char *msg = 0)
54 {
55   if (msg){
56     std::cerr << msg << std::endl;
57 	}
58   else{
59 		print_help();
60 	}
61 	exit(1);
62 }
63 
// Recursively computes a smoothed probability of the last word of n-gram
// `ng` given its (size-1)-word history, interpolating with lower orders
// down to the unigram level. The fstar/lambda formulas
// (c-cv)/(N-cv+T) and T/(N-cv+T), with T = number of distinct successors,
// follow the Witten-Bell interpolation scheme.
//   ngt:  count table providing n-gram and history frequencies
//   ng:   n-gram to score (passed by value; may be modified locally)
//   size: order of the estimate being computed (<= ngt->maxlevel())
//   cv:   cross-validation count removed from the observed frequencies
//         (leave-out style discounting); 0 disables it
double prob(ngramtable* ngt,ngram ng,int size,int cv){
	double fstar,lambda;

	assert(size<=ngt->maxlevel() && size<=ng.size);
	if (size>1){
		ngram history=ng;
		// The history must be observed (beyond the held-out count) for this
		// order to contribute; otherwise back off directly to order size-1.
		if (ngt->get(history,size,size-1) && history.freq>cv){
			fstar=0.0;
			if (ngt->get(ng,size,size)){
				// Never remove more counts than this n-gram actually has.
				cv=(cv>ng.freq)?ng.freq:cv;
				if (ng.freq>cv){
					fstar=(double)(ng.freq-cv)/(double)(history.freq -cv + history.succ);
					lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
				}else //ng.freq==cv
					// All counts held out: this n-gram type is also removed from
					// the successor-type statistics (succ-1).
					lambda=(double)(history.succ-1)/(double)(history.freq -cv + history.succ-1);
			}
			else
				lambda=(double)history.succ/(double)(history.freq -cv + history.succ);

			// Interpolate the discounted estimate with the lower-order model.
			return fstar + lambda * prob(ngt,ng,size-1,cv);
		}
		else return prob(ngt,ng,size-1,cv);

	}else{ //unigram branch
		// NOTE(review): denominator totfreq()-1 appears to reserve one count
		// (presumably for the held-out event) — confirm intent.
		if (ngt->get(ng,1,1) && ng.freq>cv)
			return (double)(ng.freq-cv)/(ngt->totfreq()-1);
		else{
			//cerr << "backoff to oov unigram " << ng.freq << " " << cv << "\n";
			// Unseen word: fall back to the OOV unigram estimate.
			*ng.wordp(1)=ngt->dict->oovcode();
			if (ngt->get(ng,1,1) && ng.freq>0)
				return (double)ng.freq/ngt->totfreq();
			else //use an automatic estimate of Pr(oov)
				return (double)ngt->dict->size()/(ngt->totfreq()+ngt->dict->size());
		}

	}

}
102 
103 
computePP(ngramtable * train,ngramtable * test,double oovpenalty,double & oovrate,int cv=0)104 double computePP(ngramtable* train,ngramtable* test,double oovpenalty,double& oovrate,int cv=0){
105 
106 
107 	ngram ng2(test->dict);ngram ng1(train->dict);
108 	int N=0; double H=0;  oovrate=0;
109 
110 	test->scan(ng2,INIT,test->maxlevel());
111  	while(test->scan(ng2,CONT,test->maxlevel())) {
112 
113 		ng1.trans(ng2);
114 		H-=log(prob(train,ng1,ng1.size,cv));
115 		if (*ng1.wordp(1)==train->dict->oovcode()){
116 			H-=oovpenalty;
117 			oovrate++;
118 		}
119 		N++;
120 	}
121 	oovrate/=N;
122 	return exp(H/N);
123 }
124 
125 
main(int argc,char ** argv)126 int main(int argc, char **argv)
127 {
128 	char *indom=NULL;   //indomain data: one sentence per line
129 	char *outdom=NULL;   //domain data: one sentence per line
130 	char *scorefile=NULL;  //score file
131 	char *evalset=NULL;    //evalset to measure performance
132 
133 	int minfreq=2;          //frequency threshold for dictionary pruning (optional)
134 	int ngsz=0;             // n-gram size
135 	int dub=10000000;      //upper bound of true vocabulary
136 	int model=2;           //data selection model: 1 only in-domain cross-entropy,
137 	                       //2 cross-entropy difference.
138 	int cv=1;              //cross-validation parameter: 1 only in-domain cross-entropy,
139 
140 	int blocksize=100000; //block-size in words
141 	int verbose=0;
142 	int useindex=0; //provided score file includes and index
143 	double convergence_treshold=0;
144 
145 	bool help=false;
146 
147 	DeclareParams((char*)
148 				  "min-word-freq", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
149 				  "f", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
150 
151 		      "ngram-order", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
152 				  "n", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
153 
154 				  "in-domain-file", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
155 				  "i", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
156 
157 				  "out-domain-file", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
158 				  "o", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
159 
160 				  "score-file", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
161 				  "s", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
162 
163 				  "dictionary-upper-bound", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
164 				  "dub", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
165 
166 				  "model", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
167 				  "m", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
168 
169 				  "cross-validation", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
170   				  "cv", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
171 
172 				  "test", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
173 				  "t", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
174 
175 				  "block-size", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
176 				  "bs", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
177 
178 				  "convergence-threshold", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
179 				  "c", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
180 
181 				  "index", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
182 				  "x", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
183 
184 				  "verbose", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
185 				  "v", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
186 								"Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
187 								"h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
188 
189 				  (char *)NULL
190 				  );
191 
192 
193 
194 	GetParams(&argc, &argv, (char*) NULL);
195 
196 	if (help){
197 		usage();
198 	}
199 	if (scorefile==NULL) {
200 		usage();
201 	}
202 
203 	if (!evalset && (!indom || !outdom)){
204 		cerr <<"Must specify in-domain and out-domain data files\n";
205 		exit(1);
206 	};
207 
208 	//score file is always required: either as output or as input
209 	if (!scorefile){
210 		cerr <<"Must specify score file\n";
211 		exit(1);
212 	};
213 
214 	if (!evalset && !model){
215 		cerr <<"Must specify data selection model\n";
216 		exit(1);
217 	}
218 
219 	if (evalset && (convergence_treshold<0 || convergence_treshold > 0.1)){
220 		cerr <<"Convergence threshold must be between 0 and 0.1. \n";
221 		exit(1);
222 	}
223 
224 	TABLETYPE table_type=COUNT;
225 
226 
227 	if (!evalset){
228 
229 		//computed dictionary on indomain data
230 		dictionary *dict = new dictionary(indom,1000000,0);
231 		dictionary *pd=new dictionary(dict,true,minfreq);
232 		delete dict;dict=pd;
233 
234 		//build in-domain table restricted to the given dictionary
235 		ngramtable *indngt=new ngramtable(indom,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
236 		double indoovpenalty=-log(dub-indngt->dict->size());
237 		ngram indng(indngt->dict);
238 		int indoovcode=indngt->dict->oovcode();
239 
240 		//build out-domain table restricted to the in-domain dictionary
241 		char command[1000]="";
242 
243 		if (useindex)
244 			sprintf(command,"cut -d \" \" -f 2- %s",outdom);
245 		else
246 			sprintf(command,"%s",outdom);
247 
248 
249 		ngramtable *outdngt=new ngramtable(command,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
250 		double outdoovpenalty=-log(dub-outdngt->dict->size());
251 		ngram outdng(outdngt->dict);
252 		int outdoovcode=outdngt->dict->oovcode();
253 
254 		cerr << "dict size idom: " << indngt->dict->size() << " odom: " << outdngt->dict->size() << "\n";
255 		cerr << "oov penalty idom: " << indoovpenalty << " odom: " << outdoovpenalty << "\n";
256 
257 		//go through the odomain sentences
258 		int bos=dict->encode(dict->BoS());
259 		mfstream inp(outdom,ios::in); ngram ng(dict);
260 		mfstream txt(outdom,ios::in);
261 		mfstream output(scorefile,ios::out);
262 
263 
264 		int linenumber=1; string line;
265 		int lenght=0;float deltaH=0; float deltaHoov=0; int words=0;string index;
266 
267 		while (getline(inp,line)){
268 
269 			istringstream lninp(line);
270 
271 			linenumber++;
272 
273 			if (useindex) lninp >> index;
274 
275 			// reset ngram at begin of sentence
276 			ng.size=1; deltaH=0;deltaHoov=0; lenght=0;
277 
278 			while(lninp>>ng){
279 
280 				if (*ng.wordp(1)==bos) continue;
281 
282 				lenght++; words++;
283 
284 				if ((words % 1000000)==0) cerr << ".";
285 
286 				if (ng.size>ngsz) ng.size=ngsz;
287 				indng.trans(ng);outdng.trans(ng);
288 
289 				if (model==1){//compute cross-entropy
290 					deltaH-=log(prob(indngt,indng,indng.size,0));
291 					deltaHoov-=(*indng.wordp(1)==indoovcode?indoovpenalty:0);
292 				}
293 
294 				if (model==2){ //compute cross-entropy difference
295 					deltaH+=log(prob(outdngt,outdng,outdng.size,cv))-log(prob(indngt,indng,indng.size,0));
296 					deltaHoov+=(*outdng.wordp(1)==outdoovcode?outdoovpenalty:0)-(*indng.wordp(1)==indoovcode?indoovpenalty:0);
297 				}
298 			}
299 
300 			output << (deltaH + deltaHoov)/lenght  << " " << line << "\n";
301 		}
302 	}
303 	else{
304 
305 		//build in-domain LM from evaluation set
306 		ngramtable *tstngt=new ngramtable(evalset,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
307 
308 		//build empty out-domain LM
309 		ngramtable *outdngt=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
310 
311 		//if indomain data is passed then limit comparison to its dictionary
312 		dictionary *dict = NULL;
313 		if (indom){
314 			cerr << "dtsel: limit evaluation dict to indomain words with freq >=" <<  minfreq << "\n";
315 			//computed dictionary on indomain data
316 			dict = new dictionary(indom,1000000,0);
317 			dictionary *pd=new dictionary(dict,true,minfreq);
318 			delete dict;dict=pd;
319 			outdngt->dict=dict;
320 		}
321 
322 		dictionary* outddict=outdngt->dict;
323 
324 		//get codes of <s>, </s> and UNK
325 		outddict->incflag(1);
326 		int bos=outddict->encode(outddict->BoS());
327 		int oov=outddict->encode(outddict->OOV());
328 		outddict->incflag(0);
329 		outddict->oovcode(oov);
330 
331 
332 		double oldPP=dub; double newPP=0; double oovrate=0;
333 
334 		long totwords=0; long totlines=0; long nextstep=blocksize;
335 
336 		double score; string index;
337 
338 		mfstream outd(scorefile,ios::in); string line;
339 
340 		//initialize n-gram
341 		ngram ng(outdngt->dict); for (int i=1;i<ngsz;i++) ng.pushc(bos); ng.freq=1;
342 
343 		//check if to use open or closed voabulary
344 
345 		if (!dict) outddict->incflag(1);
346 
347 		while (getline(outd,line)){
348 
349 			istringstream lninp(line);
350 
351 			//skip score and eventually the index
352 			lninp >> score; if (useindex) lninp >> index;
353 
354 			while (lninp >> ng){
355 
356 				if (*ng.wordp(1) == bos) continue;
357 
358 				if (ng.size>ngsz) ng.size=ngsz;
359 
360 				outdngt->put(ng);
361 
362 				totwords++;
363 			}
364 
365 			totlines++;
366 
367 			if (totwords>=nextstep){ //if block is complete
368 
369 				if (!dict) outddict->incflag(0);
370 
371 				newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
372 
373 				if (!dict) outddict->incflag(1);
374 
375 				cout << totwords << " " << newPP;
376 				if (verbose) cout << " " << totlines << " " << oovrate;
377 				cout << "\n";
378 
379 				if (convergence_treshold && (oldPP-newPP)/oldPP < convergence_treshold) return 1;
380 
381 				oldPP=newPP;
382 
383 				nextstep+=blocksize;
384 			}
385 		}
386 
387 		if (!dict) outddict->incflag(0);
388 		newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
389 		cout << totwords << " " << newPP;
390 		if (verbose) cout << " " << totlines << " " << oovrate;
391 
392 	}
393 
394 }
395 
396 
397 
398