1--[[
2--*****************************************************************************
3--* Copyright (C) 1994-2016 Lua.org, PUC-Rio.
4--*
5--* Permission is hereby granted, free of charge, to any person obtaining
6--* a copy of this software and associated documentation files (the
7--* "Software"), to deal in the Software without restriction, including
8--* without limitation the rights to use, copy, modify, merge, publish,
9--* distribute, sublicense, and/or sell copies of the Software, and to
10--* permit persons to whom the Software is furnished to do so, subject to
11--* the following conditions:
12--*
13--* The above copyright notice and this permission notice shall be
14--* included in all copies or substantial portions of the Software.
15--*
16--* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17--* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18--* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19--* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20--* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
21--* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
22--* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23--*****************************************************************************
24--]]
25
-- 'f' is declared up front so that functions defined below (such as
-- 'foo') capture it as an upvalue before it is assigned.
local f

-- The main thread: 'coroutine.running' reports it together with a true
-- "is main" flag, and it can never be resumed.
local main, ismain = coroutine.running()
assert(ismain and type(main) == "thread")
assert(coroutine.resume(main) == false)
31
32
33-- tests for multiple yield/resume arguments
34
--- Assert that two sequences have the same length and hold equal
-- elements ('==') at every position.
local function eqtab (t1, t2)
  local n = #t1
  assert(n == #t2)
  for idx = 1, n do
    local expected = t1[idx]
    assert(t2[idx] == expected)
  end
end
42
_G.x = nil   -- declare x (receives the values passed to each resume)

--- Coroutine body for the multiple-argument yield/resume tests.
-- 'a' is a list that is returned (unpacked) at the end. Every extra
-- argument is a list unpacked into one 'coroutine.yield'. The values each
-- later resume passes in are stored, as a table, in the global 'x'.
function foo (a, ...)
  local running, is_main = coroutine.running()
  assert(running == f and is_main == false)
  -- resuming the running coroutine must fail without corrupting it
  assert(coroutine.resume(f) == false)
  assert(coroutine.status(f) == "running")
  local batches = {...}
  local count = #batches
  local idx = 1
  while idx <= count do
    _G.x = {coroutine.yield(table.unpack(batches[idx]))}
    idx = idx + 1
  end
  return table.unpack(a)
end
57
-- Drive 'foo' through all of its yields. Each resume's results are what
-- 'foo' yielded (or returned); each resume's arguments end up in 'x'.
f = coroutine.create(foo)
assert(type(f) == "thread")
assert(coroutine.status(f) == "suspended")
assert(string.find(tostring(f), "thread"))

local s, a, b, c, d

-- start: 'foo' keeps {1,2,3} for its final return, then yields the
-- (empty) contents of {}
s, a, b, c, d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'})
assert(s and a == nil)
assert(coroutine.status(f) == "suspended")

-- nothing passed in, so x == {}; the next yield hands out {1}
s, a, b, c, d = coroutine.resume(f)
eqtab(_G.x, {})
assert(s and a == 1 and b == nil)

s, a, b, c, d = coroutine.resume(f, 1, 2, 3)
eqtab(_G.x, {1, 2, 3})
assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil)

-- last resume: 'foo' returns the unpacked {1,2,3} and finishes
s, a, b, c, d = coroutine.resume(f, "xuxu")
eqtab(_G.x, {"xuxu"})
assert(s and a == 1 and b == 2 and c == 3 and d == nil)
assert(coroutine.status(f) == "dead")

-- a dead coroutine cannot be resumed
s, a = coroutine.resume(f, "xuxu")
assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead")
76
77
-- yields in tail calls: 'foo' tail-calls 'coroutine.yield', so whatever
-- the next resume passes in becomes foo's own result
local function foo (i) return coroutine.yield(i) end
f = coroutine.wrap(function ()
  local i = 1
  while i <= 10 do
    assert(foo(i) == _G.x)
    i = i + 1
  end
  return 'a'
end)
for step = 1, 10 do
  _G.x = step
  assert(f(step) == step)
