1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2010-2011 Artyom Beilis (Tonkikh) <artyomtnk@yahoo.com>
4 //
5 // Distributed under:
6 //
7 // the Boost Software License, Version 1.0.
8 // (See accompanying file LICENSE_1_0.txt or copy at
9 // http://www.boost.org/LICENSE_1_0.txt)
10 //
11 // or (at your opinion) under:
12 //
13 // The MIT License
14 // (See accompanying file MIT.txt or a copy at
15 // http://www.opensource.org/licenses/mit-license.php)
16 //
17 ///////////////////////////////////////////////////////////////////////////////
18 #define CPPDB_DRIVER_SOURCE
19 #ifdef CPPDB_WITH_ODBC
20 #define CPPDB_SOURCE
21 #endif
22 #include <cppdb/backend.h>
23 #include <cppdb/utils.h>
24 #include <cppdb/numeric_util.h>
25 #include <list>
26 #include <vector>
27 #include <iostream>
28 #include <sstream>
29 #include <limits>
30 #include <iomanip>
31 #include <string.h>
32
33 #if defined(_WIN32) || defined(__WIN32) || defined(WIN32) || defined(__CYGWIN__)
34 #include <windows.h>
35 #endif
36 #include <sqlext.h>
37
38 namespace cppdb {
39 namespace odbc_backend {
40
// Fixed-width integer aliases used by the UTF-8 / UTF-16 converters below.
typedef unsigned odbc_u32;
typedef unsigned short odbc_u16;

// Pre-C++11 compile-time assertions: an array with a negative size does not
// compile, so the build breaks if any of these width assumptions fails.
// The UTF-16 helpers reinterpret SQLWCHAR buffers as odbc_u16, which is only
// valid when SQLWCHAR is 2 bytes wide.
// NOTE(review): these are real namespace-scope int arrays (not typedefs), so
// each one also occupies storage with external linkage.
int assert_on_unsigned_is_32[sizeof(unsigned) == 4 ? 1 : -1];
int assert_on_unsigned_short_is_16[sizeof(unsigned short) == 2 ? 1 : -1];
int assert_on_sqlwchar_is_16[sizeof(SQLWCHAR) == 2 ? 1 : -1];
47
48
49
50 namespace utf {
51 static const odbc_u32 illegal = 0xFFFFFFFFu;
valid(odbc_u32 v)52 inline bool valid(odbc_u32 v)
53 {
54 if(v>0x10FFFF)
55 return false;
56 if(0xD800 <=v && v<= 0xDFFF) // surragates
57 return false;
58 return true;
59 }
60 }
61
namespace utf8 {
// See RFC 3629
// Based on: http://www.w3.org/International/questions/qa-forms-utf-8

// Decodes one UTF-8 sequence starting at p and advances p past the bytes it
// read. Returns the decoded code point, or utf::illegal when the sequence is
// malformed: a stray continuation byte, an overlong form, an encoded
// surrogate (U+D800..U+DFFF), a value above U+10FFFF, or input that ends in
// the middle of a sequence.
// Precondition: p != e (the first byte is read unconditionally).
// Note: on failure p may already have been advanced past some or all bytes
// of the bad sequence, so the position is not useful for resynchronising.
template<typename Iterator>
odbc_u32 next(Iterator &p,Iterator e)
{
	unsigned char c=*p++;
	unsigned char seq0,seq1=0,seq2=0,seq3=0;
	seq0=c;
	int len=1;
	// The lead byte's high bits decide how many continuation bytes to read:
	// 11xxxxxx -> 2 bytes, 111xxxxx -> 3 bytes, 1111xxxx -> 4 bytes.
	// Bytes are only collected here; they are validated in the switch below.
	if((c & 0xC0) == 0xC0) {
		if(p==e)
			return utf::illegal;
		seq1=*p++;
		len=2;
	}
	if((c & 0xE0) == 0xE0) {
		if(p==e)
			return utf::illegal;
		seq2=*p++;
		len=3;
	}
	if((c & 0xF0) == 0xF0) {
		if(p==e)
			return utf::illegal;
		seq3=*p++;
		len=4;
	}
	switch(len) {
	case 1: // single byte: only 0x00..0x7F (ASCII) is valid on its own
		if(seq0 > 0x7F)
			return utf::illegal;
		break;
	case 2: // 2 bytes: lead C2..DF (C0/C1 would be overlong forms)
		if(0xC2 <= seq0 && seq0 <= 0xDF) {
			if(0x80 <= seq1 && seq1<= 0xBF)
				break;
		}
		return utf::illegal;
	case 3:
		if(seq0==0xE0) { // exclude overlong forms
			if(0xA0 <=seq1 && seq1<= 0xBF && 0x80 <=seq2 && seq2<=0xBF)
				break;
		}
		else if( (0xE1 <= seq0 && seq0 <=0xEC) || seq0==0xEE || seq0==0xEF) { // straightforward 3-byte ranges
			if(	0x80 <=seq1 && seq1<=0xBF &&
				0x80 <=seq2 && seq2<=0xBF)
				break;
		}
		else if(seq0 == 0xED) { // exclude surrogates U+D800..U+DFFF
			if(	0x80 <=seq1 && seq1<=0x9F &&
				0x80 <=seq2 && seq2<=0xBF)
				break;
		}
		return utf::illegal;
	case 4:
		switch(seq0) {
		case 0xF0: // planes 1-3 (second byte >= 0x90 excludes overlong forms)
			if(	0x90 <=seq1 && seq1<=0xBF &&
				0x80 <=seq2 && seq2<=0xBF &&
				0x80 <=seq3 && seq3<=0xBF)
				break;
			return utf::illegal;
		case 0xF1: // planes 4-15
		case 0xF2:
		case 0xF3:
			if(	0x80 <=seq1 && seq1<=0xBF &&
				0x80 <=seq2 && seq2<=0xBF &&
				0x80 <=seq3 && seq3<=0xBF)
				break;
			return utf::illegal;
		case 0xF4: // plane 16 (second byte <= 0x8F keeps the value <= U+10FFFF)
			if(	0x80 <=seq1 && seq1<=0x8F &&
				0x80 <=seq2 && seq2<=0xBF &&
				0x80 <=seq3 && seq3<=0xBF)
				break;
			return utf::illegal;
		default: // F5..FF never start a valid sequence
			return utf::illegal;
		}

	}

	// The sequence is valid: assemble the code point from the payload bits.
	switch(len) {
	case 1:
		return seq0;
	case 2:
		return ((seq0 & 0x1F) << 6) | (seq1 & 0x3F);
	case 3:
		return ((seq0 & 0x0F) << 12) | ((seq1 & 0x3F) << 6) | (seq2 & 0x3F) ;
	case 4:
		return ((seq0 & 0x07) << 18) | ((seq1 & 0x3F) << 12) | ((seq2 & 0x3F) << 6) | (seq3 & 0x3F) ;
	}

	return utf::illegal;
} // next


// One encoded UTF-8 sequence: the first `len` bytes of `c` are meaningful.
struct seq {
	char c[4];
	unsigned len;
};
// Encodes `value` as 1 to 4 UTF-8 bytes. No validation is done: values above
// 0xFFFF always take the 4-byte form, so the caller must pass a valid code
// point (<= U+10FFFF); surrogate values are not rejected here either.
inline seq encode(odbc_u32 value)
{
	seq out=seq();
	if(value <=0x7F) {
		out.c[0]=value;
		out.len=1;
	}
	else if(value <=0x7FF) {
		out.c[0]=(value >> 6) | 0xC0;
		out.c[1]=(value & 0x3F) | 0x80;
		out.len=2;
	}
	else if(value <=0xFFFF) {
		out.c[0]=(value >> 12) | 0xE0;
		out.c[1]=((value >> 6) & 0x3F) | 0x80;
		out.c[2]=(value & 0x3F) | 0x80;
		out.len=3;
	}
	else {
		out.c[0]=(value >> 18) | 0xF0;
		out.c[1]=((value >> 12) & 0x3F) | 0x80;
		out.c[2]=((value >> 6) & 0x3F) | 0x80;
		out.c[3]=(value & 0x3F) | 0x80;
		out.len=4;
	}
	return out;
}
} // namespace utf8
192
193
194 namespace utf16 {
195
196 // See RFC 2781
is_first_surrogate(odbc_u16 x)197 inline bool is_first_surrogate(odbc_u16 x)
198 {
199 return 0xD800 <=x && x<= 0xDBFF;
200 }
is_second_surrogate(odbc_u16 x)201 inline bool is_second_surrogate(odbc_u16 x)
202 {
203 return 0xDC00 <=x && x<= 0xDFFF;
204 }
combine_surrogate(odbc_u16 w1,odbc_u16 w2)205 inline odbc_u32 combine_surrogate(odbc_u16 w1,odbc_u16 w2)
206 {
207 return ((odbc_u32(w1 & 0x3FF) << 10) | (w2 & 0x3FF)) + 0x10000;
208 }
209
210 template<typename It>
next(It & current,It last)211 inline odbc_u32 next(It ¤t,It last)
212 {
213 odbc_u16 w1=*current++;
214 if(w1 < 0xD800 || 0xDFFF < w1) {
215 return w1;
216 }
217 if(w1 > 0xDBFF)
218 return utf::illegal;
219 if(current==last)
220 return utf::illegal;
221 odbc_u16 w2=*current++;
222 if(w2 < 0xDC00 || 0xDFFF < w2)
223 return utf::illegal;
224 return combine_surrogate(w1,w2);
225 }
width(odbc_u32 u)226 inline int width(odbc_u32 u)
227 {
228 return u>=0x100000 ? 2 : 1;
229 }
230 struct seq {
231 odbc_u16 c[2];
232 unsigned len;
233 };
encode(odbc_u32 u)234 inline seq encode(odbc_u32 u)
235 {
236 seq out=seq();
237 if(u<=0xFFFF) {
238 out.c[0]=u;
239 out.len=1;
240 }
241 else {
242 u-=0x10000;
243 out.c[0]=0xD800 | (u>>10);
244 out.c[1]=0xDC00 | (u & 0x3FF);
245 out.len=2;
246 }
247 return out;
248 }
249 } // utf16;
250
251 } // odbc_backend
252 } // cppdb
253
254
255 namespace cppdb {
256
257 class connection_info;
258 class pool;
259
260 namespace odbc_backend {
261
widen(char const * b,char const * e)262 std::string widen(char const *b,char const *e)
263 {
264 std::string result;
265 result.reserve((e-b)*2);
266 odbc_u32 code_point = 0;
267 while(b < e && (code_point = utf8::next(b,e))!=utf::illegal) {
268 utf16::seq sq = utf16::encode(code_point);
269 char *str = (char *)sq.c;
270 result.append(str,sq.len * 2);
271 }
272 if(b!=e || code_point == utf::illegal) {
273 throw cppdb_error("cppdb::odbc invalid UTF-8 input");
274 }
275 return result;
276 }
277
widen(std::string const & s)278 std::string widen(std::string const &s)
279 {
280 return widen(s.c_str(),s.c_str()+s.size());
281 }
282
narrower(std::basic_string<SQLWCHAR> const & wide)283 std::string narrower(std::basic_string<SQLWCHAR> const &wide)
284 {
285 odbc_u16 const *b = reinterpret_cast<odbc_u16 const *>(wide.c_str());
286 odbc_u16 const *e = b + wide.size();
287
288 std::string result;
289 result.reserve((e-b));
290
291 odbc_u32 code_point = 0;
292 while(b < e && (code_point = utf16::next(b,e))!=utf::illegal) {
293 utf8::seq sq = utf8::encode(code_point);
294 result.append(sq.c,sq.len);
295 }
296 if(b!=e || code_point == utf::illegal) {
297 throw cppdb_error("cppdb::odbc got invalid UTF-16");
298 }
299 return result;
300 }
301
narrower(std::string const & wide)302 std::string narrower(std::string const &wide)
303 {
304 if(wide.size() % 2 != 0) {
305 throw cppdb_error("cppdb::odbc got invalid UTF-16");
306 }
307 odbc_u16 const *b = reinterpret_cast<odbc_u16 const *>(wide.c_str());
308 odbc_u16 const *e = b + wide.size() / 2;
309
310 std::string result;
311 result.reserve((e-b));
312
313 odbc_u32 code_point = 0;
314 while(b < e && (code_point = utf16::next(b,e))!=utf::illegal) {
315 utf8::seq sq = utf8::encode(code_point);
316 result.append(sq.c,sq.len);
317 }
318 if(b!=e || code_point == utf::illegal) {
319 throw cppdb_error("cppdb::odbc got invalid UTF-16");
320 }
321 return result;
322 }
323
tosqlwide(std::string const & n)324 std::basic_string<SQLWCHAR> tosqlwide(std::string const &n)
325 {
326 std::basic_string<SQLWCHAR> result;
327 char const *b=n.c_str();
328 char const *e=b+n.size();
329 result.reserve(e-b);
330 odbc_u32 code_point = 0;
331 while(b < e && (code_point = utf8::next(b,e))!=utf::illegal) {
332 utf16::seq sq = utf16::encode(code_point);
333 result.append((SQLWCHAR*)sq.c,sq.len);
334 }
335 if(b!=e || code_point == utf::illegal) {
336 throw cppdb_error("cppdb::odbc invalid UTF-8 input");
337 }
338 return result;
339 }
340
check_odbc_errorW(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)341 void check_odbc_errorW(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)
342 {
343 if(SQL_SUCCEEDED(error))
344 return;
345 std::basic_string<SQLWCHAR> error_message;
346 int rec=1,r;
347 for(;;){
348 SQLWCHAR msg[SQL_MAX_MESSAGE_LENGTH + 2] = {0};
349 SQLWCHAR stat[SQL_SQLSTATE_SIZE + 1] = {0};
350 SQLINTEGER err;
351 SQLSMALLINT len;
352 r = SQLGetDiagRecW(type,h,rec,stat,&err,msg,sizeof(msg)/sizeof(SQLWCHAR),&len);
353 rec++;
354 if(r==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
355 if(!error_message.empty()) {
356 SQLWCHAR nl = '\n';
357 error_message+=nl;
358 }
359 error_message.append(msg);
360 }
361 else
362 break;
363
364 }
365 std::string utf8_str = "Unconvertable string";
366 try { std::string tmp = narrower(error_message); utf8_str = tmp; } catch(...){}
367 throw cppdb_error("cppdb::odbc_backend::Failed with error `" + utf8_str +"'");
368 }
369
check_odbc_errorA(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)370 void check_odbc_errorA(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type)
371 {
372 if(SQL_SUCCEEDED(error))
373 return;
374 std::string error_message;
375 int rec=1,r;
376 for(;;){
377 SQLCHAR msg[SQL_MAX_MESSAGE_LENGTH + 2] = {0};
378 SQLCHAR stat[SQL_SQLSTATE_SIZE + 1] = {0};
379 SQLINTEGER err;
380 SQLSMALLINT len;
381 r = SQLGetDiagRecA(type,h,rec,stat,&err,msg,sizeof(msg),&len);
382 rec++;
383 if(r==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
384 if(!error_message.empty())
385 error_message+='\n';
386 error_message +=(char *)msg;
387 }
388 else
389 break;
390
391 }
392 throw cppdb_error("cppdb::odbc::Failed with error `" + error_message +"'");
393 }
394
check_odbc_error(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type,bool wide)395 void check_odbc_error(SQLRETURN error,SQLHANDLE h,SQLSMALLINT type,bool wide)
396 {
397 if(wide)
398 check_odbc_errorW(error,h,type);
399 else
400 check_odbc_errorA(error,h,type);
401 }
402
403
404
405 class result : public backend::result {
406 public:
407 typedef std::pair<bool,std::string> cell_type;
408 typedef std::vector<cell_type> row_type;
409 typedef std::list<row_type> rows_type;
410
has_next()411 virtual next_row has_next()
412 {
413 rows_type::iterator p=current_;
414 if(p == rows_.end() || ++p==rows_.end())
415 return last_row_reached;
416 else
417 return next_row_exists;
418 }
next()419 virtual bool next()
420 {
421 if(started_ == false) {
422 current_ = rows_.begin();
423 started_ = true;
424 }
425 else if(current_!=rows_.end()) {
426 ++current_;
427 }
428 return current_!=rows_.end();
429 }
430 template<typename T>
do_fetch(int col,T & v)431 bool do_fetch(int col,T &v)
432 {
433 if(at(col).first)
434 return false;
435 v=parse_number<T>(at(col).second,ss_);
436 return true;
437 }
fetch(int col,short & v)438 virtual bool fetch(int col,short &v)
439 {
440 return do_fetch(col,v);
441 }
fetch(int col,unsigned short & v)442 virtual bool fetch(int col,unsigned short &v)
443 {
444 return do_fetch(col,v);
445 }
fetch(int col,int & v)446 virtual bool fetch(int col,int &v)
447 {
448 return do_fetch(col,v);
449 }
fetch(int col,unsigned & v)450 virtual bool fetch(int col,unsigned &v)
451 {
452 return do_fetch(col,v);
453 }
fetch(int col,long & v)454 virtual bool fetch(int col,long &v)
455 {
456 return do_fetch(col,v);
457 }
fetch(int col,unsigned long & v)458 virtual bool fetch(int col,unsigned long &v)
459 {
460 return do_fetch(col,v);
461 }
fetch(int col,long long & v)462 virtual bool fetch(int col,long long &v)
463 {
464 return do_fetch(col,v);
465 }
fetch(int col,unsigned long long & v)466 virtual bool fetch(int col,unsigned long long &v)
467 {
468 return do_fetch(col,v);
469 }
fetch(int col,float & v)470 virtual bool fetch(int col,float &v)
471 {
472 return do_fetch(col,v);
473 }
fetch(int col,double & v)474 virtual bool fetch(int col,double &v)
475 {
476 return do_fetch(col,v);
477 }
fetch(int col,long double & v)478 virtual bool fetch(int col,long double &v)
479 {
480 return do_fetch(col,v);
481 }
fetch(int col,std::string & v)482 virtual bool fetch(int col,std::string &v)
483 {
484 if(at(col).first)
485 return false;
486 v=at(col).second;
487 return true;
488 }
fetch(int col,std::ostream & v)489 virtual bool fetch(int col,std::ostream &v)
490 {
491 if(at(col).first)
492 return false;
493 v << at(col).second;
494 return true;
495 }
fetch(int col,std::tm & v)496 virtual bool fetch(int col,std::tm &v)
497 {
498 if(at(col).first)
499 return false;
500 v = parse_time(at(col).second);
501 return true;
502 }
is_null(int col)503 virtual bool is_null(int col)
504 {
505 return at(col).first;
506 }
cols()507 virtual int cols()
508 {
509 return cols_;
510 }
name_to_column(std::string const & cn)511 virtual int name_to_column(std::string const &cn)
512 {
513 for(unsigned i=0;i<names_.size();i++)
514 if(names_[i]==cn)
515 return i;
516 return -1;
517 }
column_to_name(int c)518 virtual std::string column_to_name(int c)
519 {
520 if(c < 0 || c >= int(names_.size()))
521 throw invalid_column();
522 return names_[c];
523 }
524
result(rows_type & rows,std::vector<std::string> & names,int cols)525 result(rows_type &rows,std::vector<std::string> &names,int cols) : cols_(cols)
526 {
527 names_.swap(names);
528 rows_.swap(rows);
529 started_ = false;
530 current_ = rows_.end();
531 ss_.imbue(std::locale::classic());
532 }
at(int col)533 cell_type &at(int col)
534 {
535 if(current_!=rows_.end() && col >= 0 && col <int(current_->size()))
536 return current_->at(col);
537 throw invalid_column();
538 }
539 private:
540 int cols_;
541 bool started_;
542 std::vector<std::string> names_;
543 rows_type::iterator current_;
544 rows_type rows_;
545 std::istringstream ss_;
546 };
547
548 class statements_cache;
549
550 class connection;
551
552 class statement : public backend::statement {
553 struct parameter {
parametercppdb::odbc_backend::statement::parameter554 parameter() :
555 null(true),
556 ctype(SQL_C_CHAR),
557 sqltype(SQL_C_NUMERIC)
558 {
559 }
set_binarycppdb::odbc_backend::statement::parameter560 void set_binary(char const *b,char const *e)
561 {
562 value.assign(b,e-b);
563 null=false;
564 ctype=SQL_C_BINARY;
565 sqltype = SQL_LONGVARBINARY;
566 }
set_textcppdb::odbc_backend::statement::parameter567 void set_text(char const *b,char const *e,bool wide)
568 {
569 if(!wide) {
570 value.assign(b,e-b);
571 null=false;
572 ctype=SQL_C_CHAR;
573 sqltype = SQL_LONGVARCHAR;
574 }
575 else {
576 std::string tmp = widen(b,e);
577 value.swap(tmp);
578 null=false;
579 ctype=SQL_C_WCHAR;
580 sqltype = SQL_WLONGVARCHAR;
581 }
582 }
setcppdb::odbc_backend::statement::parameter583 void set(std::tm const &v)
584 {
585 value = cppdb::format_time(v);
586 null=false;
587 sqltype = SQL_C_TIMESTAMP;
588 ctype = SQL_C_CHAR;
589 }
590
591 template<typename T>
setcppdb::odbc_backend::statement::parameter592 void set(T v)
593 {
594 std::ostringstream ss;
595 ss.imbue(std::locale::classic());
596 if(!std::numeric_limits<T>::is_integer)
597 ss << std::setprecision(std::numeric_limits<T>::digits10+1);
598 ss << v;
599
600 value=ss.str();
601 null=false;
602 ctype = SQL_C_CHAR;
603 if(std::numeric_limits<T>::is_integer)
604 sqltype = SQL_INTEGER;
605 else
606 sqltype = SQL_DOUBLE;
607
608 }
bindcppdb::odbc_backend::statement::parameter609 void bind(int col,SQLHSTMT stmt,bool wide)
610 {
611 int r;
612 if(null) {
613 lenval = SQL_NULL_DATA;
614 r = SQLBindParameter( stmt,
615 col,
616 SQL_PARAM_INPUT,
617 SQL_C_CHAR,
618 SQL_NUMERIC, // for null
619 10, // COLUMNSIZE
620 0, // Presision
621 0, // string
622 0, // size
623 &lenval);
624 }
625 else {
626 lenval=value.size();
627 size_t column_size = value.size();
628 if(ctype == SQL_C_WCHAR)
629 column_size/=2;
630 if(value.empty())
631 column_size=1;
632 r = SQLBindParameter( stmt,
633 col,
634 SQL_PARAM_INPUT,
635 ctype,
636 sqltype,
637 column_size, // COLUMNSIZE
638 0, // Presision
639 (void*)value.c_str(), // string
640 value.size(),
641 &lenval);
642 }
643 check_odbc_error(r,stmt,SQL_HANDLE_STMT,wide);
644 }
645
646 std::string value;
647 bool null;
648 SQLSMALLINT ctype;
649 SQLSMALLINT sqltype;
650 SQLLEN lenval;
651 };
652 public:
653 // Begin of API
reset()654 virtual void reset()
655 {
656 SQLFreeStmt(stmt_,SQL_UNBIND);
657 SQLCloseCursor(stmt_);
658 params_.resize(0);
659 if(params_no_ > 0)
660 params_.resize(params_no_);
661
662 }
param_at(int col)663 parameter ¶m_at(int col)
664 {
665 col --;
666 if(col < 0)
667 throw invalid_placeholder();
668 if(params_no_ < 0) {
669 if(params_.size() < size_t(col+1))
670 params_.resize(col+1);
671 }
672 else if(col >= params_no_) {
673 throw invalid_placeholder();
674 }
675 return params_[col];
676 }
sql_query()677 virtual std::string const &sql_query()
678 {
679 return query_;
680 }
bind(int col,std::string const & s)681 virtual void bind(int col,std::string const &s)
682 {
683 bind(col,s.c_str(),s.c_str()+s.size());
684 }
bind(int col,char const * s)685 virtual void bind(int col,char const *s)
686 {
687 bind(col,s,s+strlen(s));
688 }
bind(int col,char const * b,char const * e)689 virtual void bind(int col,char const *b,char const *e)
690 {
691 param_at(col).set_text(b,e,wide_);
692 }
bind(int col,std::tm const & s)693 virtual void bind(int col,std::tm const &s)
694 {
695 param_at(col).set(s);
696 }
bind(int col,std::istream & in)697 virtual void bind(int col,std::istream &in)
698 {
699 std::ostringstream ss;
700 ss << in.rdbuf();
701 std::string s = ss.str();
702 param_at(col).set_binary(s.c_str(),s.c_str()+s.size());
703 }
704 template<typename T>
do_bind_num(int col,T v)705 void do_bind_num(int col,T v)
706 {
707 param_at(col).set(v);
708 }
bind(int col,int v)709 virtual void bind(int col,int v)
710 {
711 do_bind_num(col,v);
712 }
bind(int col,unsigned v)713 virtual void bind(int col,unsigned v)
714 {
715 do_bind_num(col,v);
716 }
bind(int col,long v)717 virtual void bind(int col,long v)
718 {
719 do_bind_num(col,v);
720 }
bind(int col,unsigned long v)721 virtual void bind(int col,unsigned long v)
722 {
723 do_bind_num(col,v);
724 }
bind(int col,long long v)725 virtual void bind(int col,long long v)
726 {
727 do_bind_num(col,v);
728 }
bind(int col,unsigned long long v)729 virtual void bind(int col,unsigned long long v)
730 {
731 do_bind_num(col,v);
732 }
bind(int col,double v)733 virtual void bind(int col,double v)
734 {
735 do_bind_num(col,v);
736 }
bind(int col,long double v)737 virtual void bind(int col,long double v)
738 {
739 do_bind_num(col,v);
740 }
bind_null(int col)741 virtual void bind_null(int col)
742 {
743 param_at(col) = parameter();
744 }
bind_all()745 void bind_all()
746 {
747 for(unsigned i=0;i<params_.size();i++) {
748 params_[i].bind(i+1,stmt_,wide_);
749 }
750
751 }
sequence_last(std::string const & sequence)752 virtual long long sequence_last(std::string const &sequence)
753 {
754 ref_ptr<statement> st;
755 if(!sequence_last_.empty()) {
756 st = new statement(sequence_last_,dbc_,wide_,false);
757 st->bind(1,sequence);
758 }
759 else if(!last_insert_id_.empty()) {
760 st = new statement(last_insert_id_,dbc_,wide_,false);
761 }
762 else {
763 throw not_supported_by_backend(
764 "cppdb::odbc::sequence_last is not supported by odbc backend "
765 "unless properties @squence_last, @last_insert_id are specified "
766 "or @engine is one of mysql, sqlite3, postgresql, mssql");
767 }
768 ref_ptr<result> res = st->query();
769 long long last_id;
770 if(!res->next() || res->cols()!=1 || !res->fetch(0,last_id)) {
771 throw cppdb_error("cppdb::odbc::sequence_last failed to fetch last value");
772 }
773 res.reset();
774 st.reset();
775 return last_id;
776 }
affected()777 virtual unsigned long long affected()
778 {
779 SQLLEN rows = 0;
780 int r = SQLRowCount(stmt_,&rows);
781 check_error(r);
782 return rows;
783 }
query()784 virtual result *query()
785 {
786 bind_all();
787 int r = real_exec();
788 check_error(r);
789 result::rows_type rows;
790 result::row_type row;
791
792 std::string value;
793 bool is_null = false;
794 SQLSMALLINT ocols;
795 r = SQLNumResultCols(stmt_,&ocols);
796 check_error(r);
797 int cols = ocols;
798
799 std::vector<std::string> names(cols);
800 std::vector<int> types(cols,SQL_C_CHAR);
801
802 for(int col=0;col < cols;col++) {
803 SQLSMALLINT name_length=0,data_type=0,digits=0,nullable=0;
804 SQLULEN collen = 0;
805
806 if(wide_) {
807 SQLWCHAR name[257] = {0};
808 r=SQLDescribeColW(stmt_,col+1,name,256,&name_length,&data_type,&collen,&digits,&nullable);
809 check_error(r);
810 names[col]=narrower(name);
811 }
812 else {
813 SQLCHAR name[257] = {0};
814 r=SQLDescribeColA(stmt_,col+1,name,256,&name_length,&data_type,&collen,&digits,&nullable);
815 check_error(r);
816 names[col]=(char*)name;
817 }
818 switch(data_type) {
819 case SQL_CHAR:
820 case SQL_VARCHAR:
821 case SQL_LONGVARCHAR:
822 types[col]=SQL_C_CHAR;
823 break;
824 case SQL_WCHAR:
825 case SQL_WVARCHAR:
826 case SQL_WLONGVARCHAR:
827 types[col]=SQL_C_WCHAR ;
828 break;
829 case SQL_BINARY:
830 case SQL_VARBINARY:
831 case SQL_LONGVARBINARY:
832 types[col]=SQL_C_BINARY ;
833 break;
834 default:
835 types[col]=SQL_C_DEFAULT;
836 // Just a hack, actually I'm going to use C
837 ;
838 }
839 }
840
841 while((r=SQLFetch(stmt_))==SQL_SUCCESS || r==SQL_SUCCESS_WITH_INFO) {
842 row.resize(cols);
843 for(int col=0;col < cols;col++) {
844 SQLLEN len = 0;
845 is_null=false;
846 int type = types[col];
847 if(type==SQL_C_DEFAULT) {
848 char buf[64];
849 int r = SQLGetData(stmt_,col+1,SQL_C_CHAR,buf,sizeof(buf),&len);
850 check_error(r);
851 if(len == SQL_NULL_DATA) {
852 is_null = true;
853 }
854 else if(len <= 64) {
855 value.assign(buf,len);
856 }
857 else {
858 throw cppdb_error("cppdb::odbc::query - data too long");
859 }
860 }
861 else {
862 char buf[1024];
863 size_t real_len;
864 if(type == SQL_C_CHAR) {
865 real_len = sizeof(buf)-1;
866 }
867 else if(type == SQL_C_BINARY) {
868 real_len = sizeof(buf);
869 }
870 else { // SQL_C_WCHAR
871 real_len = sizeof(buf) - sizeof(SQLWCHAR);
872 }
873
874 r = SQLGetData(stmt_,col+1,type,buf,sizeof(buf),&len);
875 check_error(r);
876 if(len == SQL_NULL_DATA) {
877 is_null = true;
878 }
879 else if(len == SQL_NO_TOTAL) {
880 while(len==SQL_NO_TOTAL) {
881 value.append(buf,real_len);
882 r = SQLGetData(stmt_,col+1,type,buf,sizeof(buf),&len);
883 check_error(r);
884 }
885 value.append(buf,len);
886 }
887 else if(0<= len && size_t(len) <= real_len) {
888 value.assign(buf,len);
889 }
890 else if(len>=0) {
891 value.assign(buf,real_len);
892 size_t rem_len = len - real_len;
893 std::vector<char> tmp(rem_len+2,0);
894 r = SQLGetData(stmt_,col+1,type,&tmp[0],tmp.size(),&len);
895 check_error(r);
896 value.append(&tmp[0],rem_len);
897 }
898 else {
899 throw cppdb_error("cppdb::odbc::query invalid result length");
900 }
901 if(!is_null && type == SQL_C_WCHAR) {
902 std::string tmp=narrower(value);
903 value.swap(tmp);
904 }
905 }
906
907 row[col].first = is_null;
908 row[col].second.swap(value);
909 }
910 rows.push_back(result::row_type());
911 rows.back().swap(row);
912 }
913 if(r!=SQL_NO_DATA) {
914 check_error(r);
915 }
916 return new result(rows,names,cols);
917 }
918
real_exec()919 int real_exec()
920 {
921 int r = 0;
922 if(prepared_) {
923 r=SQLExecute(stmt_);
924 }
925 else {
926 if(wide_)
927 r=SQLExecDirectW(stmt_,(SQLWCHAR*)tosqlwide(query_).c_str(),SQL_NTS);
928 else
929 r=SQLExecDirectA(stmt_,(SQLCHAR*)query_.c_str(),SQL_NTS);
930 }
931 return r;
932 }
exec()933 virtual void exec()
934 {
935 bind_all();
936 int r=real_exec();
937 if(r!=SQL_NO_DATA)
938 check_error(r);
939 }
940 // End of API
941
statement(std::string const & q,SQLHDBC dbc,bool wide,bool prepared)942 statement(std::string const &q,SQLHDBC dbc,bool wide,bool prepared) :
943 dbc_(dbc),
944 wide_(wide),
945 query_(q),
946 params_no_(-1),
947 prepared_(prepared)
948 {
949 SQLRETURN r = SQLAllocHandle(SQL_HANDLE_STMT,dbc,&stmt_);
950 check_odbc_error(r,dbc,SQL_HANDLE_DBC,wide_);
951 if(prepared_) {
952 try {
953 if(wide_) {
954 r = SQLPrepareW(
955 stmt_,
956 (SQLWCHAR*)tosqlwide(query_).c_str(),
957 SQL_NTS);
958 }
959 else {
960 r = SQLPrepareA(
961 stmt_,
962 (SQLCHAR*)query_.c_str(),
963 SQL_NTS);
964 }
965 check_error(r);
966 }
967 catch(...) {
968 SQLFreeHandle(SQL_HANDLE_STMT,stmt_);
969 throw;
970 }
971 SQLSMALLINT params_no;
972 r = SQLNumParams(stmt_,¶ms_no);
973 check_error(r);
974 params_no_ = params_no;
975 params_.resize(params_no_);
976 }
977 else {
978 params_.reserve(50);
979 }
980 }
~statement()981 ~statement()
982 {
983 SQLFreeHandle(SQL_HANDLE_STMT,stmt_);
984 }
985 private:
check_error(int code)986 void check_error(int code)
987 {
988 check_odbc_error(code,stmt_,SQL_HANDLE_STMT,wide_);
989 }
990
991
992 SQLHDBC dbc_;
993 SQLHSTMT stmt_;
994 bool wide_;
995 std::string query_;
996 std::vector<parameter> params_;
997 int params_no_;
998
999 friend class connection;
1000 std::string sequence_last_;
1001 std::string last_insert_id_;
1002 bool prepared_;
1003
1004 };
1005
1006 class connection : public backend::connection {
1007 public:
1008
connection(connection_info const & ci)1009 connection(connection_info const &ci) :
1010 backend::connection(ci),
1011 ci_(ci)
1012 {
1013 std::string utf_mode = ci.get("@utf","narrow");
1014
1015 if(utf_mode == "narrow")
1016 wide_ = false;
1017 else if(utf_mode == "wide")
1018 wide_ = true;
1019 else
1020 throw cppdb_error("cppdb::odbc:: @utf property can be either 'narrow' or 'wide'");
1021
1022 bool env_created = false;
1023 bool dbc_created = false;
1024 bool dbc_connected = false;
1025
1026 try {
1027 SQLRETURN r = SQLAllocHandle(SQL_HANDLE_ENV,SQL_NULL_HANDLE,&env_);
1028 if(!SQL_SUCCEEDED(r)) {
1029 throw cppdb_error("cppdb::odbc::Failed to allocate environment handle");
1030 }
1031 env_created = true;
1032 r = SQLSetEnvAttr(env_,SQL_ATTR_ODBC_VERSION,(SQLPOINTER)SQL_OV_ODBC3, 0);
1033 check_odbc_error(r,env_,SQL_HANDLE_ENV,wide_);
1034 r = SQLAllocHandle(SQL_HANDLE_DBC,env_,&dbc_);
1035 check_odbc_error(r,env_,SQL_HANDLE_ENV,wide_);
1036 dbc_created = true;
1037 if(wide_) {
1038 r = SQLDriverConnectW(dbc_,0,
1039 (SQLWCHAR*)tosqlwide(conn_str(ci)).c_str(),
1040 SQL_NTS,0,0,0,SQL_DRIVER_COMPLETE);
1041 }
1042 else {
1043 r = SQLDriverConnectA(dbc_,0,
1044 (SQLCHAR*)conn_str(ci).c_str(),
1045 SQL_NTS,0,0,0,SQL_DRIVER_COMPLETE);
1046 }
1047 check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1048 }
1049 catch(...) {
1050 if(dbc_connected)
1051 SQLDisconnect(dbc_);
1052 if(dbc_created)
1053 SQLFreeHandle(SQL_HANDLE_DBC,dbc_);
1054 if(env_created)
1055 SQLFreeHandle(SQL_HANDLE_ENV,env_);
1056 throw;
1057 }
1058 }
1059
conn_str(connection_info const & ci)1060 std::string conn_str(connection_info const &ci)
1061 {
1062 std::map<std::string,std::string>::const_iterator p;
1063 std::string str;
1064 for(p=ci.properties.begin();p!=ci.properties.end();p++) {
1065 if(p->first.empty() || p->first[0]=='@')
1066 continue;
1067 str+=p->first;
1068 str+="=";
1069 str+=p->second;
1070 str+=";";
1071 }
1072 return str;
1073 }
1074
~connection()1075 ~connection()
1076 {
1077 SQLDisconnect(dbc_);
1078 SQLFreeHandle(SQL_HANDLE_DBC,dbc_);
1079 SQLFreeHandle(SQL_HANDLE_ENV,env_);
1080 }
1081
1082 /// API
begin()1083 virtual void begin()
1084 {
1085 set_autocommit(false);
1086 }
commit()1087 virtual void commit()
1088 {
1089 SQLRETURN r = SQLEndTran(SQL_HANDLE_DBC,dbc_,SQL_COMMIT);
1090 check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1091 set_autocommit(true);
1092 }
1093
rollback()1094 virtual void rollback()
1095 {
1096 try {
1097 SQLRETURN r = SQLEndTran(SQL_HANDLE_DBC,dbc_,SQL_ROLLBACK);
1098 check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1099 }catch(...) {}
1100 try {
1101 set_autocommit(true);
1102 }catch(...){}
1103 }
real_prepare(std::string const & q,bool prepared)1104 statement *real_prepare(std::string const &q,bool prepared)
1105 {
1106 std::auto_ptr<statement> st(new statement(q,dbc_,wide_,prepared));
1107 std::string seq = ci_.get("@sequence_last","");
1108 if(seq.empty()) {
1109 std::string eng=engine();
1110 if(eng == "sqlite3")
1111 st->last_insert_id_ = "select last_insert_rowid()";
1112 else if(eng == "mysql")
1113 st->last_insert_id_ = "select last_insert_id()";
1114 else if(eng == "postgresql")
1115 st->sequence_last_ = "select currval(?)";
1116 else if(eng == "mssql")
1117 st->last_insert_id_ = "select @@identity";
1118 }
1119 else {
1120 if(seq.find('?')==std::string::npos)
1121 st->last_insert_id_ = seq;
1122 else
1123 st->sequence_last_ = seq;
1124 }
1125
1126 return st.release();
1127 }
1128
prepare_statement(std::string const & q)1129 virtual statement *prepare_statement(std::string const &q)
1130 {
1131 return real_prepare(q,true);
1132 }
1133
create_statement(std::string const & q)1134 virtual statement *create_statement(std::string const &q)
1135 {
1136 return real_prepare(q,false);
1137 }
1138
1139
escape(std::string const & s)1140 virtual std::string escape(std::string const &s)
1141 {
1142 return escape(s.c_str(),s.c_str()+s.size());
1143 }
escape(char const * s)1144 virtual std::string escape(char const *s)
1145 {
1146 return escape(s,s+strlen(s));
1147 }
escape(char const *,char const *)1148 virtual std::string escape(char const * /*b*/,char const * /*e*/)
1149 {
1150 throw not_supported_by_backend("cppcms::odbc:: string escaping is not supported");
1151 }
driver()1152 virtual std::string driver()
1153 {
1154 return "odbc";
1155 }
engine()1156 virtual std::string engine()
1157 {
1158 return ci_.get("@engine","unknown");
1159 }
1160
set_autocommit(bool on)1161 void set_autocommit(bool on)
1162 {
1163 SQLPOINTER mode = (SQLPOINTER)(on ? SQL_AUTOCOMMIT_ON : SQL_AUTOCOMMIT_OFF);
1164 SQLRETURN r = SQLSetConnectAttr(
1165 dbc_, // handler
1166 SQL_ATTR_AUTOCOMMIT, // option
1167 mode, //value
1168 0);
1169 check_odbc_error(r,dbc_,SQL_HANDLE_DBC,wide_);
1170 }
1171
1172 private:
1173 SQLHENV env_;
1174 SQLHDBC dbc_;
1175 bool wide_;
1176 connection_info ci_;
1177 };
1178
1179
1180 } // odbc_backend
1181 } // cppdb
1182
1183 extern "C" {
cppdb_odbc_get_connection(cppdb::connection_info const & cs)1184 CPPDB_DRIVER_API cppdb::backend::connection *cppdb_odbc_get_connection(cppdb::connection_info const &cs)
1185 {
1186 return new cppdb::odbc_backend::connection(cs);
1187 }
1188 }
1189