1 /*
2 ** LuaSQL, ODBC driver
3 ** Authors: Pedro Rabinovitch, Roberto Ierusalimschy, Diego Nehab,
4 ** Tomas Guisasola
5 ** See Copyright Notice in license.html
6 */
7 
8 #include <assert.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <time.h>
13 
14 #if defined(_WIN32)
15 #include <windows.h>
16 #include <sqlext.h>
17 #elif defined(INFORMIX)
18 #include "infxcli.h"
19 #elif defined(UNIXODBC)
20 #include "sql.h"
21 #include "sqltypes.h"
22 #include "sqlext.h"
23 #endif
24 
25 #include "lua.h"
26 #include "lauxlib.h"
27 
28 #include "luasql.h"
29 
/* Metatable names for the four userdata types.  They are passed to
   luasql_createmeta/luasql_setmeta and checked by luaL_checkudata. */
#define LUASQL_ENVIRONMENT_ODBC "ODBC environment"
#define LUASQL_CONNECTION_ODBC "ODBC connection"
#define LUASQL_STATEMENT_ODBC "ODBC statement"
#define LUASQL_CURSOR_ODBC "ODBC cursor"
34 
35 /* holds data for paramter binding */
36 typedef struct {
37 	SQLPOINTER buf;
38 	SQLINTEGER len;
39 	SQLINTEGER type;
40 } param_data;
41 
/* Common header of the lockable driver objects.  env_data, conn_data and
   stmt_data all begin with these same two fields in this order, so
   lock_obj()/unlock_obj() access them through an obj_data pointer.
   cur_data has no `lock` field and is never passed to those helpers. */