end
_G.x = 'xuxu'
assert(f('xuxu') == 'a')
88
-- recursive: each level of 'pf' yields its running product 'n' and then
-- recurses with the next factor, giving 1, 1, 2, 6, 24, ...
function pf (n, i)
  coroutine.yield(n)
  pf(n * i, i + 1)
end

f = coroutine.wrap(pf)
local fact = 1
for k = 1, 10 do
  assert(f(1, 1) == fact)
  fact = fact * k
end
101
-- sieve implemented with co-routines

-- generate all the numbers from 2 to n
function gen (n)
  return coroutine.wrap(function ()
    for value = 2, n do coroutine.yield(value) end
  end)
end

-- filter the numbers generated by 'g', removing multiples of 'p'
function filter (p, g)
  return coroutine.wrap(function ()
    for candidate in g do
      if candidate % p ~= 0 then coroutine.yield(candidate) end
    end
  end)
end

-- generate primes up to 20: every prime found wraps the current
-- generator in a new filter that drops its multiples
local x = gen(20)
local a = {}
do
  local n = x()
  while n ~= nil do
    a[#a + 1] = n
    x = filter(n, x)
    n = x()
  end
end

-- expect 8 primes and last one is 19
assert(#a == 8 and a[#a] == 19)
x, a = nil
133
134
-- yielding across C boundaries: a yield attempted from a function that C
-- code calls without a continuation (here the 'table.sort' comparator)
-- must raise an error, which 'pcall' catches. Afterwards the coroutine
-- must still be able to yield normally.

co = coroutine.wrap(function()
       assert(not pcall(table.sort, {1,2,3}, coroutine.yield))
       coroutine.yield(20)
       return 30
     end)

assert(co() == 20)
assert(co() == 30)
144
145
-- 'f' starts out as a yielding 'for' iterator: it yields the control
-- value and returns whatever the next resume passes in.
local f = function (s, i) return coroutine.yield(i) end

-- yields from a 'for' iterator running under 'pcall' inside 'xpcall'
do
  local f1 = coroutine.wrap(function ()
               return xpcall(pcall, function (...) return ... end,
                 function ()
                   local s = 0
                   for i in f, nil, 1 do pcall(function () s = s + i end) end
                   error({s})
                 end)
             end)

  f1()
  for i = 1, 10 do assert(f1(i) == i) end
  -- resuming with nil ends the loop; the inner 'pcall' catches the error
  local r1, r2, v = f1(nil)
  assert(r1 and not r2 and v[1] == (10 + 1)*10/2)
end

-- 'f' is then redefined: it yields 'a', then raises {a + b}, where 'a' is
-- the value passed to the next resume. 'g' is a message handler that
-- doubles the first field of such an error object.
function f (a, b) a = coroutine.yield(a);  error{a + b} end
function g(x) return x[1]*2 end

-- yields inside the body of 'xpcall'; the handler transforms the error
do
  local co = coroutine.wrap(function ()
               coroutine.yield(xpcall(f, g, 10, 20))
             end)
  assert(co() == 10)
  local r, msg = co(100)
  assert(not r and msg == 240)
end
149
150
-- unyieldable C call: 'string.gsub' calls the replacement function 'f'
-- from C without a continuation. A yield attempted inside 'f' must fail
-- (caught by 'pcall') while the surrounding coroutine keeps running.
do
  local function f (c)
          assert(not pcall(coroutine.yield))
          return c .. c
        end

  local co = coroutine.wrap(function (c)
               local s = string.gsub("a", ".", f)
               return s
             end)
  assert(co() == "aa")
end
163
164
-- errors in coroutines: an error object raised after a yield is handed
-- to the resumer as-is, and the coroutine is dead afterwards
function foo ()
  coroutine.yield(3)
  error(foo)
end

function goo () foo() end

-- through 'wrap': the first call only sees the yielded value
x = coroutine.wrap(goo)
assert(x() == 3)

-- through 'resume': the yield, then the error value itself, then "dead"
x = coroutine.create(goo)
a, b = coroutine.resume(x)
assert(a)
assert(b == 3)
a, b = coroutine.resume(x)
assert(not a)
assert(b == foo)
assert(coroutine.status(x) == "dead")
a, b = coroutine.resume(x)
assert(not a)
assert(string.find(b, "dead"))
assert(coroutine.status(x) == "dead")
181
182
-- co-routines x for loop: 'all' yields the shared array 'a' once for each
-- of the n^k ways of filling positions 1..k with values 1..n
function all (a, n, k)
  if k ~= 0 then
    for v = 1, n do
      a[k] = v
      all(a, n, k - 1)
    end
  else
    coroutine.yield(a)
  end
end

local count = 0
for _ in coroutine.wrap(function () all({}, 5, 4) end) do
  count = count + 1
end
assert(count == 5^4)
199
200
-- access to locals of collected coroutines: the closure 'f' keeps working
-- on the local 'a' after the coroutine that created it is collected
local C = setmetatable({}, {__mode = "kv"})
local x = coroutine.wrap (function ()
            local a = 10
            local function f () a = a + 10; return a end
            while true do
              a = a + 1
              coroutine.yield(f)
            end
          end)

C[1] = x  -- weak reference: lets the test observe when 'x' is collected

local f = x()
assert(f() == 21)
assert(x()() == 32)
assert(x() == f)
x = nil
collectgarbage()
assert(C[1] == nil)
assert(f() == 43)
assert(f() == 53)
220
221
-- old bug: attempt to resume itself

--- Body of a coroutine that receives itself as 'current_co' and, before
-- each yield, checks that resuming itself fails.
function co_func (current_co)
  assert(coroutine.running() == current_co)
  assert(coroutine.resume(current_co) == false)   -- cannot resume itself
  coroutine.yield(10, 20)
  assert(coroutine.resume(current_co) == false)   -- nor after a yield
  coroutine.yield(23)
  return 10
end

local co = coroutine.create(co_func)
local a, b, c = coroutine.resume(co, co)
assert(a == true and b == 10 and c == 20)   -- first yield
a, b = coroutine.resume(co, co)
assert(a == true and b == 23)               -- second yield
a, b = coroutine.resume(co, co)
assert(a == true and b == 10)               -- return value
-- once dead, every further resume fails
for _ = 1, 2 do
  assert(coroutine.resume(co, co) == false)
end
242
243
-- attempt to resume 'normal' coroutine: while 'co1' waits on 'co2' it is
-- in the "normal" state, and resuming it must fail
local co1, co2
co1 = coroutine.create(function () return co2() end)
co2 = coroutine.wrap(function ()
        assert(coroutine.status(co1) == 'normal')
        assert(not coroutine.resume(co1))
        coroutine.yield(3)
      end)

a, b = coroutine.resume(co1)
assert(a)
assert(b == 3)
assert(coroutine.status(co1) == 'dead')
256
257
-- access to locals of erroneous coroutines: a closure created inside a
-- coroutine that then fails keeps working on that coroutine's local 'a'
local x = coroutine.create(function ()
            local a = 10
            _G.f = function () a = a + 1; return a end
            error('x')
          end)

assert(not coroutine.resume(x))
-- resume the dead coroutine with extra arguments, trying to overwrite
-- the previous stack position of local 'a'
assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1))
for expected = 11, 12 do
  assert(_G.f() == expected)
