--[[
--*****************************************************************************
--* Copyright (C) 1994-2016 Lua.org, PUC-Rio.
--*
--* Permission is hereby granted, free of charge, to any person obtaining
--* a copy of this software and associated documentation files (the
--* "Software"), to deal in the Software without restriction, including
--* without limitation the rights to use, copy, modify, merge, publish,
--* distribute, sublicense, and/or sell copies of the Software, and to
--* permit persons to whom the Software is furnished to do so, subject to
--* the following conditions:
--*
--* The above copyright notice and this permission notice shall be
--* included in all copies or substantial portions of the Software.
--*
--* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
--* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
--* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
--* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
--* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
--* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
--* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--*****************************************************************************
--]]

-- Regression tests for the coroutine library.
-- Every check is an `assert`, so the first failure raises an error; when all
-- checks pass the chunk returns "OK".
-- NOTE: several sections deliberately reuse/shadow names (`f`, `x`, `a`, `co`)
-- and rely on globals (`_G.x`, `_G.f`, `_X`); keep statement order intact.

-- Holds the coroutine under test in the first section; `foo` reads it as an
-- upvalue to check coroutine.running().
local f

-- The main thread is reported as such and cannot be resumed.
local main, ismain = coroutine.running()
assert(type(main) == "thread" and ismain)
assert(not coroutine.resume(main))


-- tests for multiple yield/resume arguments

--- Assert that two sequences have the same length and `==`-equal elements.
-- @param t1 expected sequence
-- @param t2 actual sequence
local function eqtab (t1, t2)
  assert(#t1 == #t2)
  for i = 1, #t1 do
    local v = t1[i]
    assert(t2[i] == v)
  end
end

_G.x = nil   -- declare x

--- Coroutine body: yields the contents of each vararg table in turn,
-- storing each set of resume arguments in `_G.x`, then returns the
-- contents of `a`.
function foo (a, ...)
  local x, y = coroutine.running()
  assert(x == f and y == false)
  -- next call should not corrupt coroutine (but must fail,
  -- as it attempts to resume the running coroutine)
  assert(coroutine.resume(f) == false)
  assert(coroutine.status(f) == "running")
  local arg = {...}
  for i=1,#arg do
    _G.x = {coroutine.yield(table.unpack(arg[i]))}
  end
  return table.unpack(a)
end

f = coroutine.create(foo)
assert(type(f) == "thread" and coroutine.status(f) == "suspended")
assert(string.find(tostring(f), "thread"))
local s,a,b,c,d
s,a,b,c,d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'})
assert(s and a == nil and coroutine.status(f) == "suspended")
s,a,b,c,d = coroutine.resume(f)
eqtab(_G.x, {})
assert(s and a == 1 and b == nil)
s,a,b,c,d = coroutine.resume(f, 1, 2, 3)
eqtab(_G.x, {1, 2, 3})
assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil)
s,a,b,c,d = coroutine.resume(f, "xuxu")
eqtab(_G.x, {"xuxu"})
assert(s and a == 1 and b == 2 and c == 3 and d == nil)
assert(coroutine.status(f) == "dead")
-- resuming a dead coroutine fails with a "dead" message
s, a = coroutine.resume(f, "xuxu")
assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead")


-- yields in tail calls
local function foo (i) return coroutine.yield(i) end
f = coroutine.wrap(function ()
  for i=1,10 do
    -- each resume's argument becomes the result of the pending yield
    assert(foo(i) == _G.x)
  end
  return 'a'
end)
for i=1,10 do _G.x = i; assert(f(i) == i) end
_G.x = 'xuxu'; assert(f('xuxu') == 'a')


-- recursive
-- Yields n, then recurses with n*i: the yielded sequence is 1, 1, 2, 6, 24, ...
function pf (n, i)
  coroutine.yield(n)
  pf(n*i, i+1)
end

f = coroutine.wrap(pf)
local s=1
for i=1,10 do
  assert(f(1, 1) == s)
  s = s*i
