1-----------------------------------------------------------------------------
2-- A hacked dispatcher module
3-- LuaSocket sample files
4-- Author: Diego Nehab
5-----------------------------------------------------------------------------
6local base = _G
7local table = require("table")
8local string = require("string")
9local socket = require("socket")
10local coroutine = require("coroutine")
11module("dispatch")
12
-- Idle limit used by the coroutine dispatcher. On every step, a client
-- socket ("tcp{client}") whose most recent scheduling timestamp is older
-- than TIMEOUT (compared against socket.gettime(); presumably seconds) has
-- its pending send/receive aborted, so that operation returns
-- nil, "timeout". The socket itself is not closed by the dispatcher.
TIMEOUT = 60
16
-----------------------------------------------------------------------------
-- Dispatcher ("handler") factories, keyed by mode name. This file defines:
--     sequential  -- runs tasks immediately, plain blocking sockets
--     coroutine   -- runs tasks as coroutines multiplexed with select
-- NOTE(review): earlier versions of this comment also listed a "threaded"
-- mode; no such factory is defined in this file.
-- The user can choose whichever one is needed via newhandler(mode).
-----------------------------------------------------------------------------
local handlert = {}
25
--- Create a new dispatcher of the requested kind.
-- @param mode name of a registered handler factory ("sequential" or
--   "coroutine"); defaults to "coroutine" when nil or false.
-- @return the dispatcher object built by that factory.
-- Raises an error (reported at the caller's position) for an unknown mode,
-- instead of the opaque "attempt to call a nil value".
function newhandler(mode)
    mode = mode or "coroutine"
    local factory = handlert[mode]
    if not factory then
        base.error("unknown dispatch handler mode: " .. base.tostring(mode), 2)
    end
    return factory()
end
31
--- start() for the sequential handler: run the task right now.
-- The first argument (the handler itself) is unused; every value the task
-- returns is passed straight back to the caller.
local seqstart = function(_, func)
    return func()
end
35
--- Factory for the sequential handler.
-- Tasks are called directly and sockets are the unwrapped LuaSocket ones,
-- so all I/O blocks normally.
function handlert.sequential()
    local handler = {
        start = seqstart,   -- runs the task immediately, to completion
        tcp = socket.tcp,   -- plain LuaSocket constructor, no wrapping
    }
    return handler
end
43
-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- we can't yield across calls to protect on Lua 5.1, so we rewrite it with
-- coroutines
-- make sure you don't require any module that uses socket.protect before
-- loading our hack
-- The replacement is installed only when _VERSION ends in "5.1". It runs
-- the protected function in a coroutine of its own: anything that function
-- yields (e.g. a wrapped socket waiting for I/O) is re-yielded to the
-- enclosing coroutine, and the values that coroutine is resumed with are
-- fed back into the protected function.
if string.sub(base._VERSION, -3) == "5.1" then
  -- Handle the outcome of one coroutine.resume(co, ...) call.
  -- `status, ...` are exactly the results that resume returned.
  local function _protect(co, status, ...)
    if not status then
      local msg = ...
      if base.type(msg) == 'table' then
        -- table errors are turned into a nil, message return
        -- (assumed to be the {message} tables raised by socket.try /
        -- socket.newtry -- TODO confirm)
        return nil, msg[1]
      else
        -- any other error is re-raised unchanged; level 0 means no
        -- position information is prepended
        base.error(msg, 0)
      end
    end
    if coroutine.status(co) == "suspended" then
      -- the protected function yielded: pass its values up to whoever
      -- resumed us, then resume it with whatever we get back
      return _protect(co, coroutine.resume(co, coroutine.yield(...)))
    else
      -- the protected function returned: forward its results
      return ...
    end
  end

  -- Returns a wrapper that calls f(...) in protected mode (table errors
  -- become nil, msg[1]) while still allowing f to yield.
  function socket.protect(f)
    return function(...)
      local co = coroutine.create(f)
      return _protect(co, coroutine.resume(co, ...))
    end
  end
end
75
-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-----------------------------------------------------------------------------
--- Create an empty set.
-- Members live in the array part (1..n) of the returned table, so it can be
-- handed directly to socket.select. A private reverse index maps each
-- member to its slot. insert ignores duplicates; remove moves the last
-- member into the freed slot, so member order is not preserved.
local function newset()
    local slot_of = {}
    local methods = {}

    function methods.insert(members, value)
        if slot_of[value] then return end
        local slot = #members + 1
        members[slot] = value
        slot_of[value] = slot
    end

    function methods.remove(members, value)
        local slot = slot_of[value]
        if not slot then return end
        slot_of[value] = nil
        local last_slot = #members
        local last = members[last_slot]
        members[last_slot] = nil
        if last ~= value then
            -- fill the hole with the former last member
            slot_of[last] = slot
            members[slot] = last
        end
    end

    return base.setmetatable({}, { __index = methods })
end
102
-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
--- Wrap a socket so that its blocking operations yield to `dispatcher`.
-- The real socket is put in non-blocking mode (timeout 0). send, receive,
-- connect and accept yield `dispatcher.sending` or `dispatcher.receiving`
-- together with the real socket and retry when resumed. If the dispatcher
-- resumes them with "timeout" they return nil, "timeout". Methods that are
-- not overridden here are forwarded to the real socket.
-- @param dispatcher table with `sending`, `receiving` and `stamp` fields
-- @param tcp the real socket, or nil if creating it failed
-- @param err error message accompanying a nil `tcp`
-- @return the wrapper object, or nil and `err`
local function cowrap(dispatcher, tcp, err)
    if not tcp then return nil, err end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- Forwarders for methods we don't override are built on first access
    -- and cached in the wrapper. They are called as wrap:method(...), so the
    -- first argument (the wrapper) is replaced by the real socket.
    -- `select` must be reached through `base`: module("dispatch") replaced
    -- this file's environment, so a bare `select` is nil here.
    local metat = { __index = function(self, key)
        local forward = function(...)
            return tcp[key](tcp, base.select(2, ...))
        end
        self[key] = forward
        return forward
    end}
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = {}
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        -- `first` tracks the index of the last byte already sent
        first = (first or 1) - 1
        local result, serr
        while true do
            -- return control to dispatcher and tell it we want to send;
            -- if it tells us we timed out, report that to our caller
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending the remainder
            result, serr, first = tcp:send(data, first + 1, last)
            -- done, or an unexpected error: stop looping
            if serr ~= "timeout" then return result, serr, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local value, rerr
        while true do
            -- return control to dispatcher and tell it we want to receive;
            -- if it tells us we timed out, report that to our caller
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving, carrying over what we already have
            value, rerr, partial = tcp:receive(pattern, partial)
            -- done, an unexpected error, or the user asked for zero
            -- timeout: return whatever we got
            if rerr ~= "timeout" or zero then
                return value, rerr, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, cerr = tcp:connect(host, port)
        if cerr ~= "timeout" then return result, cerr end
        -- the socket becomes writable once the connection completes;
        -- abort if the dispatcher reports a timeout instead
        if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
            return nil, "timeout"
        end
        -- when we come back, check if connection was successful
        result, cerr = tcp:connect(host, port)
        if result or cerr == "already connected" then return 1 end
        return nil, "non-blocking connect failed"
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while true do
            -- the socket becomes readable when a connection arrives;
            -- abort if the dispatcher reports a timeout instead
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, aerr = tcp:accept()
            if aerr ~= "timeout" then
                return cowrap(dispatcher, client, aerr)
            end
        end
    end
    -- forget every reference the dispatcher holds to this socket, then
    -- close the real socket
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end
213
214
-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
-- Metatable shared by every coroutine dispatcher; its __index table holds
-- the dispatcher methods (step and start are defined on it further down).
local cometat = { __index = {} }
219
--- Register a resumed coroutine as waiting for I/O, or propagate its error.
-- Meant to be called with the results of coroutine.resume (see start) or
-- of wakeup.
-- @param cortn the coroutine just resumed, or nil for "nothing to do"
-- @param status first result of resume; a false/nil status re-raises
-- @param operation the queue the coroutine yielded (dispatcher.sending or
--   dispatcher.receiving), or the error message when status is false
-- @param tcp the raw socket the coroutine wants to wait on
function schedule(cortn, status, operation, tcp)
    if not status then
        -- the coroutine raised an error: re-raise it in the dispatcher
        base.error(operation)
    end
    -- Only a suspended coroutine is waiting for I/O. A coroutine that ran
    -- to completion may have returned arbitrary values, which must not be
    -- mistaken for an operation queue.
    if cortn and operation and coroutine.status(cortn) == "suspended" then
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end
229
--- Stop tracking tcp in the given operation queue: forget the coroutine
-- waiting on it and drop it from the queue's select set.
function kick(operation, tcp)
    local waiting = operation.cortn
    waiting[tcp] = nil
    local members = operation.set
    members.remove(members, tcp)
end
234
--- Resume the coroutine waiting on tcp in the given queue.
-- Returns the coroutine followed by coroutine.resume's results, ready to be
-- passed to schedule. If nothing is waiting any more (e.g. the socket was
-- closed), returns nil, true so that schedule does nothing.
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then
        return nil, true
    end
    kick(operation, tcp)
    return cortn, coroutine.resume(cortn)
end
246
--- Cancel the pending operation on tcp: untrack it and resume the waiting
-- coroutine with "timeout", which the socket wrapper's I/O methods turn
-- into a nil, "timeout" return. Does nothing if no coroutine is waiting.
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then return end
    kick(operation, tcp)
    -- NOTE(review): unlike wakeup, the results of this resume are discarded:
    -- an error raised by the coroutine is silently swallowed, and if it
    -- yields another operation that operation is never scheduled.
    coroutine.resume(cortn, "timeout")
end
254
--- Run one round of the coroutine dispatcher.
-- Calls socket.select on the receiving/sending sets (timeout argument 1),
-- resumes the coroutines whose sockets are ready, and aborts pending
-- operations on client sockets idle for longer than TIMEOUT.
function cometat.__index:step()
    local receiving, sending = self.receiving, self.sending
    local readable, writable = socket.select(receiving.set, sending.set, 1)
    -- resume coroutines whose sockets are ready; schedule() re-registers
    -- whatever they yield next
    for _, ready in base.ipairs(readable) do
        schedule(wakeup(receiving, ready))
    end
    for _, ready in base.ipairs(writable) do
        schedule(wakeup(sending, ready))
    end
    -- make the I/O calls of idle client coroutines return "timeout"
    local now = socket.gettime()
    for tcp, stamp in base.pairs(self.stamp) do
        local isclient = tcp.class == "tcp{client}"
        if isclient and now - stamp > TIMEOUT then
            abort(sending, tcp)
            abort(receiving, tcp)
        end
    end
end
279
--- Launch func as a new coroutine managed by this dispatcher.
-- func runs immediately until its first yield; schedule() then registers
-- the operation it yielded, or re-raises the error it raised.
function cometat.__index:start(func)
    local cortn = coroutine.create(func)
    local ok, operation, tcp = coroutine.resume(cortn)
    schedule(cortn, ok, operation, tcp)
end
284
--- Factory for the coroutine dispatcher.
-- The dispatcher keeps one queue for sockets waiting to send and one for
-- sockets waiting to receive. Both queues share a single timestamp table,
-- also exposed as dispatcher.stamp. dispatcher.tcp() creates sockets
-- wrapped for this dispatcher.
function handlert.coroutine()
    local stamp = {}
    local function newqueue(name)
        return { name = name, set = newset(), cortn = {}, stamp = stamp }
    end
    local dispatcher = { stamp = stamp }
    dispatcher.sending = newqueue("sending")
    dispatcher.receiving = newqueue("receiving")
    function dispatcher.tcp()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end
307
308