local loader = function(loader)

local socket = require("_cqueues.socket")
local cqueues = require("cqueues")
local errno = require("cqueues.errno")

local poll = cqueues.poll
local monotime = cqueues.monotime

local AF_INET = socket.AF_INET
local AF_INET6 = socket.AF_INET6
local AF_UNIX = socket.AF_UNIX
local SOCK_STREAM = socket.SOCK_STREAM
local SOCK_DGRAM = socket.SOCK_DGRAM

local EAGAIN = errno.EAGAIN
local EPIPE = errno.EPIPE
local ETIMEDOUT = errno.ETIMEDOUT
local ENOTCONN = errno.ENOTCONN
local ENOTSOCK = errno.ENOTSOCK
local strerror = errno.strerror

local format = string.format


--
-- H E L P E R   R O U T I N E S
--
-- ========================================================================

-- Call cqueues.poll on `self`, bounded by an absolute `deadline` expressed
-- on the monotime() clock.
--
-- A nil/false deadline polls with no time limit. Returns false, without
-- polling, when the deadline has already been reached; otherwise polls once
-- for the remaining time and returns true.
local function timed_poll(self, deadline)
	if not deadline then
		poll(self)

		return true
	end

	local now = monotime()

	if deadline <= now then
		return false
	end

	poll(self, deadline - now)

	return true
end -- timed_poll


-- Build a printable label from so:peername(): "addr.port" for AF_INET and
-- AF_INET6 peers, "unix:addr" (or "unix:unnamed" when there is no address)
-- for AF_UNIX peers. Returns nothing for any other address family.
local function logname(so)
	local family, host, port = so:peername()

	if family == AF_INET or family == AF_INET6 then
		return format("%s.%s", host, port)
	end

	if family == AF_UNIX then
		return format("unix:%s", host or "unnamed")
	end
end -- logname


--
-- E R R O R   M A N A G E M E N T
--
-- All errors in the I/O routines are first passed to a per-socket error
-- handler, which can choose to return or throw them.
--
-- The default error handler is not actually installed with any socket, as
-- that would create needless churn in the registry index on socket
-- instantiation. Instead we interpose socket.onerror and socket:onerror and
-- return our default handler if none was previously installed.
--
-- ========================================================================

-- Default error handler.
--
-- EPIPE and ETIMEDOUT are handed back to the caller as return values; every
-- other error code is thrown as a formatted message at stack level `lvl`.
-- The message is prefixed with the peer label from logname() when one is
-- available, and with "socket" otherwise.
local function def_onerror(self, op, why, lvl)
	if why == EPIPE or why == ETIMEDOUT then
		return why
	end

	local peer = logname(self)
	local text

	if peer then
		text = format("[%s]:%s: %s", peer, op, strerror(why))
	else
		text = format("socket:%s: %s", op, strerror(why))
	end

	error(text, lvl)
end -- def_onerror


-- Wrap the module-level socket.onerror so that it falls back to
-- def_onerror whenever the wrapped function returns nil/false.
do
	local builtin = socket.onerror

	function socket.onerror(...)
		return builtin(...) or def_onerror
	end
end

-- Same fallback for the socket:onerror method. socket.interpose returns the
-- previously installed method, which the replacement delegates to.
do
	local builtin

	builtin = socket.interpose("onerror", function(...)
		return builtin(...) or def_onerror
	end)
end


--
-- On buffered I/O we need to preserve errors across calls, otherwise
-- unchecked transient errors might lead to unexpected behavior by
-- application code. This is particularly true regarding timeouts, and
-- especially so when mixed with iterators like socket:lines--doubly so when
-- reading MIME headers, which could terminate on ETIMEDOUT, EPIPE, or just
-- when reaching the end of the headers section.
--
-- Why not just always throw on such errors? One reason is that we partially
-- mimic Lua's file objects, which will return such errors. (And we might
-- change our semantics to fully mimic Lua in the future.)
--
-- Another reason is that it's very common to want to deal with timeouts
-- inline. For example, maybe you want to write a keep-alive message after a
-- read timeout. Timeouts are exceptional but not necessarily errors.
--
-- Maps an operation name to the mode string ("r" or "w") that oops() hands
-- to self:seterror() so the error is remembered on the socket.
local preserve = {
	read = "r",
	lines = "r",
	fill = "r",
	unpack = "r",

	write = "w",
	flush = "w",
	pack = "w",

	-- these too for good measure, even though they're not buffered
	recvfd = "r",
	sendfd = "w",
}

-- drop EPIPE errors on input channel
local nopipe = {
	read = true,
	lines = true,
	fill = true,
	unpack = true,
	recvfd = true,
}

