// $Id: ngt.cpp 245 2009-04-02 14:05:40Z fabio_brugnara $

/******************************************************************************
 IrstLM: IRST Language Model Toolkit
 Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

******************************************************************************/

// dtsel
// by M. Federico
// Copyright Marcello Federico, Fondazione Bruno Kessler, 2012
using namespace std;

#include <cmath>
#include <cassert>
#include <cstdio>
#include <sstream>
#include "util.h"
#include "mfstream.h"
#include "mempool.h"
#include "htable.h"
#include "dictionary.h"
#include "n_gram.h"
#include "ngramtable.h"
#include "cmd.h"

#define YES 1
#define NO  0
44
print_help(int TypeFlag=0)45 void print_help(int TypeFlag=0){
46 std::cerr << std::endl << "dtsel - performs data selection" << std::endl;
47 std::cerr << std::endl << "USAGE:" << std::endl
48 << " dtsel -s=<outfile> [options]" << std::endl;
49 std::cerr << std::endl << "OPTIONS:" << std::endl;
50 FullPrintParams(TypeFlag, 0, 1, stderr);
51 }
52
usage(const char * msg=0)53 void usage(const char *msg = 0)
54 {
55 if (msg){
56 std::cerr << msg << std::endl;
57 }
58 else{
59 print_help();
60 }
61 exit(1);
62 }
63
prob(ngramtable * ngt,ngram ng,int size,int cv)64 double prob(ngramtable* ngt,ngram ng,int size,int cv){
65 double fstar,lambda;
66
67 assert(size<=ngt->maxlevel() && size<=ng.size);
68 if (size>1){
69 ngram history=ng;
70 if (ngt->get(history,size,size-1) && history.freq>cv){
71 fstar=0.0;
72 if (ngt->get(ng,size,size)){
73 cv=(cv>ng.freq)?ng.freq:cv;
74 if (ng.freq>cv){
75 fstar=(double)(ng.freq-cv)/(double)(history.freq -cv + history.succ);
76 lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
77 }else //ng.freq==cv
78 lambda=(double)(history.succ-1)/(double)(history.freq -cv + history.succ-1);
79 }
80 else
81 lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
82
83 return fstar + lambda * prob(ngt,ng,size-1,cv);
84 }
85 else return prob(ngt,ng,size-1,cv);
86
87 }else{ //unigram branch
88 if (ngt->get(ng,1,1) && ng.freq>cv)
89 return (double)(ng.freq-cv)/(ngt->totfreq()-1);
90 else{
91 //cerr << "backoff to oov unigram " << ng.freq << " " << cv << "\n";
92 *ng.wordp(1)=ngt->dict->oovcode();
93 if (ngt->get(ng,1,1) && ng.freq>0)
94 return (double)ng.freq/ngt->totfreq();
95 else //use an automatic estimate of Pr(oov)
96 return (double)ngt->dict->size()/(ngt->totfreq()+ngt->dict->size());
97 }
98
99 }
100
101 }
102
103
computePP(ngramtable * train,ngramtable * test,double oovpenalty,double & oovrate,int cv=0)104 double computePP(ngramtable* train,ngramtable* test,double oovpenalty,double& oovrate,int cv=0){
105
106
107 ngram ng2(test->dict);ngram ng1(train->dict);
108 int N=0; double H=0; oovrate=0;
109
110 test->scan(ng2,INIT,test->maxlevel());
111 while(test->scan(ng2,CONT,test->maxlevel())) {
112
113 ng1.trans(ng2);
114 H-=log(prob(train,ng1,ng1.size,cv));
115 if (*ng1.wordp(1)==train->dict->oovcode()){
116 H-=oovpenalty;
117 oovrate++;
118 }
119 N++;
120 }
121 oovrate/=N;
122 return exp(H/N);
123 }
124
125
main(int argc,char ** argv)126 int main(int argc, char **argv)
127 {
128 char *indom=NULL; //indomain data: one sentence per line
129 char *outdom=NULL; //domain data: one sentence per line
130 char *scorefile=NULL; //score file
131 char *evalset=NULL; //evalset to measure performance
132
133 int minfreq=2; //frequency threshold for dictionary pruning (optional)
134 int ngsz=0; // n-gram size
135 int dub=10000000; //upper bound of true vocabulary
136 int model=2; //data selection model: 1 only in-domain cross-entropy,
137 //2 cross-entropy difference.
138 int cv=1; //cross-validation parameter: 1 only in-domain cross-entropy,
139
140 int blocksize=100000; //block-size in words
141 int verbose=0;
142 int useindex=0; //provided score file includes and index
143 double convergence_treshold=0;
144
145 bool help=false;
146
147 DeclareParams((char*)
148 "min-word-freq", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
149 "f", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
150
151 "ngram-order", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
152 "n", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
153
154 "in-domain-file", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
155 "i", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
156
157 "out-domain-file", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
158 "o", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
159
160 "score-file", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
161 "s", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
162
163 "dictionary-upper-bound", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
164 "dub", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
165
166 "model", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
167 "m", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
168
169 "cross-validation", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
170 "cv", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
171
172 "test", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
173 "t", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
174
175 "block-size", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
176 "bs", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
177
178 "convergence-threshold", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
179 "c", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
180
181 "index", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
182 "x", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
183
184 "verbose", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
185 "v", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
186 "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
187 "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
188
189 (char *)NULL
190 );
191
192
193
194 GetParams(&argc, &argv, (char*) NULL);
195
196 if (help){
197 usage();
198 }
199 if (scorefile==NULL) {
200 usage();
201 }
202
203 if (!evalset && (!indom || !outdom)){
204 cerr <<"Must specify in-domain and out-domain data files\n";
205 exit(1);
206 };
207
208 //score file is always required: either as output or as input
209 if (!scorefile){
210 cerr <<"Must specify score file\n";
211 exit(1);
212 };
213
214 if (!evalset && !model){
215 cerr <<"Must specify data selection model\n";
216 exit(1);
217 }
218
219 if (evalset && (convergence_treshold<0 || convergence_treshold > 0.1)){
220 cerr <<"Convergence threshold must be between 0 and 0.1. \n";
221 exit(1);
222 }
223
224 TABLETYPE table_type=COUNT;
225
226
227 if (!evalset){
228
229 //computed dictionary on indomain data
230 dictionary *dict = new dictionary(indom,1000000,0);
231 dictionary *pd=new dictionary(dict,true,minfreq);
232 delete dict;dict=pd;
233
234 //build in-domain table restricted to the given dictionary
235 ngramtable *indngt=new ngramtable(indom,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
236 double indoovpenalty=-log(dub-indngt->dict->size());
237 ngram indng(indngt->dict);
238 int indoovcode=indngt->dict->oovcode();
239
240 //build out-domain table restricted to the in-domain dictionary
241 char command[1000]="";
242
243 if (useindex)
244 sprintf(command,"cut -d \" \" -f 2- %s",outdom);
245 else
246 sprintf(command,"%s",outdom);
247
248
249 ngramtable *outdngt=new ngramtable(command,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
250 double outdoovpenalty=-log(dub-outdngt->dict->size());
251 ngram outdng(outdngt->dict);
252 int outdoovcode=outdngt->dict->oovcode();
253
254 cerr << "dict size idom: " << indngt->dict->size() << " odom: " << outdngt->dict->size() << "\n";
255 cerr << "oov penalty idom: " << indoovpenalty << " odom: " << outdoovpenalty << "\n";
256
257 //go through the odomain sentences
258 int bos=dict->encode(dict->BoS());
259 mfstream inp(outdom,ios::in); ngram ng(dict);
260 mfstream txt(outdom,ios::in);
261 mfstream output(scorefile,ios::out);
262
263
264 int linenumber=1; string line;
265 int lenght=0;float deltaH=0; float deltaHoov=0; int words=0;string index;
266
267 while (getline(inp,line)){
268
269 istringstream lninp(line);
270
271 linenumber++;
272
273 if (useindex) lninp >> index;
274
275 // reset ngram at begin of sentence
276 ng.size=1; deltaH=0;deltaHoov=0; lenght=0;
277
278 while(lninp>>ng){
279
280 if (*ng.wordp(1)==bos) continue;
281
282 lenght++; words++;
283
284 if ((words % 1000000)==0) cerr << ".";
285
286 if (ng.size>ngsz) ng.size=ngsz;
287 indng.trans(ng);outdng.trans(ng);
288
289 if (model==1){//compute cross-entropy
290 deltaH-=log(prob(indngt,indng,indng.size,0));
291 deltaHoov-=(*indng.wordp(1)==indoovcode?indoovpenalty:0);
292 }
293
294 if (model==2){ //compute cross-entropy difference
295 deltaH+=log(prob(outdngt,outdng,outdng.size,cv))-log(prob(indngt,indng,indng.size,0));
296 deltaHoov+=(*outdng.wordp(1)==outdoovcode?outdoovpenalty:0)-(*indng.wordp(1)==indoovcode?indoovpenalty:0);
297 }
298 }
299
300 output << (deltaH + deltaHoov)/lenght << " " << line << "\n";
301 }
302 }
303 else{
304
305 //build in-domain LM from evaluation set
306 ngramtable *tstngt=new ngramtable(evalset,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
307
308 //build empty out-domain LM
309 ngramtable *outdngt=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
310
311 //if indomain data is passed then limit comparison to its dictionary
312 dictionary *dict = NULL;
313 if (indom){
314 cerr << "dtsel: limit evaluation dict to indomain words with freq >=" << minfreq << "\n";
315 //computed dictionary on indomain data
316 dict = new dictionary(indom,1000000,0);
317 dictionary *pd=new dictionary(dict,true,minfreq);
318 delete dict;dict=pd;
319 outdngt->dict=dict;
320 }
321
322 dictionary* outddict=outdngt->dict;
323
324 //get codes of <s>, </s> and UNK
325 outddict->incflag(1);
326 int bos=outddict->encode(outddict->BoS());
327 int oov=outddict->encode(outddict->OOV());
328 outddict->incflag(0);
329 outddict->oovcode(oov);
330
331
332 double oldPP=dub; double newPP=0; double oovrate=0;
333
334 long totwords=0; long totlines=0; long nextstep=blocksize;
335
336 double score; string index;
337
338 mfstream outd(scorefile,ios::in); string line;
339
340 //initialize n-gram
341 ngram ng(outdngt->dict); for (int i=1;i<ngsz;i++) ng.pushc(bos); ng.freq=1;
342
343 //check if to use open or closed voabulary
344
345 if (!dict) outddict->incflag(1);
346
347 while (getline(outd,line)){
348
349 istringstream lninp(line);
350
351 //skip score and eventually the index
352 lninp >> score; if (useindex) lninp >> index;
353
354 while (lninp >> ng){
355
356 if (*ng.wordp(1) == bos) continue;
357
358 if (ng.size>ngsz) ng.size=ngsz;
359
360 outdngt->put(ng);
361
362 totwords++;
363 }
364
365 totlines++;
366
367 if (totwords>=nextstep){ //if block is complete
368
369 if (!dict) outddict->incflag(0);
370
371 newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
372
373 if (!dict) outddict->incflag(1);
374
375 cout << totwords << " " << newPP;
376 if (verbose) cout << " " << totlines << " " << oovrate;
377 cout << "\n";
378
379 if (convergence_treshold && (oldPP-newPP)/oldPP < convergence_treshold) return 1;
380
381 oldPP=newPP;
382
383 nextstep+=blocksize;
384 }
385 }
386
387 if (!dict) outddict->incflag(0);
388 newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
389 cout << totwords << " " << newPP;
390 if (verbose) cout << " " << totlines << " " << oovrate;
391
392 }
393
394 }
395
396
397
398