-----------------------------------------------------------------------------
-- A hacked dispatcher module
-- LuaSocket sample files
-- Author: Diego Nehab
-----------------------------------------------------------------------------
local base = _G
local table = require("table")
local string = require("string")
local socket = require("socket")
local coroutine = require("coroutine")

-- Module table. This file used to call module("dispatch"), which is
-- deprecated since Lua 5.2 and ran the rest of the chunk in a private
-- environment where standard globals (e.g. select) were not visible.
local _M = {}

-- if too much time goes by without any activity in one of our sockets, we
-- just kill it. Compared against socket.gettime() differences (seconds) and
-- read on every step(), so callers may change dispatch.TIMEOUT at runtime.
_M.TIMEOUT = 60

-----------------------------------------------------------------------------
-- Dispatcher flavours, selected through newhandler(mode):
--     sequential  start(func) simply runs func to completion
--     coroutine   start(func) runs func in a coroutine whose socket I/O is
--                 multiplexed through socket.select by step()
-- (A "threaded" flavour was planned but is not implemented here.)
-----------------------------------------------------------------------------
local handlert = {}

-- Create a dispatcher of the requested mode ("coroutine" by default).
-- Raises an error naming the mode when it is not one of the flavours above.
function _M.newhandler(mode)
    mode = mode or "coroutine"
    local make = handlert[mode]
    if not make then
        base.error("unknown dispatcher mode '" .. base.tostring(mode) .. "'", 2)
    end
    return make()
end

local function seqstart(self, func)
    return func()
end

-- sequential handler simply calls the functions and doesn't wrap I/O:
-- tcp() is plain socket.tcp and start() runs the function right away
function handlert.sequential()
    return {
        tcp = socket.tcp,
        start = seqstart
    }
end

