-----------------------------------------------------------------------------
-- A hacked dispatcher module
-- LuaSocket sample files
-- Author: Diego Nehab
-----------------------------------------------------------------------------
local base = _G
local table = require("table")
local socket = require("socket")
local coroutine = require("coroutine")
module("dispatch")

-- Lua 5.1 exposes unpack as a global, 5.2+ moves it into the table library.
-- Locals are still visible after module() replaced the global environment.
local unpack = table.unpack or base.unpack

-- Packs a variable number of values into a table whose field n records the
-- real count, so embedded or trailing nils survive unpack(t, 1, t.n).
local function pack(...)
    return {n = base.select("#", ...), ...}
end

-- if too much time goes by without any activity in one of our sockets, we
-- just kill it
TIMEOUT = 60

-----------------------------------------------------------------------------
-- Dispatcher factories, keyed by mode name. This file defines two of them:
--     sequential
--     coroutine
-- (a "threaded" mode was also planned, but no factory for it is defined
-- here). The user picks one through newhandler().
-----------------------------------------------------------------------------
local handlert = {}

-- Creates a dispatcher for the requested mode; the default is "coroutine".
-- An unknown mode raises an error that names it, rather than an opaque
-- "attempt to call a nil value".
function newhandler(mode)
    mode = mode or "coroutine"
    local factory = handlert[mode]
    if not factory then
        base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
    end
    return factory()
end

local function seqstart(self, func)
    return func()
end

-- sequential handler simply calls the functions and doesn't wrap I/O
function handlert.sequential()
    return {
        tcp = socket.tcp,
        start = seqstart
    }
end

-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- We can't yield across calls to protect, so we rewrite it coxpcall-style:
-- f runs inside its own coroutine, every value it yields is passed on to our
-- caller, and whatever our caller resumes us with is handed back to f.
-- Table errors (raised by socket.try) become nil, message; any other error
-- is re-raised unchanged (level 0, so no extra position prefix is added).
-- Make sure you don't require any module that uses socket.protect before
-- loading our hack.
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        local args = pack(...)
        while true do
            local results = pack(coroutine.resume(co, unpack(args, 1, args.n)))
            if not results[1] then
                local err = results[2]
                if base.type(err) == "table" then
                    return nil, err[1]
                else
                    base.error(err, 0)
                end
            end
            if coroutine.status(co) == "suspended" then
                -- f yielded: forward its values outward, and feed whatever
                -- we are resumed with back into f on the next iteration
                args = pack(coroutine.yield(unpack(results, 2, results.n)))
            else
                return unpack(results, 2, results.n)
            end
        end
    end
end

