1 /*=========================================================================*\
2 * Simple exception support
3 * LuaSocket toolkit
4 \*=========================================================================*/
5 #include <stdio.h>
6 
7 #include "lua.h"
8 #include "lauxlib.h"
9 #include "compat.h"
10 
11 #include "except.h"
12 
/* Lua 5.1 has no continuation-aware pcall: emulate lua_pcallk() with plain
 * lua_pcall() and discard the context and continuation arguments. Every macro
 * argument is parenthesized so that arbitrary expressions (e.g. conditional
 * expressions) passed as ctx/cont are cast as a whole, not just their first
 * operand. */
#if LUA_VERSION_NUM < 502
#define lua_pcallk(L, na, nr, err, ctx, cont) \
    (((void)(ctx)), ((void)(cont)), lua_pcall((L), (na), (nr), (err)))
#endif

/* lua_KContext is only provided from Lua 5.3 on. Older versions hand the
 * continuation context around as a plain int (see protected_cont), so alias
 * it to keep a single signature for protected_finish(). */
#if LUA_VERSION_NUM < 503
typedef int lua_KContext;
#endif
21 
22 /*=========================================================================*\
23 * Internal function prototypes.
24 \*=========================================================================*/
25 static int global_protect(lua_State *L);
26 static int global_newtry(lua_State *L);
27 static int protected_(lua_State *L);
28 static int finalize(lua_State *L);
29 static int do_nothing(lua_State *L);
30 
31 /* except functions */
32 static luaL_Reg func[] = {
33     {"newtry",    global_newtry},
34     {"protect",   global_protect},
35     {NULL,        NULL}
36 };
37 
38 /*-------------------------------------------------------------------------*\
39 * Try factory
40 \*-------------------------------------------------------------------------*/
/* Box the value on top of the stack into a one-element table whose
 * metatable is the caller's upvalue 1 (the exception metatable). The boxed
 * value stays where it was; the new table is pushed above it. */
static void wrap(lua_State *L) {
    int boxed = lua_gettop(L);              /* absolute index of the value */
    lua_createtable(L, 1, 0);
    lua_pushvalue(L, boxed);
    lua_rawseti(L, -2, 1);                  /* exception[1] = value */
    lua_pushvalue(L, lua_upvalueindex(1));
    lua_setmetatable(L, -2);                /* tag it as a wrapped exception */
}
48 
/* Body of the closure built by newtry(). Upvalues: 1 = exception metatable,
 * 2 = finalizer. A truthy first argument means success, and every argument
 * is returned unchanged. Otherwise the finalizer is called with no arguments
 * and the second argument is raised, boxed in an exception table. */
static int finalize(lua_State *L) {
    if (lua_toboolean(L, 1))
        return lua_gettop(L);               /* success: pass everything through */
    lua_pushvalue(L, lua_upvalueindex(2));
    lua_call(L, 0, 0);                      /* run the finalizer */
    lua_settop(L, 2);                       /* keep status and error value only */
    wrap(L);
    return lua_error(L);                    /* raises; does not return */
}
59 
/* Default finalizer installed by newtry() when it gets nil or no argument:
 * does nothing and returns no values. */
static int do_nothing(lua_State *L) {
    (void)L;    /* the Lua state is intentionally unused */
    return 0;
}
64 
/* newtry([finalizer]): build a "try" closure around finalize(). The closure's
 * upvalues are 1 = exception metatable (this function's own upvalue 1) and
 * 2 = the given finalizer, or do_nothing when the argument is nil/absent. */