typedef struct {
	short closed;   /* non-zero once the object has been shut down */
	int   lock;     /* number of dependents currently holding the object */
} obj_data;
47 
typedef struct {
	short         closed;             /* non-zero after env_close (obj_data header) */
	int           lock;               /* lock count for open connections */
	SQLHENV       henv;               /* environment handle */
} env_data;
53 
typedef struct {
	short         closed;             /* non-zero after conn_close (obj_data header) */
	int           lock;               /* lock count for open statements */
	env_data      *env;               /* the connection's environment */
	SQLHDBC       hdbc;               /* database connection handle */
} conn_data;
60 
typedef struct {
	short         closed;             /* non-zero after stmt_shut (obj_data header) */
	int           lock;               /* lock count for open cursors */
	unsigned char hidden;             /* statement was created implicitly by conn:execute() */
	conn_data     *conn;              /* the statement's connection */
	SQLHSTMT      hstmt;              /* statement handle */
	SQLSMALLINT   numparams;          /* number of input parameters */
	int           paramtypes;         /* registry reference to param type table (LUA_NOREF if unavailable) */
	param_data    *params;            /* array of numparams parameter buffers, or NULL */
} stmt_data;
71 
typedef struct {
	short         closed;             /* non-zero after cur_shut */
	stmt_data     *stmt;              /* the cursor's statement */
	int           numcols;            /* number of columns */
	int           coltypes, colnames; /* registry references to column type/name tables */
} cur_data;
78 
79 
/* short aliases for the ODBC handle-type constants used throughout the file */
#define hENV SQL_HANDLE_ENV
#define hSTMT SQL_HANDLE_STMT
#define hDBC SQL_HANDLE_DBC
84 
error(SQLRETURN a)85 static int error(SQLRETURN a)
86 {
87 	return (a != SQL_SUCCESS) && (a != SQL_SUCCESS_WITH_INFO) && (a != SQL_NO_DATA);
88 }
89 
90 
91 /*
92 ** Registers a given C object in the registry to avoid GC
93 */
luasql_registerobj(lua_State * L,int index,void * obj)94 static void luasql_registerobj(lua_State *L, int index, void *obj)
95 {
96 	lua_pushvalue(L, index);
97 	lua_pushlightuserdata(L, obj);
98 	lua_pushvalue(L, -2);
99 	lua_settable(L, LUA_REGISTRYINDEX);
100 	lua_pop(L, 1);
101 }
102 
103 /*
104 ** Unregisters a given C object from the registry
105 */
luasql_unregisterobj(lua_State * L,void * obj)106 static void luasql_unregisterobj(lua_State *L, void *obj)
107 {
108 	lua_pushlightuserdata(L, obj);
109 	lua_pushnil(L);
110 	lua_settable(L, LUA_REGISTRYINDEX);
111 }
112 
lock_obj(lua_State * L,int indx,void * ptr)113 static int lock_obj(lua_State *L, int indx, void *ptr)
114 {
115 	obj_data *obj = (obj_data *)ptr;
116 
117 	luasql_registerobj(L, indx, obj);
118 	return ++obj->lock;
119 }
120 
unlock_obj(lua_State * L,void * ptr)121 static int unlock_obj(lua_State *L, void *ptr)
122 {
123 	obj_data *obj = (obj_data *)ptr;
124 
125 	if(--obj->lock == 0) {
126 		luasql_unregisterobj(L, obj);
127 	}
128 
129 	return obj->lock;
130 }
131 
132 /*
133 ** Check for valid environment.
134 */
getenvironment(lua_State * L,int i)135 static env_data *getenvironment (lua_State *L, int i)
136 {
137 	env_data *env = (env_data *)luaL_checkudata (L, i, LUASQL_ENVIRONMENT_ODBC);
138 	luaL_argcheck (L, env != NULL, i, LUASQL_PREFIX"environment expected");
139 	luaL_argcheck (L, !env->closed, i, LUASQL_PREFIX"environment is closed");
140 	return env;
141 }
142 
143 
144 /*
145 ** Check for valid connection.
146 */
getconnection(lua_State * L,int i)147 static conn_data *getconnection (lua_State *L, int i)
148 {
149 	conn_data *conn = (conn_data *)luaL_checkudata (L, i, LUASQL_CONNECTION_ODBC);
150 	luaL_argcheck (L, conn != NULL, i, LUASQL_PREFIX"connection expected");
151 	luaL_argcheck (L, !conn->closed, i, LUASQL_PREFIX"connection is closed");
152 	return conn;
153 }
154 
155 
156 /*
157 ** Check for valid connection.
158 */
getstatement(lua_State * L,int i)159 static stmt_data *getstatement (lua_State *L, int i)
160 {
161 	stmt_data *stmt = (stmt_data *)luaL_checkudata (L, i, LUASQL_STATEMENT_ODBC);
162 	luaL_argcheck (L, stmt != NULL, i, LUASQL_PREFIX"statement expected");
163 	luaL_argcheck (L, !stmt->closed, i, LUASQL_PREFIX"statement is closed");
164 	return stmt;
165 }
166 
167 
168 /*
169 ** Check for valid cursor.
170 */
getcursor(lua_State * L,int i)171 static cur_data *getcursor (lua_State *L, int i)
172 {
173 	cur_data *cursor = (cur_data *)luaL_checkudata (L, i, LUASQL_CURSOR_ODBC);
174 	luaL_argcheck (L, cursor != NULL, i, LUASQL_PREFIX"cursor expected");
175 	luaL_argcheck (L, !cursor->closed, i, LUASQL_PREFIX"cursor is closed");
176 	return cursor;
177 }
178 
179 
180 /*
181 ** Pushes true and returns 1
182 */
pass(lua_State * L)183 static int pass(lua_State *L) {
184     lua_pushboolean (L, 1);
185     return 1;
186 }
187 
188 
189 /*
190 ** Fails with error message from ODBC
191 ** Inputs:
192 **   type: type of handle used in operation
193 **   handle: handle used in operation
194 */
fail(lua_State * L,const SQLSMALLINT type,const SQLHANDLE handle)195 static int fail(lua_State *L,  const SQLSMALLINT type, const SQLHANDLE handle) {
196     SQLCHAR State[6];
197     SQLINTEGER NativeError;
198     SQLSMALLINT MsgSize, i;
199     SQLRETURN ret;
200     SQLCHAR Msg[SQL_MAX_MESSAGE_LENGTH];
201     luaL_Buffer b;
202     lua_pushnil(L);
203 
204     luaL_buffinit(L, &b);
205     i = 1;
206     while (1) {
207         ret = SQLGetDiagRec(type, handle, i, State, &NativeError, Msg,
208                 sizeof(Msg), &MsgSize);
209         if (ret == SQL_NO_DATA) break;
210         luaL_addlstring(&b, (char*)Msg, MsgSize);
211         luaL_addchar(&b, '\n');
212         i++;
213     }
214     luaL_pushresult(&b);
215     return 2;
216 }
217 
malloc_stmt_params(SQLSMALLINT c)218 static param_data *malloc_stmt_params(SQLSMALLINT c)
219 {
220 	param_data *p = (param_data *)malloc(sizeof(param_data)*c);
221 	memset(p, 0, sizeof(param_data)*c);
222 
223 	return p;
224 }
225 
free_stmt_params(param_data * data,SQLSMALLINT c)226 static param_data *free_stmt_params(param_data *data, SQLSMALLINT c)
227 {
228 	if(data != NULL) {
229 		param_data *p = data;
230 
231 		for(; c>0; ++p, --c) {
232 			free(p->buf);
233 		}
234 		free(data);
235 	}
236 
237 	return NULL;
238 }
239 
240 /*
241 ** Shuts a statement
242 ** Returns non-zero on error
243 */
stmt_shut(lua_State * L,stmt_data * stmt)244 static int stmt_shut(lua_State *L, stmt_data *stmt)
245 {
246 	SQLRETURN ret;
247 
248 	unlock_obj(L, stmt->conn);
249 	stmt->closed = 1;
250 
251 	luaL_unref (L, LUA_REGISTRYINDEX, stmt->paramtypes);
252 	stmt->paramtypes = LUA_NOREF;
253 	stmt->params = free_stmt_params(stmt->params, stmt->numparams);
254 
255 	ret = SQLFreeHandle(hSTMT, stmt->hstmt);
256 	if (error(ret)) {
257 		return 1;
258 	}
259 
260 	return 0;
261 }
262 
263 /*
264 ** Closes a cursor directly
265 ** Returns non-zero on error
266 */
cur_shut(lua_State * L,cur_data * cur)267 static int cur_shut(lua_State *L, cur_data *cur)
268 {
269 	/* Nullify structure fields. */
270 	cur->closed = 1;
271 	if (error(SQLCloseCursor(cur->stmt->hstmt))) {
272 		return fail(L, hSTMT, cur->stmt->hstmt);
273 	}
274 
275 	/* release col tables */
276 	luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames);
277 	luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes);
278 
279 	/* release statement and, if hidden, shut it */
280 	if(unlock_obj(L, cur->stmt) == 0) {
281 		if(cur->stmt->hidden) {
282 			return stmt_shut(L, cur->stmt);
283 		}
284 	}
285 
286 	return 0;
287 }
288 
289 
290 /*
291 ** Returns the name of an equivalent lua type for a SQL type.
292 */
sqltypetolua(const SQLSMALLINT type)293 static const char *sqltypetolua (const SQLSMALLINT type) {
294     switch (type) {
295         case SQL_UNKNOWN_TYPE: case SQL_CHAR: case SQL_VARCHAR:
296         case SQL_TYPE_DATE: case SQL_TYPE_TIME: case SQL_TYPE_TIMESTAMP:
297         case SQL_DATE: case SQL_INTERVAL: case SQL_TIMESTAMP:
298         case SQL_LONGVARCHAR:
299         case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR:
300             return "string";
301         case SQL_BIGINT: case SQL_TINYINT:
302         case SQL_INTEGER: case SQL_SMALLINT:
303 #if LUA_VERSION_NUM>=503
304 			return "integer";
305 #endif
306 		case SQL_NUMERIC: case SQL_DECIMAL:
307         case SQL_FLOAT: case SQL_REAL: case SQL_DOUBLE:
308             return "number";
309         case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY:
310             return "binary";	/* !!!!!! nao seria string? */
311         case SQL_BIT:
312             return "boolean";
313         default:
314             assert(0);
315             return NULL;
316     }
317 }
318 
319 
320 /*
321 ** Retrieves data from the i_th column in the current row
322 ** Input
323 **   types: index in stack of column types table
324 **   hstmt: statement handle
325 **   i: column number
326 ** Returns:
327 **   0 if successfull, non-zero otherwise;
328 */
push_column(lua_State * L,int coltypes,const SQLHSTMT hstmt,SQLUSMALLINT i)329 static int push_column(lua_State *L, int coltypes, const SQLHSTMT hstmt,
330         SQLUSMALLINT i) {
331     const char *tname;
332     char type;
333     /* get column type from types table */
334 	lua_rawgeti (L, LUA_REGISTRYINDEX, coltypes);
335 	lua_rawgeti (L, -1, i);	/* typename of the column */
336     tname = lua_tostring(L, -1);
337     if (!tname)
338 		return luasql_faildirect(L, "invalid type in table.");
339     type = tname[1];
340     lua_pop(L, 2);	/* pops type name and coltypes table */
341 
342     /* deal with data according to type */
343     switch (type) {
344         /* nUmber */
345         case 'u': {
346 			SQLDOUBLE num;
347 			SQLLEN got;
348 			SQLRETURN rc = SQLGetData(hstmt, i, SQL_C_DOUBLE, &num, 0, &got);
349 			if (error(rc))
350 				return fail(L, hSTMT, hstmt);
351 			if (got == SQL_NULL_DATA)
352 				lua_pushnil(L);
353 			else
354 				lua_pushnumber(L, num);
355 			return 0;
356 		}
357 #if LUA_VERSION_NUM>=503
358 		/* iNteger */
359 		case 'n': {
360 			SQLINTEGER num;
361 			SQLLEN got;
362 			SQLRETURN rc = SQLGetData(hstmt, i, SQL_C_SLONG, &num, 0, &got);
363 			if (error(rc))
364 				return fail(L, hSTMT, hstmt);
365 			if (got == SQL_NULL_DATA)
366 				lua_pushnil(L);
367 			else
368 				lua_pushinteger(L, num);
369 			return 0;
370 		}
371 #endif
372 		/* bOol */
373         case 'o': {
374 			SQLCHAR b;
375 			SQLLEN got;
376 			SQLRETURN rc = SQLGetData(hstmt, i, SQL_C_BIT, &b, 0, &got);
377 			if (error(rc))
378 				return fail(L, hSTMT, hstmt);
379 			if (got == SQL_NULL_DATA)
380 				lua_pushnil(L);
381 			else
382 				lua_pushboolean(L, b);
383 			return 0;
384 		}
385         /* sTring */
386         case 't':
387         /* bInary */
388         case 'i': {
389 			SQLSMALLINT stype = (type == 't') ? SQL_C_CHAR : SQL_C_BINARY;
390 			SQLLEN got;
391 			char *buffer;
392 			luaL_Buffer b;
393 			SQLRETURN rc;
394 			luaL_buffinit(L, &b);
395 			buffer = luaL_prepbuffer(&b);
396 			rc = SQLGetData(hstmt, i, stype, buffer, LUAL_BUFFERSIZE, &got);
397 			if (got == SQL_NULL_DATA) {
398 				lua_pushnil(L);
399 				return 0;
400 			}
401 			/* concat intermediary chunks */
402 			while (rc == SQL_SUCCESS_WITH_INFO) {
403 				if (got >= LUAL_BUFFERSIZE || got == SQL_NO_TOTAL) {
404 					got = LUAL_BUFFERSIZE;
405 					/* get rid of null termination in string block */
406 					if (stype == SQL_C_CHAR) got--;
407 				}
408 				luaL_addsize(&b, got);
409 				buffer = luaL_prepbuffer(&b);
410 				rc = SQLGetData(hstmt, i, stype, buffer,
411 					LUAL_BUFFERSIZE, &got);
412 			}
413 			/* concat last chunk */
414 			if (rc == SQL_SUCCESS) {
415 				if (got >= LUAL_BUFFERSIZE || got == SQL_NO_TOTAL) {
416 					got = LUAL_BUFFERSIZE;
417 					/* get rid of null termination in string block */
418 					if (stype == SQL_C_CHAR) got--;
419 				}
420 				luaL_addsize(&b, got);
421 			}
422 			if (rc == SQL_ERROR) return fail(L, hSTMT, hstmt);
423 			/* return everything we got */
424 			luaL_pushresult(&b);
425 			return 0;
426 		}
427     }
428     return 0;
429 }
430 
431 /*
432 ** Get another row of the given cursor.
433 */
cur_fetch(lua_State * L)434 static int cur_fetch (lua_State *L)
435 {
436 	cur_data *cur = getcursor (L, 1);
437 	SQLHSTMT hstmt = cur->stmt->hstmt;
438 	int ret;
439 	SQLRETURN rc = SQLFetch(hstmt);
440 	if (rc == SQL_NO_DATA) {
441 		/* automatically close cursor when end of resultset is reached */
442 		if((ret = cur_shut(L, cur)) != 0) {
443 			return ret;
444 		}
445 
446 		lua_pushnil(L);
447 		return 1;
448 	} else {
449 		if (error(rc)) {
450 			return fail(L, hSTMT, hstmt);
451 		}
452 	}
453 
454 	if (lua_istable (L, 2)) {
455 		SQLUSMALLINT i;
456 		const char *opts = luaL_optstring (L, 3, "n");
457 		int num = strchr (opts, 'n') != NULL;
458 		int alpha = strchr (opts, 'a') != NULL;
459 		for (i = 1; i <= cur->numcols; i++) {
460 			ret = push_column (L, cur->coltypes, hstmt, i);
461 			if (ret) {
462 				return ret;
463 			}
464 			if (alpha) {
465 				lua_rawgeti (L, LUA_REGISTRYINDEX, cur->colnames);
466 				lua_rawgeti (L, -1, i); /* gets column name */
467 				lua_pushvalue (L, -3); /* duplicates column value */
468 				lua_rawset (L, 2); /* table[name] = value */
469 				lua_pop (L, 1);	/* pops colnames table */
470 			}
471 			if (num) {
472 				lua_rawseti (L, 2, i);
473 			} else {
474 				lua_pop (L, 1); /* pops value */
475 			}
476 		}
477 		lua_pushvalue (L, 2);
478 		return 1;	/* return table */
479 	} else {
480 		SQLUSMALLINT i;
481 		luaL_checkstack (L, cur->numcols, LUASQL_PREFIX"too many columns");
482 		for (i = 1; i <= cur->numcols; i++) {
483 			ret = push_column (L, cur->coltypes, hstmt, i);
484 			if (ret) {
485 				return ret;
486 			}
487 		}
488 		return cur->numcols;
489 	}
490 }
491 
492 /*
493 ** Closes a cursor.
494 */
cur_close(lua_State * L)495 static int cur_close (lua_State *L)
496 {
497 	int res;
498 	cur_data *cur = (cur_data *) luaL_checkudata (L, 1, LUASQL_CURSOR_ODBC);
499 	luaL_argcheck (L, cur != NULL, 1, LUASQL_PREFIX"cursor expected");
500 
501 	if (cur->closed) {
502 		lua_pushboolean (L, 0);
503 		return 1;
504 	}
505 
506 	if((res = cur_shut(L, cur)) != 0) {
507 		return res;
508 	}
509 
510 	return pass(L);
511 }
512 
513 /*
514 ** Returns the table with column names.
515 */
cur_colnames(lua_State * L)516 static int cur_colnames (lua_State *L) {
517 	cur_data *cur = (cur_data *) getcursor (L, 1);
518 	lua_rawgeti (L, LUA_REGISTRYINDEX, cur->colnames);
519 	return 1;
520 }
521 
522 
523 /*
524 ** Returns the table with column types.
525 */
cur_coltypes(lua_State * L)526 static int cur_coltypes (lua_State *L) {
527 	cur_data *cur = (cur_data *) getcursor (L, 1);
528 	lua_rawgeti (L, LUA_REGISTRYINDEX, cur->coltypes);
529 	return 1;
530 }
531 
532 
533 /*
534 ** Creates two tables with the names and the types of the columns.
535 */
create_colinfo(lua_State * L,cur_data * cur)536 static int create_colinfo (lua_State *L, cur_data *cur)
537 {
538 	SQLCHAR buffer[256];
539 	SQLSMALLINT namelen, datatype, i;
540 	SQLRETURN ret;
541 	int types, names;
542 
543 	lua_newtable(L);
544 	types = lua_gettop (L);
545 	lua_newtable(L);
546 	names = lua_gettop (L);
547 	for (i = 1; i <= cur->numcols; i++) {
548 		ret = SQLDescribeCol(cur->stmt->hstmt, i, buffer, sizeof(buffer),
549 		                     &namelen, &datatype, NULL, NULL, NULL);
550 		if (ret == SQL_ERROR) {
551 			lua_pop(L, 2);
552 			return -1;
553 		}
554 
555 		lua_pushstring (L, (char *)buffer);
556 		lua_rawseti (L, names, i);
557 		lua_pushstring(L, sqltypetolua(datatype));
558 		lua_rawseti (L, types, i);
559 	}
560 	cur->colnames = luaL_ref (L, LUA_REGISTRYINDEX);
561 	cur->coltypes = luaL_ref (L, LUA_REGISTRYINDEX);
562 
563 	return 0;
564 }
565 
566 
567 /*
568 ** Creates a cursor table and leave it on the top of the stack.
569 */
create_cursor(lua_State * L,int stmt_i,stmt_data * stmt,const SQLSMALLINT numcols)570 static int create_cursor (lua_State *L, int stmt_i, stmt_data *stmt,
571                           const SQLSMALLINT numcols)
572 {
573 	cur_data *cur;
574 
575 	lock_obj(L, stmt_i, stmt);
576 
577 	cur = (cur_data *) lua_newuserdata(L, sizeof(cur_data));
578 	luasql_setmeta (L, LUASQL_CURSOR_ODBC);
579 
580 	/* fill in structure */
581 	cur->closed = 0;
582 	cur->stmt = stmt;
583 	cur->numcols = numcols;
584 	cur->colnames = LUA_NOREF;
585 	cur->coltypes = LUA_NOREF;
586 
587 	/* make and store column information table */
588 	if(create_colinfo (L, cur) < 0) {
589 		lua_pop(L, 1);
590 		return fail(L, hSTMT, cur->stmt->hstmt);
591 	}
592 
593 	return 1;
594 }
595 
596 
597 /*
598 ** Returns the table with statement params.
599 */
stmt_paramtypes(lua_State * L)600 static int stmt_paramtypes (lua_State *L)
601 {
602 	stmt_data *stmt = getstatement(L, 1);
603 	lua_rawgeti (L, LUA_REGISTRYINDEX, stmt->paramtypes);
604 	return 1;
605 }
606 
stmt_close(lua_State * L)607 static int stmt_close(lua_State *L)
608 {
609 	stmt_data *stmt = (stmt_data *) luaL_checkudata (L, 1, LUASQL_STATEMENT_ODBC);
610 	luaL_argcheck (L, stmt != NULL, 1, LUASQL_PREFIX"statement expected");
611 	luaL_argcheck (L, stmt->lock == 0, 1,
612 	               LUASQL_PREFIX"there are still open cursors");
613 
614 	if (stmt->closed) {
615 		lua_pushboolean (L, 0);
616 		return 1;
617 	}
618 
619 	if(stmt_shut(L, stmt)) {
620 		return fail(L, hSTMT, stmt->hstmt);
621 	}
622 
623 	lua_pushboolean(L, 1);
624 	return 1;
625 }
626 
627 
628 /*
629 ** Closes a connection.
630 */
conn_close(lua_State * L)631 static int conn_close (lua_State *L)
632 {
633 	SQLRETURN ret;
634 	conn_data *conn = (conn_data *)luaL_checkudata(L,1,LUASQL_CONNECTION_ODBC);
635 	luaL_argcheck (L, conn != NULL, 1, LUASQL_PREFIX"connection expected");
636 	if (conn->closed) {
637 		lua_pushboolean (L, 0);
638 		return 1;
639 	}
640 	if (conn->lock > 0) {
641 		return luaL_error (L, LUASQL_PREFIX"there are open statements/cursors");
642 	}
643 
644 	/* Decrement connection counter on environment object */
645 	unlock_obj(L, conn->env);
646 
647 	/* Nullify structure fields. */
648 	conn->closed = 1;
649 	ret = SQLDisconnect(conn->hdbc);
650 	if (error(ret)) {
651 		return fail(L, hDBC, conn->hdbc);
652 	}
653 
654 	ret = SQLFreeHandle(hDBC, conn->hdbc);
655 	if (error(ret)) {
656 		return fail(L, hDBC, conn->hdbc);
657 	}
658 
659 	return pass(L);
660 }
661 
662 /*
663 ** Executes the given statement
664 ** Takes:
665 **  * istmt : location of the statement object on the stack
666 */
raw_execute(lua_State * L,int istmt)667 static int raw_execute(lua_State *L, int istmt)
668 {
669 	SQLSMALLINT numcols;
670 
671 	stmt_data *stmt = getstatement(L, istmt);
672 
673 	/* execute the statement */
674 	if (error(SQLExecute(stmt->hstmt))) {
675 		return fail(L, hSTMT, stmt->hstmt);
676 	}
677 
678 	/* determine the number of result columns */
679 	if (error(SQLNumResultCols(stmt->hstmt, &numcols))) {
680 		return fail(L, hSTMT, stmt->hstmt);
681 	}
682 
683 	if (numcols > 0) {
684 		/* if there is a results table (e.g., SELECT) */
685 		return create_cursor(L, -1, stmt, numcols);
686 	} else {
687 		/* if action has no results (e.g., UPDATE) */
688 		SQLINTEGER numrows;
689 		if(error(SQLRowCount(stmt->hstmt, &numrows))) {
690 			return fail(L, hSTMT, stmt->hstmt);
691 		}
692 
693 		lua_pushnumber(L, numrows);
694 		return 1;
695 	}
696 }
697 
set_param(lua_State * L,stmt_data * stmt,int i,param_data * data)698 static int set_param(lua_State *L, stmt_data *stmt, int i, param_data *data)
699 {
700 	static SQLINTEGER cbNull = SQL_NULL_DATA;
701 
702 	switch(lua_type(L, -1)) {
703 	case LUA_TNIL: {
704 			lua_pop(L, 1);
705 
706 			if(error(SQLBindParameter(stmt->hstmt, i, SQL_PARAM_INPUT, SQL_C_DEFAULT,
707 			                          SQL_DOUBLE, 0, 0, NULL, 0, &cbNull))) {
708 				return fail(L, hSTMT, stmt->hstmt);
709 			}
710 		}
711 		break;
712 
713 	case LUA_TNUMBER: {
714 			data->buf = malloc(sizeof(double));
715 			*(double *)data->buf = (double)lua_tonumber(L, -1);
716 			data->len = sizeof(double);
717 			data->type = 0;
718 
719 			lua_pop(L, 1);
720 
721 			if(error(SQLBindParameter(stmt->hstmt, i, SQL_PARAM_INPUT, SQL_C_DOUBLE,
722 			                          SQL_DOUBLE, 0, 0, data->buf, data->len,
723 			                          &data->type))) {
724 				return fail(L, hSTMT, stmt->hstmt);
725 			}
726 		}
727 		break;
728 
729 	case LUA_TSTRING: {
730 			const char *str = lua_tostring(L, -1);
731 			size_t len = strlen(str);
732 
733 			data->buf = malloc(len+1);
734 			memcpy((char *)data->buf, str, len+1);
735 			data->len = len;
736 			data->type = SQL_NTS;
737 
738 			lua_pop(L, 1);
739 
740 			if(error(SQLBindParameter(stmt->hstmt, i, SQL_PARAM_INPUT, SQL_C_CHAR,
741 			                          SQL_CHAR, len, 0, data->buf, data->len,
742 			                          &data->type))) {
743 				return fail(L, hSTMT, stmt->hstmt);
744 			}
745 		}
746 		break;
747 
748 	case LUA_TBOOLEAN: {
749 			data->buf = malloc(sizeof(SQLCHAR));
750 			*(SQLCHAR *)data->buf = (SQLCHAR)lua_toboolean(L, -1);
751 			data->len = 0;
752 			data->type = 0;
753 
754 			lua_pop(L, 1);
755 
756 			if(error(SQLBindParameter(stmt->hstmt, i, SQL_PARAM_INPUT, SQL_C_BIT,
757 			                          SQL_BIT, 0, 0, data->buf, data->len,
758 			                          &data->type))) {
759 				return fail(L, hSTMT, stmt->hstmt);
760 			}
761 		}
762 		break;
763 
764 	default:
765 		lua_pop(L, 1);
766 		return luasql_faildirect(L, "unsupported parameter type");
767 	}
768 
769 	return 0;
770 }
771 
772 /*
773 ** Reads a param table into a statement
774 */
raw_readparams_table(lua_State * L,stmt_data * stmt,int iparams)775 static int raw_readparams_table(lua_State *L, stmt_data *stmt, int iparams)
776 {
777 	static SQLINTEGER cbNull = SQL_NULL_DATA;
778 	SQLSMALLINT i;
779 	param_data *data;
780 	int res = 0;
781 
782 	free_stmt_params(stmt->params, stmt->numparams);
783 	stmt->params = malloc_stmt_params(stmt->numparams);
784 	data = stmt->params;
785 
786 	for(i=1; i<=stmt->numparams; ++i, ++data) {
787 		/* not using lua_geti for backwards compat with Lua 5.1/LuaJIT */
788 		lua_pushnumber(L, i);
789 		lua_gettable(L, iparams);
790 
791 		res = set_param(L, stmt, i, data);
792 		if(res != 0) {
793 			return res;
794 		}
795 	}
796 
797 	return 0;
798 }
799 
800 /*
801 ** Reads a param table into a statement
802 */
raw_readparams_args(lua_State * L,stmt_data * stmt,int arg,int ltop)803 static int raw_readparams_args(lua_State *L, stmt_data *stmt, int arg, int ltop)
804 {
805 	static SQLINTEGER cbNull = SQL_NULL_DATA;
806 	SQLSMALLINT i;
807 	param_data *data;
808 	int res = 0;
809 
810 	free_stmt_params(stmt->params, stmt->numparams);
811 	stmt->params = malloc_stmt_params(stmt->numparams);
812 	data = stmt->params;
813 
814 	for(i=1; i<=stmt->numparams; ++i, ++data, ++arg) {
815 		if(arg > ltop) {
816 			lua_pushnil(L);
817 		} else {
818 			lua_pushvalue(L, arg);
819 		}
820 		res = set_param(L, stmt, i, data);
821 		if(res != 0) {
822 			return res;
823 		}
824 	}
825 
826 	return 0;
827 }
828 
829 /*
830 ** Executes the prepared statement
831 ** Lua Input: [params]
832 **   params: A table of parameters to use in the statement
833 ** Lua Returns
834 **   cursor object: if there are results or
835 **   row count: number of rows affected by statement if no results
836 */
stmt_execute(lua_State * L)837 static int stmt_execute(lua_State *L)
838 {
839 	int ltop = lua_gettop(L);
840 	int res;
841 
842 	/* any parameters to use */
843 	if(ltop > 1) {
844 		stmt_data *stmt = getstatement(L, 1);
845 		if(lua_type(L, 2) == LUA_TTABLE) {
846 			res = raw_readparams_table(L, stmt, 2);
847 		} else {
848 			res = raw_readparams_args(L, stmt, 2, ltop);
849 		}
850 		if(res != 0) {
851 			return res;
852 		}
853 	}
854 
855 	return raw_execute(L, 1);
856 }
857 
858 /*
859 ** creates a table of parameter types (maybe)
860 ** Returns: the reference key of the table (noref if unable to build the table)
861 */
desc_params(lua_State * L,stmt_data * stmt)862 static int desc_params(lua_State *L, stmt_data *stmt)
863 {
864 	SQLSMALLINT i;
865 
866 	lua_newtable(L);
867 	for(i=1; i <= stmt->numparams; ++i) {
868 		SQLSMALLINT type, digits, nullable;
869 		SQLULEN len;
870 
871 		/* fun fact: most ODBC drivers don't support this function (MS Access for
872 		   example), so we can't get a param type table */
873 		if(error(SQLDescribeParam(stmt->hstmt, i, &type, &len, &digits, &nullable))) {
874 			lua_pop(L,1);
875 			return LUA_NOREF;
876 		}
877 
878 		lua_pushstring(L, sqltypetolua(type));
879 		lua_rawseti(L, -2, i);
880 	}
881 
882 	return luaL_ref(L, LUA_REGISTRYINDEX);
883 }
884 
885 /*
886 ** Prepares a statement
887 ** Lua Input: sql
888 **   sql: The SQL statement to prepare
889 ** Lua Returns:
890 **   Statement object
891 */
conn_prepare(lua_State * L)892 static int conn_prepare(lua_State *L)
893 {
894 	conn_data *conn = getconnection(L, 1);
895 	SQLCHAR *statement = (SQLCHAR *)luaL_checkstring(L, 2);
896 	SQLHDBC hdbc = conn->hdbc;
897 	SQLHSTMT hstmt;
898 	SQLRETURN ret;
899 
900 	stmt_data *stmt;
901 
902 	ret = SQLAllocHandle(hSTMT, hdbc, &hstmt);
903 	if (error(ret)) {
904 		return fail(L, hDBC, hdbc);
905 	}
906 
907 	ret = SQLPrepare(hstmt, statement, SQL_NTS);
908 	if (error(ret)) {
909 		ret = fail(L, hSTMT, hstmt);
910 		SQLFreeHandle(hSTMT, hstmt);
911 		return ret;
912 	}
913 
914 	stmt = (stmt_data *)lua_newuserdata(L, sizeof(stmt_data));
915 	memset(stmt, 0, sizeof(stmt_data));
916 
917 	stmt->closed = 0;
918 	stmt->lock = 0;
919 	stmt->hidden = 0;
920 	stmt->conn = conn;
921 	stmt->hstmt = hstmt;
922 	if(error(SQLNumParams(hstmt, &stmt->numparams))) {
923 		int res;
924 		lua_pop(L, 1);
925 		res = fail(L, hSTMT, hstmt);
926 		SQLFreeHandle(hSTMT, hstmt);
927 		return res;
928 	}
929 	stmt->paramtypes = desc_params(L, stmt);
930 	stmt->params = NULL;
931 
932 	/* activate statement object */
933 	luasql_setmeta(L, LUASQL_STATEMENT_ODBC);
934 	lock_obj(L, 1, conn);
935 
936 	return 1;
937 }
938 
939 /*
940 ** Executes a SQL statement directly
941 ** Lua Input: sql, [params]
942 **   sql: The SQL statement to execute
943 **   params: A table of parameters to use in the SQL
944 ** Lua Returns
945 **   cursor object: if there are results or
946 **   row count: number of rows affected by statement if no results
947 */
conn_execute(lua_State * L)948 static int conn_execute (lua_State *L)
949 {
950 	stmt_data *stmt;
951 	int res, istmt;
952 	int ltop = lua_gettop(L);
953 
954 	/* prepare statement */
955 	if((res = conn_prepare(L)) != 1) {
956 		return res;
957 	}
958 	istmt = lua_gettop(L);
959 	stmt = getstatement(L, istmt);
960 
961 	/* because this is a direct execute, statement is hidden from user */
962 	stmt->hidden = 1;
963 
964 	/* do we have any parameters */
965 	if(ltop > 2) {
966 		if(lua_type(L, 3) == LUA_TTABLE) {
967 			res = raw_readparams_table(L, stmt, 3);
968 		} else {
969 			res = raw_readparams_args(L, stmt, 3, ltop);
970 		}
971 		if(res != 0) {
972 			return res;
973 		}
974 	}
975 
976 	/* do it */
977 	res = raw_execute(L, istmt);
978 
979 	/* anything but a cursor, close the statement directly */
980 	if(!lua_isuserdata(L, -res)) {
981 		stmt_shut(L, stmt);
982 	}
983 
984 	/* tidy up */
985 	lua_remove(L, istmt);
986 
987 	return res;
988 }
989 
990 /*
991 ** Rolls back a transaction.
992 */
conn_commit(lua_State * L)993 static int conn_commit (lua_State *L) {
994 	conn_data *conn = (conn_data *) getconnection (L, 1);
995 	SQLRETURN ret = SQLEndTran(hDBC, conn->hdbc, SQL_COMMIT);
996 	if (error(ret))
997 		return fail(L, hSTMT, conn->hdbc);
998 	else
999 		return pass(L);
1000 }
1001 
1002 /*
1003 ** Rollback the current transaction.
1004 */
conn_rollback(lua_State * L)1005 static int conn_rollback (lua_State *L) {
1006 	conn_data *conn = (conn_data *) getconnection (L, 1);
1007 	SQLRETURN ret = SQLEndTran(hDBC, conn->hdbc, SQL_ROLLBACK);
1008 	if (error(ret))
1009 		return fail(L, hSTMT, conn->hdbc);
1010 	else
1011 		return pass(L);
1012 }
1013 
1014 /*
1015 ** Sets the auto commit mode
1016 */
conn_setautocommit(lua_State * L)1017 static int conn_setautocommit (lua_State *L) {
1018 	conn_data *conn = (conn_data *) getconnection (L, 1);
1019 	SQLRETURN ret;
1020 	if (lua_toboolean (L, 2)) {
1021 		ret = SQLSetConnectAttr(conn->hdbc, SQL_ATTR_AUTOCOMMIT,
1022 			(SQLPOINTER) SQL_AUTOCOMMIT_ON, 0);
1023 	} else {
1024 		ret = SQLSetConnectAttr(conn->hdbc, SQL_ATTR_AUTOCOMMIT,
1025 			(SQLPOINTER) SQL_AUTOCOMMIT_OFF, 0);
1026 	}
1027 	if (error(ret))
1028 		return fail(L, hSTMT, conn->hdbc);
1029 	else
1030 		return pass(L);
1031 }
1032 
1033 
1034 /*
1035 ** Create a new Connection object and push it on top of the stack.
1036 */
create_connection(lua_State * L,int o,env_data * env,SQLHDBC hdbc)1037 static int create_connection (lua_State *L, int o, env_data *env, SQLHDBC hdbc)
1038 {
1039 	conn_data *conn = (conn_data *)lua_newuserdata(L, sizeof(conn_data));
1040 
1041 	/* set auto commit mode */
1042 	SQLRETURN ret = SQLSetConnectAttr(hdbc, SQL_ATTR_AUTOCOMMIT,
1043 	                                  (SQLPOINTER) SQL_AUTOCOMMIT_ON, 0);
1044 	if (error(ret)) {
1045 		return fail(L, hDBC, hdbc);
1046 	}
1047 
1048 	luasql_setmeta (L, LUASQL_CONNECTION_ODBC);
1049 
1050 	/* fill in structure */
1051 	conn->closed = 0;
1052 	conn->lock = 0;
1053 	conn->env = env;
1054 	conn->hdbc = hdbc;
1055 
1056 	lock_obj(L, 1, env);
1057 
1058 	return 1;
1059 }
1060 
1061 
1062 /*
1063 ** Creates and returns a connection object
1064 ** Lua Input: source [, user [, pass]]
1065 **   source: data source
1066 **   user, pass: data source authentication information
1067 ** Lua Returns:
1068 **   connection object if successfull
1069 **   nil and error message otherwise.
1070 */
env_connect(lua_State * L)1071 static int env_connect (lua_State *L) {
1072 	env_data *env = (env_data *) getenvironment (L, 1);
1073 	SQLCHAR *sourcename = (SQLCHAR*)luaL_checkstring (L, 2);
1074 	SQLCHAR *username = (SQLCHAR*)luaL_optstring (L, 3, NULL);
1075 	SQLCHAR *password = (SQLCHAR*)luaL_optstring (L, 4, NULL);
1076 	SQLHDBC hdbc;
1077 	SQLRETURN ret;
1078 
1079 	/* tries to allocate connection handle */
1080 	ret = SQLAllocHandle (hDBC, env->henv, &hdbc);
1081 	if (error(ret))
1082 		return luasql_faildirect (L, "connection allocation error.");
1083 
1084 	/* tries to connect handle */
1085 	ret = SQLConnect (hdbc, sourcename, SQL_NTS,
1086 		username, SQL_NTS, password, SQL_NTS);
1087 	if (error(ret)) {
1088 		ret = fail(L, hDBC, hdbc);
1089 		SQLFreeHandle(hDBC, hdbc);
1090 		return ret;
1091 	}
1092 
1093 	/* success, return connection object */
1094 	return create_connection (L, 1, env, hdbc);
1095 }
1096 
1097 /*
1098 ** Closes an environment object
1099 */
env_close(lua_State * L)1100 static int env_close (lua_State *L)
1101 {
1102 	SQLRETURN ret;
1103 	env_data *env = (env_data *)luaL_checkudata(L, 1, LUASQL_ENVIRONMENT_ODBC);
1104 	luaL_argcheck (L, env != NULL, 1, LUASQL_PREFIX"environment expected");
1105 	if (env->closed) {
1106 		lua_pushboolean (L, 0);
1107 		return 1;
1108 	}
1109 	if (env->lock > 0) {
1110 		return luaL_error (L, LUASQL_PREFIX"there are open connections");
1111 	}
1112 
1113 	env->closed = 1;
1114 	ret = SQLFreeHandle (hENV, env->henv);
1115 	if (error(ret)) {
1116 		int ret = fail(L, hENV, env->henv);
1117 		env->henv = NULL;
1118 		return ret;
1119 	}
1120 	return pass(L);
1121 }
1122 
1123 
1124 /*
1125 ** Create metatables for each class of object.
1126 */
create_metatables(lua_State * L)1127 static void create_metatables (lua_State *L) {
1128 	struct luaL_Reg environment_methods[] = {
1129 		{"__gc", env_close}, /* Should this method be changed? */
1130 		{"close", env_close},
1131 		{"connect", env_connect},
1132 		{NULL, NULL},
1133 	};
1134 	struct luaL_Reg connection_methods[] = {
1135 		{"__gc", conn_close}, /* Should this method be changed? */
1136 		{"close", conn_close},
1137 		{"prepare", conn_prepare},
1138 		{"execute", conn_execute},
1139 		{"commit", conn_commit},
1140 		{"rollback", conn_rollback},
1141 		{"setautocommit", conn_setautocommit},
1142 		{NULL, NULL},
1143 	};
1144 	struct luaL_Reg statement_methods[] = {
1145 		{"__gc", stmt_close}, /* Should this method be changed? */
1146 		{"close", stmt_close},
1147 		{"execute", stmt_execute},
1148 		{"getparamtypes", stmt_paramtypes},
1149 		{NULL, NULL},
1150 	};
1151 	struct luaL_Reg cursor_methods[] = {
1152 		{"__gc", cur_close}, /* Should this method be changed? */
1153 		{"close", cur_close},
1154 		{"fetch", cur_fetch},
1155 		{"getcoltypes", cur_coltypes},
1156 		{"getcolnames", cur_colnames},
1157 		{NULL, NULL},
1158 	};
1159 	luasql_createmeta (L, LUASQL_ENVIRONMENT_ODBC, environment_methods);
1160 	luasql_createmeta (L, LUASQL_CONNECTION_ODBC, connection_methods);
1161 	luasql_createmeta (L, LUASQL_STATEMENT_ODBC, statement_methods);
1162 	luasql_createmeta (L, LUASQL_CURSOR_ODBC, cursor_methods);
1163 	lua_pop (L, 4);
1164 }
1165 
1166 
1167 /*
1168 ** Creates an Environment and returns it.
1169 */
create_environment(lua_State * L)1170 static int create_environment (lua_State *L)
1171 {
1172 	env_data *env;
1173 	SQLHENV henv;
1174 	SQLRETURN ret = SQLAllocHandle(hENV, SQL_NULL_HANDLE, &henv);
1175 	if (error(ret)) {
1176 		return luasql_faildirect(L, "error creating environment.");
1177 	}
1178 
1179 	ret = SQLSetEnvAttr (henv, SQL_ATTR_ODBC_VERSION, (void *)SQL_OV_ODBC3, 0);
1180 	if (error(ret)) {
1181 		ret = luasql_faildirect (L, "error setting SQL version.");
1182 		SQLFreeHandle (hENV, henv);
1183 		return ret;
1184 	}
1185 
1186 	env = (env_data *)lua_newuserdata (L, sizeof (env_data));
1187 	luasql_setmeta (L, LUASQL_ENVIRONMENT_ODBC);
1188 	/* fill in structure */
1189 	env->closed = 0;
1190 	env->lock = 0;
1191 	env->henv = henv;
1192 	return 1;
1193 }
1194 
1195 
1196 /*
1197 ** Creates the metatables for the objects and registers the
1198 ** driver open method.
1199 */
luaopen_luasql_odbc(lua_State * L)1200 LUASQL_API int luaopen_luasql_odbc (lua_State *L) {
1201 	struct luaL_Reg driver[] = {
1202 		{"odbc", create_environment},
1203 		{NULL, NULL},
1204 	};
1205 	create_metatables (L);
1206 	lua_newtable (L);
1207 	luaL_setfuncs (L, driver, 0);
1208 	luasql_set_info (L);
1209 	return 1;
1210 }
1211