1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved.  Released under a BSD (revised)
4 license as described in the file LICENSE.
5  */
6 #include <sys/types.h>
7 
8 #ifndef _WIN32
9 #include <sys/mman.h>
10 #include <sys/wait.h>
11 #include <unistd.h>
12 #include <netinet/tcp.h>
13 #endif
14 
#include <errno.h>
#include <signal.h>
#include <stdint.h>

#include <fstream>
18 
19 #ifdef _WIN32
20 #include <winsock2.h>
21 #include <Windows.h>
22 #include <io.h>
23 typedef int socklen_t;
24 
// Stub for the POSIX daemon() call, which has no Windows equivalent.
// Fails with ENOSYS instead of terminating the process, so callers that check
// the result (enable_sources treats a nonzero return as an error) report the
// problem rather than the program silently exiting with status 0.
int daemon(int /*nochdir*/, int /*noclose*/)
{
	errno = ENOSYS;
	return -1;
}
getpid()30 int getpid()
31 {
32 	return (int) ::GetCurrentProcessId();
33 }
34 #else
35 #include <netdb.h>
36 #endif
37 #include <boost/program_options.hpp>
38 
39 #ifdef __FreeBSD__
40 #include <netinet/in.h>
41 #endif
42 
43 #include <errno.h>
44 #include <stdio.h>
45 #include <assert.h>
46 namespace po = boost::program_options;
47 
48 #include "parse_example.h"
49 #include "cache.h"
50 #include "unique_sort.h"
51 #include "constant.h"
52 #include "vw.h"
53 
54 using namespace std;
55 
initialize_mutex(MUTEX * pm)56 void initialize_mutex(MUTEX * pm)
57 {
58 #ifndef _WIN32
59   pthread_mutex_init(pm, NULL);
60 #else
61 	::InitializeCriticalSection(pm);
62 #endif
63 }
64 
delete_mutex(MUTEX * pm)65 void delete_mutex(MUTEX * pm)
66 {
67 #ifndef _WIN32
68 	// no operation necessary here
69 #else
70 	::DeleteCriticalSection(pm);
71 #endif
72 }
73 
initialize_condition_variable(CV * pcv)74 void initialize_condition_variable(CV * pcv)
75 {
76 #ifndef _WIN32
77   pthread_cond_init(pcv, NULL);
78 #else
79 	::InitializeConditionVariable(pcv);
80 #endif
81 }
82 
mutex_lock(MUTEX * pm)83 void mutex_lock(MUTEX * pm)
84 {
85 #ifndef _WIN32
86 	pthread_mutex_lock(pm);
87 #else
88 	::EnterCriticalSection(pm);
89 #endif
90 }
91 
mutex_unlock(MUTEX * pm)92 void mutex_unlock(MUTEX * pm)
93 {
94 #ifndef _WIN32
95 	pthread_mutex_unlock(pm);
96 #else
97 	::LeaveCriticalSection(pm);
98 #endif
99 }
100 
condition_variable_wait(CV * pcv,MUTEX * pm)101 void condition_variable_wait(CV * pcv, MUTEX * pm)
102 {
103 #ifndef _WIN32
104 	pthread_cond_wait(pcv, pm);
105 #else
106 	::SleepConditionVariableCS(pcv, pm, INFINITE);
107 #endif
108 }
109 
condition_variable_signal(CV * pcv)110 void condition_variable_signal(CV * pcv)
111 {
112 #ifndef _WIN32
113 	pthread_cond_signal(pcv);
114 #else
115 	::WakeConditionVariable(pcv);
116 #endif
117 }
118 
condition_variable_signal_all(CV * pcv)119 void condition_variable_signal_all(CV * pcv)
120 {
121 #ifndef _WIN32
122 	pthread_cond_broadcast(pcv);
123 #else
124 	::WakeAllConditionVariable(pcv);
125 #endif
126 }
127 
// Set by handle_sigterm() and polled by the daemon supervisor loop in
// enable_sources(). Presumably irrelevant when VW is used as a library.
// NOTE(review): a flag written from a signal handler should ideally be
// `volatile sig_atomic_t`; the type is kept as-is because other translation
// units may declare it.
bool got_sigterm;

