1local loader = function(loader)
2
3local socket = require("_cqueues.socket")
4local cqueues = require("cqueues")
5local errno = require("cqueues.errno")
6
7local poll = cqueues.poll
8local monotime = cqueues.monotime
9
10local AF_INET = socket.AF_INET
11local AF_INET6 = socket.AF_INET6
12local AF_UNIX = socket.AF_UNIX
13local SOCK_STREAM = socket.SOCK_STREAM
14local SOCK_DGRAM = socket.SOCK_DGRAM
15
16local EAGAIN = errno.EAGAIN
17local EPIPE = errno.EPIPE
18local ETIMEDOUT = errno.ETIMEDOUT
19local ENOTCONN = errno.ENOTCONN
20local ENOTSOCK = errno.ENOTSOCK
21local strerror = errno.strerror
22
23local format = string.format
24
25
26--
27-- H E L P E R  R O U T I N E S
28--
29-- ========================================================================
30
-- Yield on the socket until it is ready or the deadline passes.
--
-- deadline is an absolute cqueues.monotime() value, or nil/false to poll
-- without a time limit. Returns false without polling when the deadline
-- has already been reached; otherwise polls and returns true.
local function timed_poll(self, deadline)
	if not deadline then
		poll(self)
		return true
	end

	local remaining = deadline - monotime()

	if remaining <= 0 then
		return false
	end

	poll(self, remaining)

	return true
end -- timed_poll
48
49
-- Describe the socket's peer for use in error messages.
--
-- Returns "addr.port" for AF_INET/AF_INET6 peers and "unix:path" (or
-- "unix:unnamed" when there is no path) for AF_UNIX peers. Any other
-- family, such as the 0 that the extended peername returns for an
-- unconnected socket, yields nil.
local function logname(so)
	local family, host, port = so:peername()

	if family == AF_UNIX then
		return format("unix:%s", host or "unnamed")
	end

	if family == AF_INET or family == AF_INET6 then
		return format("%s.%s", host, port)
	end
end -- logname
59
60
61--
62-- E R R O R  M A N A G E M E N T
63--
64-- All errors in the I/O routines are first passed to a per-socket error
65-- handler, which can choose to return or throw them.
66--
67-- The default error handler is not actually installed with any socket, as
68-- that would create needless churn in the registry index on socket
69-- instantiation. Instead we interpose socket.onerror and socket:onerror and
70-- return our default handler if none was previously installed.
71--
72-- ========================================================================
73
-- Default per-socket error handler.
--
-- EPIPE and ETIMEDOUT are returned to the caller as plain values so they
-- can be handled inline. Any other error is raised with error() at stack
-- level lvl. The message is prefixed with the peer name when logname()
-- can produce one, and with "socket" otherwise.
local function def_onerror(self, op, why, lvl)
	if why == EPIPE or why == ETIMEDOUT then
		return why
	end

	local peer = logname(self)
	local msg = peer
		and format("[%s]:%s: %s", peer, op, strerror(why))
		or format("socket:%s: %s", op, strerror(why))

	error(msg, lvl)
end -- def_onerror
93
94
-- Module-level socket.onerror(): fall back to def_onerror when the
-- original accessor yields no handler.
do
	local orig_onerror = socket.onerror

	function socket.onerror(...)
		local handler = orig_onerror(...)

		return handler or def_onerror
	end
end

-- Per-socket socket:onerror(): same fallback, installed via interpose.
do
	local orig_method
	orig_method = socket.interpose("onerror", function(...)
		local handler = orig_method(...)

		return handler or def_onerror
	end)
