1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 //  Copyright (C) 2010-2011  Artyom Beilis (Tonkikh) <artyomtnk@yahoo.com>
4 //
5 //  Distributed under:
6 //
7 //                   the Boost Software License, Version 1.0.
8 //              (See accompanying file LICENSE_1_0.txt or copy at
9 //                     http://www.boost.org/LICENSE_1_0.txt)
10 //
11 //  or (at your opinion) under:
12 //
13 //                               The MIT License
14 //                 (See accompanying file MIT.txt or a copy at
15 //              http://www.opensource.org/licenses/mit-license.php)
16 //
17 ///////////////////////////////////////////////////////////////////////////////
18 #define CPPDB_DRIVER_SOURCE
19 #ifdef CPPDB_WITH_ODBC
20 #define CPPDB_SOURCE
21 #endif
22 #include <cppdb/backend.h>
23 #include <cppdb/utils.h>
24 #include <cppdb/numeric_util.h>
25 #include <list>
26 #include <vector>
27 #include <iostream>
28 #include <sstream>
29 #include <limits>
30 #include <iomanip>
31 #include <string.h>
32 
33 #if defined(_WIN32) || defined(__WIN32) || defined(WIN32) || defined(__CYGWIN__)
34 #include <windows.h>
35 #endif
36 #include <sqlext.h>
37 
38 namespace cppdb {
39 namespace odbc_backend {
40 
41 typedef unsigned odbc_u32;
42 typedef unsigned short odbc_u16;
43 
44 int assert_on_unsigned_is_32[sizeof(unsigned) == 4 ? 1 : -1];
45 int assert_on_unsigned_short_is_16[sizeof(unsigned short) == 2 ? 1 : -1];
46 int assert_on_sqlwchar_is_16[sizeof(SQLWCHAR) == 2 ? 1 : -1];
47 
48 
49 
50 namespace utf {
51 	static const odbc_u32 illegal = 0xFFFFFFFFu;
valid(odbc_u32 v)52 	inline bool valid(odbc_u32 v)
53 	{
54 		if(v>0x10FFFF)
55 			return false;
56 		if(0xD800 <=v && v<= 0xDFFF) // surragates
57 			return false;
58 		return true;
59 	}
60 }
61 
62 namespace utf8 {
63 	// See RFC 3629
64 	// Based on: http://www.w3.org/International/questions/qa-forms-utf-8
65 	template<typename Iterator>
next(Iterator & p,Iterator e)66 	odbc_u32 next(Iterator &p,Iterator e)
67 	{
68 		unsigned char c=*p++;
69 		unsigned char seq0,seq1=0,seq2=0,seq3=0;
70 		seq0=c;
71 		int len=1;
72 		if((c & 0xC0) == 0xC0) {
73 			if(p==e)
74 				return utf::illegal;
75 			seq1=*p++;
76 			len=2;
77 		}
78 		if((c & 0xE0) == 0xE0) {
79 			if(p==e)
80 				return utf::illegal;
81 			seq2=*p++;
82 			len=3;
83 		}
84 		if((c & 0xF0) == 0xF0) {
85 			if(p==e)
86 				return utf::illegal;
87 			seq3=*p++;
88 			len=4;
89 		}
90 		switch(len) {
91 		case 1: // ASCII -- remove codes for HTML only
92 			if(seq0 > 0x7F)
93 				return utf::illegal;
94 			break;
95 		case 2: // non-overloading 2 bytes
96 			if(0xC2 <= seq0 && seq0 <= 0xDF) {
97 				if(0x80 <= seq1 && seq1<= 0xBF)
98 					break;
99 			}
100 			return utf::illegal;
101 		case 3:
102 			if(seq0==0xE0) { // exclude overloadings
103 				if(0xA0 <=seq1 && seq1<= 0xBF && 0x80 <=seq2 && seq2<=0xBF)
104 					break;
105 			}
106 			else if( (0xE1 <= seq0 && seq0 <=0xEC) || seq0==0xEE || seq0==0xEF) { // stright 3 bytes
107 				if(	0x80 <=seq1 && seq1<=0xBF &&
108 					0x80 <=seq2 && seq2<=0xBF)
109 					break;
110 			}
111 			else if(seq0 == 0xED) { // exclude surrogates
112 				if(	0x80 <=seq1 && seq1<=0x9F &&
113 					0x80 <=seq2 && seq2<=0xBF)
114 					break;
115 			}
116 			return utf::illegal;
117 		case 4:
118 			switch(seq0) {
119 			case 0xF0: // planes 1-3
120 				if(	0x90 <=seq1 && seq1<=0xBF &&
121 					0x80 <=seq2 && seq2<=0xBF &&
122 					0x80 <=seq3 && seq3<=0xBF)
123 					break;
124 				return utf::illegal;
125 			case 0xF1: // planes 4-15
126 			case 0xF2:
127 			case 0xF3:
128 				if(	0x80 <=seq1 && seq1<=0xBF &&
129 					0x80 <=seq2 && seq2<=0xBF &&
130 					0x80 <=seq3 && seq3<=0xBF)
131 					break;
132 				return utf::illegal;
133 			case 0xF4: // pane 16
134 				if(	0x80 <=seq1 && seq1<=0x8F &&
135 					0x80 <=seq2 && seq2<=0xBF &&
136 					0x80 <=seq3 && seq3<=0xBF)
137 					break;
138 				return utf::illegal;
139 			default:
140 				return utf::illegal;
141 			}
142 
143 		}
144 
145 		switch(len) {
146 		case 1:
147 			return seq0;
148 		case 2:
149 			return ((seq0 & 0x1F) << 6) | (seq1 & 0x3F);
150 		case 3:
151 			return ((seq0 & 0x0F) << 12) | ((seq1 & 0x3F) << 6) | (seq2 & 0x3F)  ;
152 		case 4:
153 			return ((seq0 & 0x07) << 18) | ((seq1 & 0x3F) << 12) | ((seq2 & 0x3F) << 6) | (seq3 & 0x3F) ;
154 		}
155 
156 		return utf::illegal;
157 	} // valid
158 
159 
160 	struct seq {
161 		char c[4];
162 		unsigned len;
163 	};
encode(odbc_u32 value)164 	inline seq encode(odbc_u32 value)
165 	{
166 		seq out=seq();
167 		if(value <=0x7F) {
168 			out.c[0]=value;
169 			out.len=1;
170 		}
171 		else if(value <=0x7FF) {
172 			out.c[0]=(value >> 6) | 0xC0;
173 			out.c[1]=(value & 0x3F) | 0x80;
174 			out.len=2;
175 		}
176 		else if(value <=0xFFFF) {
177 			out.c[0]=(value >> 12) | 0xE0;
178 			out.c[1]=((value >> 6) & 0x3F) | 0x80;
179 			out.c[2]=(value & 0x3F) | 0x80;
180 			out.len=3;
181 		}
182 		else {
183 			out.c[0]=(value >> 18) | 0xF0;
184 			out.c[1]=((value >> 12) & 0x3F) | 0x80;
185 			out.c[2]=((value >> 6) & 0x3F) | 0x80;
186 			out.c[3]=(value & 0x3F) | 0x80;
187 			out.len=4;
188 		}
189 		return out;
190 	}
191 } // namespace utf8
192 
193 
194 namespace utf16 {
195 
196 	// See RFC 2781
is_first_surrogate(odbc_u16 x)197 	inline bool is_first_surrogate(odbc_u16 x)
198 	{
199 		return 0xD800 <=x && x<= 0xDBFF;
200 	}
is_second_surrogate(odbc_u16 x)201 	inline bool is_second_surrogate(odbc_u16 x)
202 	{
203 		return 0xDC00 <=x && x<= 0xDFFF;
204 	}
combine_surrogate(odbc_u16 w1,odbc_u16 w2)205 	inline odbc_u32 combine_surrogate(odbc_u16 w1,odbc_u16 w2)
206 	{
207 		return ((odbc_u32(w1 & 0x3FF) << 10) | (w2 & 0x3FF)) + 0x10000;
208 	}
209 
210 	template<typename It>
next(It & current,It last)211 	inline odbc_u32 next(It &current,It last)
212 	{
213 		odbc_u16 w1=*current++;
214 		if(w1 < 0xD800 || 0xDFFF < w1) {
215 			return w1;
216 		}
217 		if(w1 > 0xDBFF)
218 			return utf::illegal;
219 		if(current==last)
220 			return utf::illegal;
221 		odbc_u16 w2=*current++;
222 		if(w2 < 0xDC00 || 0xDFFF < w2)
223 			return utf::illegal;
224 		return combine_surrogate(w1,w2);
225 	}
width(odbc_u32 u)226 	inline int width(odbc_u32 u)
227 	{
228 		return u>=0x100000 ? 2 : 1;
229 	}
230 	struct seq {
231 		odbc_u16 c[2];
232 		unsigned len;
233 	};
encode(odbc_u32 u)234 	inline seq encode(odbc_u32 u)
235 	{
236 		seq out=seq();
237 		if(u<=0xFFFF) {
238 			out.c[0]=u;
239 			out.len=1;
240 		}
241 		else {
242 			u-=0x10000;
243 			out.c[0]=0xD800 | (u>>10);
244 			out.c[1]=0xDC00 | (u & 0x3FF);
245 			out.len=2;
246 		}
247 		return out;
248 	}
249 } // utf16;
250 
251 } // odbc_backend
252 } // cppdb
253 
254 
255 namespace cppdb {
256 
257 class connection_info;
258 class pool;
259 
260 namespace odbc_backend {
261 
widen(char const * b,char const * e)262 std::string widen(char const *b,char const *e)
263 {
264 	std::string result;
265 	result.reserve((e-b)*2);
266 	odbc_u32 code_point = 0;
267 	while(b < e && (code_point = utf8::next(b,e))!=utf::illegal) {
268 		utf16::seq sq = utf16::encode(code_point);
269 		char *str = (char *)sq.c;
270 		result.append(str,sq.len * 2);
271 	}
272 	if(b!=e || code_point == utf::illegal) {
273 		throw cppdb_error("cppdb::odbc invalid UTF-8 input");
274 	}
275 	return result;
276 }
277 
widen(std::string const & s)278 std::string widen(std::string const &s)
279 {
280 	return widen(s.c_str(),s.c_str()+s.size());
281 }
282 
narrower(std::basic_string<SQLWCHAR> const & wide)283 std::string narrower(std::basic_string<SQLWCHAR> const &wide)
284 {
285 	odbc_u16 const *b = reinterpret_cast<odbc_u16 const *>(wide.c_str());
286 	odbc_u16 const *e = b + wide.size();
287 
288 	std::string result;
289 	result.reserve((e-b));
290 
291 	odbc_u32 code_point = 0;
292 	while(b < e && (code_point = utf16::next(b,e))!=utf::illegal) {
293 		utf8::seq sq = utf8::encode(code_point);
294 		result.append(sq.c,sq.len);
295 	}
296 	if(b!=e || code_point == utf::illegal) {
297 		throw cppdb_error("cppdb::odbc got invalid UTF-16");
298 	}
299 	return result;
300 }
301 
narrower(std::string const & wide)302 std::string narrower(std::string const &wide)
303 {
304 	if(wide.size() % 2 != 0) {
305 		throw cppdb_error("cppdb::odbc got invalid UTF-16");
306 	}
307 	odbc_u16 const *b = reinterpret_cast<odbc_u16 const *>(wide.c_str());
308 	odbc_u16 const *e = b + wide.size() / 2;
309 
310 	std::string result;
311 	result.reserve((e-b));
312 
313 	odbc_u32 code_point = 0;
314 	while(b < e && (code_point = utf16::next(b,e))!=utf::illegal) {
315 		utf8::seq sq = utf8::encode(code_point);
316 		result.append(sq.c,sq.len);
317 	}
318 	if(b!=e || code_point == utf::illegal) {
319 		throw cppdb_error("cppdb::odbc got invalid UTF-16");
320 	}
321 	return result;
322 }
323 
tosqlwide(std::string const & n)324 std::basic_string<SQLWCHAR> tosqlwide(std::string const &n)
325 {
326 	std::basic_string<SQLWCHAR> result;
327 	char const *b=n.c_str();
328 	char const *e=b+n.size();
329 	result.reserve(e-b);
330 	odbc_u32 code_point = 0;
331 	while(b < e && (code_point = utf8::next(b,e))!=utf::illegal) {
332 		utf16::seq sq = utf16::encode(code_point);
333 		result.append((SQLWCHAR*)sq.c,sq.len);
334 	}
335 	if(b!=e || code_point == utf::illegal) {
336 		throw cppdb_error("cppdb::odbc invalid UTF-8 input");
337 	}
338 	return result;
339 }
340 
check_odbc_errorW(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)341 void check_odbc_errorW(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)
342 {
343 	if(SQL_SUCCEEDED(error))
344 		return;
345 	std::basic_string<SQLWCHAR> error_message;
346 	int rec=1,r;
347 	for(;;){
348 		SQLWCHAR msg[SQL_MAX_MESSAGE_LENGTH + 2] = {0};
349 		SQLWCHAR stat[SQL_SQLSTATE_SIZE + 1] = {0};
350 		SQLINTEGER err;
351 		SQLSMALLINT len;
352 		r = SQLGetDiagRecW(type,h,rec,stat,&err,msg,sizeof(msg)/sizeof(SQLWCHAR),&len);
353 		rec++;
354 		if(r==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
355 			if(!error_message.empty()) {
356 				SQLWCHAR nl = '\n';
357 				error_message+=nl;
358 			}
359 			error_message.append(msg);
360 		}
361 		else
362 			break;
363 
364 	}
365 	std::string utf8_str = "Unconvertable string";
366 	try { std::string tmp = narrower(error_message); utf8_str = tmp; } catch(...){}
367 	throw cppdb_error("cppdb::odbc_backend::Failed with error `" + utf8_str +"'");
368 }
369 
check_odbc_errorA(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)370 void check_odbc_errorA(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)
371 {
372 	if(SQL_SUCCEEDED(error))
373 		return;
374 	std::string error_message;
375 	int rec=1,r;
376 	for(;;){
377 		SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH + 2] = {0};
378 		SQLCHAR stat[SQL_SQLSTATE_SIZE + 1] = {0};
379 		SQLINTEGER err;
380 		SQLSMALLINT len;
381 		r = SQLGetDiagRecA(type,h,rec,stat,&err,msg,sizeof(msg),&len);
382 		rec++;
383 		if(r==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
384 			if(!error_message.empty())
385 				error_message+='\n';
386 			error_message +=(char *)msg;
387 		}
388 		else
389 			break;
390 
391 	}
392 	throw cppdb_error("cppdb::odbc::Failed with error `" + error_message +"'");
393 }
394 
check_odbc_error(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type,bool wide)395 void check_odbc_error(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type,bool wide)
396 {
397 	if(wide)
398 		check_odbc_errorW(error,h,type);
399 	else
400 		check_odbc_errorA(error,h,type);
401 }
402 
403 
404 
405 class result : public backend::result {
406 public:
407 	typedef std::pair<bool,std::string> cell_type;
408 	typedef std::vector<cell_type> row_type;
409 	typedef std::list<row_type> rows_type;
410 
has_next()411 	virtual next_row has_next()
412 	{
413 		rows_type::iterator p=current_;
414 		if(p == rows_.end() || ++p==rows_.end())
415 			return last_row_reached;
416 		else
417 			return next_row_exists;
418 	}
next()419 	virtual bool next()
420 	{
421 		if(started_ == false) {
422 			current_ = rows_.begin();
423 			started_ = true;
424 		}
425 		else if(current_!=rows_.end()) {
426 			++current_;
427 		}
428 		return current_!=rows_.end();
429 	}
430 	template<typename T>
do_fetch(int col,T & v)431 	bool do_fetch(int col,T &v)
432 	{
433 		if(at(col).first)
434 			return false;
435 		v=parse_number<T>(at(col).second,ss_);
436 		return true;
437 	}
fetch(int col,short & v)438 	virtual bool fetch(int col,short &v)
439 	{
440 		return  do_fetch(col,v);
441 	}
fetch(int col,unsigned short & v)442 	virtual bool fetch(int col,unsigned short &v)
443 	{
444 		return  do_fetch(col,v);
445 	}
fetch(int col,int & v)446 	virtual bool fetch(int col,int &v)
447 	{
448 		return  do_fetch(col,v);
449 	}
fetch(int col,unsigned & v)450 	virtual bool fetch(int col,unsigned &v)
451 	{
452 		return  do_fetch(col,v);
453 	}
fetch(int col,long & v)454 	virtual bool fetch(int col,long &v)
455 	{
456 		return  do_fetch(col,v);
457 	}
fetch(int col,unsigned long & v)458 	virtual bool fetch(int col,unsigned long &v)
459 	{
460 		return  do_fetch(col,v);
461 	}
fetch(int col,long long & v)462 	virtual bool fetch(int col,long long &v)
463 	{
464 		return  do_fetch(col,v);
465 	}
fetch(int col,unsigned long long & v)466 	virtual bool fetch(int col,unsigned long long &v)
467 	{
468 		return  do_fetch(col,v);
469 	}
fetch(int col,float & v)470 	virtual bool fetch(int col,float &v)
471 	{
472 		return  do_fetch(col,v);
473 	}
fetch(int col,double & v)474 	virtual bool fetch(int col,double &v)
475 	{
476 		return  do_fetch(col,v);
477 	}
fetch(int col,long double & v)478 	virtual bool fetch(int col,long double &v)
479 	{
480 		return  do_fetch(col,v);
481 	}
fetch(int col,std::string & v)482 	virtual bool fetch(int col,std::string &v)
483 	{
484 		if(at(col).first)
485 			return false;
486 		v=at(col).second;
487 		return true;
488 	}
fetch(int col,std::ostream & v)489 	virtual bool fetch(int col,std::ostream &v)
490 	{
491 		if(at(col).first)
492 			return false;
493 		v << at(col).second;
494 		return true;
495 	}
fetch(int col,std::tm & v)496 	virtual bool fetch(int col,std::tm &v)
497 	{
498 		if(at(col).first)
499 			return false;
500 		v = parse_time(at(col).second);
501 		return true;
502 	}
is_null(int col)503 	virtual bool is_null(int col)
504 	{
505 		return at(col).first;
506 	}
cols()507 	virtual int cols()
508 	{
509 		return cols_;
510 	}
name_to_column(std::string const & cn)511 	virtual int name_to_column(std::string const &cn)
512 	{
513 		for(unsigned i=0;i<names_.size();i++)
514 			if(names_[i]==cn)
515 				return i;
516 		return -1;
517 	}
column_to_name(int c)518 	virtual std::string column_to_name(int c)
519 	{
520 		if(c < 0 || c >= int(names_.size()))
521 			throw invalid_column();
522 		return names_[c];
523 	}
524 
result(rows_type & rows,std::vector<std::string> & names,int cols)525 	result(rows_type &rows,std::vector<std::string> &names,int cols) : cols_(cols)
526 	{
527 		names_.swap(names);
528 		rows_.swap(rows);
529 		started_ = false;
530 		current_ = rows_.end();
531 		ss_.imbue(std::locale::classic());
532 	}
at(int col)533 	cell_type &at(int col)
534 	{
535 		if(current_!=rows_.end() && col >= 0 && col <int(current_->size()))
536 			return current_->at(col);
537 		throw invalid_column();
538 	}
539 private:
540 	int cols_;
541 	bool started_;
542 	std::vector<std::string> names_;
543 	rows_type::iterator current_;
544 	rows_type rows_;
545 	std::istringstream ss_;
546 };
547 
548 class statements_cache;
549 
550 class connection;
551 
552 class statement : public backend::statement {
553 	struct parameter {
parametercppdb::odbc_backend::statement::parameter554 		parameter() :
555 			null(true),
556 			ctype(SQL_C_CHAR),
557 			sqltype(SQL_C_NUMERIC)
558 		{
559 		}
set_binarycppdb::odbc_backend::statement::parameter560 		void set_binary(char const *b,char const *e)
561 		{
562 			value.assign(b,e-b);
563 			null=false;
564 			ctype=SQL_C_BINARY;
565 			sqltype = SQL_LONGVARBINARY;
566 		}
set_textcppdb::odbc_backend::statement::parameter567 		void set_text(char const *b,char const *e,bool wide)
568 		{
569 			if(!wide) {
570 				value.assign(b,e-b);
571 				null=false;
572 				ctype=SQL_C_CHAR;
573 				sqltype = SQL_LONGVARCHAR;
574 			}
575 			else {
576 				std::string tmp = widen(b,e);
577 				value.swap(tmp);
578 				null=false;
579 				ctype=SQL_C_WCHAR;
580 				sqltype = SQL_WLONGVARCHAR;
581 			}
582 		}
setcppdb::odbc_backend::statement::parameter583 		void set(std::tm const &v)
584 		{
585 			value = cppdb::format_time(v);
586 			null=false;
587 			sqltype = SQL_C_TIMESTAMP;
588 			ctype = SQL_C_CHAR;
589 		}
590 
591 		template<typename T>
setcppdb::odbc_backend::statement::parameter592 		void set(T v)
593 		{
594 			std::ostringstream ss;
595 			ss.imbue(std::locale::classic());
596 			if(!std::numeric_limits<T>::is_integer)
597 				ss << std::setprecision(std::numeric_limits<T>::digits10+1);
598 			ss << v;
599 
600 			value=ss.str();
601 			null=false;
602 			ctype = SQL_C_CHAR;
603 			if(std::numeric_limits<T>::is_integer)
604 				sqltype = SQL_INTEGER;
605 			else
606 				sqltype = SQL_DOUBLE;
607 
608 		}
bindcppdb::odbc_backend::statement::parameter609 		void bind(int col,SQLHSTMT stmt,bool wide)
610 		{
611 			int r;
612 			if(null) {
613 				lenval = SQL_NULL_DATA;
614 				r = SQLBindParameter(	stmt,
615 							col,
616 							SQL_PARAM_INPUT,
617 							SQL_C_CHAR,
618 							SQL_NUMERIC, // for null
619 							10, // COLUMNSIZE
620 							0, //  Presision
621 							0, // string
622 							0, // size
623 							&lenval);
624 			}
625             else {
626                 lenval=value.size();
627                 size_t column_size = value.size();
628                 if(ctype == SQL_C_WCHAR)
629                     column_size/=2;
630                 if(value.empty())
631                     column_size=1;
632                 r = SQLBindParameter(	stmt,
633                             col,
634                             SQL_PARAM_INPUT,
635                             ctype,
636                             sqltype,
637                             column_size, // COLUMNSIZE
638                             0, //  Presision
639                             (void*)value.c_str(), // string
640                             value.size(),
641                             &lenval);
642             }
643 			check_odbc_error(r,stmt,SQL_HANDLE_STMT,wide);
644 		}
645 
646 		std::string value;
647 		bool null;
648 		SQLSMALLINT ctype;
649 		SQLSMALLINT sqltype;
650 		SQLLEN lenval;
651 	};
652 public:
653 	// Begin of API
reset()654 	virtual void reset()
655 	{
656 		SQLFreeStmt(stmt_,SQL_UNBIND);
657 		SQLCloseCursor(stmt_);
658 		params_.resize(0);
659 		if(params_no_ > 0)
660 			params_.resize(params_no_);
661 
662 	}
param_at(int col)663 	parameter &param_at(int col)
664 	{
665 		col --;
666 		if(col < 0)
667 			throw invalid_placeholder();
668 		if(params_no_ < 0) {
669 			if(params_.size() < size_t(col+1))
670 				params_.resize(col+1);
671 		}
672 		else if(col >= params_no_) {
673 			throw invalid_placeholder();
674 		}
675 		return params_[col];
676 	}
sql_query()677 	virtual std::string const &sql_query()
678 	{
679 		return query_;
680 	}
bind(int col,std::string const & s)681 	virtual void bind(int col,std::string const &s)
682 	{
683 		bind(col,s.c_str(),s.c_str()+s.size());
684 	}
bind(int col,char const * s)685 	virtual void bind(int col,char const *s)
686 	{
687 		bind(col,s,s+strlen(s));
688 	}
bind(int col,char const * b,char const * e)689 	virtual void bind(int col,char const *b,char const *e)
690 	{
691 		param_at(col).set_text(b,e,wide_);
692 	}
bind(int col,std::tm const & s)693 	virtual void bind(int col,std::tm const &s)
694 	{
695 		param_at(col).set(s);
696 	}
bind(int col,std::istream & in)697 	virtual void bind(int col,std::istream &in)
698 	{
699 		std::ostringstream ss;
700 		ss << in.rdbuf();
701 		std::string s = ss.str();
702 		param_at(col).set_binary(s.c_str(),s.c_str()+s.size());
703 	}
704 	template<typename T>
do_bind_num(int col,T v)705 	void do_bind_num(int col,T v)
706 	{
707 		param_at(col).set(v);
708 	}
bind(int col,int v)709 	virtual void bind(int col,int v)
710 	{
711 		do_bind_num(col,v);
712 	}
bind(int col,unsigned v)713 	virtual void bind(int col,unsigned v)
714 	{
715 		do_bind_num(col,v);
716 	}
bind(int col,long v)717 	virtual void bind(int col,long v)
718 	{
719 		do_bind_num(col,v);
720 	}
bind(int col,unsigned long v)721 	virtual void bind(int col,unsigned long v)
722 	{
723 		do_bind_num(col,v);
724 	}
bind(int col,long long v)725 	virtual void bind(int col,long long v)
726 	{
727 		do_bind_num(col,v);
728 	}
bind(int col,unsigned long long v)729 	virtual void bind(int col,unsigned long long v)
730 	{
731 		do_bind_num(col,v);
732 	}
bind(int col,double v)733 	virtual void bind(int col,double v)
734 	{
735 		do_bind_num(col,v);
736 	}
bind(int col,long double v)737 	virtual void bind(int col,long double v)
738 	{
739 		do_bind_num(col,v);
740 	}
bind_null(int col)741 	virtual void bind_null(int col)
742 	{
743 		param_at(col) = parameter();
744 	}
bind_all()745 	void bind_all()
746 	{
747 		for(unsigned i=0;i<params_.size();i++) {
748 			params_[i].bind(i+1,stmt_,wide_);
749 		}
750 
751 	}
sequence_last(std::string const & sequence)752 	virtual long long sequence_last(std::string const &sequence)
753 	{
754 		ref_ptr<statement> st;
755 		if(!sequence_last_.empty()) {
756 			st = new statement(sequence_last_,dbc_,wide_,false);
757 			st->bind(1,sequence);
758 		}
759 		else if(!last_insert_id_.empty()) {
760 			st = new statement(last_insert_id_,dbc_,wide_,false);
761 		}
762 		else {
763 			throw not_supported_by_backend(
764 				"cppdb::odbc::sequence_last is not supported by odbc backend "
765 				"unless properties @squence_last, @last_insert_id are specified "
766 				"or @engine is one of mysql, sqlite3, postgresql, mssql");
767 		}
768 		ref_ptr<result> res = st->query();
769 		long long last_id;
770 		if(!res->next() || res->cols()!=1 || !res->fetch(0,last_id)) {
771 			throw cppdb_error("cppdb::odbc::sequence_last failed to fetch last value");
772 		}
773 		res.reset();
774 		st.reset();
775 		return last_id;
776 	}
affected()777 	virtual unsigned long long affected()
778 	{
779 		SQLLEN rows = 0;
780 		int r = SQLRowCount(stmt_,&rows);
781 		check_error(r);
782 		return rows;
783 	}
query()784 	virtual result *query()
785 	{
786 		bind_all();
787 		int r = real_exec();
788 		check_error(r);
789 		result::rows_type rows;
790 		result::row_type row;
791 
792 		std::string value;
793 		bool is_null = false;
794 		SQLSMALLINT ocols;
795 		r = SQLNumResultCols(stmt_,&ocols);
796 		check_error(r);
797 		int cols = ocols;
798 
799 		std::vector<std::string> names(cols);
800 		std::vector<int> types(cols,SQL_C_CHAR);
801 
802 		for(int col=0;col < cols;col++) {
803 			SQLSMALLINT name_length=0,data_type=0,digits=0,nullable=0;
804 			SQLULEN collen = 0;
805 
806 			if(wide_) {
807 				SQLWCHAR name[257] = {0};
808 				r=SQLDescribeColW(stmt_,col+1,name,256,&name_length,&data_type,&collen,&digits,&nullable);
809 				check_error(r);
810 				names[col]=narrower(name);
811 			}
812 			else {
813 				SQLCHAR name[257] = {0};
814 				r=SQLDescribeColA(stmt_,col+1,name,256,&name_length,&data_type,&collen,&digits,&nullable);
815 				check_error(r);
816 				names[col]=(char*)name;
817 			}
818 			switch(data_type) {
819 			case SQL_CHAR:
820 			case SQL_VARCHAR:
821 			case SQL_LONGVARCHAR:
822 				types[col]=SQL_C_CHAR;
823 				break;
824 			case SQL_WCHAR:
825 			case SQL_WVARCHAR:
826 			case SQL_WLONGVARCHAR:
827 				types[col]=SQL_C_WCHAR ;
828 				break;
829 			case SQL_BINARY:
830 			case SQL_VARBINARY:
831 			case SQL_LONGVARBINARY:
832 				types[col]=SQL_C_BINARY ;
833 				break;
834 			default:
835 				types[col]=SQL_C_DEFAULT;
836 				// Just a hack, actually I'm going to use C
837 				;
838 			}
839 		}
840 
841 		while((r=SQLFetch(stmt_))==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
842 			row.resize(cols);
843 			for(int col=0;col < cols;col++) {
844 				SQLLEN len = 0;
845 				is_null=false;
846 				int type = types[col];
847 				if(type==SQL_C_DEFAULT) {
848 					char buf[64];
849 					int r = SQLGetData(stmt_,col+1,SQL_C_CHAR,buf,sizeof(buf),&len);
850 					check_error(r);
851 					if(len == SQL_NULL_DATA) {
852 						is_null = true;
853 					}
854 					else if(len <= 64) {
855 						value.assign(buf,len);
856 					}
857 					else {
858 						throw cppdb_error("cppdb::odbc::query - data too long");
859 					}
860 				}
861 				else {
862 					char buf[1024];
863 					size_t real_len;
864 					if(type == SQL_C_CHAR) {
865 						real_len = sizeof(buf)-1;
866 					}
867 					else if(type == SQL_C_BINARY) {
868 						real_len = sizeof(buf);
869 					}
870 					else { // SQL_C_WCHAR
871 						real_len = sizeof(buf) - sizeof(SQLWCHAR);
872 					}
873 
874 					r = SQLGetData(stmt_,col+1,type,buf,sizeof(buf),&len);
875 					check_error(r);
876 					if(len == SQL_NULL_DATA) {
877 						is_null = true;
878 					}
879 					else if(len == SQL_NO_TOTAL) {
880 						while(len==SQL_NO_TOTAL) {
881 							value.append(buf,real_len);
882 							r = SQLGetData(stmt_,col+1,type,buf,sizeof(buf),&len);
883 							check_error(r);
884 						}
885 						value.append(buf,len);
886 					}
887 					else if(0<= len && size_t(len) <= real_len) {
888 						value.assign(buf,len);
889 					}
890 					else if(len>=0) {
891 						value.assign(buf,real_len);
892 						size_t rem_len = len - real_len;
893 						std::vector<char> tmp(rem_len+2,0);
894 						r = SQLGetData(stmt_,col+1,type,&tmp[0],tmp.size(),&len);
895 						check_error(r);
896 						value.append(&tmp[0],rem_len);
897 					}
898 					else {
899 						throw cppdb_error("cppdb::odbc::query invalid result length");
900 					}
901 					if(!is_null && type == SQL_C_WCHAR) {
902 						std::string tmp=narrower(value);
903 						value.swap(tmp);
904 					}
905 				}
906 
907 				row[col].first = is_null;
908 				row[col].second.swap(value);
909 			}
910 			rows.push_back(result::row_type());
911 			rows.back().swap(row);
912 		}
913 		if(r!=SQL_NO_DATA) {
914 			check_error(r);
915 		}
916 		return new result(rows,names,cols);
917 	}
918 
real_exec()919 	int real_exec()
920 	{
921 		int r = 0;
922 		if(prepared_) {
923 			r=SQLExecute(stmt_);
924 		}
925 		else {
926 			if(wide_)
927 				r=SQLExecDirectW(stmt_,(SQLWCHAR*)tosqlwide(query_).c_str(),SQL_NTS);
928 			else
929 				r=SQLExecDirectA(stmt_,(SQLCHAR*)query_.c_str(),SQL_NTS);
930 		}
931 		return r;
932 	}
exec()933 	virtual void exec()
934 	{
935 		bind_all();
936 		int r=real_exec();
937 		if(r!=SQL_NO_DATA)
938 			check_error(r);
939 	}
940 	// End of API
941 
statement(std::string const & q,SQLHDBC dbc,bool wide,bool prepared)942 	statement(std::string const &q,SQLHDBC dbc,bool wide,bool prepared) :
943 		dbc_(dbc),
944 		wide_(wide),
945 		query_(q),
946 		params_no_(-1),
947 		prepared_(prepared)
948 	{
949 		SQLRETURN r = SQLAllocHandle(SQL_HANDLE_STMT,dbc,&stmt_);
950 		check_odbc_error(r,dbc,SQL_HANDLE_DBC,wide_);
951 		if(prepared_) {
952 			try {
953 				if(wide_) {
954 					r = SQLPrepareW(
955 						stmt_,
956 						(SQLWCHAR*)tosqlwide(query_).c_str(),
957 						SQL_NTS);
958 				}
959 				else {
960 					r = SQLPrepareA(
961 						stmt_,
962 						(SQLCHAR*)query_.c_str(),
963 						SQL_NTS);
964 				}
965 				check_error(r);
966 			}
967 			catch(...) {
968 				SQLFreeHandle(SQL_HANDLE_STMT,stmt_);
969 				throw;
970 			}
971 			SQLSMALLINT params_no;
972 			r = SQLNumParams(stmt_,&params_no);
973 			check_error(r);
974 			params_no_ = params_no;
975 			params_.resize(params_no_);
976 		}
977 		else {
978 			params_.reserve(50);
979 		}
980 	}
~statement()981 	~statement()
982 	{
983 		SQLFreeHandle(SQL_HANDLE_STMT,stmt_);
984 	}
985 private:
check_error(int code)986 	void check_error(int code)
987 	{
988 		check_odbc_error(code,stmt_,SQL_HANDLE_STMT,wide_);
989 	}
990 
991 
992 	SQLHDBC dbc_;
993 	SQLHSTMT stmt_;
994 	bool wide_;
995 	std::string query_;
996 	std::vector<parameter> params_;
997 	int params_no_;
998 
999 	friend class connection;
1000 	std::string sequence_last_;
1001 	std::string last_insert_id_;
1002 	bool prepared_;
1003 
1004 };
1005 
1006 class connection : public backend::connection {
1007 public:
1008 
connection(connection_info const & ci)1009 	connection(connection_info const &ci) :
1010 		backend::connection(ci),
1011 		ci_(ci)
1012 	{
1013 		std::string utf_mode = ci.get("@utf","narrow");
1014 
1015 		if(utf_mode == "narrow")
1016 			wide_ = false;
1017 		else if(utf_mode == "wide")
1018 			wide_ = true;
1019 		else
1020 			throw cppdb_error("cppdb::odbc:: @utf property can be either 'narrow' or 'wide'");
1021 
1022 		bool env_created = false;
1023 		bool dbc_created = false;
1024 		bool dbc_connected = false;
1025 
1026 		try {
1027 			SQLRETURN r = SQLAllocHandle(SQL_HANDLE_ENV,SQL_NULL_HANDLE,&env_);
1028 			if(!SQL_SUCCEEDED(r)) {
1029 				throw cppdb_error("cppdb::odbc::Failed to allocate environment handle");
1030 			}
1031 			env_created = true;
1032 			r = SQLSetEnvAttr(env_,SQL_ATTR_ODBC_VERSION,(SQLPOINTER)SQL_OV_ODBC3, 0);
1033 			check_odbc_error(r,env_,SQL_HANDLE_ENV,wide_);
1034 			r = SQLAllocHandle(SQL_HANDLE_DBC,env_,&dbc_);
1035 			check_odbc_error(r,env_,SQL_HANDLE_ENV,wide_);
1036 			dbc_created = true;
1037 			if(wide_) {
1038 				r = SQLDriverConnectW(dbc_,0,
1039 							(SQLWCHAR*)tosqlwide(conn_str(ci)).c_str(),
1040 							SQL_NTS,0,0,0,SQL_DRIVER_COMPLETE);
1041 			}
1042 			else {
1043 				r = SQLDriverConnectA(dbc_,0,
1044 							(SQLCHAR*)conn_str(ci).c_str(),
1045 							SQL_NTS,0,0,0,SQL_DRIVER_COMPLETE);
1046 			}
1047 			check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1048 		}
1049 		catch(...) {
1050 			if(dbc_connected)
1051 				SQLDisconnect(dbc_);
1052 			if(dbc_created)
1053 				SQLFreeHandle(SQL_HANDLE_DBC,dbc_);
1054 			if(env_created)
1055 				SQLFreeHandle(SQL_HANDLE_ENV,env_);
1056 			throw;
1057 		}
1058 	}
1059 
conn_str(connection_info const & ci)1060 	std::string conn_str(connection_info const &ci)
1061 	{
1062 		std::map<std::string,std::string>::const_iterator p;
1063 		std::string str;
1064 		for(p=ci.properties.begin();p!=ci.properties.end();p++) {
1065 			if(p->first.empty() || p->first[0]=='@')
1066 				continue;
1067 			str+=p->first;
1068 			str+="=";
1069 			str+=p->second;
1070 			str+=";";
1071 		}
1072 		return str;
1073 	}
1074 
~connection()1075 	~connection()
1076 	{
1077 		SQLDisconnect(dbc_);
1078 		SQLFreeHandle(SQL_HANDLE_DBC,dbc_);
1079 		SQLFreeHandle(SQL_HANDLE_ENV,env_);
1080 	}
1081 
1082 	/// API
begin()1083 	virtual void begin()
1084 	{
1085 		set_autocommit(false);
1086 	}
commit()1087 	virtual void commit()
1088 	{
1089 		SQLRETURN r = SQLEndTran(SQL_HANDLE_DBC,dbc_,SQL_COMMIT);
1090 		check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1091 		set_autocommit(true);
1092 	}
1093 
rollback()1094 	virtual void rollback()
1095 	{
1096 		try {
1097 			SQLRETURN r = SQLEndTran(SQL_HANDLE_DBC,dbc_,SQL_ROLLBACK);
1098 			check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1099 		}catch(...) {}
1100 		try {
1101 			set_autocommit(true);
1102 		}catch(...){}
1103 	}
real_prepare(std::string const & q,bool prepared)1104 	statement *real_prepare(std::string const &q,bool prepared)
1105 	{
1106 		std::auto_ptr<statement> st(new statement(q,dbc_,wide_,prepared));
1107 		std::string seq = ci_.get("@sequence_last","");
1108 		if(seq.empty()) {
1109 			std::string eng=engine();
1110 			if(eng == "sqlite3")
1111 				st->last_insert_id_ = "select last_insert_rowid()";
1112 			else if(eng == "mysql")
1113 				st->last_insert_id_ = "select last_insert_id()";
1114 			else if(eng == "postgresql")
1115 				st->sequence_last_ = "select currval(?)";
1116 			else if(eng == "mssql")
1117 				st->last_insert_id_ = "select @@identity";
1118 		}
1119 		else {
1120 			if(seq.find('?')==std::string::npos)
1121 				st->last_insert_id_ = seq;
1122 			else
1123 				st->sequence_last_ = seq;
1124 		}
1125 
1126 		return st.release();
1127 	}
1128 
prepare_statement(std::string const & q)1129 	virtual statement *prepare_statement(std::string const &q)
1130 	{
1131 		return real_prepare(q,true);
1132 	}
1133 
create_statement(std::string const & q)1134 	virtual statement *create_statement(std::string const &q)
1135 	{
1136 		return real_prepare(q,false);
1137 	}
1138 
1139 
escape(std::string const & s)1140 	virtual std::string escape(std::string const &s)
1141 	{
1142 		return escape(s.c_str(),s.c_str()+s.size());
1143 	}
escape(char const * s)1144 	virtual std::string escape(char const *s)
1145 	{
1146 		return escape(s,s+strlen(s));
1147 	}
escape(char const *,char const *)1148 	virtual std::string escape(char const * /*b*/,char const * /*e*/)
1149 	{
1150 		throw not_supported_by_backend("cppcms::odbc:: string escaping is not supported");
1151 	}
driver()1152 	virtual std::string driver()
1153 	{
1154 		return "odbc";
1155 	}
engine()1156 	virtual std::string engine()
1157 	{
1158 		return ci_.get("@engine","unknown");
1159 	}
1160 
set_autocommit(bool on)1161 	void set_autocommit(bool on)
1162 	{
1163 		SQLPOINTER mode = (SQLPOINTER)(on ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF);
1164 		SQLRETURN r = SQLSetConnectAttr(
1165 					dbc_, // handler
1166 					SQL_ATTR_AUTOCOMMIT, // option
1167 					mode, //value
1168 					0);
1169 		check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1170 	}
1171 
1172 private:
1173 	SQLHENV env_;
1174 	SQLHDBC dbc_;
1175 	bool wide_;
1176 	connection_info ci_;
1177 };
1178 
1179 
1180 } // odbc_backend
1181 } // cppdb
1182 
1183 extern "C" {
cppdb_odbc_get_connection(cppdb::connection_info const & cs)1184 	CPPDB_DRIVER_API cppdb::backend::connection *cppdb_odbc_get_connection(cppdb::connection_info const &cs)
1185 	{
1186 		return new cppdb::odbc_backend::connection(cs);
1187 	}
1188 }
1189