1 /* $Id$ */
2 
3 /*
4  *
5  * Copyright (C) 2003 David Mazieres (dm@uun.org)
6  *
7  * This program is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License as
9  * published by the Free Software Foundation; either version 2, or (at
10  * your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful, but
13  * WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15  * General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
20  * USA
21  *
22  */
23 
24 #include "async.h"
25 #include "dns.h"
26 
/*
 * Handle for a connection attempt in progress.  The concrete types below
 * delete themselves once they have reported a result; before that,
 * tcpconnect_cancel deletes one to abort it.  The virtual destructor
 * makes deleting a derived object through this base pointer well defined.
 */
struct tcpconnect_t {
  virtual ~tcpconnect_t () {}
};
30 
31 struct tcpportconnect_t : tcpconnect_t {
32   u_int16_t port;
33   cbi cb;
34   int fd;
35   dnsreq_t *dnsp;
36   str *namep;
37 
38   tcpportconnect_t (const in_addr &a, u_int16_t port, cbi cb);
39   tcpportconnect_t (str hostname, u_int16_t port, cbi cb,
40 		bool dnssearch, str *namep);
41   ~tcpportconnect_t ();
42 
replytcpportconnect_t43   void reply (int s) { if (s == fd) fd = -1; (*cb) (s); delete this; }
failtcpportconnect_t44   void fail (int error) { errno = error; reply (-1); }
45   void connect_to_name (str hostname, bool dnssearch);
46   void name_cb (ptr<hostent> h, int err);
47   void connect_to_in_addr (const in_addr &a);
48   void connect_cb ();
49 };
50 
tcpportconnect_t(const in_addr & a,u_int16_t p,cbi c)51 tcpportconnect_t::tcpportconnect_t (const in_addr &a, u_int16_t p, cbi c)
52   : port (p), cb (c), fd (-1), dnsp (NULL), namep (NULL)
53 {
54   connect_to_in_addr (a);
55 }
56 
tcpportconnect_t(str hostname,u_int16_t p,cbi c,bool dnssearch,str * np)57 tcpportconnect_t::tcpportconnect_t (str hostname, u_int16_t p, cbi c,
58 			    bool dnssearch, str *np)
59   : port (p), cb (c), fd (-1), dnsp (NULL), namep (np)
60 {
61   connect_to_name (hostname, dnssearch);
62 }
63 
~tcpportconnect_t()64 tcpportconnect_t::~tcpportconnect_t ()
65 {
66   if (dnsp)
67     dnsreq_cancel (dnsp);
68   if (fd >= 0) {
69     fdcb (fd, selwrite, NULL);
70     close (fd);
71   }
72 }
73 
74 void
connect_to_name(str hostname,bool dnssearch)75 tcpportconnect_t::connect_to_name (str hostname, bool dnssearch)
76 {
77   dnsp = dns_hostbyname (hostname, wrap (this, &tcpportconnect_t::name_cb),
78 			 dnssearch);
79 }
80 
81 void
name_cb(ptr<hostent> h,int err)82 tcpportconnect_t::name_cb (ptr<hostent> h, int err)
83 {
84   dnsp = NULL;
85   if (!h) {
86     if (dns_tmperr (err))
87       fail (EAGAIN);
88     else
89       fail (ENOENT);
90     return;
91   }
92   if (namep)
93     *namep = h->h_name;
94   connect_to_in_addr (*(in_addr *) h->h_addr);
95 }
96 
97 void
connect_to_in_addr(const in_addr & a)98 tcpportconnect_t::connect_to_in_addr (const in_addr &a)
99 {
100   sockaddr_in sin;
101   bzero (&sin, sizeof (sin));
102   sin.sin_family = AF_INET;
103   sin.sin_port = htons (port);
104   sin.sin_addr = a;
105 
106   fd = inetsocket (SOCK_STREAM);
107   if (fd < 0) {
108     delaycb (0, wrap (this, &tcpportconnect_t::fail, errno));
109     return;
110   }
111   make_async (fd);
112   close_on_exec (fd);
113   if (connect (fd, (sockaddr *) &sin, sizeof (sin)) < 0
114       && errno != EINPROGRESS) {
115     delaycb (0, wrap (this, &tcpportconnect_t::fail, errno));
116     return;
117   }
118   fdcb (fd, selwrite, wrap (this, &tcpportconnect_t::connect_cb));
119 }
120 
121 void
connect_cb()122 tcpportconnect_t::connect_cb ()
123 {
124   fdcb (fd, selwrite, NULL);
125 
126   sockaddr_in sin;
127   socklen_t sn = sizeof (sin);
128   if (!getpeername (fd, (sockaddr *) &sin, &sn)) {
129     reply (fd);
130     return;
131   }
132 
133   int err = 0;
134   sn = sizeof (err);
135   getsockopt (fd, SOL_SOCKET, SO_ERROR, (char *) &err, &sn);
136   fail (err ? err : ECONNREFUSED);
137 }
138 
139 tcpconnect_t *
tcpconnect(in_addr addr,u_int16_t port,cbi cb)140 tcpconnect (in_addr addr, u_int16_t port, cbi cb)
141 {
142   return New tcpportconnect_t (addr, port, cb);
143 }
144 
145 tcpconnect_t *
tcpconnect(str hostname,u_int16_t port,cbi cb,bool dnssearch,str * namep)146 tcpconnect (str hostname, u_int16_t port, cbi cb,
147 	    bool dnssearch, str *namep)
148 {
149   return New tcpportconnect_t (hostname, port, cb, dnssearch, namep);
150 }
151 
152 void
tcpconnect_cancel(tcpconnect_t * tc)153 tcpconnect_cancel (tcpconnect_t *tc)
154 {
155   delete tc;
156 }
157 
/*
 * Connection attempt driven by DNS SRV records: the SRV targets are
 * tried in order with staggered, overlapping attempts.  Without SRV
 * records it falls back to the looked-up address and defport.  Like
 * tcpportconnect_t, it deletes itself after calling cb once.
 */
