1-----------------------------------------------------------------------------
2-- A hacked dispatcher module
3-- LuaSocket sample files
4-- Author: Diego Nehab
5-----------------------------------------------------------------------------
6local base = _G
7local table = require("table")
8local socket = require("socket")
9local coroutine = require("coroutine")
10module("dispatch")
11
-- How long, in seconds (compared against socket.gettime), a coroutine may
-- stay blocked on pending I/O on a client socket before the coroutine
-- dispatcher's step() aborts that wait, making the wrapped call return
-- nil, "timeout". The socket itself is not closed by the dispatcher.
-- Defined inside module("dispatch"), so it is a field of the module table.
TIMEOUT = 60
15
-----------------------------------------------------------------------------
-- Dispatcher implementations ("handlers"), keyed by mode name:
--     sequential   jobs run inline, with ordinary blocking sockets
--     coroutine    each job runs in a coroutine; I/O is multiplexed
--                  with socket.select
-- NOTE(review): the original header also listed a "threaded" dispatcher,
-- but no handlert.threaded is defined in this file.
-- The user picks whichever one is needed through newhandler(mode).
-----------------------------------------------------------------------------
local handlert = {}
24
-- Creates a dispatcher handler for the given mode, which must be a key of
-- handlert ("sequential" or "coroutine" here). Defaults to "coroutine" when
-- mode is nil.
-- Raises an error naming the mode if no such handler exists, instead of
-- the opaque "attempt to call a nil value" a missing entry would cause.
function newhandler(mode)
    mode = mode or "coroutine"
    local factory = handlert[mode]
    if not factory then
        -- level 2: blame the caller that passed the bad mode
        base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
    end
    return factory()
end
30
-- start() for the sequential handler: there is nothing to schedule, so the
-- job runs right away and whatever it returns is handed back unchanged.
-- self is accepted only so this can be called as handler:start(func).
local function seqstart(self, func)
    local job = func
    return job()
end
34
-- Sequential handler: tcp is the plain socket.tcp constructor (so I/O is
-- not wrapped and blocks as usual) and start runs each job inline via
-- seqstart, one after another.
function handlert.sequential()
    local handler = {}
    handler.tcp = socket.tcp
    handler.start = seqstart
    return handler
