1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD (revised)
4 license as described in the file LICENSE.
5 */
6 #include <sys/types.h>
7
8 #ifndef _WIN32
9 #include <sys/mman.h>
10 #include <sys/wait.h>
11 #include <unistd.h>
12 #include <netinet/tcp.h>
13 #endif
14
#include <signal.h>

#include <fstream>
#include <vector>
18
19 #ifdef _WIN32
20 #include <winsock2.h>
21 #include <Windows.h>
22 #include <io.h>
23 typedef int socklen_t;
24
// Windows has no daemon(). This stand-in ignores both arguments and terminates the
// process with exit status 0 instead of detaching it into the background.
int daemon(int a, int b)
{
  (void)a;
  (void)b;
  exit(0);
  return 0; // not reached: exit() does not return
}
getpid()30 int getpid()
31 {
32 return (int) ::GetCurrentProcessId();
33 }
34 #else
35 #include <netdb.h>
36 #endif
37 #include <boost/program_options.hpp>
38
39 #ifdef __FreeBSD__
40 #include <netinet/in.h>
41 #endif
42
43 #include <errno.h>
44 #include <stdio.h>
45 #include <assert.h>
46 namespace po = boost::program_options;
47
48 #include "parse_example.h"
49 #include "cache.h"
50 #include "unique_sort.h"
51 #include "constant.h"
52 #include "vw.h"
53
54 using namespace std;
55
initialize_mutex(MUTEX * pm)56 void initialize_mutex(MUTEX * pm)
57 {
58 #ifndef _WIN32
59 pthread_mutex_init(pm, NULL);
60 #else
61 ::InitializeCriticalSection(pm);
62 #endif
63 }
64
delete_mutex(MUTEX * pm)65 void delete_mutex(MUTEX * pm)
66 {
67 #ifndef _WIN32
68 // no operation necessary here
69 #else
70 ::DeleteCriticalSection(pm);
71 #endif
72 }
73
initialize_condition_variable(CV * pcv)74 void initialize_condition_variable(CV * pcv)
75 {
76 #ifndef _WIN32
77 pthread_cond_init(pcv, NULL);
78 #else
79 ::InitializeConditionVariable(pcv);
80 #endif
81 }
82
mutex_lock(MUTEX * pm)83 void mutex_lock(MUTEX * pm)
84 {
85 #ifndef _WIN32
86 pthread_mutex_lock(pm);
87 #else
88 ::EnterCriticalSection(pm);
89 #endif
90 }
91
mutex_unlock(MUTEX * pm)92 void mutex_unlock(MUTEX * pm)
93 {
94 #ifndef _WIN32
95 pthread_mutex_unlock(pm);
96 #else
97 ::LeaveCriticalSection(pm);
98 #endif
99 }
100
condition_variable_wait(CV * pcv,MUTEX * pm)101 void condition_variable_wait(CV * pcv, MUTEX * pm)
102 {
103 #ifndef _WIN32
104 pthread_cond_wait(pcv, pm);
105 #else
106 ::SleepConditionVariableCS(pcv, pm, INFINITE);
107 #endif
108 }
109
condition_variable_signal(CV * pcv)110 void condition_variable_signal(CV * pcv)
111 {
112 #ifndef _WIN32
113 pthread_cond_signal(pcv);
114 #else
115 ::WakeConditionVariable(pcv);
116 #endif
117 }
118
condition_variable_signal_all(CV * pcv)119 void condition_variable_signal_all(CV * pcv)
120 {
121 #ifndef _WIN32
122 pthread_cond_broadcast(pcv);
123 #else
124 ::WakeAllConditionVariable(pcv);
125 #endif
126 }
127
// NOTE(review): this flag is written from a signal handler (handle_sigterm) and read by
// the supervising wait loop in enable_sources. Only `volatile sig_atomic_t` (or a
// lock-free atomic) is guaranteed safe to write from a handler. Changing the type would
// also require updating any extern declaration elsewhere, so confirm before doing so.
128 // Set by handle_sigterm. That handler is only installed by enable_sources in daemon mode, so this presumably should not matter in library mode.
129 bool got_sigterm;
130
handle_sigterm(int)131 void handle_sigterm (int)
132 {
133 got_sigterm = true;
134 }
135
// Decide whether an example belongs to the holdout (test-only) set.
//   counter        - running example counter within the current pass
//   period         - hold out one in every `period` examples (used when after == 0)
//   after          - when non-zero, hold out every example with counter >= after
//   holdout_off    - disables holdout entirely
//   target_modulus - the residue of counter % period that is held out: 0 in the
//                    normal case, period-1 when empty lines separate examples
// A zero period used to trigger an integer division by zero. It now means that
// nothing is held out by period.
bool is_test_only(uint32_t counter, uint32_t period, uint32_t after, bool holdout_off, uint32_t target_modulus)
{
  if (holdout_off)
    return false;
  if (after != 0) // hold out by position
    return counter >= after;
  if (period == 0) // guard: counter % 0 is undefined behaviour
    return false;
  return counter % period == target_modulus; // hold out by period
}
145
new_parser()146 parser* new_parser()
147 {
148 parser& ret = calloc_or_die<parser>();
149 ret.input = new io_buf;
150 ret.output = new io_buf;
151 ret.local_example_number = 0;
152 ret.in_pass_counter = 0;
153 ret.ring_size = 1 << 8;
154 ret.done = false;
155 ret.used_index = 0;
156
157 return &ret;
158 }
159
set_compressed(parser * par)160 void set_compressed(parser* par){
161 finalize_source(par);
162 par->input = new comp_io_buf;
163 par->output = new comp_io_buf;
164 }
165
cache_numbits(io_buf * buf,int filepointer)166 uint32_t cache_numbits(io_buf* buf, int filepointer)
167 {
168 v_array<char> t = v_init<char>();
169
170 uint32_t v_length;
171 buf->read_file(filepointer, (char*)&v_length, sizeof(v_length));
172 if(v_length>29){
173 cerr << "cache version too long, cache file is probably invalid" << endl;
174 throw exception();
175 }
176 else if (v_length == 0) {
177 cerr << "cache version too short, cache file is probably invalid" << endl;
178 throw exception();
179 }
180
181 t.erase();
182 if (t.size() < v_length)
183 t.resize(v_length);
184
185 buf->read_file(filepointer,t.begin,v_length);
186 version_struct v_tmp(t.begin);
187 if ( v_tmp != version )
188 {
189 cout << "cache has possibly incompatible version, rebuilding" << endl;
190 t.delete_v();
191 return 0;
192 }
193
194 char temp;
195 if (buf->read_file(filepointer, &temp, 1) < 1)
196 {
197 cout << "failed to read" << endl;
198 throw exception();
199 }
200 if (temp != 'c')
201 {
202 cout << "data file is not a cache file" << endl;
203 throw exception();
204 }
205
206 t.delete_v();
207
208 const int total = sizeof(uint32_t);
209 char* p[total];
210 if (buf->read_file(filepointer, p, total) < total)
211 {
212 return true;
213 }
214
215 uint32_t cache_numbits = *(uint32_t *)p;
216 return cache_numbits;
217 }
218
member(v_array<int> ids,int id)219 bool member(v_array<int> ids, int id)
220 {
221 for (size_t i = 0; i < ids.size(); i++)
222 if (ids[i] == id)
223 return true;
224 return false;
225 }
226
// Rewind the example sources at the end of a pass.
// 1. If a cache was being written during the pass:
//    - flush and close it, and move "<name>.writing" over the final cache name;
//    - close every input (descriptors that are also prediction sinks are only popped,
//      not closed);
//    - reopen the finished cache for reading and switch to read_cached_features.
// 2. If the parser is resettable:
//    - daemon mode: wait until every parsed example has been output, close the client
//      socket, accept a new client and pick the reader and printer from its data format;
//    - otherwise: rewind each input file and re-validate its cache header, throwing if
//      the cache reports fewer bits than `numbits`.
reset_source(vw & all,size_t numbits)227 void reset_source(vw& all, size_t numbits)
228 {
229 io_buf* input = all.p->input;
230 input->current = 0;
231 if (all.p->write_cache)
232 {
233 all.p->output->flush();
234 all.p->write_cache = false;
235 all.p->output->close_file();
// Delete any existing final cache before renaming. Presumably this is because
// rename() may refuse to overwrite on some platforms (e.g. Windows); confirm.
236 remove(all.p->output->finalname.begin);
237 rename(all.p->output->currentname.begin, all.p->output->finalname.begin);
// Drain the input list. Compressed buffers close via close_file(). For plain
// descriptors, the fd is popped but left open if it is also a prediction sink.
// NOTE(review): the `(size_t) fd` cast is converted back to int by member().
238 while(input->num_files() > 0)
239 if (input->compressed())
240 input->close_file();
241 else
242 {
243 int fd = input->files.pop();
244 if (!member(all.final_prediction_sink, (size_t) fd))
245 io_buf::close_file_or_socket(fd);
246 }
247 input->open_file(all.p->output->finalname.begin, all.stdin_off, io_buf::READ); //pushing is merged into open_file
248 all.p->reader = read_cached_features;
249 }
250 if ( all.p->resettable == true )
251 {
252 if (all.daemon)
253 {
254 // wait for all predictions to be sent back to client
255 mutex_lock(&all.p->output_lock);
256 while (all.p->local_example_number != all.p->end_parsed_examples)
257 condition_variable_wait(&all.p->output_done, &all.p->output_lock);
258 mutex_unlock(&all.p->output_lock);
259
// NOTE(review): assumes files[0] is the current client socket.
260 // close socket, erase final prediction sink and socket
261 io_buf::close_file_or_socket(all.p->input->files[0]);
262 all.final_prediction_sink.erase();
263 all.p->input->files.erase();
264
265 sockaddr_in client_address;
266 socklen_t size = sizeof(client_address);
267 int f = (int)accept(all.p->bound_sock,(sockaddr*)&client_address,&size);
268 if (f < 0)
269 {
270 cerr << "accept: " << strerror(errno) << endl;
271 throw exception();
272 }
273
274 // note: breaking cluster parallel online learning by dropping support for id
275
// The new client socket serves as both the prediction sink and the input source.
276 all.final_prediction_sink.push_back((size_t) f);
277 all.p->input->files.push_back(f);
278
// Choose the reader and printer from the new client's data format.
279 if (isbinary(*(all.p->input))) {
280 all.p->reader = read_cached_features;
281 all.print = binary_print_result;
282 } else {
283 all.p->reader = read_features;
284 all.print = print_result;
285 }
286 }
287 else {
// Rewind every input and re-read its cache header. A cache built with fewer
// bits than the current model is treated as a fatal inconsistency.
288 for (size_t i = 0; i < input->files.size();i++)
289 {
290 input->reset_file(input->files[i]);
291 if (cache_numbits(input, input->files[i]) < numbits) {
292 cerr << "argh, a bug in caching of some sort! Exiting\n" ;
293 throw exception();
294 }
295 }
296 }
297 }
298 }
299
finalize_source(parser * p)300 void finalize_source(parser* p)
301 {
302 #ifdef _WIN32
303 int f = _fileno(stdin);
304 #else
305 int f = fileno(stdin);
306 #endif
307 while (!p->input->files.empty() && p->input->files.last() == f)
308 p->input->files.pop();
309 p->input->close_files();
310
311 delete p->input;
312 p->output->close_files();
313 delete p->output;
314 }
315
make_write_cache(vw & all,string & newname,bool quiet)316 void make_write_cache(vw& all, string &newname, bool quiet)
317 {
318 io_buf* output = all.p->output;
319 if (output->files.size() != 0){
320 cerr << "Warning: you tried to make two write caches. Only the first one will be made." << endl;
321 return;
322 }
323
324 string temp = newname+string(".writing");
325 push_many(output->currentname,temp.c_str(),temp.length()+1);
326
327 int f = output->open_file(temp.c_str(), all.stdin_off, io_buf::WRITE);
328 if (f == -1) {
329 cerr << "can't create cache file !" << endl;
330 return;
331 }
332
333 uint32_t v_length = (uint32_t)version.to_string().length()+1;
334
335 output->write_file(f, &v_length, sizeof(v_length));
336 output->write_file(f,version.to_string().c_str(),v_length);
337 output->write_file(f,"c",1);
338 output->write_file(f, &all.num_bits, sizeof(all.num_bits));
339
340 push_many(output->finalname,newname.c_str(),newname.length()+1);
341 all.p->write_cache = true;
342 if (!quiet)
343 cerr << "creating cache_file = " << newname << endl;
344 }
345
parse_cache(vw & all,po::variables_map & vm,string source,bool quiet)346 void parse_cache(vw& all, po::variables_map &vm, string source,
347 bool quiet)
348 {
349 vector<string> caches;
350 if (vm.count("cache_file"))
351 caches = vm["cache_file"].as< vector<string> >();
352 if (vm.count("cache"))
353 caches.push_back(source+string(".cache"));
354
355 all.p->write_cache = false;
356
357 for (size_t i = 0; i < caches.size(); i++)
358 {
359 int f = -1;
360 if (!vm.count("kill_cache"))
361 try {
362 f = all.p->input->open_file(caches[i].c_str(), all.stdin_off, io_buf::READ);
363 }
364 catch (exception e){ f = -1;}
365 if (f == -1)
366 make_write_cache(all, caches[i], quiet);
367 else {
368 uint32_t c = cache_numbits(all.p->input, f);
369 if (c < all.num_bits) {
370 all.p->input->close_file();
371 make_write_cache(all, caches[i], quiet);
372 }
373 else {
374 if (!quiet)
375 cerr << "using cache_file = " << caches[i].c_str() << endl;
376 all.p->reader = read_cached_features;
377 if (c == all.num_bits)
378 all.p->sorted_cache = true;
379 else
380 all.p->sorted_cache = false;
381 all.p->resettable = true;
382 }
383 }
384 }
385
386 all.parse_mask = (1 << all.num_bits) - 1;
387 if (caches.size() == 0)
388 {
389 if (!quiet)
390 cerr << "using no cache" << endl;
391 all.p->output->space.delete_v();
392 }
393 }
394
// For macOS: some versions of <sys/mman.h> define MAP_ANON but not MAP_ANONYMOUS.
396 #ifndef MAP_ANONYMOUS
397 # define MAP_ANONYMOUS MAP_ANON
398 #endif
399
enable_sources(vw & all,bool quiet,size_t passes)400 void enable_sources(vw& all, bool quiet, size_t passes)
401 {
402 all.p->input->current = 0;
403 parse_cache(all, all.vm, all.data_filename, quiet);
404
405 if (all.daemon || all.active)
406 {
407 #ifdef _WIN32
408 WSAData wsaData;
409 WSAStartup(MAKEWORD(2,2), &wsaData);
410 int lastError = WSAGetLastError();
411 #endif
412 all.p->bound_sock = (int)socket(PF_INET, SOCK_STREAM, 0);
413 if (all.p->bound_sock < 0) {
414 cerr << "socket: " << strerror(errno) << endl;
415 throw exception();
416 }
417
418 int on = 1;
419 if (setsockopt(all.p->bound_sock, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)) < 0)
420 cerr << "setsockopt SO_REUSEADDR: " << strerror(errno) << endl;
421
422 // Enable TCP Keep Alive to prevent socket leaks
423 int enableTKA = 1;
424 if (setsockopt(all.p->bound_sock, SOL_SOCKET, SO_KEEPALIVE, (char*)&enableTKA, sizeof(enableTKA)) < 0)
425 cerr << "setsockopt SO_KEEPALIVE: " << strerror(errno) << endl;
426
427 sockaddr_in address;
428 address.sin_family = AF_INET;
429 address.sin_addr.s_addr = htonl(INADDR_ANY);
430 short unsigned int port = 26542;
431 if (all.vm.count("port"))
432 port = (uint16_t)all.vm["port"].as<size_t>();
433 address.sin_port = htons(port);
434
435 // attempt to bind to socket
436 if ( ::bind(all.p->bound_sock,(sockaddr*)&address, sizeof(address)) < 0 )
437 {
438 cerr << "bind: " << strerror(errno) << endl;
439 throw exception();
440 }
441
442 // listen on socket
443 if (listen(all.p->bound_sock, 1) < 0) {
444 cerr << "listen: " << strerror(errno) << endl;
445 throw exception();
446 }
447
448 // write port file
449 if (all.vm.count("port_file"))
450 {
451 socklen_t address_size = sizeof(address);
452 if (getsockname(all.p->bound_sock, (sockaddr*)&address, &address_size) < 0)
453 {
454 cerr << "getsockname: " << strerror(errno) << endl;
455 }
456 ofstream port_file;
457 port_file.open(all.vm["port_file"].as<string>().c_str());
458 if (!port_file.is_open())
459 {
460 cerr << "error writing port file" << endl;
461 throw exception();
462 }
463 port_file << ntohs(address.sin_port) << endl;
464 port_file.close();
465 }
466
467 // background process
468 if (!all.active && daemon(1,1))
469 {
470 cerr << "daemon: " << strerror(errno) << endl;
471 throw exception();
472 }
473 // write pid file
474 if (all.vm.count("pid_file"))
475 {
476 ofstream pid_file;
477 pid_file.open(all.vm["pid_file"].as<string>().c_str());
478 if (!pid_file.is_open())
479 {
480 cerr << "error writing pid file" << endl;
481 throw exception();
482 }
483 pid_file << getpid() << endl;
484 pid_file.close();
485 }
486
487 if (all.daemon && !all.active)
488 {
489 #ifdef _WIN32
490 throw exception();
491 #else
492 // weights will be shared across processes, accessible to children
493 float* shared_weights =
494 (float*)mmap(0,(all.length() << all.reg.stride_shift) * sizeof(float),
495 PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0);
496
497 size_t float_count = all.length() << all.reg.stride_shift;
498 weight* dest = shared_weights;
499 memcpy(dest, all.reg.weight_vector, float_count*sizeof(float));
500 free(all.reg.weight_vector);
501 all.reg.weight_vector = dest;
502
503 // learning state to be shared across children
504 shared_data* sd = (shared_data *)mmap(0,sizeof(shared_data),
505 PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0);
506 memcpy(sd, all.sd, sizeof(shared_data));
507 free(all.sd);
508 all.sd = sd;
509
510 // create children
511 size_t num_children = all.num_children;
512 v_array<int> children = v_init<int>();
513 children.resize(num_children);
514 for (size_t i = 0; i < num_children; i++)
515 {
516 // fork() returns pid if parent, 0 if child
517 // store fork value and run child process if child
518 if ((children[i] = fork()) == 0)
519 goto child;
520 }
521
522 // install signal handler so we can kill children when killed
523 {
524 struct sigaction sa;
525 // specifically don't set SA_RESTART in sa.sa_flags, so that
526 // waitid will be interrupted by SIGTERM with handler installed
527 memset(&sa, 0, sizeof(sa));
528 sa.sa_handler = handle_sigterm;
529 sigaction(SIGTERM, &sa, NULL);
530 }
531
532 while (true)
533 {
534 // wait for child to change state; if finished, then respawn
535 int status;
536 pid_t pid = wait(&status);
537 if (got_sigterm)
538 {
539 for (size_t i = 0; i < num_children; i++)
540 kill(children[i], SIGTERM);
541 VW::finish(all);
542 exit(0);
543 }
544 if (pid < 0)
545 continue;
546 for (size_t i = 0; i < num_children; i++)
547 if (pid == children[i])
548 {
549 if ((children[i]=fork()) == 0)
550 goto child;
551 break;
552 }
553 }
554
555 #endif
556 }
557
558 #ifndef _WIN32
559 child:
560 #endif
561 sockaddr_in client_address;
562 socklen_t size = sizeof(client_address);
563 all.p->max_fd = 0;
564 if (!all.quiet)
565 cerr << "calling accept" << endl;
566 int f = (int)accept(all.p->bound_sock,(sockaddr*)&client_address,&size);
567 if (f < 0)
568 {
569 cerr << "accept: " << strerror(errno) << endl;
570 throw exception();
571 }
572
573 all.p->label_sock = f;
574 all.print = print_result;
575
576 all.final_prediction_sink.push_back((size_t) f);
577
578 all.p->input->files.push_back(f);
579 all.p->max_fd = max(f, all.p->max_fd);
580 if (!all.quiet)
581 cerr << "reading data from port " << port << endl;
582
583 all.p->max_fd++;
584 if(all.active)
585 all.p->reader = read_features;
586 else {
587 if (isbinary(*(all.p->input))) {
588 all.p->reader = read_cached_features;
589 all.print = binary_print_result;
590 } else {
591 all.p->reader = read_features;
592 }
593 all.p->sorted_cache = true;
594 }
595 all.p->resettable = all.p->write_cache || all.daemon;
596 }
597 else
598 {
599 if (all.p->input->files.size() > 0)
600 {
601 if (!quiet)
602 cerr << "ignoring text input in favor of cache input" << endl;
603 }
604 else
605 {
606 string temp = all.data_filename;
607 if (!quiet)
608 cerr << "Reading datafile = " << temp << endl;
609 int f = all.p->input->open_file(temp.c_str(), all.stdin_off, io_buf::READ);
610 if (f == -1 && temp.size() != 0)
611 {
612 cerr << "can't open '" << temp << "', sailing on!" << endl;
613 }
614 all.p->reader = read_features;
615 all.p->resettable = all.p->write_cache;
616 }
617 }
618
619 if (passes > 1 && !all.p->resettable)
620 {
621 cerr << all.program_name << ": need a cache file for multiple passes: try using --cache_file" << endl;
622 throw exception();
623 }
624 all.p->input->count = all.p->input->files.size();
625 if (!quiet)
626 cerr << "num sources = " << all.p->input->files.size() << endl;
627 }
628
parser_done(parser * p)629 bool parser_done(parser* p)
630 {
631 if (p->done)
632 {
633 if (p->used_index != p->begin_parsed_examples)
634 return false;
635 return true;
636 }
637 return false;
638 }
639
set_done(vw & all)640 void set_done(vw& all)
641 {
642 all.early_terminate = true;
643 mutex_lock(&all.p->examples_lock);
644 all.p->done = true;
645 mutex_unlock(&all.p->examples_lock);
646 }
647
addgrams(vw & all,size_t ngram,size_t skip_gram,v_array<feature> & atomics,v_array<audit_data> & audits,size_t initial_length,v_array<size_t> & gram_mask,size_t skips)648 void addgrams(vw& all, size_t ngram, size_t skip_gram, v_array<feature>& atomics, v_array<audit_data>& audits,
649 size_t initial_length, v_array<size_t> &gram_mask, size_t skips)
650 {
651 if (ngram == 0 && gram_mask.last() < initial_length)
652 {
653 size_t last = initial_length - gram_mask.last();
654 for(size_t i = 0; i < last; i++)
655 {
656 size_t new_index = atomics[i].weight_index;
657 for (size_t n = 1; n < gram_mask.size(); n++)
658 new_index = new_index*quadratic_constant + atomics[i+gram_mask[n]].weight_index;
659 feature f = {1.,(uint32_t)(new_index)};
660 atomics.push_back(f);
661 if ((all.audit || all.hash_inv) && audits.size() >= initial_length)
662 {
663 string feature_name(audits[i].feature);
664 for (size_t n = 1; n < gram_mask.size(); n++)
665 {
666 feature_name += string("^");
667 feature_name += string(audits[i+gram_mask[n]].feature);
668 }
669 string feature_space = string(audits[i].space);
670
671 audit_data a_feature = {NULL,NULL,new_index, 1., true};
672 a_feature.space = (char*)malloc(feature_space.length()+1);
673 strcpy(a_feature.space, feature_space.c_str());
674 a_feature.feature = (char*)malloc(feature_name.length()+1);
675 strcpy(a_feature.feature, feature_name.c_str());
676 audits.push_back(a_feature);
677 }
678 }
679 }
680 if (ngram > 0)
681 {
682 gram_mask.push_back(gram_mask.last()+1+skips);
683 addgrams(all, ngram-1, skip_gram, atomics, audits, initial_length, gram_mask, 0);
684 gram_mask.pop();
685 }
686 if (skip_gram > 0 && ngram > 0)
687 addgrams(all, ngram, skip_gram-1, atomics, audits, initial_length, gram_mask, skips+1);
688 }
689
/**
 * Appends k-skip-n-grams to the feature vector of each namespace.
 * Definition of k-skip-n-grams (at most k skipped positions in total per gram):
 * Consider a feature vector - a, b, c, d, e, f
 * 2-skip-2-grams would be - ab, ac, ad, bc, bd, be, cd, ce, cf, de, df, ef
 * 1-skip-3-grams would be - abc, abd, acd, bcd, bce, bde, cde, cdf, cef, def
 * Note that for an n-gram, the (n-1)-grams, (n-2)-grams ... 2-grams are also appended.
 * The k-skip-n-grams are appended to the feature vector.
 * The hash is built incrementally as h(a, b) = h(a)*X + h(b), where X is the fixed
 * multiplier quadratic_constant (the same value is used at every position).
 */
