1 /*
2  *
3  * luamm:  C++ binding for lua
4  *
5  * Copyright (C) 2010 Pavel Labath et al.
6  *
7  * This program is free software: you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation, either version 3 of the License, or
10  * (at your option) any later version.
11  *
12  * This program is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 #include <config.h>
22 
23 #include "luamm.hh"
24 
25 namespace lua {
26 namespace {
27 
#if LUA_VERSION_NUM >= 502
// Lua 5.2 deprecated lua_equal() and lua_lessthan(); whatever compatibility is
// left is provided only through macros. We need genuine functions because their
// addresses are taken (they are used as template arguments below), so any macro
// is removed and the functions are re-implemented on top of lua_compare().

#undef lua_equal
int lua_equal(lua_State *L, int index1, int index2) {
  constexpr int op = LUA_OPEQ;
  return lua_compare(L, index1, index2, op);
}

#undef lua_lessthan
int lua_lessthan(lua_State *L, int index1, int index2) {
  constexpr int op = LUA_OPLT;
  return lua_compare(L, index1, index2, op);
}
#endif
43 
// Keys under which this binding stores its own values in the lua registry.
// They are pushed as lua strings, so only their text matters, not their address.
constexpr char cpp_exception_metatable[] = "lua::cpp_exception_metatable";
constexpr char cpp_function_metatable[] = "lua::cpp_function_metatable";
constexpr char lua_exception_namespace[] = "lua::lua_exception_namespace";
constexpr char this_cpp_object[] = "lua::this_cpp_object";
49 
50 // converts C++ exceptions to strings, so lua can do something with them
exception_to_string(lua_State * l)51 int exception_to_string(lua_State *l) {
52   auto *ptr = static_cast<std::exception_ptr *>(lua_touserdata(l, -1));
53   assert(ptr);
54   try {
55     std::rethrow_exception(*ptr);
56   } catch (std::exception &e) { lua_pushstring(l, e.what()); } catch (...) {
57     lua_pushstring(l, typeid(*ptr).name());
58   }
59   return 1;
60 }
61 
absindex(lua_State * l,int index)62 int absindex(lua_State *l, int index) {
63   return index < 0 && -index <= lua_gettop(l) ? lua_gettop(l) + 1 + index
64                                               : index;
65 }
66 
67 // Just like getfield(), only without calling metamethods (or throwing random
68 // exceptions)
rawgetfield(lua_State * l,int index,const char * k)69 inline void rawgetfield(lua_State *l, int index, const char *k) {
70   index = absindex(l, index);
71   if (lua_checkstack(l, 1) == 0) { throw std::bad_alloc(); }
72 
73   lua_pushstring(l, k);
74   lua_rawget(l, index);
75 }
76 
77 // Just like setfield(), only without calling metamethods (or throwing random
78 // exceptions)
rawsetfield(lua_State * l,int index,const char * k)79 inline void rawsetfield(lua_State *l, int index, const char *k) {
80   index = absindex(l, index);
81   if (lua_checkstack(l, 2) == 0) { throw std::bad_alloc(); }
82 
83   lua_pushstring(l, k);
84   lua_insert(l, -2);
85   lua_rawset(l, index);
86 }
87 
closure_trampoline(lua_State * l)88 int closure_trampoline(lua_State *l) {
89   lua_checkstack(l, 2);
90   rawgetfield(l, REGISTRYINDEX, this_cpp_object);
91   assert(lua_islightuserdata(l, -1));
92   auto *L = static_cast<state *>(lua_touserdata(l, -1));
93   lua_pop(l, 1);
94 
95   try {
96     auto *fn = static_cast<cpp_function *>(L->touserdata(lua_upvalueindex(1)));
97     assert(fn);
98     return (*fn)(L);
99   } catch (lua::exception &e) {
100     // rethrow lua errors as such
101     e.push_lua_error(L);
102   } catch (...) {
103     // C++ exceptions (pointers to them, actually) are stored as lua userdata
104     // and then thrown
105     L->createuserdata<std::exception_ptr>(std::current_exception());
106     L->rawgetfield(REGISTRYINDEX, cpp_exception_metatable);
107     L->setmetatable(-2);
108   }
109 
110   // lua_error does longjmp(), so destructors for objects in this function will
111   // not be called
112   return lua_error(l);
113 }
114 
115 /*
 * This function is called when lua encounters an error outside of any protected
 * environment.
 * Throwing the exception through lua code appears to work, even if lua was
 * compiled without -fexceptions. If it turns out to fail under some conditions,
 * it could be replaced with some longjmp() magic. But that shouldn't be
121  * necessary, as this function will not be called under normal conditions (we
122  * execute everything in protected mode).
123  */
panic_throw(lua_State * l)124 int panic_throw(lua_State *l) {
125   if (lua_checkstack(l, 1) == 0) { throw std::bad_alloc(); }
126 
127   rawgetfield(l, REGISTRYINDEX, this_cpp_object);
128   assert(lua_islightuserdata(l, -1));
129   auto *L = static_cast<state *>(lua_touserdata(l, -1));
130   lua_pop(l, 1);
131 
132   throw lua::exception(L);
133 }
134 
135 // protected mode wrappers for various lua functions
safe_concat_trampoline(lua_State * l)136 int safe_concat_trampoline(lua_State *l) {
137   lua_concat(l, lua_gettop(l));
138   return 1;
139 }
140 
141 template <int (*compare)(lua_State *, int, int)>
safe_compare_trampoline(lua_State * l)142 int safe_compare_trampoline(lua_State *l) {
143   int r = compare(l, 1, 2);
144   lua_pop(l, 2);
145   lua_pushinteger(l, r);
146   return 1;
147 }
148 
safe_gc_trampoline(lua_State * l)149 int safe_gc_trampoline(lua_State *l) {
150   int what = lua_tointeger(l, -2);
151   int data = lua_tointeger(l, -1);
152   lua_pop(l, 2);
153   lua_pushinteger(l, lua_gc(l, what, data));
154   return 1;
155 }
156 
157 template <void (*misc)(lua_State *, int), int nresults>
safe_misc_trampoline(lua_State * l)158 int safe_misc_trampoline(lua_State *l) {
159   misc(l, 1);
160   return nresults;
161 }
162 
163 // Overloaded for Lua 5.3+ as lua_gettable and others return an int
164 template <int (*misc)(lua_State *, int), int nresults>
safe_misc_trampoline(lua_State * l)165 int safe_misc_trampoline(lua_State *l) {
166   misc(l, 1);
167   return nresults;
168 }
169 
safe_next_trampoline(lua_State * l)170 int safe_next_trampoline(lua_State *l) {
171   int r = lua_next(l, 1);
172   lua_checkstack(l, 1);
173   lua_pushinteger(l, r);
174   return r != 0 ? 3 : 1;
175 }
176 
177 }  // namespace
178 
get_error_msg(state * L)179 std::string exception::get_error_msg(state *L) {
180   static const std::string default_msg("Unknown lua exception");
181 
182   try {
183     return L->tostring(-1);
184   } catch (not_string_error &e) { return default_msg; }
185 }
186 
exception(state * l)187 exception::exception(state *l) : std::runtime_error(get_error_msg(l)), L(l) {
188   L->checkstack(1);
189 
190   L->rawgetfield(REGISTRYINDEX, lua_exception_namespace);
191   L->insert(-2);
192   key = L->ref(-2);
193   L->pop(1);
194 }
195 
~exception()196 exception::~exception() {
197   if (L == nullptr) { return; }
198   L->checkstack(1);
199 
200   L->rawgetfield(REGISTRYINDEX, lua_exception_namespace);
201   L->unref(-1, key);
202   L->pop();
203 }
204 
push_lua_error(state * l)205 void exception::push_lua_error(state *l) {
206   if (l != L) {
207     throw std::runtime_error(
208         "Cannot transfer exceptions between different lua contexts");
209   }
210   l->checkstack(2);
211 
212   l->rawgetfield(REGISTRYINDEX, lua_exception_namespace);
213   l->rawgeti(-1, key);
214   l->replace(-2);
215 }
216 
state()217 state::state() {
218   if (lua_State *l = luaL_newstate()) {
219     cobj.reset(l, &lua_close);
220   } else {
221     // docs say this can happen only in case of a memory allocation error
222     throw std::bad_alloc();
223   }
224 
225   // set our panic function
226   lua_atpanic(cobj.get(), panic_throw);
227 
228   checkstack(2);
229 
230   // store a pointer to ourselves
231   pushlightuserdata(this);
232   rawsetfield(REGISTRYINDEX, this_cpp_object);
233 
234   // a metatable for C++ exceptions travelling through lua code
235   newmetatable(cpp_exception_metatable);
236   lua_pushcfunction(cobj.get(), &exception_to_string);
237   rawsetfield(-2, "__tostring");
238   pushboolean(false);
239   rawsetfield(-2, "__metatable");
240   pushdestructor<std::exception_ptr>();
241   rawsetfield(-2, "__gc");
242   pop();
243 
244   // a metatable for C++ functions callable from lua code
245   newmetatable(cpp_function_metatable);
246   pushboolean(false);
247   rawsetfield(-2, "__metatable");
248   pushdestructor<cpp_function>();
249   rawsetfield(-2, "__gc");
250   pop();
251 
252   // while they're travelling through C++ code, lua exceptions will reside here
253   newtable();
254   rawsetfield(REGISTRYINDEX, lua_exception_namespace);
255 
256   luaL_openlibs(cobj.get());
257 }
258 
call(int nargs,int nresults,int errfunc)259 void state::call(int nargs, int nresults, int errfunc) {
260   int r = lua_pcall(cobj.get(), nargs, nresults, errfunc);
261   if (r == 0) { return; }
262 
263   if (r == LUA_ERRMEM) {
264     // memory allocation error, cross your fingers
265     throw std::bad_alloc();
266   }
267 
268   checkstack(3);
269   rawgetfield(REGISTRYINDEX, cpp_exception_metatable);
270   if (getmetatable(-2)) {
271     if (rawequal(-1, -2)) {
272       // it's a C++ exception, rethrow it
273       auto *ptr = static_cast<std::exception_ptr *>(touserdata(-3));
274       assert(ptr);
275 
276       /*
277        * we create a copy, so we can pop the object without fearing the
278        * exception will be collected by lua's GC
279        */
280       std::exception_ptr t(*ptr);
281       ptr = nullptr;
282       pop(3);
283       std::rethrow_exception(t);
284     }
285     pop(2);
286   }
287   // it's a lua exception, wrap it
288   if (r == LUA_ERRERR) { throw lua::errfunc_error(this); }
289   { throw lua::exception(this); }
290 }
291 
checkstack(int extra)292 void state::checkstack(int extra) {
293   if (lua_checkstack(cobj.get(), extra) == 0) { throw std::bad_alloc(); }
294 }
295 
concat(int n)296 void state::concat(int n) {
297   assert(n >= 0);
298   checkstack(1);
299   lua_pushcfunction(cobj.get(), safe_concat_trampoline);
300   insert(-n - 1);
301   call(n, 1, 0);
302 }
303 
equal(int index1,int index2)304 bool state::equal(int index1, int index2) {
305   // avoid pcall overhead in trivial cases
306   if (rawequal(index1, index2)) { return true; }
307 
308   return safe_compare(&safe_compare_trampoline<lua_equal>, index1, index2);
309 }
310 
gc(int what,int data)311 int state::gc(int what, int data) {
312   checkstack(3);
313   lua_pushcfunction(cobj.get(), safe_gc_trampoline);
314   pushinteger(what);
315   pushinteger(data);
316   call(2, 1, 0);
317   assert(state::_isnumber(-1));
318   int r = tointeger(-1);
319   pop();
320   return r;
321 }
322 
getfield(int index,const char * k)323 void state::getfield(int index, const char *k) {
324   checkstack(1);
325   index = absindex(index);
326   pushstring(k);
327   gettable(index);
328 }
329 
getglobal(const char * name)330 void state::getglobal(const char *name) {
331 #if LUA_VERSION_NUM >= 502
332   checkstack(1);
333   pushinteger(LUA_RIDX_GLOBALS);
334   gettable(REGISTRYINDEX);
335   getfield(-1, name);
336   replace(-2);
337 #else
338   getfield(LUA_GLOBALSINDEX, name);
339 #endif
340 }
341 
gettable(int index)342 void state::gettable(int index) {
343   checkstack(2);
344   pushvalue(index);
345   insert(-2);
346   lua_pushcfunction(cobj.get(), (&safe_misc_trampoline<&lua_gettable, 1>));
347   insert(-3);
348   call(2, 1, 0);
349 }
350 
lessthan(int index1,int index2)351 bool state::lessthan(int index1, int index2) {
352   return safe_compare(&safe_compare_trampoline<&lua_lessthan>, index1, index2);
353 }
354 
loadfile(const char * filename)355 void state::loadfile(const char *filename) {
356   switch (luaL_loadfile(cobj.get(), filename)) {
357     case 0:
358       return;
359     case LUA_ERRSYNTAX:
360       throw lua::syntax_error(this);
361     case LUA_ERRFILE:
362       throw lua::file_error(this);
363     case LUA_ERRMEM:
364       throw std::bad_alloc();
365     default:
366       assert(0);
367   }
368 }
369 
loadstring(const char * s)370 void state::loadstring(const char *s) {
371   switch (luaL_loadstring(cobj.get(), s)) {
372     case 0:
373       return;
374     case LUA_ERRSYNTAX:
375       throw lua::syntax_error(this);
376     case LUA_ERRMEM:
377       throw std::bad_alloc();
378     default:
379       assert(0);
380   }
381 }
382 
next(int index)383 bool state::next(int index) {
384   checkstack(2);
385   pushvalue(index);
386   insert(-2);
387   lua_pushcfunction(cobj.get(), &safe_next_trampoline);
388   insert(-3);
389 
390   call(2, MULTRET, 0);
391 
392   assert(state::_isnumber(-1));
393   int r = tointeger(-1);
394   pop();
395   return r != 0;
396 }
397 
pushclosure(const cpp_function & fn,int n)398 void state::pushclosure(const cpp_function &fn, int n) {
399   checkstack(2);
400 
401   createuserdata<cpp_function>(fn);
402   rawgetfield(REGISTRYINDEX, cpp_function_metatable);
403   setmetatable(-2);
404 
405   insert(-n - 1);
406   lua_pushcclosure(cobj.get(), &closure_trampoline, n + 1);
407 }
408 
rawgetfield(int index,const char * k)409 void state::rawgetfield(int index, const char *k) {
410   lua::rawgetfield(cobj.get(), index, k);
411 }
412 
rawsetfield(int index,const char * k)413 void state::rawsetfield(int index, const char *k) {
414   lua::rawsetfield(cobj.get(), index, k);
415 }
416 
safe_compare(lua_CFunction trampoline,int index1,int index2)417 bool state::safe_compare(lua_CFunction trampoline, int index1, int index2) {
418   // if one of the indexes is invalid, return false
419   if (isnone(index1) || isnone(index2)) { return false; }
420 
421   // convert relative indexes into absolute
422   index1 = absindex(index1);
423   index2 = absindex(index2);
424 
425   checkstack(3);
426   lua_pushcfunction(cobj.get(), trampoline);
427   pushvalue(index1);
428   pushvalue(index2);
429   call(2, 1, 0);
430   assert(state::_isnumber(-1));
431   int r = tointeger(-1);
432   pop();
433   return r != 0;
434 }
435 
setfield(int index,const char * k)436 void state::setfield(int index, const char *k) {
437   checkstack(1);
438   index = absindex(index);
439   pushstring(k);
440   insert(-2);
441   settable(index);
442 }
443 
setglobal(const char * name)444 void state::setglobal(const char *name) {
445 #if LUA_VERSION_NUM >= 502
446   stack_sentry s(*this, -1);
447   checkstack(1);
448   pushinteger(LUA_RIDX_GLOBALS);
449   gettable(REGISTRYINDEX);
450   insert(-2);
451   setfield(-2, name);
452   pop();
453 #else
454   setfield(LUA_GLOBALSINDEX, name);
455 #endif
456 }
457 
settable(int index)458 void state::settable(int index) {
459   checkstack(2);
460   pushvalue(index);
461   insert(-3);
462   lua_pushcfunction(cobj.get(), (&safe_misc_trampoline<&lua_settable, 0>));
463   insert(-4);
464   call(3, 0, 0);
465 }
466 
tostring(int index)467 std::string state::tostring(int index) {
468   size_t len;
469   const char *str = lua_tolstring(cobj.get(), index, &len);
470   if (str == nullptr) { throw not_string_error(); }
471   return std::string(str, len);
472 }
473 }  // namespace lua
474