end
106
107
108--
109-- On buffered I/O we need to preserve errors across calls, otherwise
110-- unchecked transient errors might lead to unexpected behavior by
111-- application code. This is particularly true regarding timeouts, and
112-- especially so when mixed with iterators like socket:lines--doubly so when
113-- reading MIME headers, which could terminate on ETIMEDOUT, EPIPE, or just
114-- when reaching the end of the headers section.
115--
116-- Why not just always throw on such errors? One reason is that we partially
117-- mimic Lua's file objects, which will return such errors. (And we might
118-- change our semantics to fully mimic Lua in the future.)
119--
120-- Another reason is that it's very common to want to deal with timeouts
121-- inline. For example, maybe you want to write a keep-alive message after a
122-- read timeout. Timeouts are exceptional but not necessarily errors.
123--
-- Operations whose errors are recorded with self:seterror() (see oops).
-- The value is the channel argument handed to seterror: "r" for input
-- operations, "w" for output operations.
local preserve = {
	-- input
	read = "r",
	lines = "r",
	fill = "r",
	unpack = "r",
	recvfd = "r", -- not buffered, kept for good measure

	-- output
	write = "w",
	flush = "w",
	pack = "w",
	sendfd = "w", -- not buffered, kept for good measure
}

-- Input operations for which EPIPE is dropped and reported as EOF.
local nopipe = {
	read = true,
	lines = true,
	fill = true,
	unpack = true,
	recvfd = true,
}
136
-- Route an I/O error through the socket's error handler.
--
-- When an input operation (listed in nopipe) sees EPIPE, the error is
-- swallowed and nothing is returned, which callers treat as EOF.
-- Otherwise, if op is listed in preserve, the error is first recorded on
-- that channel with self:seterror() so later calls observe it. The
-- handler's results are then returned.
--
-- level is the error level as seen from the caller (default 2). It is
-- bumped by one before it is handed to the handler. onerror is
-- tail-called, so this is not strictly needed, but it lets the handler
-- pass the value straight to error() without adjusting for its own
-- activation record.
local function oops(self, op, why, level)
	local handler = self:onerror() or def_onerror

	if nopipe[op] and why == EPIPE then
		return -- EOF
	end

	local channel = preserve[op]

	if channel then
		self:seterror(channel, why)
	end

	return handler(self, op, why, (level or 2) + 1)
end -- oops
153
154
155--
156-- A P I  E X T E N S I O N S
157--
158-- The core sockets implementation in C will not yield on I/O, or throw
-- recoverable errors. These things are done in Lua code for simplicity and
160-- portability--Lua 5.1/LuaJIT doesn't support resumption of C routines.
161--
162-- ========================================================================
163
164--
165-- Extended socket.pair
166--
-- socket.pair(type) additionally accepts "stream" and "dgram" as
-- shorthands for SOCK_STREAM and SOCK_DGRAM. Any other value is passed
-- to the original socket.pair unchanged.
local _pair = socket.pair
local pair_types = { stream = SOCK_STREAM, dgram = SOCK_DGRAM }

socket.pair = function(type)
	return _pair(pair_types[type] or type)