generateGrams(vw & all,example * & ex)701 void generateGrams(vw& all, example* &ex) {
702 for(unsigned char* index = ex->indices.begin; index < ex->indices.end; index++)
703 {
704 size_t length = ex->atomics[*index].size();
705 for (size_t n = 1; n < all.ngram[*index]; n++)
706 {
707 all.p->gram_mask.erase();
708 all.p->gram_mask.push_back((size_t)0);
709 addgrams(all, n, all.skips[*index], ex->atomics[*index],
710 ex->audit_features[*index],
711 length, all.p->gram_mask, 0);
712 }
713 }
714 }
715
get_unused_example(vw & all)716 example* get_unused_example(vw& all)
717 {
718 while (true)
719 {
720 mutex_lock(&all.p->examples_lock);
721 if (all.p->examples[all.p->begin_parsed_examples % all.p->ring_size].in_use == false)
722 {
723 example& ret = all.p->examples[all.p->begin_parsed_examples++ % all.p->ring_size];
724 ret.in_use = true;
725 mutex_unlock(&all.p->examples_lock);
726 return &ret;
727 }
728 else
729 condition_variable_wait(&all.p->example_unused, &all.p->examples_lock);
730 mutex_unlock(&all.p->examples_lock);
731 }
732 }
733
parse_atomic_example(vw & all,example * ae,bool do_read=true)734 bool parse_atomic_example(vw& all, example* ae, bool do_read = true)
735 {
736 if (do_read && all.p->reader(&all, ae) <= 0)
737 return false;
738
739 if(all.p->sort_features && ae->sorted == false)
740 unique_sort_features(all.audit, (uint32_t)all.parse_mask, ae);
741
742 if (all.p->write_cache)
743 {
744 all.p->lp.cache_label(&ae->l,*(all.p->output));
745 cache_features(*(all.p->output), ae, (uint32_t)all.parse_mask);
746 }
747 return true;
748 }
749
end_pass_example(vw & all,example * ae)750 void end_pass_example(vw& all, example* ae)
751 {
752 all.p->lp.default_label(&ae->l);
753 ae->end_pass = true;
754 all.p->in_pass_counter = 0;
755 }
756
feature_limit(vw & all,example * ex)757 void feature_limit(vw& all, example* ex)
758 {
759 for(unsigned char* index = ex->indices.begin; index < ex->indices.end; index++)
760 if (all.limit[*index] < ex->atomics[*index].size())
761 {
762 v_array<feature>& features = ex->atomics[*index];
763
764 qsort(features.begin, features.size(), sizeof(feature), order_features);
765
766 unique_features(features, all.limit[*index]);
767 }
768 }
769
770 namespace VW{
setup_example(vw & all,example * ae)771 void setup_example(vw& all, example* ae)
772 {
773 ae->partial_prediction = 0.;
774 ae->num_features = 0;
775 ae->total_sum_feat_sq = 0;
776 ae->loss = 0.;
777
778 ae->example_counter = (size_t)(all.p->end_parsed_examples);
779 if (!all.p->emptylines_separate_examples)
780 all.p->in_pass_counter++;
781
782 ae->test_only = is_test_only(all.p->in_pass_counter, all.holdout_period, all.holdout_after, all.holdout_set_off, all.p->emptylines_separate_examples ? (all.holdout_period-1) : 0);
783
784 if (all.p->emptylines_separate_examples && example_is_newline(*ae))
785 all.p->in_pass_counter++;
786
787 all.sd->t += all.p->lp.get_weight(&ae->l);
788 ae->example_t = (float)all.sd->t;
789
790
791 if (all.ignore_some)
792 {
793 if (all.audit || all.hash_inv)
794 for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
795 if (all.ignore[*i])
796 ae->audit_features[*i].erase();
797
798 for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
799 if (all.ignore[*i])
800 {//delete namespace
801 ae->atomics[*i].erase();
802 memmove(i,i+1,(ae->indices.end - (i+1))*sizeof(*i));
803 ae->indices.end--;
804 i--;
805 }
806 }
807
808 if(all.ngram_strings.size() > 0)
809 generateGrams(all, ae);
810
811 if (all.add_constant) {
812 //add constant feature
813 ae->indices.push_back(constant_namespace);
814 feature temp = {1,(uint32_t) constant};
815 ae->atomics[constant_namespace].push_back(temp);
816 ae->total_sum_feat_sq++;
817 }
818
819 if(all.limit_strings.size() > 0)
820 feature_limit(all,ae);
821
822 uint32_t multiplier = all.wpp << all.reg.stride_shift;
823 if(multiplier != 1) //make room for per-feature information.
824 {
825 for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
826 for(feature* j = ae->atomics[*i].begin; j != ae->atomics[*i].end; j++)
827 j->weight_index *= multiplier;
828 if (all.audit || all.hash_inv)
829 for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
830 for(audit_data* j = ae->audit_features[*i].begin; j != ae->audit_features[*i].end; j++)
831 j->weight_index *= multiplier;
832 }
833
834 for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
835 {
836 ae->num_features += ae->atomics[*i].end - ae->atomics[*i].begin;
837 ae->total_sum_feat_sq += ae->sum_feat_sq[*i];
838 }
839
840 for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
841 {
842 ae->num_features
843 += ae->atomics[(int)(*i)[0]].size()
844 *ae->atomics[(int)(*i)[1]].size();
845 ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
846 }
847
848 for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++)
849 {
850 ae->num_features
851 += ae->atomics[(int)(*i)[0]].size()
852 *ae->atomics[(int)(*i)[1]].size()
853 *ae->atomics[(int)(*i)[2]].size();
854 ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]];
855 }
856 }
857 }
858
859 namespace VW{
new_unused_example(vw & all)860 example* new_unused_example(vw& all) {
861 example* ec = get_unused_example(all);
862 all.p->lp.default_label(&ec->l);
863 all.p->begin_parsed_examples++;
864 ec->example_counter = all.p->begin_parsed_examples;
865 return ec;
866 }
read_example(vw & all,char * example_line)867 example* read_example(vw& all, char* example_line)
868 {
869 example* ret = get_unused_example(all);
870
871 read_line(all, ret, example_line);
872 parse_atomic_example(all,ret,false);
873 setup_example(all, ret);
874 all.p->end_parsed_examples++;
875
876 return ret;
877 }
878
read_example(vw & all,string example_line)879 example* read_example(vw& all, string example_line) { return read_example(all, (char*)example_line.c_str()); }
880
add_constant_feature(vw & vw,example * ec)881 void add_constant_feature(vw& vw, example*ec) {
882 uint32_t cns = constant_namespace;
883 ec->indices.push_back(cns);
884 feature temp = {1,(uint32_t) constant};
885 ec->atomics[cns].push_back(temp);
886 ec->total_sum_feat_sq++;
887 ec->num_features++;
888 }
889
add_label(example * ec,float label,float weight,float base)890 void add_label(example* ec, float label, float weight, float base)
891 {
892 ec->l.simple.label = label;
893 ec->l.simple.weight = weight;
894 ec->l.simple.initial = base;
895 }
896
import_example(vw & all,vector<feature_space> vf)897 example* import_example(vw& all, vector<feature_space> vf)
898 {
899 example* ret = get_unused_example(all);
900 all.p->lp.default_label(&ret->l);
901 for (size_t i = 0; i < vf.size();i++)
902 {
903 uint32_t index = vf[i].first;
904 ret->indices.push_back(index);
905 for (size_t j = 0; j < vf[i].second.size(); j++)
906 {
907 ret->sum_feat_sq[index] += vf[i].second[j].x * vf[i].second[j].x;
908 ret->atomics[index].push_back(vf[i].second[j]);
909 }
910 }
911 parse_atomic_example(all,ret,false);
912 setup_example(all, ret);
913 all.p->end_parsed_examples++;
914 return ret;
915 }
916
import_example(vw & all,primitive_feature_space * features,size_t len)917 example* import_example(vw& all, primitive_feature_space* features, size_t len)
918 {
919 example* ret = get_unused_example(all);
920 all.p->lp.default_label(&ret->l);
921 for (size_t i = 0; i < len;i++)
922 {
923 uint32_t index = features[i].name;
924 ret->indices.push_back(index);
925 for (size_t j = 0; j < features[i].len; j++)
926 {
927 ret->sum_feat_sq[index] += features[i].fs[j].x * features[i].fs[j].x;
928 ret->atomics[index].push_back(features[i].fs[j]);
929 }
930 }
931 parse_atomic_example(all,ret,false); // all.p->parsed_examples++;
932 setup_example(all, ret);
933
934 return ret;
935 }
936
export_example(vw & all,example * ec,size_t & len)937 primitive_feature_space* export_example(vw& all, example* ec, size_t& len)
938 {
939 len = ec->indices.size();
940 primitive_feature_space* fs_ptr = new primitive_feature_space[len];
941
942 int fs_count = 0;
943 for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
944 {
945 fs_ptr[fs_count].name = *i;
946 fs_ptr[fs_count].len = ec->atomics[*i].size();
947 fs_ptr[fs_count].fs = new feature[fs_ptr[fs_count].len];
948
949 int f_count = 0;
950 for (feature *f = ec->atomics[*i].begin; f != ec->atomics[*i].end; f++)
951 {
952 feature t = *f;
953 t.weight_index >>= all.reg.stride_shift;
954 fs_ptr[fs_count].fs[f_count] = t;
955 f_count++;
956 }
957 fs_count++;
958 }
959 return fs_ptr;
960 }
961
releaseFeatureSpace(primitive_feature_space * features,size_t len)962 void releaseFeatureSpace(primitive_feature_space* features, size_t len)
963 {
964 for (size_t i = 0; i < len;i++)
965 delete features[i].fs;
966 delete (features);
967 }
968
parse_example_label(vw & all,example & ec,string label)969 void parse_example_label(vw& all, example&ec, string label) {
970 v_array<substring> words = v_init<substring>();
971 char* cstr = (char*)label.c_str();
972 substring str = { cstr, cstr+label.length() };
973 words.push_back(str);
974 all.p->lp.parse_label(all.p, all.sd, &ec.l, words);
975 words.erase();
976 words.delete_v();
977 }
978
empty_example(vw & all,example & ec)979 void empty_example(vw& all, example& ec)
980 {
981 if (all.audit || all.hash_inv)
982 for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
983 {
984 for (audit_data* temp
985 = ec.audit_features[*i].begin;
986 temp != ec.audit_features[*i].end; temp++)
987 {
988 if (temp->alloced)
989 {
990 free(temp->space);
991 free(temp->feature);
992 temp->alloced=false;
993 }
994 }
995 ec.audit_features[*i].erase();
996 }
997
998 for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
999 {
1000 ec.atomics[*i].erase();
1001 ec.sum_feat_sq[*i]=0;
1002 }
1003
1004 ec.indices.erase();
1005 ec.tag.erase();
1006 ec.sorted = false;
1007 ec.end_pass = false;
1008 }
1009
finish_example(vw & all,example * ec)1010 void finish_example(vw& all, example* ec)
1011 {
1012 mutex_lock(&all.p->output_lock);
1013 all.p->local_example_number++;
1014 condition_variable_signal(&all.p->output_done);
1015 mutex_unlock(&all.p->output_lock);
1016
1017 empty_example(all, *ec);
1018
1019 mutex_lock(&all.p->examples_lock);
1020 assert(ec->in_use);
1021 ec->in_use = false;
1022 condition_variable_signal(&all.p->example_unused);
1023 if (all.p->done)
1024 condition_variable_signal_all(&all.p->example_available);
1025 mutex_unlock(&all.p->examples_lock);
1026 }
1027 }
1028
// Parser thread entry point; `in` is the vw instance.
// Each iteration claims a free ring slot (get_unused_example). The slot is filled by
// VW::setup_example when all of these hold:
//  - no reset is pending;
//  - neither pass_length nor max_examples has been reached;
//  - the reader produced an example.
// Otherwise the input is rewound (reset_source), passes_complete advances and the slot
// becomes an end-of-pass marker (end_pass_example). Once numpasses passes are complete,
// done is set.
// Either way the slot is then published by incrementing end_parsed_examples under
// examples_lock and waking every waiter on example_available.
1029 #ifdef _WIN32
main_parse_loop(LPVOID in)1030 DWORD WINAPI main_parse_loop(LPVOID in)
1031 #else
1032 void *main_parse_loop(void *in)
1033 #endif
1034 {
1035 vw* all = (vw*) in;
1036 size_t example_number = 0; // for variable-size batch learning algorithms
1037
1038
// NOTE(review): p->done is read here without examples_lock, while writers (below and
// set_done) hold the lock. Confirm this unsynchronized read is acceptable.
1039 while(!all->p->done)
1040 {
1041 example* ae = get_unused_example(*all);
1042 if (!all->do_reset_source && example_number != all->pass_length && all->max_examples > example_number
1043 && parse_atomic_example(*all, ae) )
1044 {
1045 VW::setup_example(*all, ae);
1046 example_number++;
1047 }
1048 else
1049 {
// End of pass (input exhausted, reset requested, or a size limit reached).
1050 reset_source(*all, all->num_bits);
1051 all->do_reset_source = false;
1052 all->passes_complete++;
1053 end_pass_example(*all, ae);
// If the final pass ended exactly at pass_length, restart the pass count with the
// pass length grown to 2n+1.
1054 if (all->passes_complete == all->numpasses && example_number == all->pass_length)
1055 {
1056 all->passes_complete = 0;
1057 all->pass_length = all->pass_length*2+1;
1058 }
1059 if (all->passes_complete >= all->numpasses && all->max_examples >= example_number)
1060 {
1061 mutex_lock(&all->p->examples_lock);
1062 all->p->done = true;
1063 mutex_unlock(&all->p->examples_lock);
1064 }
1065 example_number = 0;
1066 }
// Publish the slot (a real example or an end-of-pass marker) to the learner.
1067 mutex_lock(&all->p->examples_lock);
1068 all->p->end_parsed_examples++;
1069 condition_variable_signal_all(&all->p->example_available);
1070 mutex_unlock(&all->p->examples_lock);
1071 }
1072 return NULL;
1073 }
1074
1075 namespace VW{
get_example(parser * p)1076 example* get_example(parser* p)
1077 {
1078 mutex_lock(&p->examples_lock);
1079 if (p->end_parsed_examples != p->used_index) {
1080 size_t ring_index = p->used_index++ % p->ring_size;
1081 if (!(p->examples+ring_index)->in_use)
1082 cout << p->used_index << " " << p->end_parsed_examples << " " << ring_index << endl;
1083 assert((p->examples+ring_index)->in_use);
1084 mutex_unlock(&p->examples_lock);
1085
1086 return p->examples + ring_index;
1087 }
1088 else {
1089 if (!p->done)
1090 {
1091 condition_variable_wait(&p->example_available, &p->examples_lock);
1092 mutex_unlock(&p->examples_lock);
1093 return get_example(p);
1094 }
1095 else {
1096 mutex_unlock(&p->examples_lock);
1097 return NULL;
1098 }
1099 }
1100 }
1101
get_topic_prediction(example * ec,size_t i)1102 float get_topic_prediction(example* ec, size_t i)
1103 {
1104 return ec->topic_predictions[i];
1105 }
1106
get_label(example * ec)1107 float get_label(example* ec)
1108 {
1109 return ec->l.simple.label;
1110 }
1111
get_importance(example * ec)1112 float get_importance(example* ec)
1113 {
1114 return ec->l.simple.weight;
1115 }
1116
get_initial(example * ec)1117 float get_initial(example* ec)
1118 {
1119 return ec->l.simple.initial;
1120 }
1121
get_prediction(example * ec)1122 float get_prediction(example* ec)
1123 {
1124 return ec->pred.scalar;
1125 }
1126
get_cost_sensitive_prediction(example * ec)1127 float get_cost_sensitive_prediction(example* ec)
1128 {
1129 return (float)ec->pred.multiclass;
1130 }
1131
get_tag_length(example * ec)1132 size_t get_tag_length(example* ec)
1133 {
1134 return ec->tag.size();
1135 }
1136
get_tag(example * ec)1137 const char* get_tag(example* ec)
1138 {
1139 return ec->tag.begin;
1140 }
1141
get_feature_number(example * ec)1142 size_t get_feature_number(example* ec)
1143 {
1144 return ec->num_features;
1145 }
1146 }
1147
initialize_examples(vw & all)1148 void initialize_examples(vw& all)
1149 {
1150 all.p->used_index = 0;
1151 all.p->begin_parsed_examples = 0;
1152 all.p->end_parsed_examples = 0;
1153 all.p->done = false;
1154
1155 all.p->examples = calloc_or_die<example>(all.p->ring_size);
1156
1157 for (size_t i = 0; i < all.p->ring_size; i++)
1158 {
1159 memset(&all.p->examples[i].l, 0, sizeof(polylabel));
1160 all.p->examples[i].in_use = false;
1161 }
1162 }
1163
adjust_used_index(vw & all)1164 void adjust_used_index(vw& all)
1165 {
1166 all.p->used_index=all.p->begin_parsed_examples;
1167 }
1168
initialize_parser_datastructures(vw & all)1169 void initialize_parser_datastructures(vw& all)
1170 {
1171 initialize_examples(all);
1172 initialize_mutex(&all.p->examples_lock);
1173 initialize_condition_variable(&all.p->example_available);
1174 initialize_condition_variable(&all.p->example_unused);
1175 initialize_mutex(&all.p->output_lock);
1176 initialize_condition_variable(&all.p->output_done);
1177 }
1178
1179 namespace VW {
start_parser(vw & all,bool init_structures)1180 void start_parser(vw& all, bool init_structures)
1181 {
1182 if (init_structures)
1183 initialize_parser_datastructures(all);
1184 #ifndef _WIN32
1185 pthread_create(&all.parse_thread, NULL, main_parse_loop, &all);
1186 #else
1187 all.parse_thread = ::CreateThread(NULL, 0, static_cast<LPTHREAD_START_ROUTINE>(main_parse_loop), &all, NULL, NULL);
1188 #endif
1189 }
1190 }
free_parser(vw & all)1191 void free_parser(vw& all)
1192 {
1193 all.p->channels.delete_v();
1194 all.p->words.delete_v();
1195 all.p->name.delete_v();
1196
1197 if(all.ngram_strings.size() > 0)
1198 all.p->gram_mask.delete_v();
1199
1200 for (size_t i = 0; i < all.p->ring_size; i++)
1201 {
1202 dealloc_example(all.p->lp.delete_label, all.p->examples[i]);
1203 }
1204 free(all.p->examples);
1205
1206 io_buf* output = all.p->output;
1207 if (output != NULL)
1208 {
1209 output->finalname.delete_v();
1210 output->currentname.delete_v();
1211 }
1212
1213 all.p->counts.delete_v();
1214 }
1215
release_parser_datastructures(vw & all)1216 void release_parser_datastructures(vw& all)
1217 {
1218 delete_mutex(&all.p->examples_lock);
1219 delete_mutex(&all.p->output_lock);
1220 }
1221
1222 namespace VW {
end_parser(vw & all)1223 void end_parser(vw& all)
1224 {
1225 #ifndef _WIN32
1226 pthread_join(all.parse_thread, NULL);
1227 #else
1228 ::WaitForSingleObject(all.parse_thread, INFINITE);
1229 ::CloseHandle(all.parse_thread);
1230 #endif
1231 release_parser_datastructures(all);
1232 }
1233 }
1234