#include <vector>
#ifdef _WIN32
#include <WinSock2.h>
#ifndef SHUT_RD
#   define SHUT_RD SD_RECEIVE
#endif

#ifndef SHUT_WR
#   define SHUT_WR SD_SEND
#endif

#ifndef SHUT_RDWR
#   define SHUT_RDWR SD_BOTH
#endif
#else
#include <netdb.h>
#include <unistd.h>
#endif
#include "io_buf.h"
#include "cache.h"
#include "network.h"
#include "reductions.h"
22 
23 struct sender {
24   io_buf* buf;
25   int sd;
26   vw* all;//loss ring_size others
27   example** delay_ring;
28   size_t sent_index;
29   size_t received_index;
30 };
31 
open_sockets(sender & s,string host)32 void open_sockets(sender& s, string host)
33 {
34   s.sd = open_socket(host.c_str());
35   s.buf = new io_buf();
36   s.buf->files.push_back(s.sd);
37 }
38 
send_features(io_buf * b,example & ec,uint32_t mask)39 void send_features(io_buf *b, example& ec, uint32_t mask)
40 { // note: subtracting 1 b/c not sending constant
41   output_byte(*b,(unsigned char) (ec.indices.size()-1));
42 
43   for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
44     if (*i == constant_namespace)
45       continue;
46     output_features(*b, *i, ec.atomics[*i].begin, ec.atomics[*i].end, mask);
47   }
48   b->flush();
49 }
50 
receive_result(sender & s)51 void receive_result(sender& s)
52 {
53   float res, weight;
54 
55   get_prediction(s.sd,res,weight);
56   example& ec = *s.delay_ring[s.received_index++ % s.all->p->ring_size];
57   ec.pred.scalar = res;
58 
59   label_data& ld = ec.l.simple;
60   ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ld.label) * ld.weight;
61 
62   return_simple_example(*(s.all), NULL, ec);
63 }
64 
learn(sender & s,LEARNER::base_learner & base,example & ec)65 void learn(sender& s, LEARNER::base_learner& base, example& ec)
66 {
67   if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index)
68     receive_result(s);
69 
70   s.all->set_minmax(s.all->sd, ec.l.simple.label);
71   s.all->p->lp.cache_label(&ec.l, *s.buf);//send label information.
72   cache_tag(*s.buf, ec.tag);
73   send_features(s.buf,ec, (uint32_t)s.all->parse_mask);
74   s.delay_ring[s.sent_index++ % s.all->p->ring_size] = &ec;
75 }
76 
finish_example(vw & all,sender &,example & ec)77 void finish_example(vw& all, sender&, example& ec){}
78 
end_examples(sender & s)79 void end_examples(sender& s)
80 { //close our outputs to signal finishing.
81   while (s.received_index != s.sent_index)
82     receive_result(s);
83   shutdown(s.buf->files[0],SHUT_WR);
84 }
85 
finish(sender & s)86 void finish(sender& s)
87 {
88   s.buf->files.delete_v();
89   s.buf->space.delete_v();
90   free(s.delay_ring);
91   delete s.buf;
92 }
93 
sender_setup(vw & all)94 LEARNER::base_learner* sender_setup(vw& all)
95 {
96   if (missing_option<string, true>(all, "sendto", "send examples to <host>"))
97     return NULL;
98 
99   sender& s = calloc_or_die<sender>();
100   s.sd = -1;
101   if (all.vm.count("sendto"))
102     {
103       string host = all.vm["sendto"].as< string >();
104       open_sockets(s, host);
105     }
106 
107   s.all = &all;
108   s.delay_ring = calloc_or_die<example*>(all.p->ring_size);
109 
110   LEARNER::learner<sender>& l = init_learner(&s, learn, 1);
111   l.set_finish(finish);
112   l.set_finish_example(finish_example);
113   l.set_end_examples(end_examples);
114   return make_base(l);
115 }
116