end
176
177
178--
179-- Throwable socket:setbufsiz
180--
-- On success returns the input and output buffer sizes reported by the
-- core method. On failure returns nil, nil followed by whatever the
-- socket's error handler returns (or raises).
local _setbufsiz; _setbufsiz = socket.interpose("setbufsiz", function(self, input_, output_)
	local input, output, why = _setbufsiz(self, input_, output_)

	if input then
		return input, output
	end

	return nil, nil, oops(self, "setbufsiz", why)
end)
190
191
192--
193-- Yielding socket:listen
194--
-- Retry the core listen, polling while it reports EAGAIN. timeout is
-- relative to cqueues.monotime() and defaults to self:timeout(); if both
-- are nil the wait is unbounded. Returns self on success, or nil plus
-- the error handler's results.
local _listen; _listen = socket.interpose("listen", function(self, timeout)
	timeout = timeout or self:timeout()

	local deadline = timeout and (monotime() + timeout)

	while true do
		local ok, why = _listen(self)

		if ok then
			return self
		elseif why ~= EAGAIN then
			return nil, oops(self, "listen", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "listen", ETIMEDOUT)
		end
	end
end)
216
217
218--
219-- Yielding socket:accept
220--
-- Retry the core accept, polling while it reports EAGAIN. Returns the
-- accepted connection, or nil plus the error handler's results.
local _accept; _accept = socket.interpose("accept", function(self, opts, timeout)
	-- :accept used to take just a timeout as argument
	if type(opts) == "number" then
		opts, timeout = nil, opts
	elseif not timeout then
		timeout = self:timeout()
	end

	local deadline = timeout and (monotime() + timeout)

	while true do
		local con, why = _accept(self, opts)

		if con then
			return con
		elseif why ~= EAGAIN then
			return nil, oops(self, "accept", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "accept", ETIMEDOUT)
		end
	end
end)
245
246
247--
248-- Add socket:clients
249--
-- Return an iterator that calls self:accept(opts, timeout) on each step,
-- e.g. `for con in srv:clients() do ... end`.
socket.interpose("clients", function(self, opts, timeout)
	local function next_client()
		return self:accept(opts, timeout)
	end

	return next_client
end)
253
254
255--
256-- Yielding socket:connect
257--
-- Retry the core connect, polling while it reports EAGAIN. timeout
-- defaults to self:timeout(); if both are nil the wait is unbounded.
-- Returns self on success, or nil plus the error handler's results.
local _connect; _connect = socket.interpose("connect", function(self, timeout)
	timeout = timeout or self:timeout()

	local deadline = timeout and (monotime() + timeout)

	while true do
		local ok, why = _connect(self)

		if ok then
			return self
		elseif why ~= EAGAIN then
			return nil, oops(self, "connect", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "connect", ETIMEDOUT)
		end
	end
end)
279
280
281--
282-- Yielding socket:starttls
283--
-- Retry the core starttls until it stops reporting EAGAIN. Returns self
-- on success, or nil plus the error handler's results.
local _starttls; _starttls = socket.interpose("starttls", function(self, arg1, arg2)
	-- The context and the timeout may be passed in either order: a
	-- userdata argument is taken as the context, a number as the
	-- timeout. Without a numeric timeout the socket default applies.
	local ctx = (type(arg1) == "userdata" and arg1)
		or (type(arg2) == "userdata" and arg2)
		or nil
	local timeout = (type(arg1) == "number" and arg1)
		or (type(arg2) == "number" and arg2)
		or self:timeout()
	local deadline = timeout and (monotime() + timeout)

	while true do
		local ok, why = _starttls(self, ctx)

		if ok then
			return self
		elseif why ~= EAGAIN then
			return nil, oops(self, "starttls", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "starttls", ETIMEDOUT)
		end
	end
end)
318
319
320--
321-- Smarter socket:checktls
322--
-- Cached outcome of loading openssl.ssl: nil = not yet tried, true =
-- loaded (whynossl then holds the module), false = failed (whynossl
-- holds the load error).
local havessl, whynossl

-- The first call tries to require openssl.ssl. If that fails, this and
-- every later call return nil plus the load error without retrying.
-- Otherwise the core checktls is consulted.
local _checktls; _checktls = socket.interpose("checktls", function(self)
	if havessl == nil then
		havessl, whynossl = pcall(require, "openssl.ssl")
	end

	if not havessl then
		return nil, whynossl
	end

	return _checktls(self)
end)
340
341
342--
343-- Yielding socket:flush
344--
local _flush

-- Flush the output buffer with the given mode, polling while the core
-- flush reports EAGAIN. timeout is relative and may be nil for no limit.
-- level is the caller's error level (see oops). Returns true on success,
-- or false plus the error handler's results.
local function timed_flush(self, mode, timeout, level)
	local ok, why = _flush(self, mode)

	if ok then
		return true
	end

	local deadline = timeout and (monotime() + timeout)

	repeat
		if why ~= EAGAIN then
			return false, oops(self, "flush", why, level + 1)
		end

		if not timed_poll(self, deadline) then
			return false, oops(self, "flush", ETIMEDOUT, level + 1)
		end

		ok, why = _flush(self, mode)
	until ok

	return true
end -- timed_flush