struct tcpsrvconnect_t : tcpconnect_t {
  u_int16_t defport;		// Fallback port without SRV records (0: none)
  cbi cb;			// Completion callback: socket or -1
  int dnserr;			// Most significant DNS error seen so far
  dnsreq_t *areq;		// Outstanding hostname lookup, or NULL
  ptr<hostent> h;		// Hostname lookup result, if any
  dnsreq_t *srvreq;		// Outstanding SRV lookup, or NULL
  ptr<srvlist> srvl;		// SRV records, if any
  timecb_t *tmo;		// Pending timer, or NULL
  vec<tcpconnect_t *> cons;	// One per attempt started (same index as
				// srvl->s_srvs); NULL once it finishes
  int cbad;			// Number of failed SRV attempts
  int error;			// errno to report if every attempt fails
  ptr<srvlist> *srvlp;		// If non-NULL, receives the SRV list
  str *namep;			// If non-NULL, receives the name connected to

  tcpsrvconnect_t (str name, str service, cbi cb, u_int16_t dp,
		   bool search, ptr<srvlist> *sp, str *np);
  tcpsrvconnect_t (ref<srvlist> sl, cbi cb, str *np);
  ~tcpsrvconnect_t ();
  void dnsacb (ptr<hostent>, int err);
  void dnssrvcb (ptr<srvlist>, int err);
  void maybe_start (int err);
  void connectcb (int cn, int fd);
  void nextsrv (bool timeout = false);
};
183 
/*
 * Start the attempt for the next SRV entry, srvl->s_srvs[cons.size ()],
 * then arm a 4-second timer that calls nextsrv again.  Entries are thus
 * tried in order, with earlier attempts left running.  srvl must be
 * non-NULL.  TIMEOUT is true when called from that timer, which has
 * already fired and so is not removed.
 */