end
42
-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- We can't yield across calls to the stock socket.protect, so our yielding
-- I/O wrappers would break inside protected functions. We rewrite it in the
-- style of coxpcall:
--   * the protected function runs in its own coroutine;
--   * every yield it makes is passed up to the enclosing coroutine (the
--     dispatcher's);
--   * the values we are resumed with are passed back down to it.
-- Table-valued errors (the exception form used by LuaSocket's try
-- functions) become a nil, message return; any other error is re-raised.
-- Make sure you don't require any module that uses socket.protect before
-- loading our hack.
function socket.protect(f)
    -- Lua 5.1 has a global unpack; later versions only have table.unpack
    local unpack = base.unpack or table.unpack
    -- Packs values with an explicit count so embedded/trailing nils survive.
    -- This replaces the Lua 5.0 implicit `arg` table (absent in LuaJIT and
    -- Lua 5.2+) and avoids unpack/table.remove on tables with nil holes.
    local function pack(...)
        return { n = base.select("#", ...), ... }
    end
    return function(...)
        local co = coroutine.create(f)
        local args = pack(...)
        while true do
            local results = pack(coroutine.resume(co, unpack(args, 1, args.n)))
            if not results[1] then
                local err = results[2]
                if base.type(err) == 'table' then
                    return nil, err[1]
                else
                    base.error(err)
                end
            end
            if coroutine.status(co) == "suspended" then
                -- f yielded: forward the yield up, feed the reply back in
                args = pack(coroutine.yield(unpack(results, 2, results.n)))
            else
                -- f returned: hand back everything after resume's status
                return unpack(results, 2, results.n)
            end
        end
    end
end
68
-----------------------------------------------------------------------------
-- Simple set data structure. O(1) everything.
-----------------------------------------------------------------------------
-- Returns a table whose array part (1..n) holds the members, so it can be
-- handed directly to socket.select. The :insert and :remove methods are
-- reached through the metatable, so they never show up as elements.
-- A private reverse index maps each member to its slot. Removal moves the
-- last member into the vacated slot, so member order is not preserved.
local function newset()
    local reverse = {}
    local set = {}
    return base.setmetatable(set, {__index = {
        -- add value unless it is already a member
        insert = function(self, value)
            if not reverse[value] then
                -- use # instead of table.getn (deprecated in 5.1, gone in 5.2)
                local slot = #self + 1
                self[slot] = value
                reverse[value] = slot
            end
        end,
        -- remove value if it is a member, keeping the array part dense
        remove = function(self, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                local top = table.remove(self)
                if top ~= value then
                    -- fill the hole with the element we just popped
                    reverse[top] = index
                    self[index] = top
                end
            end
        end
    }})
end
95
-----------------------------------------------------------------------------
-- socket.tcp() wrapper for the coroutine dispatcher
-----------------------------------------------------------------------------
-- Wraps a LuaSocket tcp object so that calls made inside a dispatcher
-- coroutine yield (operation, tcp) to the dispatcher instead of blocking.
-- The dispatcher later resumes the coroutine once the socket is ready, or
-- resumes it with "timeout" to abort the wait.
--   dispatcher  table built by handlert.coroutine; its stamp, sending and
--               receiving fields are used
--   tcp, error  results of socket.tcp() or tcp:accept(); when tcp is nil
--               the pair is returned unchanged as nil, error
-- Returns the wrapper object, which otherwise behaves like tcp.
local function cowrap(dispatcher, tcp, error)
    if not tcp then return nil, error end
    -- put it in non-blocking mode right away
    tcp:settimeout(0)
    -- Methods we don't override below are forwarded to the real socket,
    -- with the wrapper replaced by tcp as the receiver. Each forwarder is
    -- built on first use and cached in the wrapper.
    -- Uses ... rather than the Lua 5.0 implicit `arg` table (missing in
    -- LuaJIT and Lua 5.2+), which also preserves trailing nil arguments.
    local metat = { __index = function(self, key)
        local forward = function(_, ...)
            return tcp[key](tcp, ...)
        end
        self[key] = forward
        return forward
    end}
    -- does our user want to do his own non-blocking I/O?
    local zero = false
    -- create a wrap object that will behave just like a real socket object
    local wrap = {}
    -- we ignore settimeout to preserve our 0 timeout, but record whether
    -- the user wants to do his own non-blocking I/O (mode is ignored)
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end
    -- send in non-blocking mode and yield on timeout
    function wrap:send(data, first, last)
        -- index of the last byte already sent
        first = (first or 1) - 1
        local result, err
        while true do
            -- hand control to the dispatcher until tcp is writable; if it
            -- resumes us with "timeout", report that to our caller
            if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try sending the rest
            result, err, first = tcp:send(data, first + 1, last)
            -- stop when done or on any error other than "timeout"
            if err ~= "timeout" then return result, err, first end
        end
    end
    -- receive in non-blocking mode and yield on timeout
    -- or simply return partial read, if user requested timeout = 0
    function wrap:receive(pattern, partial)
        local value, err
        while true do
            -- hand control to the dispatcher until tcp is readable; if it
            -- resumes us with "timeout", report that to our caller
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            -- try receiving, continuing from what we already have
            value, err, partial = tcp:receive(pattern, partial)
            -- stop when done, on any error other than "timeout", or when
            -- the user asked for zero timeout (return whatever we got)
            if err ~= "timeout" or zero then
                return value, err, partial
            end
        end
    end
    -- connect in non-blocking mode and yield on timeout
    function wrap:connect(host, port)
        local result, err = tcp:connect(host, port)
        if err ~= "timeout" then return result, err end
        -- the socket becomes writable once the connection completes; if
        -- the dispatcher resumes us with "timeout", just abort
        if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
            return nil, "timeout"
        end
        -- when we come back, check if connection was successful
        result, err = tcp:connect(host, port)
        if result or err == "already connected" then return 1 end
        return nil, "non-blocking connect failed"
    end
    -- accept in non-blocking mode and yield on timeout
    function wrap:accept()
        while true do
            -- the socket becomes readable when a connection arrives; if
            -- the dispatcher resumes us with "timeout", just abort
            if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                return nil, "timeout"
            end
            local client, err = tcp:accept()
            if err ~= "timeout" then
                -- the accepted client gets the same treatment
                return cowrap(dispatcher, client, err)
            end
        end
    end
    -- forget everything the dispatcher tracks for tcp, then close it
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        dispatcher.sending.set:remove(tcp)
        dispatcher.sending.cortn[tcp] = nil
        dispatcher.receiving.set:remove(tcp)
        dispatcher.receiving.cortn[tcp] = nil
        return tcp:close()
    end
    return base.setmetatable(wrap, metat)
end
207
208
-----------------------------------------------------------------------------
-- Our coroutine dispatcher
-----------------------------------------------------------------------------
-- Metatable shared by every dispatcher built in handlert.coroutine; the
-- dispatcher methods (start, step) are defined on its __index table below.
local cometat = { __index = {} }
213
-- Registers a coroutine to be resumed once tcp is ready for operation.
-- Meant to be called as schedule(cortn, coroutine.resume(cortn)):
--   cortn      the coroutine just resumed (nil means there is nothing to do)
--   status     first result of coroutine.resume; false means it raised
--   operation  what it yielded (dispatcher.sending / .receiving), or the
--              error message when status is false
--   tcp        the raw socket it yielded
-- Errors raised inside the coroutine are re-raised here.
-- A coroutine that has finished is dropped. Previously, whatever it
-- returned was treated as an operation and indexed, which crashed on
-- e.g. `return 1`.
function schedule(cortn, status, operation, tcp)
    if not status then base.error(operation) end
    -- only a coroutine that yielded (still suspended) is waiting on I/O
    if cortn and operation and coroutine.status(cortn) == "suspended" then
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end
223
-- Forgets the coroutine waiting for `operation` on tcp and drops tcp from
-- that operation's select set. The activity stamp is left untouched.
function kick(operation, tcp)
    local waiting = operation.cortn
    waiting[tcp] = nil
    local pending = operation.set
    pending:remove(tcp)
end
228
-- Resumes the coroutine waiting for `operation` on tcp, if any, after
-- unregistering it. Returns the coroutine followed by everything
-- coroutine.resume returned, i.e. exactly the arguments schedule() expects.
-- When nobody is waiting (e.g. wrap:close already cleared the entry),
-- returns nil, true so that schedule() does nothing.
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then
        return nil, true
    end
    kick(operation, tcp)
    return cortn, coroutine.resume(cortn)
end
240
-- Aborts the wait of the coroutine (if any) blocked on `operation` for tcp.
-- It is resumed with "timeout", which makes the wrapped I/O call return
-- nil, "timeout".
--   * If the coroutine then yields for more I/O, it is rescheduled.
--     Previously the results of this resume were dropped, so such a
--     coroutine was never resumed again.
--   * If it raises, the error propagates through schedule(), as it does
--     for every other resume in the dispatcher.
--   * If it finishes, nothing more happens.
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        local status, nextop, nexttcp = coroutine.resume(cortn, "timeout")
        if not status or coroutine.status(cortn) == "suspended" then
            schedule(cortn, status, nextop, nexttcp)
        end
    end
end

-- step through all active cortns
-- One dispatcher round:
--   1. wait up to 1 second in socket.select, then resume the coroutines
--      whose sockets became readable/writable, rescheduling them when
--      they yield back to us;
--   2. abort waits on client sockets whose stamp is older than TIMEOUT.
function cometat.__index:step()
    -- check which sockets are interesting and act on them
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, 1)
    -- for all readable connections, resume their cortns and reschedule
    -- when they yield back to us
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    -- for all writable connections, do the same
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    -- politely ask replacement I/O functions in idle cortns to
    -- return reporting a timeout
    local now = socket.gettime()
    local stamps = self.stamp
    local function stale(tcp)
        local stamp = stamps[tcp]
        return stamp ~= nil and now - stamp > TIMEOUT
    end
    -- collect candidates first: aborting resumes coroutines that may
    -- register other sockets in self.stamp, and adding keys to a table
    -- while traversing it with pairs() is undefined
    local idle = {}
    for tcp in base.pairs(stamps) do
        if tcp.class == "tcp{client}" and stale(tcp) then
            idle[#idle + 1] = tcp
        end
    end
    for _, tcp in base.ipairs(idle) do
        -- re-check before each abort: an earlier abort may have closed
        -- this socket or rescheduled a coroutine on it with a fresh stamp
        if stale(tcp) then abort(self.sending, tcp) end
        if stale(tcp) then abort(self.receiving, tcp) end
    end
end
273
-- Launches func as a new dispatcher coroutine. It runs immediately until it
-- first yields (e.g. from a wrapped socket call) and is then handed to
-- schedule(), which registers the wait or re-raises an error.
function cometat.__index:start(func)
    local job = coroutine.create(func)
    local status, operation, tcp = coroutine.resume(job)
    schedule(job, status, operation, tcp)
end
278
-- Coroutine handler:
--   * :start(func) runs each job in its own coroutine;
--   * .tcp() returns sockets wrapped by cowrap, which yield to the
--     dispatcher instead of blocking;
--   * :step() performs one round of select / resume / timeout checks.
function handlert.coroutine()
    -- a single activity-stamp table, shared by both operation queues
    local stamp = {}
    local function newqueue(name)
        return {
            name = name,
            set = newset(),
            cortn = {},
            stamp = stamp,
        }
    end
    local dispatcher = { stamp = stamp }
    dispatcher.sending = newqueue("sending")
    dispatcher.receiving = newqueue("receiving")
    dispatcher.tcp = function()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end
301
302