-- socket:flush([mode][, timeout]) in either order: a string argument is
-- the mode, a number the timeout. Without a numeric timeout the socket
-- default (self:timeout()) applies.
_flush = socket.interpose("flush", function (self, arg1, arg2)
	local mode = (type(arg1) == "string" and arg1)
		or (type(arg2) == "string" and arg2)
		or nil
	local timeout = (type(arg1) == "number" and arg1)
		or (type(arg2) == "number" and arg2)
		or self:timeout()

	return timed_flush(self, mode, timeout, 2)
end)
388
389
390--
391-- Yielding socket:read
392--
-- Read one value per format in (what, ...) via self:recv, polling on
-- EAGAIN within self:timeout(). Returns the values in order. Reading
-- stops early (returning nothing for the current and later formats) at
-- EOF or end-of-headers, that is, when recv yields neither data nor an
-- error. Other errors return nil plus the error handler's results. func
-- names the public operation ("read" or "lines") for error reporting.
local function read(self, func, what, ...)
	if not what then
		return
	end

	local data, why = self:recv(what)

	if not data then
		local timeout = self:timeout()
		local deadline = timeout and (monotime() + timeout)

		while not data do
			if not why then
				return -- EOF or end-of-headers
			elseif why ~= EAGAIN then
				return nil, oops(self, func, why, 2)
			elseif not timed_poll(self, deadline) then
				return nil, oops(self, func, ETIMEDOUT, 2)
			end

			data, why = self:recv(what)
		end
	end

	return data, read(self, func, ...)
end
421
-- socket:read(...): with no format, reads a single line ("*l").
socket.interpose("read", function(self, what, ...)
	if not what then
		return read(self, "read", "*l")
	end

	return read(self, "read", what, ...)
end)
429
430
431--
432-- Yielding socket:write
433--
434-- This is complicated by the fact that we want error messages to get the
435-- correct stack trace, and also because on failure we want to return a list
436-- of error values of indeterminate length.
437--
-- Send each argument (converted with tostring) in full-buffering mode,
-- stopping at the first nil/false argument. Polls on EAGAIN with a fresh
-- self:timeout() deadline for each wait. Returns self when everything
-- was sent, or nil plus the error handler's results.
local function writeall(self, data, ...)
	if not data then
		return self
	end

	local chunk = tostring(data)
	local len = #chunk
	local pos = 1

	while pos <= len do
		-- use only full buffering mode here to minimize socket I/O
		local nsent, why = self:send(chunk, pos, len, "f")

		pos = pos + nsent

		if pos <= len then
			if why ~= EAGAIN then
				return nil, oops(self, "write", why, 3)
			end

			local timeout = self:timeout()

			if not timed_poll(self, timeout and (monotime() + timeout)) then
				return nil, oops(self, "write", ETIMEDOUT, 3)
			end
		end
	end

	return writeall(self, ...)
end
469
-- Map a success flag to Lua file-object conventions: self on success,
-- otherwise nil followed by the remaining (error) values.
local function fileresult(self, ok, ...)
	if not ok then
		return nil, ...
	end

	return self
end -- fileresult
477
-- Second half of socket:write: flush what writeall buffered.
--
-- ok and ... are writeall's results. On failure they are passed through
-- as nil, .... On success the output buffer is flushed within the
-- socket's configured timeout, and the result is mapped through
-- fileresult (self, or nil plus error values).
local function flushwrite(self, ok, ...)
	if not ok then
		return nil, ...
	end

	-- Flush the buffer here because we used full buffering mode in
	-- writeall. But pass empty mode so it uses the configured flushing
	-- mode instead of an implicit flush all.
	--
	-- The socket timeout is passed explicitly because timed_flush treats
	-- a nil timeout as "no deadline". Without it, a stalled peer could
	-- block this write forever, even though writeall and socket:flush
	-- both honor self:timeout().
	return fileresult(self, timed_flush(self, "", self:timeout(), 2))
end -- flushwrite
488
-- socket:write(...): buffer every argument with writeall, then flush via
-- flushwrite. Returns self on success, or nil followed by error values.
local function buffered_write(self, ...)
	return flushwrite(self, writeall(self, ...))
end

socket.interpose("write", buffered_write)
492
493
494--
495-- Add socket:lines
496--
497-- We optimize single-mode case so we're not unpacking tables all the time.
498--
local unpack = assert(table.unpack or unpack)