-- Common error exit for the I/O wrappers.
--
--   op    - operation name, used for the lookup tables above and handed to
--           the error handler
--   why   - error code
--   level - stack level for a thrown error (defaults to 2)
--
-- EPIPE on an input operation is treated as EOF and yields nothing.
-- Otherwise the error is recorded with self:seterror() for operations listed
-- in `preserve`, and then passed to the socket's error handler (or
-- def_onerror), whose results are returned.
local function oops(self, op, why, level)
	local handler = self:onerror() or def_onerror
	local channel = preserve[op]

	if why == EPIPE and nopipe[op] then
		return -- EOF
	end

	if channel then
		self:seterror(channel, why)
	end

	-- NOTE: There's normally no need to increment on a tail-call
	-- (except when directly calling the error() routine), but we
	-- increment here so the callee has the correct stack level to pass
	-- to error() directly, without making adjustments for its own
	-- activation record.
	return handler(self, op, why, (level or 2) + 1)
end -- oops


--
-- A P I   E X T E N S I O N S
--
-- The core sockets implementation in C will not yield on I/O, or throw
-- recoverable errors. These things are done in Lua code for simplicity and
-- portability--Lua 5.1/LuaJIT doesn't support resumption of C routines.
--
-- ========================================================================

--
-- Extended socket.pair
--
-- Accepts the strings "stream" and "dgram" in place of SOCK_STREAM and
-- SOCK_DGRAM; any other value is passed through untouched.
--
local _pair = socket.pair

socket.pair = function(kind)
	if kind == "stream" then
		return _pair(SOCK_STREAM)
	elseif kind == "dgram" then
		return _pair(SOCK_DGRAM)
	end

	return _pair(kind)
end


--
-- Throwable socket:setbufsiz
--
-- Returns the two sizes reported by the underlying method, or routes its
-- error through oops() and returns nil, nil plus whatever the error handler
-- returned.
--
local _setbufsiz

_setbufsiz = socket.interpose("setbufsiz", function(self, input_, output_)
	local insize, outsize, err = _setbufsiz(self, input_, output_)

	if insize then
		return insize, outsize
	end

	return nil, nil, oops(self, "setbufsiz", err)
end)


--
-- Yielding socket:listen
--
-- Retries the underlying method while it reports EAGAIN, polling in between
-- until `timeout` (default self:timeout()) seconds have elapsed. Returns
-- self on success, otherwise nil plus the error handler's results.
--
local _listen

_listen = socket.interpose("listen", function(self, timeout)
	timeout = timeout or self:timeout()

	local deadline = timeout and (monotime() + timeout)

	while true do
		local done, err = _listen(self)

		if done then
			return self
		end

		if err ~= EAGAIN then
			return nil, oops(self, "listen", err)
		end

		if not timed_poll(self, deadline) then
			return nil, oops(self, "listen", ETIMEDOUT)
		end
	end
end)


--
-- Yielding socket:accept
--
-- Returns the accepted connection, or nil plus the error handler's results.
--
local _accept

_accept = socket.interpose("accept", function(self, opts, timeout)
	-- :accept used to take just a timeout as argument
	if type(opts) == "number" then
		timeout = opts
		opts = nil
	elseif not timeout then
		timeout = self:timeout()
	end

	local deadline = timeout and (monotime() + timeout)

	while true do
		local con, err = _accept(self, opts)

		if con then
			return con
		end

		if err ~= EAGAIN then
			return nil, oops(self, "accept", err)
		end

		if not timed_poll(self, deadline) then
			return nil, oops(self, "accept", ETIMEDOUT)
		end
	end
end)


--
-- Add socket:clients
--
-- Returns an iterator function; each call forwards to
-- self:accept(opts, timeout).
--
socket.interpose("clients", function(self, opts, timeout)
	local function next_client()
		return self:accept(opts, timeout)
	end

	return next_client
end)


--
-- Yielding socket:connect
--
-- Same retry scheme as socket:listen. Returns self on success.
--
local _connect

_connect = socket.interpose("connect", function(self, timeout)
	timeout = timeout or self:timeout()

	local deadline = timeout and (monotime() + timeout)

	while true do
		local done, err = _connect(self)

		if done then
			return self
		end

		if err ~= EAGAIN then
			return nil, oops(self, "connect", err)
		end

		if not timed_poll(self, deadline) then
			return nil, oops(self, "connect", ETIMEDOUT)
		end
	end
end)