end


-- sieve implemented with co-routines

-- generate all the numbers from 2 to n
function gen (n)
  return coroutine.wrap(function ()
    for i=2,n do coroutine.yield(i) end
  end)
end

-- filter the numbers generated by 'g', removing multiples of 'p'
function filter (p, g)
  return coroutine.wrap(function ()
    for n in g do
      if n%p ~= 0 then coroutine.yield(n) end
    end
  end)
end

-- generate primes up to 20: each prime found stacks a new filter on the source
local x = gen(20)
local a = {}
while true do
  local n = x()
  if n == nil then break end
  table.insert(a, n)
  x = filter(n, x)
end

-- expect 8 primes and last one is 19
assert(#a == 8 and a[#a] == 19)
x, a = nil


-- yielding across C boundaries
-- (only the basic yield-then-return case through coroutine.wrap remains here)

local co = coroutine.wrap(function()
  coroutine.yield(20)
  return 30
end)

assert(co() == 20)
assert(co() == 30)


-- unyieldable C call
-- `f` is called back from string.gsub (a C function) inside a coroutine;
-- it does not yield, so this checks that such a callback runs normally.
do
  local function f (c)
    return c .. c
  end

  local co = coroutine.wrap(function (c)
    local s = string.gsub("a", ".", f)
    return s
  end)
  assert(co() == "aa")
end


-- errors in coroutines
-- (`foo` here assigns the local declared in the tail-call section)
function foo ()
  coroutine.yield(3)
  error(foo)
end

function goo() foo() end
x = coroutine.wrap(goo)
assert(x() == 3)
x = coroutine.create(goo)
a,b = coroutine.resume(x)
assert(a and b == 3)
-- the error object (the function `foo` itself) is returned unchanged
a,b = coroutine.resume(x)
assert(not a and b == foo and coroutine.status(x) == "dead")
a,b = coroutine.resume(x)
assert(not a and string.find(b, "dead") and coroutine.status(x) == "dead")


-- co-routines x for loop
-- Yields `a` once for every assignment of 1..n to slots a[1..k] (n^k times).
function all (a, n, k)
  if k == 0 then coroutine.yield(a)
  else
    for i=1,n do
      a[k] = i
      all(a, n, k-1)
    end
  end
end

local a = 0
for t in coroutine.wrap(function () all({}, 5, 4) end) do
  a = a+1
end
assert(a == 5^4)


-- access to locals of collected coroutines
local C = {}; setmetatable(C, {__mode = "kv"})
local x = coroutine.wrap (function ()
  local a = 10
  local function f () a = a+10; return a end
  while true do
    a = a+1
    coroutine.yield(f)
  end
end)

C[1] = x;

local f = x()
assert(f() == 21 and x()() == 32 and x() == f)
-- drop the only strong reference; the weak table entry must vanish, while
-- the closure `f` keeps working on its (now closed) upvalue `a`
x = nil
collectgarbage()
assert(C[1] == nil)
assert(f() == 43 and f() == 53)


-- old bug: attempt to resume itself

function co_func (current_co)
  assert(coroutine.running() == current_co)
  assert(coroutine.resume(current_co) == false)
  coroutine.yield(10, 20)
  assert(coroutine.resume(current_co) == false)
  coroutine.yield(23)
  return 10
end

local co = coroutine.create(co_func)
local a,b,c = coroutine.resume(co, co)
assert(a == true and b == 10 and c == 20)
a,b = coroutine.resume(co, co)
assert(a == true and b == 23)
a,b = coroutine.resume(co, co)
assert(a == true and b == 10)
assert(coroutine.resume(co, co) == false)
assert(coroutine.resume(co, co) == false)


-- attempt to resume 'normal' coroutine
local co1, co2
co1 = coroutine.create(function () return co2() end)
co2 = coroutine.wrap(function ()
  -- co1 is resuming us, so it is 'normal' and cannot be resumed
  assert(coroutine.status(co1) == 'normal')
  assert(not coroutine.resume(co1))
  coroutine.yield(3)
end)

a,b = coroutine.resume(co1)
assert(a and b == 3)
assert(coroutine.status(co1) == 'dead')


-- access to locals of erroneous coroutines
local x = coroutine.create (function ()
  local a = 10
  _G.f = function () a=a+1; return a end
  error('x')
end)

assert(not coroutine.resume(x))
-- overwrite previous position of local `a'
assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1))
assert(_G.f() == 11)
assert(_G.f() == 12)


-- leaving a pending coroutine open
_X = coroutine.wrap(function ()
  local a = 10
  local x = function () a = a+1 end
  coroutine.yield()
end)

_X()

assert(coroutine.running() == main)


-- testing yields inside metamethods
-- Each metamethod first yields (nil, <tag>) so `run` can record which
-- metamethods fired, then computes its result.

local mt = {
  __eq = function(a,b) coroutine.yield(nil, "eq"); return a.x == b.x end,
  __lt = function(a,b) coroutine.yield(nil, "lt"); return a.x < b.x end,
  __le = function(a,b) coroutine.yield(nil, "le"); return a - b <= 0 end,
  __add = function(a,b) coroutine.yield(nil, "add"); return a.x + b.x end,
  __sub = function(a,b) coroutine.yield(nil, "sub"); return a.x - b.x end,
  __mod = function(a,b) coroutine.yield(nil, "mod"); return a.x % b.x end,
  __unm = function(a,b) coroutine.yield(nil, "unm"); return -a.x end,

  __concat = function(a,b)
    coroutine.yield(nil, "concat");
    a = type(a) == "table" and a.x or a
    b = type(b) == "table" and b.x or b
    return a .. b
  end,
  __index = function (t,k) coroutine.yield(nil, "idx"); return t.k[k] end,
  __newindex = function (t,k,v) coroutine.yield(nil, "nidx"); t.k[k] = v end,
}


local function new (x)
  return setmetatable({x = x, k = {}}, mt)
end


local a = new(10)
local b = new(12)
local c = new"hello"

--- Run `f` in a coroutine, checking that the yielded tags match `t` in order.
-- @param f function under test; returns a truthy result when done
-- @param t expected sequence of tags yielded by metamethods
-- @return f's result and `t`
local function run (f, t)
  local i = 1
  local c = coroutine.wrap(f)
  while true do
    local res, stat = c()
    if res then assert(t[i] == nil); return res, t end
    assert(stat == t[i])
    i = i + 1
  end
end


assert(run(function () if (a>=b) then return '>=' else return '<' end end,
       {"le", "sub"}) == "<")
-- '<=' using '<'
-- NOTE: relies on the Lua 5.3 rule that `__lt` emulates a missing `__le`;
-- Lua 5.4 only keeps this under LUA_COMPAT_LT_LE.
mt.__le = nil
assert(run(function () if (a<=b) then return '<=' else return '>' end end,
       {"lt"}) == "<=")

assert(run(function () if (a==b) then return '==' else return '~=' end end,
       {"eq"}) == "~=")

assert(run(function () return a % b end, {"mod"}) == 10)

assert(run(function () return a..b end, {"concat"}) == "1012")

assert(run(function() return a .. b .. c .. a end,
       {"concat", "concat", "concat"}) == "1012hello10")

assert(run(function() return "a" .. "b" .. a .. "c" .. c .. b .. "x" end,
       {"concat", "concat", "concat"}) == "ab10chello12x")


-- testing yields inside 'for' iterators

-- Iterator that counts 0 < i <= s, yielding ("for") on every even control value.
local f = function (s, i)
  if i%2 == 0 then coroutine.yield(nil, "for") end
  if i < s then return i + 1 end
end

assert(run(function ()
             local s = 0
             for i in f, 4, 0 do s = s + i end
             return s
           end, {"for", "for", "for"}) == 10)

return "OK"