1 #include <vector>
2 #ifdef _WIN32
3 #include <WinSock2.h>
4 #ifndef SHUT_RD
5 # define SHUT_RD SD_RECEIVE
6 #endif
7
8 #ifndef SHUT_WR
9 # define SHUT_WR SD_SEND
10 #endif
11
12 #ifndef SHUT_RDWR
13 # define SHUT_RDWR SD_BOTH
14 #endif
15 #else
16 #include <netdb.h>
17 #endif
18 #include "io_buf.h"
19 #include "cache.h"
20 #include "network.h"
21 #include "reductions.h"
22
23 struct sender {
24 io_buf* buf;
25 int sd;
26 vw* all;//loss ring_size others
27 example** delay_ring;
28 size_t sent_index;
29 size_t received_index;
30 };
31
open_sockets(sender & s,string host)32 void open_sockets(sender& s, string host)
33 {
34 s.sd = open_socket(host.c_str());
35 s.buf = new io_buf();
36 s.buf->files.push_back(s.sd);
37 }
38
send_features(io_buf * b,example & ec,uint32_t mask)39 void send_features(io_buf *b, example& ec, uint32_t mask)
40 { // note: subtracting 1 b/c not sending constant
41 output_byte(*b,(unsigned char) (ec.indices.size()-1));
42
43 for (unsigned char* i = ec.indices.begin; i != ec.indices.end; i++) {
44 if (*i == constant_namespace)
45 continue;
46 output_features(*b, *i, ec.atomics[*i].begin, ec.atomics[*i].end, mask);
47 }
48 b->flush();
49 }
50
receive_result(sender & s)51 void receive_result(sender& s)
52 {
53 float res, weight;
54
55 get_prediction(s.sd,res,weight);
56 example& ec = *s.delay_ring[s.received_index++ % s.all->p->ring_size];
57 ec.pred.scalar = res;
58
59 label_data& ld = ec.l.simple;
60 ec.loss = s.all->loss->getLoss(s.all->sd, ec.pred.scalar, ld.label) * ld.weight;
61
62 return_simple_example(*(s.all), NULL, ec);
63 }
64
learn(sender & s,LEARNER::base_learner & base,example & ec)65 void learn(sender& s, LEARNER::base_learner& base, example& ec)
66 {
67 if (s.received_index + s.all->p->ring_size / 2 - 1 == s.sent_index)
68 receive_result(s);
69
70 s.all->set_minmax(s.all->sd, ec.l.simple.label);
71 s.all->p->lp.cache_label(&ec.l, *s.buf);//send label information.
72 cache_tag(*s.buf, ec.tag);
73 send_features(s.buf,ec, (uint32_t)s.all->parse_mask);
74 s.delay_ring[s.sent_index++ % s.all->p->ring_size] = &ec;
75 }
76
finish_example(vw & all,sender &,example & ec)77 void finish_example(vw& all, sender&, example& ec){}
78
end_examples(sender & s)79 void end_examples(sender& s)
80 { //close our outputs to signal finishing.
81 while (s.received_index != s.sent_index)
82 receive_result(s);
83 shutdown(s.buf->files[0],SHUT_WR);
84 }
85
finish(sender & s)86 void finish(sender& s)
87 {
88 s.buf->files.delete_v();
89 s.buf->space.delete_v();
90 free(s.delay_ring);
91 delete s.buf;
92 }
93
sender_setup(vw & all)94 LEARNER::base_learner* sender_setup(vw& all)
95 {
96 if (missing_option<string, true>(all, "sendto", "send examples to <host>"))
97 return NULL;
98
99 sender& s = calloc_or_die<sender>();
100 s.sd = -1;
101 if (all.vm.count("sendto"))
102 {
103 string host = all.vm["sendto"].as< string >();
104 open_sockets(s, host);
105 }
106
107 s.all = &all;
108 s.delay_ring = calloc_or_die<example*>(all.p->ring_size);
109
110 LEARNER::learner<sender>& l = init_learner(&s, learn, 1);
111 l.set_finish(finish);
112 l.set_finish_example(finish_example);
113 l.set_end_examples(end_examples);
114 return make_base(l);
115 }
116