-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-- Members live in the array part (so socket.select can read the set
-- directly); reverse maps each member to its index.
-----------------------------------------------------------------------------
local function newset()
    local reverse = {}
    local set = {}
    return base.setmetatable(set, {__index = {
        insert = function(s, value)
            if not reverse[value] then
                s[#s + 1] = value
                reverse[value] = #s
            end
        end,
        remove = function(s, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                -- fill the hole with the last element to stay O(1)
                local top = table.remove(s)
                if top ~= value then
                    reverse[top] = index
                    s[index] = top
                end
            end
        end
    }})
end

-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
-- Wraps tcp in an object that mimics a socket but yields
-- (operation, tcp) to the dispatcher whenever an operation would block.
-- Returns nil, err unchanged when tcp is nil (e.g. socket.tcp() failed).
local function cowrap(dispatcher, tcp, err)
    if not tcp then return nil, err end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- metatable for wrap produces new methods on demand for those that we
    -- don't override explicitly; each one forwards to the real socket,
    -- passing tcp instead of the wrapper as the receiver
    local metat = { __index = function(wrapper, key)
        local method = function(_, ...)
            return tcp[key](tcp, ...)
        end
        wrapper[key] = method
        return method
    end }
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = { }
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        first = (first or 1) - 1
        local result, err
        while true do
            -- return control to dispatcher and tell it we want to send
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending, resuming after the last byte already sent
            result, err, first = tcp:send(data, first + 1, last)
            -- if we are done, or there was an unexpected error,
            -- break away from loop
            if err ~= "timeout" then return result, err, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local value, err
        while true do
            -- return control to dispatcher and tell it we want to receive
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving
            value, err, partial = tcp:receive(pattern, partial)
            -- if we are done, or there was an unexpected error,
            -- break away from loop. also, if the user requested
            -- zero timeout, return all we got
            if (err ~= "timeout") or zero then
                return value, err, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, err = tcp:connect(host, port)
        if err == "timeout" then
            -- return control to dispatcher. we will be writable when
            -- connection succeeds.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- when we come back, check if connection was successful
            result, err = tcp:connect(host, port)
            if result or err == "already connected" then return 1
            else return nil, "non-blocking connect failed" end
        else return result, err end
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while true do
            -- return control to dispatcher. we will be readable when a
            -- connection arrives.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, accepterr = tcp:accept()
            if accepterr ~= "timeout" then
                return cowrap(dispatcher, client, accepterr)
            end
        end
    end
    -- forget every pending cortn and timestamp for this socket, then close it
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end


-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
local cometat = { __index = {} }

-- Consumes the results of resuming cortn. If it yielded an (operation, tcp)
-- pair, park it in that operation's set and timestamp it; if the resume
-- failed, re-raise the error (passed in the operation slot).
function schedule(cortn, status, operation, tcp)
    if status then
        if cortn and operation then
            operation.set:insert(tcp)
            operation.cortn[tcp] = cortn
            operation.stamp[tcp] = socket.gettime()
        end
    else base.error(operation) end
end

-- Removes tcp, and the cortn waiting on it, from operation.
function kick(operation, tcp)
    operation.cortn[tcp] = nil
    operation.set:remove(tcp)
end

-- Resumes the cortn waiting on tcp for operation. Returns the cortn followed
-- by coroutine.resume's results, ready to be passed to schedule().
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    -- if cortn is still valid, wake it up
    if cortn then
        kick(operation, tcp)
        return cortn, coroutine.resume(cortn)
    -- otherwise, just get scheduler not to do anything
    else
        return nil, true
    end
end

-- Like wakeup, but resumes the waiting cortn with "timeout" so its pending
-- I/O call returns nil, "timeout". Returns the same values as wakeup so the
-- caller can reschedule the cortn if it yields again instead of losing it.
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        return cortn, coroutine.resume(cortn, "timeout")
    else
        return nil, true
    end
end

-- step through all active cortns
function cometat.__index:step()
    -- check which sockets are interesting and act on them
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, 1)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout. Expired sockets are collected first:
    -- resuming a cortn can reach schedule(), which may add keys to
    -- self.stamp, and adding keys during a pairs() traversal is undefined.
    local now = socket.gettime()
    local expired = {}
    for tcp, stamp in base.pairs(self.stamp) do
        if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then
            expired[#expired + 1] = tcp
        end
    end
    for _, tcp in base.ipairs(expired) do
        -- remember who was waiting to receive before the sender runs, so a
        -- sender that re-parks itself as a receiver on this same socket is
        -- not immediately aborted a second time
        local receiver = self.receiving.cortn[tcp]
        schedule(abort(self.sending, tcp))
        if receiver and self.receiving.cortn[tcp] == receiver then
            schedule(abort(self.receiving, tcp))
        end
        -- nobody waits on this socket any more: drop its stamp so it is not
        -- rescanned on every step
        if not self.sending.cortn[tcp] and not self.receiving.cortn[tcp] then
            self.stamp[tcp] = nil
        end
    end
end

-- Runs func in a new cortn until its first yield, then schedules it.
function cometat.__index:start(func)
    local cortn = coroutine.create(func)
    schedule(cortn, coroutine.resume(cortn))
end

-- Builds a coroutine dispatcher. sending and receiving share one stamp
-- table, so a socket carries a single last-activity time.
function handlert.coroutine()
    local stamp = {}
    local dispatcher = {
        stamp = stamp,
        sending = {
            name = "sending",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
        receiving = {
            name = "receiving",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
    }
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end