end
270
271
-- leaving a pending coroutine open: '_X' is resumed once and left
-- suspended while its local 'a' is still captured by a closure
_X = coroutine.wrap(function ()
      local a = 10
      local bump = function () a = a + 1 end
      coroutine.yield()
    end)

_X()

-- control is back in the main thread
assert(main == coroutine.running())
282
283
284
-- testing yields inside metamethods

-- Every metamethod in 'mt' first yields (nil, <tag>) so the driver can
-- check which metamethods ran and in what order. It then computes its
-- result from the operands' 'x' fields (or, for indexing, from the inner
-- table 'k').
local mt
do
  -- wrap 'body' so that it is preceded by a yield of (nil, tag)
  local function yielding (tag, body)
    return function (...)
      coroutine.yield(nil, tag)
      return body(...)
    end
  end

  mt = {
    __eq = yielding("eq", function (a, b) return a.x == b.x end),
    __lt = yielding("lt", function (a, b) return a.x < b.x end),
    -- '<=' goes through subtraction, so it also triggers '__sub'
    __le = yielding("le", function (a, b) return a - b <= 0 end),
    __add = yielding("add", function (a, b) return a.x + b.x end),
    __sub = yielding("sub", function (a, b) return a.x - b.x end),
    __mod = yielding("mod", function (a, b) return a.x % b.x end),
    __unm = yielding("unm", function (a, b) return -a.x end),

    __concat = yielding("concat", function (a, b)
      local lhs = type(a) == "table" and a.x or a
      local rhs = type(b) == "table" and b.x or b
      return lhs .. rhs
    end),
    __index = yielding("idx", function (t, k) return t.k[k] end),
    __newindex = yielding("nidx", function (t, k, v) t.k[k] = v end),
  }