static int global_newtry(lua_State *L) {
    lua_settop(L, 1);                       /* index 1: finalizer or nil */
    lua_pushvalue(L, lua_upvalueindex(1));  /* future upvalue 1 */
    if (lua_isnil(L, 1))
        lua_pushcfunction(L, do_nothing);   /* future upvalue 2 (default) */
    else
        lua_pushvalue(L, 1);                /* future upvalue 2 */
    lua_pushcclosure(L, finalize, 2);
    return 1;                               /* the closure on top */
}
73 
74 /*-------------------------------------------------------------------------*\
75 * Protect factory
76 \*-------------------------------------------------------------------------*/
/* If the value on top of the stack is a wrapped exception (a table whose
 * metatable is upvalue 1), push nil followed by the boxed error value and
 * return 1. Otherwise leave the stack unchanged and return 0. */
static int unwrap(lua_State *L) {
    int tagged;
    if (!lua_istable(L, -1))
        return 0;
    if (!lua_getmetatable(L, -1))           /* pushes nothing when absent */
        return 0;
    tagged = lua_rawequal(L, -1, lua_upvalueindex(1));
    lua_pop(L, 1);                          /* discard the metatable */
    if (!tagged)
        return 0;
    lua_pushnil(L);
    lua_rawgeti(L, -2, 1);                  /* exception[1] */
    return 1;
}
89 
/* Shared tail of a protect()'ed call. It is used right after lua_pcallk()
 * and also as its continuation (via protected_cont).
 * - Status 0 or LUA_YIELD: every value on the stack is returned.
 * - A wrapped exception: turned into nil plus the boxed value.
 * - Any other error: re-raised unchanged. */
static int protected_finish(lua_State *L, int status, lua_KContext ctx) {
    (void)ctx;                              /* no context is needed */
    if (status == 0 || status == LUA_YIELD)
        return lua_gettop(L);
    return unwrap(L) ? 2 : lua_error(L);
}
97 
#if LUA_VERSION_NUM == 502
/* Lua 5.2 continuations are plain C functions that query their status and
 * context through lua_getctx(); adapt that to protected_finish(). */
static int protected_cont(lua_State *L) {
    int context = 0;
    int outcome = lua_getctx(L, &context);
    return protected_finish(L, outcome, context);
}
#else
/* Other versions use protected_finish() itself as the continuation; under
 * Lua 5.1 the lua_pcallk shim simply discards it. */
#define protected_cont protected_finish
#endif
107 
/* Body of the closure built by protect(). Upvalues: 1 = exception metatable
 * (consulted by unwrap()), 2 = the protected function. Calls that function in
 * protected mode with all received arguments and hands the outcome to
 * protected_finish(). */
static int protected_(lua_State *L) {
    int nargs = lua_gettop(L);
    int status;
    lua_pushvalue(L, lua_upvalueindex(2));
    lua_insert(L, 1);                       /* callee goes below its arguments */
    status = lua_pcallk(L, nargs, LUA_MULTRET, 0, 0, protected_cont);
    return protected_finish(L, status, 0);
}
115 
/* protect(f): return a closure around protected_() whose upvalues are
 * 1 = exception metatable (this function's own upvalue 1) and 2 = f. */
static int global_protect(lua_State *L) {
    lua_settop(L, 1);                       /* index 1: the function to protect */
    lua_pushvalue(L, lua_upvalueindex(1));  /* future upvalue 1 */
    lua_pushvalue(L, 1);                    /* future upvalue 2 */
    lua_pushcclosure(L, protected_, 2);
    return 1;                               /* the closure on top */
}
123 
124 /*-------------------------------------------------------------------------*\
125 * Init module
126 \*-------------------------------------------------------------------------*/
/* Register newtry() and protect() into the table on top of the stack
 * (presumably the module table being built — confirm against the caller).
 * Both functions share a single upvalue: a fresh metatable that tags wrapped
 * exceptions. Its __metatable field is false so Lua code cannot fetch or
 * replace it. */
int except_open(lua_State *L) {
    lua_createtable(L, 0, 1);               /* exception metatable, one field */
    lua_pushboolean(L, 0);
    lua_setfield(L, -2, "__metatable");
    luaL_setfuncs(L, func, 1);              /* consumes the metatable upvalue */
    return 0;
}
134