--
-- Yielding socket:starttls
--
-- The two optional arguments may come in either order: a userdata value is
-- taken as the context handed to the underlying method, a number as the
-- timeout. Without a numeric argument the timeout is self:timeout().
--
local _starttls

_starttls = socket.interpose("starttls", function(self, arg1, arg2)
	local kind1, kind2 = type(arg1), type(arg2)
	local ctx, timeout

	if kind1 == "userdata" then
		ctx = arg1
	elseif kind2 == "userdata" then
		ctx = arg2
	end

	if kind1 == "number" then
		timeout = arg1
	elseif kind2 == "number" then
		timeout = arg2
	else
		timeout = self:timeout()
	end

	local deadline = timeout and monotime() + timeout

	while true do
		local done, err = _starttls(self, ctx)

		if done then
			return self
		end

		if err ~= EAGAIN then
			return nil, oops(self, "starttls", err)
		end

		if not timed_poll(self, deadline) then
			return nil, oops(self, "starttls", ETIMEDOUT)
		end
	end
end)


--
-- Smarter socket:checktls
--
-- Attempts to require "openssl.ssl" on first use and caches the outcome.
-- When the module cannot be loaded, returns nil plus the load error instead
-- of calling the underlying method.
--
local havessl, whynossl

local _checktls

_checktls = socket.interpose("checktls", function(self)
	-- havessl stays nil until the first attempt; pcall then makes it a
	-- boolean
	if havessl == nil then
		havessl, whynossl = pcall(require, "openssl.ssl")
	end

	if not havessl then
		return nil, whynossl
	end

	return _checktls(self)
end)


--
-- Yielding socket:flush
--
local _flush;

-- Flush with retry. The deadline is only started once the first attempt has
-- failed. `level` is the stack level forwarded (plus one) to oops(). Returns
-- true on success, otherwise false plus the error handler's results.
local function timed_flush(self, mode, timeout, level)
	local done, err = _flush(self, mode)

	if done then
		return true
	end

	local deadline = timeout and (monotime() + timeout)

	while not done do
		if err ~= EAGAIN then
			return false, oops(self, "flush", err, level + 1)
		end

		if not timed_poll(self, deadline) then
			return false, oops(self, "flush", ETIMEDOUT, level + 1)
		end

		done, err = _flush(self, mode)
	end

	return true
end -- timed_flush

-- The two optional arguments may come in either order: a string is the
-- flush mode, a number the timeout (default self:timeout()).
_flush = socket.interpose("flush", function (self, arg1, arg2)
	local kind1, kind2 = type(arg1), type(arg2)
	local mode, timeout

	if kind1 == "string" then
		mode = arg1
	elseif kind2 == "string" then
		mode = arg2
	end

	if kind1 == "number" then
		timeout = arg1
	elseif kind2 == "number" then
		timeout = arg2
	else
		timeout = self:timeout()
	end

	return timed_flush(self, mode, timeout, 2)
end)


--
-- Yielding socket:read
--
-- Reads one value per format in the argument list via self:recv(), stopping
-- at the first nil/false format. `func` is the operation name reported to
-- oops(). Each format gets its own deadline, derived from self:timeout()
-- after the first failed attempt.
--
local function read(self, func, what, ...)
	if not what then
		return
	end

	local chunk, err = self:recv(what)

	if not chunk then
		local timeout = self:timeout()
		local deadline = timeout and (monotime() + timeout)

		while not chunk do
			if not err then
				return -- EOF or end-of-headers
			elseif err ~= EAGAIN then
				return nil, oops(self, func, err, 2)
			elseif not timed_poll(self, deadline) then
				return nil, oops(self, func, ETIMEDOUT, 2)
			end

			chunk, err = self:recv(what)
		end
	end

	return chunk, read(self, func, ...)
end

-- With no format given, reads using "*l".
socket.interpose("read", function(self, what, ...)
	if not what then
		return read(self, "read", "*l")
	end

	return read(self, "read", what, ...)
end)