-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- we can't yield across calls to protect on Lua 5.1, so we rewrite it with
-- coroutines
-- make sure you don't require any module that uses socket.protect before
-- loading our hack
if string.sub(base._VERSION, -3) == "5.1" then
    -- Drive the protected coroutine co to completion. Every yield it makes
    -- (the I/O waits issued by wrapped sockets) is forwarded to the outer
    -- dispatcher coroutine, and the dispatcher's answer is passed back in.
    local function _protect(co, status, ...)
        if not status then
            local msg = ...
            -- table errors (the form LuaSocket's try/newtry raise) become
            -- nil, message; anything else is re-raised unchanged
            if base.type(msg) == 'table' then
                return nil, msg[1]
            else
                base.error(msg, 0)
            end
        end
        if coroutine.status(co) == "suspended" then
            return _protect(co, coroutine.resume(co, coroutine.yield(...)))
        else
            return ...
        end
    end

    function socket.protect(f)
        return function(...)
            local co = coroutine.create(f)
            return _protect(co, coroutine.resume(co, ...))
        end
    end
end

-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-- Elements live in the array part (so the set can be handed to
-- socket.select); reverse maps each element to its array index.
-----------------------------------------------------------------------------
local function newset()
    local reverse = {}
    local set = {}
    return base.setmetatable(set, {__index = {
        insert = function(set, value)
            if not reverse[value] then
                table.insert(set, value)
                reverse[value] = #set
            end
        end,
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                -- fill the hole with the last element to keep a sequence
                local top = table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end

-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-- Returns an object that behaves like tcp, but whose blocking operations
-- yield (operation, tcp) to the dispatcher until tcp is ready, and return
-- nil, "timeout" if the dispatcher resumes them with "timeout".
-- Passes nil, err through when tcp is nil (socket creation failed).
-----------------------------------------------------------------------------
local function cowrap(dispatcher, tcp, err)
    if not tcp then return nil, err end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- methods we don't override explicitly are generated on first use and
    -- cached: each forwards its arguments (minus the wrap itself) to tcp
    local metat = { __index = function(proxy, key)
        local method = function(...)
            return tcp[key](tcp, base.select(2, ...))
        end
        proxy[key] = method
        return method
    end}
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = { }
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    function wrap:settimeout(value, mode)
        if value == 0 then zero = true
        else zero = false end
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        first = (first or 1) - 1
        local result, err
        while true do
            -- return control to dispatcher and tell it we want to send
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending; on a partial send, first becomes the index of
            -- the last byte sent so the next attempt resumes after it
            result, err, first = tcp:send(data, first + 1, last)
            -- if we are done, or there was an unexpected error,
            -- break away from loop
            if err ~= "timeout" then return result, err, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local err = "timeout"
        local value
        while true do
            -- return control to dispatcher and tell it we want to receive
            -- if upon return the dispatcher tells us we timed out,
            -- return an error to whoever called us
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving; data read so far is fed back as the prefix
            value, err, partial = tcp:receive(pattern, partial)
            -- if we are done, or there was an unexpected error,
            -- break away from loop. also, if the user requested
            -- zero timeout, return all we got
            if (err ~= "timeout") or zero then
                return value, err, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, err = tcp:connect(host, port)
        if err == "timeout" then
            -- return control to dispatcher. we will be writable when
            -- connection succeeds.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- when we come back, check if connection was successful
            result, err = tcp:connect(host, port)
            if result or err == "already connected" then return 1
            else return nil, "non-blocking connect failed" end
        else return result, err end
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while 1 do
            -- return control to dispatcher. we will be readable when a
            -- connection arrives.
            -- if upon return the dispatcher tells us we have a
            -- timeout, just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, err = tcp:accept()
            if err ~= "timeout" then
                return cowrap(dispatcher, client, err)
            end
        end
    end
    -- drop every dispatcher record of tcp (timestamp and any coroutine
    -- parked on it for sending or receiving), then close the socket
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end


-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-- An "operation" is dispatcher.sending or dispatcher.receiving: a set of
-- sockets to hand to socket.select, the coroutine parked on each socket,
-- and the shared last-activity timestamps.
-----------------------------------------------------------------------------
local cometat = { __index = {} }

-- Park a coroutine after it returned control to the dispatcher.
-- Arguments are cortn followed by coroutine.resume's results: a false
-- status means the coroutine raised an error, which is re-raised here;
-- otherwise a suspended coroutine yielded (operation, tcp) and is queued
-- on that operation until tcp is ready.
local function schedule(cortn, status, operation, tcp)
    if not status then base.error(operation) end
    -- a coroutine that finished returned its own results, which must not
    -- be mistaken for an (operation, tcp) request
    if cortn and operation and coroutine.status(cortn) == "suspended" then
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end

-- Forget the coroutine parked on tcp for this operation.
local function kick(operation, tcp)
    operation.cortn[tcp] = nil
    operation.set:remove(tcp)
end

-- Resume the coroutine parked on tcp for this operation, if any.
-- Returns cortn followed by coroutine.resume's results, or nil, true when
-- nothing is parked there (e.g. the socket was closed meanwhile), which
-- schedule() treats as a no-op.
local function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        return cortn, coroutine.resume(cortn)
    end
    -- otherwise, just get scheduler not to do anything
    return nil, true
end

-- Resume the coroutine parked on tcp with "timeout", so the wrapped I/O
-- call returns nil, "timeout". Whatever the coroutine does next (wait on
-- another socket, finish, or fail) is handled by schedule() like any
-- other wakeup, so it is neither orphaned nor are its errors lost.
local function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        schedule(cortn, coroutine.resume(cortn, "timeout"))
    end
end

_M.schedule = schedule
_M.kick = kick
_M.wakeup = wakeup
_M.abort = abort

-- One round of the event loop: wait up to 1 second for the sockets that
-- coroutines are parked on, resume the ready ones, then time out client
-- sockets idle for longer than dispatch.TIMEOUT.
function cometat.__index:step()
    -- check which sockets are interesting and act on them
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, 1)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout
    local now = socket.gettime()
    local timeout = _M.TIMEOUT
    local stamps = self.stamp
    local function stale(tcp)
        local stamp = stamps[tcp]
        return stamp ~= nil and tcp.class == "tcp{client}"
            and now - stamp > timeout
    end
    -- collect first: aborted coroutines may be rescheduled, which would add
    -- keys to stamps while it is being traversed
    local expired = {}
    for tcp in base.pairs(stamps) do
        if stale(tcp) then expired[#expired + 1] = tcp end
    end
    for _, tcp in base.ipairs(expired) do
        -- re-check before each abort: resuming a coroutine may have closed
        -- tcp or started a fresh (re-stamped) wait on it
        if stale(tcp) then abort(self.sending, tcp) end
        if stale(tcp) then abort(self.receiving, tcp) end
    end
end

-- Run func in a new coroutine until its first I/O wait, then park it.
function cometat.__index:start(func)
    local cortn = coroutine.create(func)
    schedule(cortn, coroutine.resume(cortn))
end

-- Build a coroutine dispatcher. Both operations share one timestamp table,
-- exposed as dispatcher.stamp; dispatcher.tcp() returns wrapped sockets.
function handlert.coroutine()
    local stamp = {}
    local dispatcher = {
        stamp = stamp,
        sending = {
            name = "sending",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
        receiving = {
            name = "receiving",
            set = newset(),
            cortn = {},
            stamp = stamp
        },
    }
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end

-- module() also published the table as the global 'dispatch'; keep doing
-- so explicitly for callers that rely on it.
base.dispatch = _M

return _M