void
tcpsrvconnect_t::nextsrv (bool timeout)
{
  if (!timeout)
    timecb_remove (tmo);
  tmo = NULL;

  u_int n = cons.size ();

  // Every entry already has an attempt; nothing more to start.
  if (n >= srvl->s_nsrv)
    return;

  // warn ("nextsrv %d (port %d)\n", n, srvl->s_srvs[n].port);

  /* An entry with no port (including one zeroed by connectcb) or no
   * target name fails at once with ENOENT.  connectcb may start the
   * next entry or delete this object, so return immediately.  */
  if (!srvl->s_srvs[n].port || !srvl->s_srvs[n].name[0]) {
    cons.push_back (NULL);
    errno = ENOENT;
    connectcb (n, -1);
    return;
  }
  // Target is the name whose address we already have: reuse it.
  else if (h && !strcasecmp (srvl->s_srvs[n].name, h->h_name))
    cons.push_back (tcpconnect (*(in_addr *) h->h_addr, srvl->s_srvs[n].port,
				wrap (this, &tcpsrvconnect_t::connectcb, n)));
  else {
    // Use an AF_INET address hint for the target if the SRV answer
    // carried one; otherwise look the target up (dnssearch false).
    str name = srvl->s_srvs[n].name;
    addrhint **hint;
    for (hint = srvl->s_hints;
	 *hint && ((*hint)->h_addrtype != AF_INET
		   || strcasecmp ((*hint)->h_name, name));
	 hint++)
      ;
    if (*hint)
      cons.push_back (tcpconnect (inaddr_cast ((*hint)->h_address),
				  srvl->s_srvs[n].port,
				  wrap (this, &tcpsrvconnect_t::connectcb,
					n)));
    else
      cons.push_back (tcpconnect (srvl->s_srvs[n].name, srvl->s_srvs[n].port,
				  wrap (this, &tcpsrvconnect_t::connectcb, n),
				  false));
  }

  // Stagger: start the following entry in 4 seconds unless an earlier
  // failure starts it sooner (see connectcb).
  tmo = delaycb (4, wrap (this, &tcpsrvconnect_t::nextsrv, true));
}
228 
229 void
connectcb(int cn,int fd)230 tcpsrvconnect_t::connectcb (int cn, int fd)
231 {
232   cons[cn] = NULL;
233 
234   if (fd >= 0) {
235     errno = 0;
236     if (namep) {
237       if (srvl) {
238 	*namep = srvl->s_srvs[cn].name;
239 	srvl->s_srvs[cn].port = 0;
240       }
241       else
242 	*namep = h->h_name;
243     }
244     (*cb) (fd);
245     delete this;
246     return;
247   }
248 
249   // warn ("%s:%d %m\n", srvl->s_srvs[cn].name, srvl->s_srvs[cn].port);
250 
251   if (!error)
252     error = errno;
253   else if (errno == EAGAIN)
254     error = errno;
255   else if (error != EAGAIN && errno != ENOENT)
256     error = errno;
257 
258   if (srvl)
259     srvl->s_srvs[cn].port = 0;
260 
261   if (!srvl || ++cbad >= srvl->s_nsrv) {
262     errno = error;
263     (*cb) (-1);
264     delete this;
265     return;
266   }
267 
268   if (!cons.back ())
269     nextsrv ();
270 }
271 
tcpsrvconnect_t(str name,str s,cbi cb,u_int16_t dp,bool search,ptr<srvlist> * sp,str * np)272 tcpsrvconnect_t::tcpsrvconnect_t (str name, str s, cbi cb, u_int16_t dp,
273 				  bool search, ptr<srvlist> *sp, str *np)
274   : defport (dp), cb (cb), dnserr (0), tmo (NULL), cbad (0),
275     error (0), srvlp (sp), namep (np)
276 {
277   areq = dns_hostbyname (name, wrap (this, &tcpsrvconnect_t::dnsacb), search);
278   srvreq = dns_srvbyname (name, "tcp", s,
279 			  wrap (this, &tcpsrvconnect_t::dnssrvcb), search);
280 }
281 
tcpsrvconnect_t(ref<srvlist> sl,cbi cb,str * np)282 tcpsrvconnect_t::tcpsrvconnect_t (ref<srvlist> sl, cbi cb, str *np)
283 
284   : defport (0), cb (cb), dnserr (0), areq (NULL), srvreq (NULL),
285     tmo (NULL), cbad (0), error (0), srvlp (NULL), namep (np)
286 {
287   delaycb (0, wrap (this, &tcpsrvconnect_t::dnssrvcb, sl, 0));
288 }
289 
~tcpsrvconnect_t()290 tcpsrvconnect_t::~tcpsrvconnect_t ()
291 {
292   for (tcpconnect_t **cp = cons.base (); cp < cons.lim (); cp++)
293     tcpconnect_cancel (*cp);
294   dnsreq_cancel (areq);
295   dnsreq_cancel (srvreq);
296   timecb_remove (tmo);
297 }
298 
299 void
maybe_start(int err)300 tcpsrvconnect_t::maybe_start (int err)
301 {
302   if (err && err != NXDOMAIN && err != ARERR_NXREC) {
303     if (!dnserr)
304       dnserr = err;
305     else if (!dns_tmperr (dnserr) && dns_tmperr (err))
306       dnserr = err;
307   }
308   if (srvreq || (!srvl && areq))
309     return;
310   if (srvl)
311     nextsrv ();
312   else if (h && defport) {
313     cons.push_back (tcpconnect (*(in_addr *) h->h_addr, defport,
314 				wrap (this, &tcpsrvconnect_t::connectcb, 0)));
315   }
316   else {
317     if (dns_tmperr (dnserr))
318       errno = EAGAIN;
319     else
320       errno = ENOENT;
321     (*cb) (-1);
322     delete this;
323   }
324 }
325 
326 void
dnsacb(ptr<hostent> hh,int err)327 tcpsrvconnect_t::dnsacb (ptr<hostent> hh, int err)
328 {
329   areq = NULL;
330   h = hh;
331   maybe_start (err);
332 }
333 
334 void
dnssrvcb(ptr<srvlist> s,int err)335 tcpsrvconnect_t::dnssrvcb (ptr<srvlist> s, int err)
336 {
337   srvreq = NULL;
338   srvl = s;
339   if (srvlp)
340     *srvlp = srvl;
341   maybe_start (err);
342 }
343 
344 tcpconnect_t *
tcpconnect_srv(str hostname,str service,u_int16_t defport,cbi cb,bool dnssearch,ptr<srvlist> * srvlp,str * np)345 tcpconnect_srv (str hostname, str service, u_int16_t defport,
346 		cbi cb, bool dnssearch, ptr<srvlist> *srvlp, str *np)
347 {
348   if (srvlp && *srvlp)
349     return New tcpsrvconnect_t (*srvlp, cb, np);
350   else
351     return New tcpsrvconnect_t (hostname, service, cb, defport,
352 				dnssearch, srvlp, np);
353 }
354 
355 tcpconnect_t *
tcpconnect_srv_retry(ref<srvlist> srvl,cbi cb,str * np)356 tcpconnect_srv_retry (ref<srvlist> srvl, cbi cb, str *np)
357 {
358   return New tcpsrvconnect_t (srvl, cb, np);
359 }
360 
361