-- socket:lines([mode, ...]): return an iterator that calls read() with
-- the given formats (default "*l"). Extra formats are captured once, so
-- the common single-format case avoids unpacking a table on each step.
socket.interpose("lines", function (self, mode, ...)
	local n = select("#", ...)

	if mode and n > 0 then
		local args = { ... }

		return function ()
			return read(self, "lines", mode, unpack(args, 1, n))
		end
	end

	mode = mode or "*l"

	return function ()
		return read(self, "lines", mode)
	end
end)
519
520
-- Normalize optional (mode, timeout) arguments given in either order.
-- If the first argument converts with tonumber() it is the timeout and
-- the second is the mode; otherwise they are already (mode, timeout).
-- Returns mode, timeout.
local function xopts(arg1, arg2)
	local timeout_first = tonumber(arg1) ~= nil

	if timeout_first then
		return arg2, arg1
	end

	return arg1, arg2
end -- xopts
529
530
-- Convert a relative timeout (defaulting to self:timeout()) into an
-- absolute monotime() deadline. Returns the falsy timeout itself (nil or
-- false) when there is none.
local function xdeadline(self, timeout)
	if not timeout then
		timeout = self:timeout()
	end

	if timeout then
		return monotime() + timeout
	end

	return timeout
end -- xdeadline
536
537
538--
539-- Smarter socket:read
540--
-- socket:xread(what[, mode][, timeout]): a single recv with an explicit
-- mode and optional timeout, polling on EAGAIN. Returns the data, nothing
-- at EOF, or nil plus the error handler's results.
socket.interpose("xread", function (self, what, ...)
	local mode, timeout = xopts(...)
	local data, why = self:recv(what, mode)

	if data then
		return data
	end

	local deadline = xdeadline(self, timeout)

	repeat
		if not why then
			return --> EOF
		elseif why ~= EAGAIN then
			return nil, oops(self, "read", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "read", ETIMEDOUT)
		end

		data, why = self:recv(what, mode)
	until data

	return data
end) -- xread
566
567
568--
569-- Smarter socket:write
570--
-- socket:xwrite(data[, mode][, timeout]): send data with an explicit mode,
-- polling on EAGAIN within one overall deadline, then flush with the
-- same mode (or "" for the configured mode). Whatever time remains of
-- the deadline is handed to the flush. Returns self, or nil plus the
-- error handler's results.
socket.interpose("xwrite", function (self, data, ...)
	local mode, timeout = xopts(...)
	local len = #data

	--
	-- should we default to full-buffering here (and the :send below) if
	-- mode is nil?
	--
	local sent, why = self:send(data, 1, len, mode)
	local pos = 1 + sent

	if pos <= len then
		local deadline = xdeadline(self, timeout)

		while pos <= len do
			if why ~= EAGAIN then
				return nil, oops(self, "write", why)
			end

			if not timed_poll(self, deadline) then
				return nil, oops(self, "write", ETIMEDOUT)
			end

			sent, why = self:send(data, pos, len, mode)
			pos = pos + sent
		end

		timeout = deadline and math.max(0, deadline - monotime())
	end

	return fileresult(self, self:flush(mode or "", timeout))
end)
605
606
607--
608-- Smarter socket:lines
609--
-- socket:xlines(what[, mode][, timeout]): an iterator that calls
-- self:xread with the same arguments on each step.
socket.interpose("xlines", function (self, what, ...)
	local mode, timeout = xopts(...)

	local function nextline()
		return self:xread(what, mode, timeout)
	end

	return nextline
end)
617
618
619--
620-- Yielding socket:sendfd
621--
-- Retry the core sendfd, polling on EAGAIN. timeout defaults to
-- self:timeout(). Returns the core result on success, or false plus the
-- error handler's results.
local _sendfd; _sendfd = socket.interpose("sendfd", function (self, msg, fd, timeout)
	local deadline = xdeadline(self, timeout)

	while true do
		local ok, why = _sendfd(self, msg, fd)

		if ok then
			return ok
		elseif why ~= EAGAIN then
			return false, oops(self, "sendfd", why)
		elseif not timed_poll(self, deadline) then
			return false, oops(self, "sendfd", ETIMEDOUT)
		end
	end
end)
645
646
647--
648-- Yielding socket:recvfd
649--
-- Retry the core recvfd, polling on EAGAIN. timeout defaults to
-- self:timeout(). Returns msg, fd on success, or nil, nil plus the
-- error handler's results.
local _recvfd; _recvfd = socket.interpose("recvfd", function (self, prepbufsiz, timeout)
	local deadline = xdeadline(self, timeout)

	while true do
		local msg, fd, why = _recvfd(self, prepbufsiz)

		if msg then
			return msg, fd
		elseif why ~= EAGAIN then
			return nil, nil, oops(self, "recvfd", why)
		elseif not timed_poll(self, deadline) then
			return nil, nil, oops(self, "recvfd", ETIMEDOUT)
		end
	end
end)
673
674
675--
676-- Yielding socket:pack
677--
-- Retry the core pack, polling on EAGAIN within self:timeout(). Returns
-- the core result on success, or false plus the error handler's results.
local _pack; _pack = socket.interpose("pack", function (self, num, nbits, mode)
	local ok, why = _pack(self, num, nbits, mode)

	if ok then
		return ok
	end

	local deadline = xdeadline(self, nil)

	repeat
		if why ~= EAGAIN then
			return false, oops(self, "pack", why)
		elseif not timed_poll(self, deadline) then
			return false, oops(self, "pack", ETIMEDOUT)
		end

		ok, why = _pack(self, num, nbits, mode)
	until ok

	return ok
end)
700
701
702--
703-- Yielding socket:unpack
704--
-- Retry the core unpack, polling on EAGAIN within self:timeout(). Returns
-- the unpacked number, or nil plus the error handler's results.
local _unpack; _unpack = socket.interpose("unpack", function (self, nbits)
	local num, why = _unpack(self, nbits)

	if num then
		return num
	end

	local deadline = xdeadline(self, nil)

	repeat
		if why ~= EAGAIN then
			return nil, oops(self, "unpack", why)
		elseif not timed_poll(self, deadline) then
			return nil, oops(self, "unpack", ETIMEDOUT)
		end

		num, why = _unpack(self, nbits)
	until num

	return num
end)
727
728
729--
730-- Yielding socket:fill
731--
-- Retry the core fill, polling on EAGAIN. timeout defaults to
-- self:timeout() and is only consulted after the first attempt fails.
-- Returns true on success, or false plus the error handler's results.
local _fill; _fill = socket.interpose("fill", function (self, size, timeout)
	local ok, why = _fill(self, size)

	if ok then
		return true
	end

	local deadline = xdeadline(self, timeout)

	repeat
		if why ~= EAGAIN then
			return false, oops(self, "fill", why)
		elseif not timed_poll(self, deadline) then
			return false, oops(self, "fill", ETIMEDOUT)
		end

		ok, why = _fill(self, size)
	until ok

	return true
end)
756
757
758--
759-- Extend socket:peername
760--
-- Call a name accessor (get) on self. On success, return its af, r1, r2.
-- On failure, r1 is the error code: ENOTCONN, ENOTSOCK and EAGAIN are
-- reported as address family 0 ("no name"), anything else as nil, err.
local function getname(get, self)
	local af, r1, r2 = get(self)

	if not af then
		local err = r1
		local unnamed = err == ENOTCONN or err == ENOTSOCK or err == EAGAIN

		if unnamed then
			return 0
		end

		return nil, err
	end

	return af, r1, r2
end
772
local _peername, _localname

-- socket:peername(): core lookup, with "not connected"-style errors
-- mapped to address family 0 by getname.
_peername = socket.interpose("peername", function (self)
	return getname(_peername, self)
end)


--
-- Extend socket:localname
--
-- Same error mapping as peername.
_localname = socket.interpose("localname", function (self)
	return getname(_localname, self)
end)
784
785
-- Record the loader function (this whole module body, which the final
-- `return loader(loader)` passes to itself) on the socket table, then
-- return the extended socket table as the module value.
socket.loader = loader

return socket
789
790end -- loader
791
792return loader(loader)
793