// SIGTERM handler: records that the signal arrived and does nothing else.
void handle_sigterm (int /*signum*/)
{
  got_sigterm = true;
}
135 
// Decides whether an example is held out (used for evaluation only).
//   counter        - position of the example within the current pass
//   period         - hold out one example in every `period` (used when after == 0)
//   after          - when nonzero, hold out every example with counter >= after
//   holdout_off    - disables holdout entirely
//   target_modulus - residue that selects the held-out example: 0 in the normal
//                    case, or period-1 when empty lines separate examples
// A period of 0 disables periodic holdout instead of dividing by zero.
bool is_test_only(uint32_t counter, uint32_t period, uint32_t after, bool holdout_off, uint32_t target_modulus)
{
  if (holdout_off)
    return false;
  if (after != 0) // hold out by position
    return counter >= after;
  if (period == 0) // counter % 0 would be undefined behavior
    return false;
  return counter % period == target_modulus; // hold out by period
}
145 
new_parser()146 parser* new_parser()
147 {
148   parser& ret = calloc_or_die<parser>();
149   ret.input = new io_buf;
150   ret.output = new io_buf;
151   ret.local_example_number = 0;
152   ret.in_pass_counter = 0;
153   ret.ring_size = 1 << 8;
154   ret.done = false;
155   ret.used_index = 0;
156 
157   return &ret;
158 }
159 
set_compressed(parser * par)160 void set_compressed(parser* par){
161   finalize_source(par);
162   par->input = new comp_io_buf;
163   par->output = new comp_io_buf;
164 }
165 
cache_numbits(io_buf * buf,int filepointer)166 uint32_t cache_numbits(io_buf* buf, int filepointer)
167 {
168   v_array<char> t = v_init<char>();
169 
170   uint32_t v_length;
171   buf->read_file(filepointer, (char*)&v_length, sizeof(v_length));
172   if(v_length>29){
173     cerr << "cache version too long, cache file is probably invalid" << endl;
174     throw exception();
175   }
176   else if (v_length == 0) {
177     cerr << "cache version too short, cache file is probably invalid" << endl;
178     throw exception();
179   }
180 
181   t.erase();
182   if (t.size() < v_length)
183     t.resize(v_length);
184 
185   buf->read_file(filepointer,t.begin,v_length);
186   version_struct v_tmp(t.begin);
187   if ( v_tmp != version )
188     {
189       cout << "cache has possibly incompatible version, rebuilding" << endl;
190       t.delete_v();
191       return 0;
192     }
193 
194   char temp;
195   if (buf->read_file(filepointer, &temp, 1) < 1)
196     {
197       cout << "failed to read" << endl;
198       throw exception();
199     }
200   if (temp != 'c')
201     {
202       cout << "data file is not a cache file" << endl;
203       throw exception();
204     }
205 
206   t.delete_v();
207 
208   const int total = sizeof(uint32_t);
209   char* p[total];
210   if (buf->read_file(filepointer, p, total) < total)
211     {
212       return true;
213     }
214 
215   uint32_t cache_numbits = *(uint32_t *)p;
216   return cache_numbits;
217 }
218 
member(v_array<int> ids,int id)219 bool member(v_array<int> ids, int id)
220 {
221   for (size_t i = 0; i < ids.size(); i++)
222     if (ids[i] == id)
223       return true;
224   return false;
225 }
226 
// Rewinds the parser's input so that another pass over the data can begin.
//
// - If a cache was being written during this pass, the cache is finalized:
//   flushed, closed, and renamed from its temporary name (currentname) to its
//   final name. Then every current input file is closed (non-compressed
//   descriptors that are also prediction sinks are left open), and the
//   finished cache is reopened as the input using the cached-feature reader.
// - If the input is resettable:
//   * daemon mode: waits until every parsed example has been finished, closes
//     the current client connection, accepts a new client on the bound socket,
//     and selects reader/printer based on whether that client sends binary
//     (cached) data;
//   * otherwise: rewinds each input file and re-validates its cache header,
//     throwing if it was written with fewer than `numbits` bits.
// Throws std::exception if accept() fails or a cache header is rejected.
void reset_source(vw& all, size_t numbits)
{
  io_buf* input = all.p->input;
  input->current = 0;
  if (all.p->write_cache)
    {
      all.p->output->flush();
      all.p->write_cache = false;
      all.p->output->close_file();
      // Delete any existing final file first; presumably because rename()
      // does not replace an existing destination on every platform.
      remove(all.p->output->finalname.begin);
      rename(all.p->output->currentname.begin, all.p->output->finalname.begin);
      // Close all inputs read during this pass.
      while(input->num_files() > 0)
	if (input->compressed())
	  input->close_file();
	else
	  {
	    int fd = input->files.pop();
	    // Don't close a descriptor that is also used to write predictions.
	    if (!member(all.final_prediction_sink, (size_t) fd))
	      io_buf::close_file_or_socket(fd);
	  }
      input->open_file(all.p->output->finalname.begin, all.stdin_off, io_buf::READ); //pushing is merged into open_file
      all.p->reader = read_cached_features;
    }
  if ( all.p->resettable == true )
    {
      if (all.daemon)
	{
	  // wait for all predictions to be sent back to client
	  mutex_lock(&all.p->output_lock);
	  while (all.p->local_example_number != all.p->end_parsed_examples)
	    condition_variable_wait(&all.p->output_done, &all.p->output_lock);
	  mutex_unlock(&all.p->output_lock);

	  // close socket, erase final prediction sink and socket
	  // NOTE(review): assumes the client socket is the only input (files[0]).
	  io_buf::close_file_or_socket(all.p->input->files[0]);
	  all.final_prediction_sink.erase();
	  all.p->input->files.erase();

	  // Block until the next client connects; it becomes both the input
	  // and the prediction sink.
	  sockaddr_in client_address;
	  socklen_t size = sizeof(client_address);
	  int f = (int)accept(all.p->bound_sock,(sockaddr*)&client_address,&size);
	  if (f < 0)
	    {
	      cerr << "accept: " << strerror(errno) << endl;
	      throw exception();
	    }

	  // note: breaking cluster parallel online learning by dropping support for id

	  all.final_prediction_sink.push_back((size_t) f);
	  all.p->input->files.push_back(f);

	  // Pick reader/printer according to the format the client sends.
	  if (isbinary(*(all.p->input))) {
	    all.p->reader = read_cached_features;
	    all.print = binary_print_result;
	  } else {
	    all.p->reader = read_features;
	    all.print = print_result;
	  }
	}
      else {
	// Rewind every input and verify its cache header is still usable.
	for (size_t i = 0; i < input->files.size();i++)
	  {
	    input->reset_file(input->files[i]);
	    if (cache_numbits(input, input->files[i]) < numbits) {
	      cerr << "argh, a bug in caching of some sort!  Exiting\n" ;
	      throw exception();
	    }
	  }
      }
    }
}
299 
finalize_source(parser * p)300 void finalize_source(parser* p)
301 {
302 #ifdef _WIN32
303   int f = _fileno(stdin);
304 #else
305   int f = fileno(stdin);
306 #endif
307   while (!p->input->files.empty() && p->input->files.last() == f)
308     p->input->files.pop();
309   p->input->close_files();
310 
311   delete p->input;
312   p->output->close_files();
313   delete p->output;
314 }
315 
make_write_cache(vw & all,string & newname,bool quiet)316 void make_write_cache(vw& all, string &newname, bool quiet)
317 {
318   io_buf* output = all.p->output;
319   if (output->files.size() != 0){
320     cerr << "Warning: you tried to make two write caches.  Only the first one will be made." << endl;
321     return;
322   }
323 
324   string temp = newname+string(".writing");
325   push_many(output->currentname,temp.c_str(),temp.length()+1);
326 
327   int f = output->open_file(temp.c_str(), all.stdin_off, io_buf::WRITE);
328   if (f == -1) {
329     cerr << "can't create cache file !" << endl;
330     return;
331   }
332 
333   uint32_t v_length = (uint32_t)version.to_string().length()+1;
334 
335   output->write_file(f, &v_length, sizeof(v_length));
336   output->write_file(f,version.to_string().c_str(),v_length);
337   output->write_file(f,"c",1);
338   output->write_file(f, &all.num_bits, sizeof(all.num_bits));
339 
340   push_many(output->finalname,newname.c_str(),newname.length()+1);
341   all.p->write_cache = true;
342   if (!quiet)
343     cerr << "creating cache_file = " << newname << endl;
344 }
345 
parse_cache(vw & all,po::variables_map & vm,string source,bool quiet)346 void parse_cache(vw& all, po::variables_map &vm, string source,
347 		 bool quiet)
348 {
349   vector<string> caches;
350   if (vm.count("cache_file"))
351     caches = vm["cache_file"].as< vector<string> >();
352   if (vm.count("cache"))
353     caches.push_back(source+string(".cache"));
354 
355   all.p->write_cache = false;
356 
357   for (size_t i = 0; i < caches.size(); i++)
358     {
359       int f = -1;
360       if (!vm.count("kill_cache"))
361 	try {
362         f = all.p->input->open_file(caches[i].c_str(), all.stdin_off, io_buf::READ);
363 	}
364 	catch (exception e){ f = -1;}
365       if (f == -1)
366 	make_write_cache(all, caches[i], quiet);
367       else {
368 	uint32_t c = cache_numbits(all.p->input, f);
369 	if (c < all.num_bits) {
370           all.p->input->close_file();
371 	  make_write_cache(all, caches[i], quiet);
372 	}
373 	else {
374 	  if (!quiet)
375 	    cerr << "using cache_file = " << caches[i].c_str() << endl;
376 	  all.p->reader = read_cached_features;
377 	  if (c == all.num_bits)
378 	    all.p->sorted_cache = true;
379 	  else
380 	    all.p->sorted_cache = false;
381 	  all.p->resettable = true;
382 	}
383       }
384     }
385 
386   all.parse_mask = (1 << all.num_bits) - 1;
387   if (caches.size() == 0)
388     {
389       if (!quiet)
390 	cerr << "using no cache" << endl;
391       all.p->output->space.delete_v();
392     }
393 }
394 
// macOS/BSD may only define MAP_ANON; alias it so MAP_ANONYMOUS works everywhere.
396 #ifndef MAP_ANONYMOUS
397 # define MAP_ANONYMOUS MAP_ANON
398 #endif
399 
enable_sources(vw & all,bool quiet,size_t passes)400 void enable_sources(vw& all, bool quiet, size_t passes)
401 {
402   all.p->input->current = 0;
403   parse_cache(all, all.vm, all.data_filename, quiet);
404 
405   if (all.daemon || all.active)
406     {
407 #ifdef _WIN32
408       WSAData wsaData;
409       WSAStartup(MAKEWORD(2,2), &wsaData);
410       int lastError = WSAGetLastError();
411 #endif
412       all.p->bound_sock = (int)socket(PF_INET, SOCK_STREAM, 0);
413       if (all.p->bound_sock < 0) {
414 	cerr << "socket: " << strerror(errno) << endl;
415 	throw exception();
416       }
417 
418       int on = 1;
419       if (setsockopt(all.p->bound_sock, SOL_SOCKET, SO_REUSEADDR, (char*)&on, sizeof(on)) < 0)
420 	cerr << "setsockopt SO_REUSEADDR: " << strerror(errno) << endl;
421 
422       // Enable TCP Keep Alive to prevent socket leaks
423       int enableTKA = 1;
424       if (setsockopt(all.p->bound_sock, SOL_SOCKET, SO_KEEPALIVE, (char*)&enableTKA, sizeof(enableTKA)) < 0)
425         cerr << "setsockopt SO_KEEPALIVE: " << strerror(errno) << endl;
426 
427       sockaddr_in address;
428       address.sin_family = AF_INET;
429       address.sin_addr.s_addr = htonl(INADDR_ANY);
430       short unsigned int port = 26542;
431       if (all.vm.count("port"))
432 	port = (uint16_t)all.vm["port"].as<size_t>();
433       address.sin_port = htons(port);
434 
435       // attempt to bind to socket
436       if ( ::bind(all.p->bound_sock,(sockaddr*)&address, sizeof(address)) < 0 )
437 	{
438 	  cerr << "bind: " << strerror(errno) << endl;
439 	  throw exception();
440 	}
441 
442       // listen on socket
443       if (listen(all.p->bound_sock, 1) < 0) {
444         cerr << "listen: " << strerror(errno) << endl;
445         throw exception();
446       }
447 
448       // write port file
449       if (all.vm.count("port_file"))
450 	{
451           socklen_t address_size = sizeof(address);
452           if (getsockname(all.p->bound_sock, (sockaddr*)&address, &address_size) < 0)
453             {
454               cerr << "getsockname: " << strerror(errno) << endl;
455             }
456 	  ofstream port_file;
457 	  port_file.open(all.vm["port_file"].as<string>().c_str());
458 	  if (!port_file.is_open())
459 	    {
460 	      cerr << "error writing port file" << endl;
461 	      throw exception();
462 	    }
463 	  port_file << ntohs(address.sin_port) << endl;
464 	  port_file.close();
465 	}
466 
467       // background process
468       if (!all.active && daemon(1,1))
469 	{
470 	  cerr << "daemon: " << strerror(errno) << endl;
471 	  throw exception();
472 	}
473       // write pid file
474       if (all.vm.count("pid_file"))
475 	{
476 	  ofstream pid_file;
477 	  pid_file.open(all.vm["pid_file"].as<string>().c_str());
478 	  if (!pid_file.is_open())
479 	    {
480 	      cerr << "error writing pid file" << endl;
481 	      throw exception();
482 	    }
483 	  pid_file << getpid() << endl;
484 	  pid_file.close();
485 	}
486 
487       if (all.daemon && !all.active)
488 	{
489 #ifdef _WIN32
490 		throw exception();
491 #else
492 	  // weights will be shared across processes, accessible to children
493 	  float* shared_weights =
494 	    (float*)mmap(0,(all.length() << all.reg.stride_shift) * sizeof(float),
495 			 PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0);
496 
497 	  size_t float_count = all.length() << all.reg.stride_shift;
498 	  weight* dest = shared_weights;
499 	  memcpy(dest, all.reg.weight_vector, float_count*sizeof(float));
500 	  free(all.reg.weight_vector);
501 	  all.reg.weight_vector = dest;
502 
503 	  // learning state to be shared across children
504 	  shared_data* sd = (shared_data *)mmap(0,sizeof(shared_data),
505 			 PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0);
506 	  memcpy(sd, all.sd, sizeof(shared_data));
507 	  free(all.sd);
508 	  all.sd = sd;
509 
510 	  // create children
511 	  size_t num_children = all.num_children;
512 	  v_array<int> children = v_init<int>();
513 	  children.resize(num_children);
514 	  for (size_t i = 0; i < num_children; i++)
515 	    {
516 	      // fork() returns pid if parent, 0 if child
517 	      // store fork value and run child process if child
518 	      if ((children[i] = fork()) == 0)
519 		goto child;
520 	    }
521 
522 	  // install signal handler so we can kill children when killed
523 	  {
524 	    struct sigaction sa;
525 	    // specifically don't set SA_RESTART in sa.sa_flags, so that
526 	    // waitid will be interrupted by SIGTERM with handler installed
527 	    memset(&sa, 0, sizeof(sa));
528 	    sa.sa_handler = handle_sigterm;
529 	    sigaction(SIGTERM, &sa, NULL);
530 	  }
531 
532 	  while (true)
533 	    {
534 	      // wait for child to change state; if finished, then respawn
535 	      int status;
536 	      pid_t pid = wait(&status);
537 	      if (got_sigterm)
538 		{
539 		  for (size_t i = 0; i < num_children; i++)
540 		    kill(children[i], SIGTERM);
541                   VW::finish(all);
542 		  exit(0);
543 		}
544 	      if (pid < 0)
545 		continue;
546 	      for (size_t i = 0; i < num_children; i++)
547 		if (pid == children[i])
548 		  {
549 		    if ((children[i]=fork()) == 0)
550 		      goto child;
551 		    break;
552 		  }
553 	    }
554 
555 #endif
556 	}
557 
558 #ifndef _WIN32
559 	child:
560 #endif
561       sockaddr_in client_address;
562       socklen_t size = sizeof(client_address);
563       all.p->max_fd = 0;
564       if (!all.quiet)
565 	cerr << "calling accept" << endl;
566       int f = (int)accept(all.p->bound_sock,(sockaddr*)&client_address,&size);
567       if (f < 0)
568 	{
569 	  cerr << "accept: " << strerror(errno) << endl;
570 	  throw exception();
571 	}
572 
573       all.p->label_sock = f;
574       all.print = print_result;
575 
576       all.final_prediction_sink.push_back((size_t) f);
577 
578       all.p->input->files.push_back(f);
579       all.p->max_fd = max(f, all.p->max_fd);
580       if (!all.quiet)
581 	cerr << "reading data from port " << port << endl;
582 
583       all.p->max_fd++;
584       if(all.active)
585 	all.p->reader = read_features;
586       else {
587 	if (isbinary(*(all.p->input))) {
588 	  all.p->reader = read_cached_features;
589 	  all.print = binary_print_result;
590 	} else {
591 	  all.p->reader = read_features;
592 	}
593 	all.p->sorted_cache = true;
594       }
595       all.p->resettable = all.p->write_cache || all.daemon;
596     }
597   else
598     {
599       if (all.p->input->files.size() > 0)
600 	{
601 	  if (!quiet)
602 	    cerr << "ignoring text input in favor of cache input" << endl;
603 	}
604       else
605 	{
606 	  string temp = all.data_filename;
607 	  if (!quiet)
608 	    cerr << "Reading datafile = " << temp << endl;
609 	  int f = all.p->input->open_file(temp.c_str(), all.stdin_off, io_buf::READ);
610 	  if (f == -1 && temp.size() != 0)
611 	    {
612 			cerr << "can't open '" << temp << "', sailing on!" << endl;
613 	    }
614 	  all.p->reader = read_features;
615 	  all.p->resettable = all.p->write_cache;
616 	}
617     }
618 
619   if (passes > 1 && !all.p->resettable)
620     {
621       cerr << all.program_name << ": need a cache file for multiple passes: try using --cache_file" << endl;
622       throw exception();
623     }
624   all.p->input->count = all.p->input->files.size();
625   if (!quiet)
626     cerr << "num sources = " << all.p->input->files.size() << endl;
627 }
628 
parser_done(parser * p)629 bool parser_done(parser* p)
630 {
631   if (p->done)
632     {
633       if (p->used_index != p->begin_parsed_examples)
634 	return false;
635       return true;
636     }
637   return false;
638 }
639 
set_done(vw & all)640 void set_done(vw& all)
641 {
642   all.early_terminate = true;
643   mutex_lock(&all.p->examples_lock);
644   all.p->done = true;
645   mutex_unlock(&all.p->examples_lock);
646 }
647 
// Recursively appends n-gram / skip-gram features built from the first
// `initial_length` entries of `atomics` (and the matching audit entries).
//
// gram_mask holds the offsets, relative to a start position i, of the atoms
// making up the gram currently being built (gram_mask[0] is 0).
//   ngram     - how many more atoms still have to be appended to the gram
//   skip_gram - remaining budget of positions that may be skipped in total
//   skips     - positions skipped before the atom appended next
// When ngram reaches 0 the gram is complete. For every start position i, the
// combined hash h(a0)*quadratic_constant + h(a1) ... is then pushed as a new
// feature of value 1, together with an audit entry named "f0^f1^..." when
// auditing (its strings are malloc'd and flagged alloced).
void addgrams(vw& all, size_t ngram, size_t skip_gram, v_array<feature>& atomics, v_array<audit_data>& audits,
	      size_t initial_length, v_array<size_t> &gram_mask, size_t skips)
{
  if (ngram == 0 && gram_mask.last() < initial_length)
    {
      // Number of start positions for which the whole gram still fits.
      size_t last = initial_length - gram_mask.last();
      for(size_t i = 0; i < last; i++)
	{
	  size_t new_index = atomics[i].weight_index;
	  for (size_t n = 1; n < gram_mask.size(); n++)
	    new_index = new_index*quadratic_constant + atomics[i+gram_mask[n]].weight_index;
	  feature f = {1.,(uint32_t)(new_index)};
	  atomics.push_back(f);
	  // Only build audit names when audit entries exist for all original atoms.
	  if ((all.audit || all.hash_inv) && audits.size() >= initial_length)
	    {
	      string feature_name(audits[i].feature);
	      for (size_t n = 1; n < gram_mask.size(); n++)
		{
		  feature_name += string("^");
		  feature_name += string(audits[i+gram_mask[n]].feature);
		}
	      string feature_space = string(audits[i].space);

	      audit_data a_feature = {NULL,NULL,new_index, 1., true};
	      a_feature.space = (char*)malloc(feature_space.length()+1);
	      strcpy(a_feature.space, feature_space.c_str());
	      a_feature.feature = (char*)malloc(feature_name.length()+1);
	      strcpy(a_feature.feature, feature_name.c_str());
	      audits.push_back(a_feature);
	    }
	}
    }
  if (ngram > 0)
    {
      // Append the next atom right after the previous one (plus any skips).
      gram_mask.push_back(gram_mask.last()+1+skips);
      addgrams(all, ngram-1, skip_gram, atomics, audits, initial_length, gram_mask, 0);
      gram_mask.pop();
    }
  // Alternatively spend one unit of the skip budget before the next atom.
  if (skip_gram > 0 && ngram > 0)
    addgrams(all, ngram, skip_gram-1, atomics, audits, initial_length, gram_mask, skips+1);
}
689 
690 /**
691  * This function adds k-skip-n-grams to the feature vector.
692  * Definition of k-skip-n-grams:
693  * Consider a feature vector - a, b, c, d, e, f
694  * 2-skip-2-grams would be - ab, ac, ad, bc, bd, be, cd, ce, cf, de, df, ef
 * 1-skip-3-grams would be - abc, abd, acd, bcd, bce, bde, cde, cdf, cef, def (k bounds the total number of skipped positions per gram)
696  * Note that for a n-gram, (n-1)-grams, (n-2)-grams... 2-grams are also appended
697  * The k-skip-n-grams are appended to the feature vector.
 * Hash is evaluated using the principle h(a, b) = h(a)*X + h(b), where X is quadratic_constant,
 * applied left to right for longer grams.
700  */
generateGrams(vw & all,example * & ex)701 void generateGrams(vw& all, example* &ex) {
702   for(unsigned char* index = ex->indices.begin; index < ex->indices.end; index++)
703     {
704       size_t length = ex->atomics[*index].size();
705       for (size_t n = 1; n < all.ngram[*index]; n++)
706 	{
707 	  all.p->gram_mask.erase();
708 	  all.p->gram_mask.push_back((size_t)0);
709 	  addgrams(all, n, all.skips[*index], ex->atomics[*index],
710 		   ex->audit_features[*index],
711 		   length, all.p->gram_mask, 0);
712 	}
713     }
714 }
715 
get_unused_example(vw & all)716 example* get_unused_example(vw& all)
717 {
718   while (true)
719     {
720       mutex_lock(&all.p->examples_lock);
721       if (all.p->examples[all.p->begin_parsed_examples % all.p->ring_size].in_use == false)
722 	{
723 	  example& ret = all.p->examples[all.p->begin_parsed_examples++ % all.p->ring_size];
724 	  ret.in_use = true;
725 	  mutex_unlock(&all.p->examples_lock);
726 	  return &ret;
727 	}
728       else
729 	condition_variable_wait(&all.p->example_unused, &all.p->examples_lock);
730       mutex_unlock(&all.p->examples_lock);
731     }
732 }
733 
parse_atomic_example(vw & all,example * ae,bool do_read=true)734 bool parse_atomic_example(vw& all, example* ae, bool do_read = true)
735 {
736   if (do_read && all.p->reader(&all, ae) <= 0)
737     return false;
738 
739   if(all.p->sort_features && ae->sorted == false)
740     unique_sort_features(all.audit, (uint32_t)all.parse_mask, ae);
741 
742   if (all.p->write_cache)
743     {
744       all.p->lp.cache_label(&ae->l,*(all.p->output));
745       cache_features(*(all.p->output), ae, (uint32_t)all.parse_mask);
746     }
747   return true;
748 }
749 
end_pass_example(vw & all,example * ae)750 void end_pass_example(vw& all, example* ae)
751 {
752   all.p->lp.default_label(&ae->l);
753   ae->end_pass = true;
754   all.p->in_pass_counter = 0;
755 }
756 
feature_limit(vw & all,example * ex)757 void feature_limit(vw& all, example* ex)
758 {
759   for(unsigned char* index = ex->indices.begin; index < ex->indices.end; index++)
760     if (all.limit[*index] < ex->atomics[*index].size())
761       {
762 	v_array<feature>& features = ex->atomics[*index];
763 
764 	qsort(features.begin, features.size(), sizeof(feature), order_features);
765 
766 	unique_features(features, all.limit[*index]);
767       }
768 }
769 
770 namespace VW{
setup_example(vw & all,example * ae)771 void setup_example(vw& all, example* ae)
772 {
773   ae->partial_prediction = 0.;
774   ae->num_features = 0;
775   ae->total_sum_feat_sq = 0;
776   ae->loss = 0.;
777 
778   ae->example_counter = (size_t)(all.p->end_parsed_examples);
779   if (!all.p->emptylines_separate_examples)
780     all.p->in_pass_counter++;
781 
782   ae->test_only = is_test_only(all.p->in_pass_counter, all.holdout_period, all.holdout_after, all.holdout_set_off, all.p->emptylines_separate_examples ? (all.holdout_period-1) : 0);
783 
784   if (all.p->emptylines_separate_examples && example_is_newline(*ae))
785     all.p->in_pass_counter++;
786 
787   all.sd->t += all.p->lp.get_weight(&ae->l);
788   ae->example_t = (float)all.sd->t;
789 
790 
791   if (all.ignore_some)
792     {
793       if (all.audit || all.hash_inv)
794 	for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
795 	  if (all.ignore[*i])
796 	    ae->audit_features[*i].erase();
797 
798       for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
799 	if (all.ignore[*i])
800 	  {//delete namespace
801 	    ae->atomics[*i].erase();
802 	    memmove(i,i+1,(ae->indices.end - (i+1))*sizeof(*i));
803 	    ae->indices.end--;
804 	    i--;
805 	  }
806     }
807 
808   if(all.ngram_strings.size() > 0)
809     generateGrams(all, ae);
810 
811   if (all.add_constant) {
812     //add constant feature
813     ae->indices.push_back(constant_namespace);
814     feature temp = {1,(uint32_t) constant};
815     ae->atomics[constant_namespace].push_back(temp);
816     ae->total_sum_feat_sq++;
817   }
818 
819   if(all.limit_strings.size() > 0)
820     feature_limit(all,ae);
821 
822   uint32_t multiplier = all.wpp << all.reg.stride_shift;
823   if(multiplier != 1) //make room for per-feature information.
824     {
825       for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
826 	for(feature* j = ae->atomics[*i].begin; j != ae->atomics[*i].end; j++)
827 	  j->weight_index *= multiplier;
828       if (all.audit || all.hash_inv)
829 	for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
830 	  for(audit_data* j = ae->audit_features[*i].begin; j != ae->audit_features[*i].end; j++)
831 	    j->weight_index *= multiplier;
832     }
833 
834   for (unsigned char* i = ae->indices.begin; i != ae->indices.end; i++)
835     {
836       ae->num_features += ae->atomics[*i].end - ae->atomics[*i].begin;
837       ae->total_sum_feat_sq += ae->sum_feat_sq[*i];
838     }
839 
840   for (vector<string>::iterator i = all.pairs.begin(); i != all.pairs.end();i++)
841     {
842       ae->num_features
843 	+= ae->atomics[(int)(*i)[0]].size()
844 	*ae->atomics[(int)(*i)[1]].size();
845       ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]]*ae->sum_feat_sq[(int)(*i)[1]];
846     }
847 
848   for (vector<string>::iterator i = all.triples.begin(); i != all.triples.end();i++)
849     {
850       ae->num_features
851 	+= ae->atomics[(int)(*i)[0]].size()
852 	*ae->atomics[(int)(*i)[1]].size()
853 	*ae->atomics[(int)(*i)[2]].size();
854       ae->total_sum_feat_sq += ae->sum_feat_sq[(int)(*i)[0]] * ae->sum_feat_sq[(int)(*i)[1]] * ae->sum_feat_sq[(int)(*i)[2]];
855     }
856 }
857 }
858 
859 namespace VW{
new_unused_example(vw & all)860   example* new_unused_example(vw& all) {
861     example* ec = get_unused_example(all);
862     all.p->lp.default_label(&ec->l);
863     all.p->begin_parsed_examples++;
864     ec->example_counter = all.p->begin_parsed_examples;
865     return ec;
866   }
read_example(vw & all,char * example_line)867   example* read_example(vw& all, char* example_line)
868   {
869     example* ret = get_unused_example(all);
870 
871     read_line(all, ret, example_line);
872 	parse_atomic_example(all,ret,false);
873     setup_example(all, ret);
874     all.p->end_parsed_examples++;
875 
876     return ret;
877   }
878 
read_example(vw & all,string example_line)879   example* read_example(vw& all, string example_line) { return read_example(all, (char*)example_line.c_str()); }
880 
add_constant_feature(vw & vw,example * ec)881   void add_constant_feature(vw& vw, example*ec) {
882     uint32_t cns = constant_namespace;
883     ec->indices.push_back(cns);
884     feature temp = {1,(uint32_t) constant};
885     ec->atomics[cns].push_back(temp);
886     ec->total_sum_feat_sq++;
887     ec->num_features++;
888   }
889 
add_label(example * ec,float label,float weight,float base)890   void add_label(example* ec, float label, float weight, float base)
891   {
892     ec->l.simple.label = label;
893     ec->l.simple.weight = weight;
894     ec->l.simple.initial = base;
895   }
896 
import_example(vw & all,vector<feature_space> vf)897   example* import_example(vw& all, vector<feature_space> vf)
898   {
899     example* ret = get_unused_example(all);
900     all.p->lp.default_label(&ret->l);
901     for (size_t i = 0; i < vf.size();i++)
902       {
903 	uint32_t index = vf[i].first;
904 	ret->indices.push_back(index);
905 	for (size_t j = 0; j < vf[i].second.size(); j++)
906 	  {
907 	    ret->sum_feat_sq[index] += vf[i].second[j].x * vf[i].second[j].x;
908 	    ret->atomics[index].push_back(vf[i].second[j]);
909 	  }
910       }
911 	parse_atomic_example(all,ret,false);
912     setup_example(all, ret);
913     all.p->end_parsed_examples++;
914     return ret;
915   }
916 
import_example(vw & all,primitive_feature_space * features,size_t len)917   example* import_example(vw& all, primitive_feature_space* features, size_t len)
918   {
919     example* ret = get_unused_example(all);
920     all.p->lp.default_label(&ret->l);
921     for (size_t i = 0; i < len;i++)
922       {
923 	uint32_t index = features[i].name;
924 	ret->indices.push_back(index);
925 	for (size_t j = 0; j < features[i].len; j++)
926 	  {
927 	    ret->sum_feat_sq[index] += features[i].fs[j].x * features[i].fs[j].x;
928 	    ret->atomics[index].push_back(features[i].fs[j]);
929 	  }
930       }
931     parse_atomic_example(all,ret,false); // all.p->parsed_examples++;
932     setup_example(all, ret);
933 
934     return ret;
935   }
936 
export_example(vw & all,example * ec,size_t & len)937   primitive_feature_space* export_example(vw& all, example* ec, size_t& len)
938   {
939     len = ec->indices.size();
940     primitive_feature_space* fs_ptr = new primitive_feature_space[len];
941 
942     int fs_count = 0;
943     for (unsigned char* i = ec->indices.begin; i != ec->indices.end; i++)
944       {
945 		fs_ptr[fs_count].name = *i;
946 		fs_ptr[fs_count].len = ec->atomics[*i].size();
947 		fs_ptr[fs_count].fs = new feature[fs_ptr[fs_count].len];
948 
949 		int f_count = 0;
950 		for (feature *f = ec->atomics[*i].begin; f != ec->atomics[*i].end; f++)
951 		  {
952 			feature t = *f;
953 			t.weight_index >>= all.reg.stride_shift;
954 			fs_ptr[fs_count].fs[f_count] = t;
955 			f_count++;
956 		  }
957 		fs_count++;
958       }
959     return fs_ptr;
960   }
961 
releaseFeatureSpace(primitive_feature_space * features,size_t len)962   void releaseFeatureSpace(primitive_feature_space* features, size_t len)
963   {
964     for (size_t i = 0; i < len;i++)
965       delete features[i].fs;
966     delete (features);
967   }
968 
parse_example_label(vw & all,example & ec,string label)969   void parse_example_label(vw& all, example&ec, string label) {
970     v_array<substring> words = v_init<substring>();
971     char* cstr = (char*)label.c_str();
972     substring str = { cstr, cstr+label.length() };
973     words.push_back(str);
974     all.p->lp.parse_label(all.p, all.sd, &ec.l, words);
975     words.erase();
976     words.delete_v();
977   }
978 
empty_example(vw & all,example & ec)979   void empty_example(vw& all, example& ec)
980   {
981 	if (all.audit || all.hash_inv)
982       for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
983 	{
984 	  for (audit_data* temp
985 		 = ec.audit_features[*i].begin;
986 	       temp != ec.audit_features[*i].end; temp++)
987 	    {
988 	      if (temp->alloced)
989 		{
990 		  free(temp->space);
991 		  free(temp->feature);
992 		  temp->alloced=false;
993 		}
994 	    }
995 	  ec.audit_features[*i].erase();
996 	}
997 
998     for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++)
999       {
1000 	ec.atomics[*i].erase();
1001 	ec.sum_feat_sq[*i]=0;
1002       }
1003 
1004     ec.indices.erase();
1005     ec.tag.erase();
1006     ec.sorted = false;
1007     ec.end_pass = false;
1008   }
1009 
finish_example(vw & all,example * ec)1010   void finish_example(vw& all, example* ec)
1011   {
1012     mutex_lock(&all.p->output_lock);
1013     all.p->local_example_number++;
1014     condition_variable_signal(&all.p->output_done);
1015     mutex_unlock(&all.p->output_lock);
1016 
1017     empty_example(all, *ec);
1018 
1019     mutex_lock(&all.p->examples_lock);
1020     assert(ec->in_use);
1021     ec->in_use = false;
1022     condition_variable_signal(&all.p->example_unused);
1023     if (all.p->done)
1024       condition_variable_signal_all(&all.p->example_available);
1025     mutex_unlock(&all.p->examples_lock);
1026   }
1027 }
1028 
// Body of the background parse thread; `in` is the vw instance.
// Each iteration takes a ring slot (get_unused_example) and either parses the
// next example into it, or, at the end of a pass, resets the input source and
// turns the slot into an end-of-pass example. Either way the slot is then
// published to consumers by incrementing end_parsed_examples and broadcasting
// example_available. The loop exits once p->done is set; this function sets it
// after the last pass.
#ifdef _WIN32
DWORD WINAPI main_parse_loop(LPVOID in)
#else
void *main_parse_loop(void *in)
#endif
{
	vw* all = (vw*) in;
	size_t example_number = 0;  // for variable-size batch learning algorithms


	// NOTE(review): `done` is read here without examples_lock. The only write
	// visible in this file (below) is made by this same thread.
	while(!all->p->done)
	  {
            example* ae = get_unused_example(*all);
	    // Parse a regular example unless one of these holds: a reset was
	    // requested, the current pass reached pass_length examples,
	    // max_examples was reached, or parsing failed (presumably end of input).
	    if (!all->do_reset_source && example_number != all->pass_length && all->max_examples > example_number
		   && parse_atomic_example(*all, ae) )
	     {
	       VW::setup_example(*all, ae);
	       example_number++;
	     }
	    else
	     {
	       // End of a pass: rewind the source and emit `ae` as the end-of-pass marker.
	       reset_source(*all, all->num_bits);
	       all->do_reset_source = false;
	       all->passes_complete++;
	       end_pass_example(*all, ae);
	       // The pass ended because pass_length was hit on the final pass, so
	       // restart the pass count with a longer pass (2*pass_length + 1).
	       if (all->passes_complete == all->numpasses && example_number == all->pass_length)
			 {
			   all->passes_complete = 0;
			   all->pass_length = all->pass_length*2+1;
			 }
	       // All passes are done: flag completion under examples_lock so that
	       // consumers checking `done` under that lock observe it consistently.
	       if (all->passes_complete >= all->numpasses && all->max_examples >= example_number)
			 {
			   mutex_lock(&all->p->examples_lock);
			   all->p->done = true;
			   mutex_unlock(&all->p->examples_lock);
			 }
	       example_number = 0;
	     }
	   // Publish `ae` (regular or end-of-pass) and wake any waiting consumers.
	   mutex_lock(&all->p->examples_lock);
	   all->p->end_parsed_examples++;
	   condition_variable_signal_all(&all->p->example_available);
	   mutex_unlock(&all->p->examples_lock);
	  }
	return NULL;
}
1074 
1075 namespace VW{
get_example(parser * p)1076 example* get_example(parser* p)
1077 {
1078   mutex_lock(&p->examples_lock);
1079   if (p->end_parsed_examples != p->used_index) {
1080     size_t ring_index = p->used_index++ % p->ring_size;
1081     if (!(p->examples+ring_index)->in_use)
1082       cout << p->used_index << " " << p->end_parsed_examples << " " << ring_index << endl;
1083     assert((p->examples+ring_index)->in_use);
1084     mutex_unlock(&p->examples_lock);
1085 
1086     return p->examples + ring_index;
1087   }
1088   else {
1089     if (!p->done)
1090       {
1091 	condition_variable_wait(&p->example_available, &p->examples_lock);
1092 	mutex_unlock(&p->examples_lock);
1093 	return get_example(p);
1094       }
1095     else {
1096       mutex_unlock(&p->examples_lock);
1097       return NULL;
1098     }
1099   }
1100 }
1101 
get_topic_prediction(example * ec,size_t i)1102 float get_topic_prediction(example* ec, size_t i)
1103 {
1104 	return ec->topic_predictions[i];
1105 }
1106 
get_label(example * ec)1107 float get_label(example* ec)
1108 {
1109 	return ec->l.simple.label;
1110 }
1111 
get_importance(example * ec)1112 float get_importance(example* ec)
1113 {
1114 	return ec->l.simple.weight;
1115 }
1116 
get_initial(example * ec)1117 float get_initial(example* ec)
1118 {
1119 	return ec->l.simple.initial;
1120 }
1121 
get_prediction(example * ec)1122 float get_prediction(example* ec)
1123 {
1124 	return ec->pred.scalar;
1125 }
1126 
get_cost_sensitive_prediction(example * ec)1127 float get_cost_sensitive_prediction(example* ec)
1128 {
1129        return (float)ec->pred.multiclass;
1130 }
1131 
get_tag_length(example * ec)1132 size_t get_tag_length(example* ec)
1133 {
1134 	return ec->tag.size();
1135 }
1136 
get_tag(example * ec)1137 const char* get_tag(example* ec)
1138 {
1139 	return ec->tag.begin;
1140 }
1141 
get_feature_number(example * ec)1142 size_t get_feature_number(example* ec)
1143 {
1144 	return ec->num_features;
1145 }
1146 }
1147 
initialize_examples(vw & all)1148 void initialize_examples(vw& all)
1149 {
1150   all.p->used_index = 0;
1151   all.p->begin_parsed_examples = 0;
1152   all.p->end_parsed_examples = 0;
1153   all.p->done = false;
1154 
1155   all.p->examples = calloc_or_die<example>(all.p->ring_size);
1156 
1157   for (size_t i = 0; i < all.p->ring_size; i++)
1158     {
1159       memset(&all.p->examples[i].l, 0, sizeof(polylabel));
1160       all.p->examples[i].in_use = false;
1161     }
1162 }
1163 
adjust_used_index(vw & all)1164 void adjust_used_index(vw& all)
1165 {
1166 	all.p->used_index=all.p->begin_parsed_examples;
1167 }
1168 
initialize_parser_datastructures(vw & all)1169 void initialize_parser_datastructures(vw& all)
1170 {
1171   initialize_examples(all);
1172   initialize_mutex(&all.p->examples_lock);
1173   initialize_condition_variable(&all.p->example_available);
1174   initialize_condition_variable(&all.p->example_unused);
1175   initialize_mutex(&all.p->output_lock);
1176   initialize_condition_variable(&all.p->output_done);
1177 }
1178 
1179 namespace VW {
start_parser(vw & all,bool init_structures)1180 void start_parser(vw& all, bool init_structures)
1181 {
1182   if (init_structures)
1183 	initialize_parser_datastructures(all);
1184   #ifndef _WIN32
1185   pthread_create(&all.parse_thread, NULL, main_parse_loop, &all);
1186   #else
1187   all.parse_thread = ::CreateThread(NULL, 0, static_cast<LPTHREAD_START_ROUTINE>(main_parse_loop), &all, NULL, NULL);
1188   #endif
1189 }
1190 }
free_parser(vw & all)1191 void free_parser(vw& all)
1192 {
1193   all.p->channels.delete_v();
1194   all.p->words.delete_v();
1195   all.p->name.delete_v();
1196 
1197   if(all.ngram_strings.size() > 0)
1198     all.p->gram_mask.delete_v();
1199 
1200   for (size_t i = 0; i < all.p->ring_size; i++)
1201     {
1202       dealloc_example(all.p->lp.delete_label, all.p->examples[i]);
1203     }
1204   free(all.p->examples);
1205 
1206   io_buf* output = all.p->output;
1207   if (output != NULL)
1208     {
1209       output->finalname.delete_v();
1210       output->currentname.delete_v();
1211     }
1212 
1213   all.p->counts.delete_v();
1214 }
1215 
release_parser_datastructures(vw & all)1216 void release_parser_datastructures(vw& all)
1217 {
1218   delete_mutex(&all.p->examples_lock);
1219   delete_mutex(&all.p->output_lock);
1220 }
1221 
1222 namespace VW {
end_parser(vw & all)1223 void end_parser(vw& all)
1224 {
1225   #ifndef _WIN32
1226   pthread_join(all.parse_thread, NULL);
1227   #else
1228   ::WaitForSingleObject(all.parse_thread, INFINITE);
1229   ::CloseHandle(all.parse_thread);
1230   #endif
1231   release_parser_datastructures(all);
1232 }
1233 }
1234