end
305
306
--- Build a test object. Field 'x' holds the value the metamethods work
-- on; 'k' is the table '__index'/'__newindex' read from and write to.
local function new (x)
  local obj = {k = {}}
  obj.x = x
  return setmetatable(obj, mt)
end

local a, b, c = new(10), new(12), new"hello"
315
--- Run 'f' as a coroutine until it returns a truthy result.
-- Each yield must be (nil, tag), with tags matching 't' in order. Once 'f'
-- returns, 't' must have no tags left. Returns the result and 't'.
local function run (f, t)
  local co = coroutine.wrap(f)
  local step = 1
  repeat
    local res, stat = co()
    if res then
      assert(t[step] == nil)
      return res, t
    end
    assert(stat == t[step])
    step = step + 1
  until false
end
326
327
-- '>=' is evaluated as 'b <= a' through '__le'; that metamethod computes
-- 'a - b', so the driver sees an "le" yield followed by a "sub" yield
assert(run(function () if (a>=b) then return '>=' else return '<' end end,
       {"le", "sub"}) == "<")
-- '<=' using '<': with '__le' removed, 'a <= b' is evaluated as
-- 'not (b < a)' through '__lt' (Lua 5.3 semantics).
-- NOTE(review): Lua 5.4 drops this fallback unless built with
-- LUA_COMPAT_LT_LE; confirm the targeted Lua version.
mt.__le = nil
assert(run(function () if (a<=b) then return '<=' else return '>' end end,
       {"lt"}) == "<=")
-- '==' between two distinct tables that share 'mt' goes through '__eq'
assert(run(function () if (a==b) then return '==' else return '~=' end end,
       {"eq"}) == "~=")

assert(run(function () return a % b end, {"mod"}) == 10)

assert(run(function () return a..b end, {"concat"}) == "1012")

-- a chain of '..' is evaluated right to left; each step with a table
-- operand calls '__concat' once
assert(run(function() return a .. b .. c .. a end,
       {"concat", "concat", "concat"}) == "1012hello10")

-- adjacent plain strings are joined directly, so only the three steps
-- that involve a table operand yield
assert(run(function() return "a" .. "b" .. a .. "c" .. c .. b .. "x" end,
       {"concat", "concat", "concat"}) == "ab10chello12x")
346
347
-- testing yields inside 'for' iterators

-- Iterator whose state 's' is an upper bound: yields (nil, "for") to the
-- driver whenever the control value is even, then returns the next value
-- until 's' is reached.
local f = function (s, i)
      if i%2 == 0 then coroutine.yield(nil, "for") end
      if i < s then return i + 1 end
    end

-- Starting from 0 with bound 4, the loop body sees i = 1, 2, 3, 4 (sum 10),
-- and the iterator yields for the control values 0, 2 and 4.
assert(run(function ()
             local s = 0
             for i in f, 4, 0 do s = s + i end
             return s
           end, {"for", "for", "for"}) == 10)
360
361
-- Chunk result. Presumably the test driver that runs this file checks it.
-- NOTE(review): confirm against the caller.
return "OK"
363