--
-- Yielding socket:write
--
-- This is complicated by the fact that we want error messages to get the
-- correct stack trace, and also because on failure we want to return a list
-- of error values of indeterminate length.
--
-- Send each argument in turn (converted with tostring), stopping at the
-- first nil/false argument. Returns self on success, otherwise nil plus the
-- error handler's results.
--
-- The deadline is armed once per data item, on the first EAGAIN, so that
-- repeated EAGAINs cannot keep pushing it forward and the socket timeout
-- can actually expire.
local writeall; writeall = function(self, data, ...)
	if not data then
		return self
	end

	data = tostring(data)

	local i = 1
	local deadline
	local armed = false

	while i <= #data do
		-- use only full buffering mode here to minimize socket I/O
		local n, why = self:send(data, i, #data, "f")

		i = i + n

		if i <= #data then
			if why == EAGAIN then
				if not armed then
					local timeout = self:timeout()

					deadline = timeout and (monotime() + timeout)
					armed = true
				end

				if not timed_poll(self, deadline) then
					return nil, oops(self, "write", ETIMEDOUT, 3)
				end
			else
				return nil, oops(self, "write", why, 3)
			end
		end
	end

	return writeall(self, ...)
end

-- Convert an (ok, ...) result list into the file-object convention: self
-- when ok is truthy, otherwise nil followed by the remaining values.
local function fileresult(self, ok, ...)
	if ok then
		return self
	else
		return nil, ...
	end
end -- fileresult

-- Second half of socket:write. Passes a writeall() failure straight
-- through; otherwise flushes and converts the flush result with
-- fileresult().
local function flushwrite(self, ok, ...)
	if not ok then
		return nil, ...
	end

	-- Flush the buffer here because we used full buffering mode in
	-- writeall. But pass empty mode so it uses the configured flushing
	-- mode instead of an implicit flush all. The flush is bounded by the
	-- socket's own timeout, matching the default used by socket:flush.
	return fileresult(self, timed_flush(self, "", self:timeout(), 2))
end -- flushwrite

socket.interpose("write", function (self, ...)
	return flushwrite(self, writeall(self, ...))
end)


--
-- Add socket:lines
--
-- We optimize single-mode case so we're not unpacking tables all the time.
--
local unpack = assert(table.unpack or unpack)

-- Returns an iterator function that reads with the given format(s) on each
-- call; the format defaults to "*l".
socket.interpose("lines", function (self, mode, ...)
	if mode then
		-- select("#", ...) keeps the true count even with nil arguments
		local n = select("#", ...)
		if n > 0 then
			local args = { ... }

			return function ()
				return read(self, "lines", mode, unpack(args, 1, n))
			end
		end
	else
		mode = "*l"
	end

	return function ()
		return read(self, "lines", mode)
	end
end)


-- returns mode, timeout
--
-- The two optional arguments may be given as (mode, timeout) or as
-- (timeout, mode); a first argument that tonumber() accepts is taken to be
-- the timeout.
local function xopts(arg1, arg2)
	if tonumber(arg1) then
		return arg2, arg1
	else
		return arg1, arg2
	end
end -- xopts


-- Turn a relative timeout (default self:timeout()) into an absolute
-- deadline on the monotime() clock. A nil/false timeout is returned as-is.
local function xdeadline(self, timeout)
	timeout = timeout or self:timeout()

	return timeout and (monotime() + timeout)
end -- xdeadline


--
-- Smarter socket:read
--
-- socket:xread(what[, mode][, timeout]) reads a single value with
-- self:recv(what, mode), retrying on EAGAIN until the deadline. Returns the
-- data, nothing on EOF, or nil plus the error handler's results.
--
socket.interpose("xread", function (self, what, ...)
	local mode, timeout = xopts(...)

	local data, why = self:recv(what, mode)

	if not data then
		local deadline = xdeadline(self, timeout)

		repeat
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return nil, oops(self, "read", ETIMEDOUT)
				end
			elseif why then
				return nil, oops(self, "read", why)
			else
				return --> EOF
			end

			data, why = self:recv(what, mode)
		until data
	end

	return data
end) -- xread


--
-- Smarter socket:write
--
-- socket:xwrite(data[, mode][, timeout]) sends `data` with
-- self:send(data, i, #data, mode), retrying on EAGAIN until the deadline,
-- then flushes. Returns self on success, otherwise nil plus error values.
--
socket.interpose("xwrite", function (self, data, ...)
	local mode, timeout = xopts(...)
	local i = 1

	--
	-- should we default to full-buffering here (and the :send below) if
	-- mode is nil?
	--
	local n, why = self:send(data, i, #data, mode)

	i = i + n

	if i <= #data then
		local deadline = xdeadline(self, timeout)

		repeat
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return nil, oops(self, "write", ETIMEDOUT)
				end
			else
				return nil, oops(self, "write", why)
			end

			n, why = self:send(data, i, #data, mode)

			i = i + n
		until i > #data

		-- only the time left before the deadline is given to the flush
		timeout = deadline and math.max(0, deadline - monotime())
	end

	return fileresult(self, self:flush(mode or "", timeout))
end)


--
-- Smarter socket:lines
--
-- Returns an iterator function; each call forwards to
-- self:xread(what, mode, timeout).
--
socket.interpose("xlines", function (self, what, ...)
	local mode, timeout = xopts(...)

	return function ()
		return self:xread(what, mode, timeout)
	end
end)


--
-- Yielding socket:sendfd
--
-- Retries the underlying method on EAGAIN until `timeout` (default
-- self:timeout()) elapses. Returns its truthy result on success, otherwise
-- false plus the error handler's results.
--
local _sendfd; _sendfd = socket.interpose("sendfd", function (self, msg, fd, timeout)
	if not timeout then
		timeout = self:timeout()
	end
	local deadline = timeout and (monotime() + timeout)
	local ok, why

	repeat
		ok, why = _sendfd(self, msg, fd)

		if not ok then
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return false, oops(self, "sendfd", ETIMEDOUT)
				end
			else
				return false, oops(self, "sendfd", why)
			end
		end
	until ok

	return ok
end)


--
-- Yielding socket:recvfd
--
-- Retries the underlying method on EAGAIN until `timeout` (default
-- self:timeout()) elapses. Returns msg, fd on success, otherwise nil, nil
-- plus the error handler's results.
--
local _recvfd; _recvfd = socket.interpose("recvfd", function (self, prepbufsiz, timeout)
	if not timeout then
		timeout = self:timeout()
	end
	local deadline = timeout and (monotime() + timeout)
	local msg, fd, why

	repeat
		msg, fd, why = _recvfd(self, prepbufsiz)

		if not msg then
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return nil, nil, oops(self, "recvfd", ETIMEDOUT)
				end
			else
				return nil, nil, oops(self, "recvfd", why)
			end
		end
	until msg

	return msg, fd
end)


--
-- Yielding socket:pack
--
-- Retries the underlying method on EAGAIN, bounded by self:timeout().
-- Returns its truthy result on success, otherwise false plus the error
-- handler's results.
--
local _pack; _pack = socket.interpose("pack", function (self, num, nbits, mode)
	local ok, why = _pack(self, num, nbits, mode)

	if not ok then
		local timeout = self:timeout()
		local deadline = timeout and (monotime() + timeout)

		repeat
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return false, oops(self, "pack", ETIMEDOUT)
				end
			else
				return false, oops(self, "pack", why)
			end

			ok, why = _pack(self, num, nbits, mode)
		until ok
	end

	return ok
end)


--
-- Yielding socket:unpack
--
-- Retries the underlying method on EAGAIN, bounded by self:timeout().
-- Returns the unpacked value on success, otherwise nil plus the error
-- handler's results.
--
local _unpack; _unpack = socket.interpose("unpack", function (self, nbits)
	local num, why = _unpack(self, nbits)

	if not num then
		local timeout = self:timeout()
		local deadline = timeout and (monotime() + timeout)

		repeat
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return nil, oops(self, "unpack", ETIMEDOUT)
				end
			else
				return nil, oops(self, "unpack", why)
			end

			num, why = _unpack(self, nbits)
		until num
	end

	return num
end)


--
-- Yielding socket:fill
--
-- Retries the underlying method on EAGAIN until `timeout` (default
-- self:timeout()) elapses. Returns true on success, otherwise false plus
-- the error handler's results.
--
local _fill; _fill = socket.interpose("fill", function (self, size, timeout)
	local ok, why = _fill(self, size)

	if not ok then
		if not timeout then
			timeout = self:timeout()
		end
		local deadline = timeout and (monotime() + timeout)

		repeat
			if why == EAGAIN then
				if not timed_poll(self, deadline) then
					return false, oops(self, "fill", ETIMEDOUT)
				end
			else
				return false, oops(self, "fill", why)
			end

			ok, why = _fill(self, size)
		until ok
	end

	return true
end)


--
-- Extend socket:peername
--
-- Shared wrapper for the name getters. Passes a successful result through
-- as af, r1, r2. ENOTCONN, ENOTSOCK and EAGAIN are reported as the single
-- value 0 rather than as an error; any other failure yields nil plus the
-- error code.
--
local function getname(get, self)
	local af, r1, r2 = get(self)

	if af then
		return af, r1, r2
	elseif r1 == ENOTCONN or r1 == ENOTSOCK or r1 == EAGAIN then
		return 0
	else
		return nil, r1
	end
end

local _peername; _peername = socket.interpose("peername", function (self)
	return getname(_peername, self)
end)


--
-- Extend socket:localname
--
local _localname; _localname = socket.interpose("localname", function (self)
	return getname(_localname, self)
end)


socket.loader = loader

return socket

end -- loader

return loader(loader)