1 // Copyright (c) 1999-2018 David Muse
2 // See the file COPYING for more information
3
4 #include <config.h>
5 #include <sqlrelay/sqlrclient.h>
6 #include <rudiments/inetsocketclient.h>
7 #include <rudiments/unixsocketclient.h>
8 #include <rudiments/file.h>
9 #include <rudiments/environment.h>
10 #include <rudiments/charstring.h>
11 #include <rudiments/stdio.h>
12 #include <rudiments/error.h>
13 #include <rudiments/permissions.h>
14 #include <rudiments/gss.h>
15 #include <rudiments/tls.h>
16 #include <rudiments/sys.h>
17 #include <defines.h>
18 #include <defaults.h>
19
20 #ifndef MAXPATHLEN
21 #define MAXPATHLEN 256
22 #endif
23
// Private implementation state for sqlrconnection (pimpl idiom: the public
// class only holds a pointer to this).  The class has no constructor of its
// own; its fields are populated by sqlrconnection::init() and the setters.
class sqlrconnectionprivate {
	friend class sqlrconnection;
	private:

		// clients
		// _cs points at whichever of _ics/_ucs is currently in use
		inetsocketclient	_ics;
		unixsocketclient	_ucs;
		socketclient		*_cs;

		// session state
		// set when END_SESSION / SUSPEND_SESSION has been written
		bool	_endsessionsent;
		bool	_suspendsessionsent;
		bool	_connected;

		// connection
		// _server/_listenerinetport/_listenerunixport locate the
		// listener; the _connection* values are the port/socket
		// learned from getNewPort() or supplied to resumeSession().
		// _connectionunixport points either at
		// _connectionunixportbuffer or at a caller-owned string.
		char		*_server;
		uint16_t	_listenerinetport;
		uint16_t	_connectioninetport;
		char		*_listenerunixport;
		const char	*_connectionunixport;
		char		_connectionunixportbuffer[MAXPATHLEN+1];
		// timeouts as (seconds, microseconds) pairs
		int32_t		_connecttimeoutsec;
		int32_t		_connecttimeoutusec;
		int32_t		_authtimeoutsec;
		int32_t		_authtimeoutusec;
		int32_t		_responsetimeoutsec;
		int32_t		_responsetimeoutusec;
		// passed through to the socket clients' connect() calls
		int32_t		_retrytime;
		int32_t		_tries;

		// auth
		char		*_user;
		uint32_t	_userlen;
		char		*_password;
		uint32_t	_passwordlen;

		// gss
		bool		_usekrb;
		char		*_krbservice;
		char		*_krbmech;
		char		*_krbflags;
		gsscredentials	_gcred;
		gssmechanism	_gmech;
		gsscontext	_gctx;

		// tls
		// _tlsvalidate is "no", "ca", "ca+host" (compare the full
		// host name) or anything else (compare only the domain)
		bool		_usetls;
		char		*_tlsversion;
		char		*_tlscert;
		char		*_tlspassword;
		char		*_tlsciphers;
		char		*_tlsvalidate;
		char		*_tlsca;
		uint16_t	_tlsdepth;
		tlscontext	_tctx;

		// the active security context: &_gctx, &_tctx or NULL
		securitycontext	*_ctx;

		// error
		int64_t		_errorno;
		char		*_error;

		// identify
		char		*_id;

		// db version
		char		*_dbversion;

		// db host name
		char		*_dbhostname;

		// db ip address
		char		*_dbipaddress;

		// server version
		char		*_serverversion;

		// current database name
		char		*_currentdbname;

		// current schema name
		char		*_currentschemaname;

		// bind format
		char		*_bindformat;

		// bind delimiters
		bool		_questionmarksupported;
		bool		_colonsupported;
		bool		_atsignsupported;
		bool		_dollarsignsupported;

		// client info
		char		*_clientinfo;
		uint64_t	_clientinfolen;

		// debug
		// _webdebug starts at -1 in init();
		// NOTE(review): its other values aren't visible here
		bool		_debug;
		int32_t		_webdebug;
		int		(*_printfunction)(const char *,...);
		file		_debugfile;

		// copy references flag
		// when true, string arguments are duplicated and owned (and
		// freed) by the connection; otherwise the caller's pointers
		// are stored and must outlive the connection
		bool		_copyrefs;

		// cursor list
		sqlrcursor	*_firstcursor;
		sqlrcursor	*_lastcursor;
};
133
sqlrconnection(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries,bool copyreferences)134 sqlrconnection::sqlrconnection(const char *server, uint16_t port,
135 const char *socket,
136 const char *user, const char *password,
137 int32_t retrytime, int32_t tries,
138 bool copyreferences) {
139 init(server,port,socket,user,password,retrytime,tries,copyreferences);
140 }
141
sqlrconnection(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries)142 sqlrconnection::sqlrconnection(const char *server, uint16_t port,
143 const char *socket,
144 const char *user, const char *password,
145 int32_t retrytime, int32_t tries) {
146 init(server,port,socket,user,password,retrytime,tries,false);
147 }
148
init(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries,bool copyreferences)149 void sqlrconnection::init(const char *server, uint16_t port,
150 const char *socket,
151 const char *user, const char *password,
152 int32_t retrytime, int32_t tries,
153 bool copyreferences) {
154
155 pvt=new sqlrconnectionprivate;
156
157 pvt->_copyrefs=copyreferences;
158
159 // retry reads if they get interrupted by signals
160 pvt->_ucs.translateByteOrder();
161 pvt->_ucs.retryInterruptedReads();
162 pvt->_ics.retryInterruptedReads();
163 pvt->_cs=&pvt->_ucs;
164
165 // connection
166 pvt->_server=(pvt->_copyrefs)?
167 charstring::duplicate(server):
168 (char *)server;
169 pvt->_listenerinetport=port;
170 pvt->_listenerunixport=(pvt->_copyrefs)?
171 charstring::duplicate(socket):
172 (char *)socket;
173 pvt->_retrytime=retrytime;
174 pvt->_tries=tries;
175
176 // initialize timeouts
177 setTimeoutFromEnv("SQLR_CLIENT_CONNECT_TIMEOUT",
178 &pvt->_connecttimeoutsec,&pvt->_connecttimeoutusec);
179 setTimeoutFromEnv("SQLR_CLIENT_AUTHENTICATION_TIMEOUT",
180 &pvt->_authtimeoutsec,&pvt->_authtimeoutusec);
181 setTimeoutFromEnv("SQLR_CLIENT_RESPONSE_TIMEOUT",
182 &pvt->_responsetimeoutsec,&pvt->_responsetimeoutusec);
183
184 // authentication
185 pvt->_user=(pvt->_copyrefs)?
186 charstring::duplicate(user):
187 (char *)user;
188 pvt->_password=(pvt->_copyrefs)?
189 charstring::duplicate(password):
190 (char *)password;
191 pvt->_userlen=charstring::length(user);
192 pvt->_passwordlen=charstring::length(password);
193 pvt->_usekrb=false;
194 pvt->_krbservice=NULL;
195 pvt->_krbmech=NULL;
196 pvt->_krbflags=NULL;
197
198 pvt->_usetls=false;
199 pvt->_tlsversion=NULL;
200 pvt->_tlscert=NULL;
201 pvt->_tlspassword=NULL;
202 pvt->_tlsciphers=NULL;
203 pvt->_tlsvalidate=(pvt->_copyrefs)?
204 charstring::duplicate("no"):
205 (char *)"no";
206 pvt->_tlsca=NULL;
207 pvt->_tlsdepth=0;
208
209 pvt->_ctx=NULL;
210
211 // database id
212 pvt->_id=NULL;
213
214 // db version
215 pvt->_dbversion=NULL;
216
217 // db host name
218 pvt->_dbhostname=NULL;
219
220 // db ip address
221 pvt->_dbipaddress=NULL;
222
223 // server version
224 pvt->_serverversion=NULL;
225
226 // current database name
227 pvt->_currentdbname=NULL;
228
229 // current schema name
230 pvt->_currentschemaname=NULL;
231
232 // bind format
233 pvt->_bindformat=NULL;
234
235 // bind delimiters
236 pvt->_questionmarksupported=true;
237 pvt->_colonsupported=true;
238 pvt->_atsignsupported=true;
239 pvt->_dollarsignsupported=true;
240
241 // client info
242 pvt->_clientinfo=NULL;
243 pvt->_clientinfolen=0;
244
245 // session state
246 pvt->_connected=false;
247 clearSessionFlags();
248
249 // debug print function
250 pvt->_printfunction=NULL;
251
252 // enable/disable debug
253 const char *sqlrdebug=environment::getValue("SQLRDEBUG");
254 if (!sqlrdebug || !*sqlrdebug) {
255 sqlrdebug=environment::getValue("SQLR_CLIENT_DEBUG");
256 }
257 pvt->_debug=(sqlrdebug && *sqlrdebug && !charstring::isNo(sqlrdebug));
258 if (pvt->_debug && !charstring::isYes(sqlrdebug) &&
259 !charstring::isNo(sqlrdebug)) {
260 setDebugFile(sqlrdebug);
261 }
262 pvt->_webdebug=-1;
263
264 // error
265 pvt->_errorno=0;
266 pvt->_error=NULL;
267
268 // cursor list
269 pvt->_firstcursor=NULL;
270 pvt->_lastcursor=NULL;
271 }
272
clearSessionFlags()273 void sqlrconnection::clearSessionFlags() {
274
275 // indicate that the session hasn't been suspended or ended
276 pvt->_endsessionsent=false;
277 pvt->_suspendsessionsent=false;
278 }
279
~sqlrconnection()280 sqlrconnection::~sqlrconnection() {
281
282 // unless it was already ended or suspended, end the session
283 if (!pvt->_endsessionsent && !pvt->_suspendsessionsent) {
284 endSession();
285 }
286
287 // deallocate error
288 delete[] pvt->_error;
289
290 // deallocate id
291 delete[] pvt->_id;
292
293 // deallocate dbversion
294 delete[] pvt->_dbversion;
295
296 // deallocate db host name
297 delete[] pvt->_dbhostname;
298
299 // deallocate db ip address
300 delete[] pvt->_dbipaddress;
301
302 // deallocate server version
303 delete[] pvt->_serverversion;
304
305 // deallocate current database name
306 delete[] pvt->_currentdbname;
307
308 // deallocate current schema name
309 delete[] pvt->_currentschemaname;
310
311 // deallocate bindformat
312 delete[] pvt->_bindformat;
313
314 // deallocate client info
315 delete[] pvt->_clientinfo;
316
317 // deallocate copied references
318 if (pvt->_copyrefs) {
319 delete[] pvt->_server;
320 delete[] pvt->_listenerunixport;
321 delete[] pvt->_user;
322 delete[] pvt->_password;
323 delete[] pvt->_krbservice;
324 delete[] pvt->_krbmech;
325 delete[] pvt->_krbflags;
326 delete[] pvt->_tlsversion;
327 delete[] pvt->_tlscert;
328 delete[] pvt->_tlspassword;
329 delete[] pvt->_tlsvalidate;
330 delete[] pvt->_tlsca;
331 }
332
333 // detach all cursors attached to this client
334 sqlrcursor *currentcursor=pvt->_firstcursor;
335 while (currentcursor) {
336 pvt->_firstcursor=currentcursor;
337 currentcursor=currentcursor->next();
338 pvt->_firstcursor->sqlrc(NULL);
339 }
340
341 if (pvt->_debug) {
342 debugPreStart();
343 debugPrint("Deallocated connection\n");
344 debugPreEnd();
345 }
346
347 pvt->_debugfile.close();
348
349 delete pvt;
350 }
351
enableKerberos(const char * service,const char * mech,const char * flags)352 void sqlrconnection::enableKerberos(const char *service,
353 const char *mech,
354 const char *flags) {
355
356 // clear any existing configuration
357 if (pvt->_usekrb || pvt->_usetls) {
358 disableEncryption();
359 }
360
361 if (!gss::supported()) {
362 return;
363 }
364
365 pvt->_usekrb=true;
366
367 // "Negotiate" is Windows' default mech. Force Kerberos instead.
368 char *os=sys::getOperatingSystemName();
369 if (!charstring::compare(os,"Windows",7) &&
370 charstring::isNullOrEmpty(mech)) {
371 mech="Kerberos";
372 }
373 delete[] os;
374
375 if (pvt->_copyrefs) {
376 delete[] pvt->_krbservice;
377 pvt->_krbservice=charstring::duplicate(
378 (!charstring::isNullOrEmpty(service))?
379 service:DEFAULT_KRBSERVICE);
380 delete[] pvt->_krbmech;
381 pvt->_krbmech=charstring::duplicate(mech);
382 delete[] pvt->_krbflags;
383 pvt->_krbflags=charstring::duplicate(flags);
384 } else {
385 pvt->_krbservice=(char *)
386 (!charstring::isNullOrEmpty(service)?
387 service:DEFAULT_KRBSERVICE);
388 pvt->_krbmech=(char *)mech;
389 pvt->_krbflags=(char *)flags;
390 }
391 }
392
enableTls(const char * version,const char * cert,const char * password,const char * ciphers,const char * validate,const char * ca,uint16_t depth)393 void sqlrconnection::enableTls(const char *version,
394 const char *cert,
395 const char *password,
396 const char *ciphers,
397 const char *validate,
398 const char *ca,
399 uint16_t depth) {
400
401 // clear any existing configuration
402 if (pvt->_usekrb || pvt->_usetls) {
403 disableEncryption();
404 }
405
406 if (!tls::supported()) {
407 return;
408 }
409
410 pvt->_usetls=true;
411
412 if (pvt->_copyrefs) {
413 delete[] pvt->_tlsversion;
414 pvt->_tlsversion=charstring::duplicate(version);
415 delete[] pvt->_tlscert;
416 pvt->_tlscert=charstring::duplicate(cert);
417 delete[] pvt->_tlspassword;
418 pvt->_tlspassword=charstring::duplicate(password);
419 delete[] pvt->_tlsciphers;
420 pvt->_tlsciphers=charstring::duplicate(ciphers);
421 delete[] pvt->_tlsvalidate;
422 pvt->_tlsvalidate=charstring::duplicate(validate);
423 delete[] pvt->_tlsca;
424 pvt->_tlsca=charstring::duplicate(ca);
425 } else {
426 pvt->_tlsversion=(char *)version;
427 pvt->_tlscert=(char *)cert;
428 pvt->_tlspassword=(char *)password;
429 pvt->_tlsciphers=(char *)ciphers;
430 pvt->_tlsvalidate=(char *)validate;
431 pvt->_tlsca=(char *)ca;
432 }
433 pvt->_tlsdepth=depth;
434 }
435
disableEncryption()436 void sqlrconnection::disableEncryption() {
437
438 if (pvt->_copyrefs) {
439 delete[] pvt->_krbservice;
440 delete[] pvt->_krbmech;
441 delete[] pvt->_krbflags;
442
443 delete[] pvt->_tlsversion;
444 delete[] pvt->_tlscert;
445 delete[] pvt->_tlspassword;
446 delete[] pvt->_tlsciphers;
447 delete[] pvt->_tlsvalidate;
448 delete[] pvt->_tlsca;
449 }
450 pvt->_krbservice=NULL;
451 pvt->_krbmech=NULL;
452 pvt->_krbflags=NULL;
453 pvt->_usekrb=false;
454
455 pvt->_tlsversion=NULL;
456 pvt->_tlscert=NULL;
457 pvt->_tlspassword=NULL;
458 pvt->_tlsciphers=NULL;
459 pvt->_tlsvalidate=NULL;
460 pvt->_tlsca=NULL;
461 pvt->_tlsdepth=0;
462 pvt->_usetls=false;
463 }
464
setConnectTimeout(int32_t timeoutsec,int32_t timeoutusec)465 void sqlrconnection::setConnectTimeout(int32_t timeoutsec,
466 int32_t timeoutusec) {
467 pvt->_connecttimeoutsec=timeoutsec;
468 pvt->_connecttimeoutusec=timeoutusec;
469 }
470
setAuthenticationTimeout(int32_t timeoutsec,int32_t timeoutusec)471 void sqlrconnection::setAuthenticationTimeout(int32_t timeoutsec,
472 int32_t timeoutusec) {
473 pvt->_authtimeoutsec=timeoutsec;
474 pvt->_authtimeoutusec=timeoutusec;
475 }
476
setResponseTimeout(int32_t timeoutsec,int32_t timeoutusec)477 void sqlrconnection::setResponseTimeout(int32_t timeoutsec,
478 int32_t timeoutusec) {
479 pvt->_responsetimeoutsec=timeoutsec;
480 pvt->_responsetimeoutusec=timeoutusec;
481 }
482
setTimeoutFromEnv(const char * var,int32_t * timeoutsec,int32_t * timeoutusec)483 void sqlrconnection::setTimeoutFromEnv(const char *var,
484 int32_t *timeoutsec,
485 int32_t *timeoutusec) {
486 const char *timeout=environment::getValue(var);
487 if (charstring::isNumber(timeout)) {
488 *timeoutsec=charstring::toInteger(timeout);
489 long double dbl=charstring::toFloatC(timeout);
490 dbl=dbl-(long double)(*timeoutsec);
491 *timeoutusec=(int32_t)(dbl*1000000.0);
492 } else {
493 *timeoutsec=-1;
494 *timeoutusec=-1;
495 }
496 }
497
getConnectTimeout(int32_t * timeoutsec,int32_t * timeoutusec)498 void sqlrconnection::getConnectTimeout(int32_t *timeoutsec,
499 int32_t *timeoutusec) {
500 *timeoutsec=pvt->_connecttimeoutsec;
501 *timeoutusec=pvt->_connecttimeoutusec;
502 }
503
getAuthenticationTimeout(int32_t * timeoutsec,int32_t * timeoutusec)504 void sqlrconnection::getAuthenticationTimeout(int32_t *timeoutsec,
505 int32_t *timeoutusec) {
506 *timeoutsec=pvt->_authtimeoutsec;
507 *timeoutusec=pvt->_authtimeoutusec;
508 }
509
getResponseTimeout(int32_t * timeoutsec,int32_t * timeoutusec)510 void sqlrconnection::getResponseTimeout(int32_t *timeoutsec,
511 int32_t *timeoutusec) {
512 *timeoutsec=pvt->_responsetimeoutsec;
513 *timeoutusec=pvt->_responsetimeoutusec;
514 }
515
endSession()516 void sqlrconnection::endSession() {
517
518 if (pvt->_debug) {
519 debugPreStart();
520 debugPrint("Ending Session\n");
521 debugPreEnd();
522 }
523
524 // abort each cursor's result set
525 sqlrcursor *currentcursor=pvt->_firstcursor;
526 while (currentcursor) {
527 // FIXME: do we need to clearResultSet() here too?
528 if (!currentcursor->endofresultset()) {
529 currentcursor->closeResultSet(false);
530 }
531 currentcursor->havecursorid(false);
532 currentcursor=currentcursor->next();
533 }
534
535 // write an END_SESSION to the connection
536 if (pvt->_connected) {
537 pvt->_cs->write((uint16_t)END_SESSION);
538 flushWriteBuffer();
539 pvt->_endsessionsent=true;
540 closeConnection();
541 }
542 }
543
flushWriteBuffer()544 void sqlrconnection::flushWriteBuffer() {
545 pvt->_cs->flushWriteBuffer(-1,-1);
546 }
547
closeConnection()548 void sqlrconnection::closeConnection() {
549 pvt->_cs->close();
550 pvt->_connected=false;
551 }
552
suspendSession()553 bool sqlrconnection::suspendSession() {
554
555 if (!openSession()) {
556 return 0;
557 }
558
559 clearError();
560
561 if (pvt->_debug) {
562 debugPreStart();
563 debugPrint("Suspending Session\n");
564 debugPreEnd();
565 }
566
567 // suspend the session
568 pvt->_cs->write((uint16_t)SUSPEND_SESSION);
569 flushWriteBuffer();
570 pvt->_suspendsessionsent=true;
571
572 // check for error
573 if (gotError()) {
574 return false;
575 }
576
577 // get port to resume on
578 bool retval=getNewPort();
579
580 closeConnection();
581
582 return retval;
583 }
584
// Connects to the listener if not already connected: configures the
// security context, tries the unix socket first and then the inet port,
// validates the TLS host when applicable, then sends the protocol version
// and credentials.  Returns true if connected (or already connected);
// on failure sets an error and returns false.
bool sqlrconnection::openSession() {

	// nothing to do if a session is already open
	if (pvt->_connected) {
		return true;
	}

	// (re)build the kerberos/tls/plain security context
	if (!reConfigureSockets()) {
		pvt->_connected=false;
		return false;
	}

	if (pvt->_debug) {
		debugPreStart();
		debugPrint("Connecting to listener...");
		debugPrint("\n");
		debugPreEnd();
	}

	// open a connection to the listener
	int	openresult=RESULT_ERROR;

	// first, try for a unix connection
	if (!charstring::isNullOrEmpty(pvt->_listenerunixport)) {

		if (pvt->_debug) {
			debugPreStart();
			debugPrint("Unix socket: ");
			debugPrint(pvt->_listenerunixport);
			debugPrint("\n");
			debugPreEnd();
		}

		openresult=pvt->_ucs.connect(pvt->_listenerunixport,
						pvt->_connecttimeoutsec,
						pvt->_connecttimeoutusec,
						pvt->_retrytime,pvt->_tries);
		if (openresult==RESULT_SUCCESS) {

			// 64k kernel socket buffers
			pvt->_ucs.setSocketReadBufferSize(65536);
			pvt->_ucs.setSocketWriteBufferSize(65536);

			pvt->_cs=&pvt->_ucs;
		}
	}

	// then try for an inet connection
	// (only if the unix attempt was skipped or failed)
	if (openresult!=RESULT_SUCCESS && pvt->_listenerinetport) {

		if (pvt->_debug) {
			debugPreStart();
			debugPrint("Inet socket: ");
			debugPrint(pvt->_server);
			debugPrint(":");
			debugPrint((int64_t)pvt->_listenerinetport);
			debugPrint("\n");
			debugPreEnd();
		}

		openresult=pvt->_ics.connect(pvt->_server,
						pvt->_listenerinetport,
						pvt->_connecttimeoutsec,
						pvt->_connecttimeoutusec,
						pvt->_retrytime,pvt->_tries);
		if (openresult==RESULT_SUCCESS) {

			// 64k kernel socket buffers
			pvt->_ics.setSocketReadBufferSize(65536);
			pvt->_ics.setSocketWriteBufferSize(65536);

			// send small protocol messages immediately
			pvt->_ics.dontUseNaglesAlgorithm();

			pvt->_cs=&pvt->_ics;
		}
	}

	// handle failures
	if (openresult!=RESULT_SUCCESS) {
		setConnectFailedError();
		return false;
	}

	// if tls is enabled and we're using an inet socket,
	// then we may need to validate the host
	if (pvt->_usetls && pvt->_cs==&pvt->_ics) {

		if (!validateCertificate()) {
			setError("TLS error: common name mismatch");
			pvt->_cs->close();
			return false;
		}
	}

	// if we made it here then everything went
	// well and we are successfully connected
	pvt->_connected=true;

	// send protocol info
	// (neither protocol() nor auth() flushes; presumably this data
	// goes out with the caller's next flushWriteBuffer())
	protocol();

	// auth
	auth();

	return true;
}
688
// Checks the server's TLS certificate against the name we connected to,
// according to _tlsvalidate:
//   "no"      - no check; always returns true
//   "ca"      - no extra check; CA validation already happened in connect()
//   "ca+host" - the full server name must match a SAN (or the CN)
//   other     - only the text after the first '.' (the "domain") of both
//               names is compared
// Subject alternate names, when present, are used instead of the common
// name.  Comparisons ignore case.  Returns true on a match.
// NOTE(review): wildcard SANs/CNs ("*.example.com") are not expanded, so in
// "ca+host" mode they never match; in domain mode the "*." label is simply
// stripped like any other first label.
bool sqlrconnection::validateCertificate() {

	// If we're not doing any validation then just return true.  If we're
	// just doing ca validation then the connect would have failed if the
	// certificate was invalid, so we can just return true for that too.
	if (charstring::isNo(pvt->_tlsvalidate) ||
		!charstring::compareIgnoringCase(pvt->_tlsvalidate,"ca")) {
		return true;
	}

	// get the cert from the server
	tlscertificate	*cert=((tlscontext *)pvt->_ctx)->
						getPeerCertificate();
	if (!cert) {
		// this should never happen, the connect()
		// should have failed if no cert was supplied
		return false;
	}

	// get the subject alternate names and common name from the cert
	linkedlist< char * >	*sans=cert->getSubjectAlternateNames();
	const char		*commonname=cert->getCommonName();

	// should we validate the host name or domain?
	bool	host=!charstring::compareIgnoringCase(
					pvt->_tlsvalidate,"ca+host");

	// get the server name to validate against
	// (in domain mode, drop everything up to and including the first '.')
	const char	*server=pvt->_server;
	if (!host) {
		const char	*dot=charstring::findFirst(server,'.');
		if (dot) {
			server=dot+1;
		}
	}

	// if there are any subject alternate
	// names then validate against those
	if (sans && sans->getLength()) {

		for (linkedlistnode< char * > *node=sans->getFirst();
						node; node=node->getNext()) {

			const char	*san=node->getValue();
			if (!host) {
				const char	*dot=
					charstring::findFirst(san,'.');
				if (dot) {
					san=dot+1;
				}
			}

			if (pvt->_debug) {
				debugPreStart();
				debugPrint(pvt->_tlsvalidate);
				debugPrint(": ");
				debugPrint(server);
				debugPrint("=");
				debugPrint(san);
				debugPrint("\n");
				debugPreEnd();
			}

			if (!charstring::compareIgnoringCase(server,san)) {
				return true;
			}
		}

		// SANs were present but none matched; the CN is not consulted
		return false;
	}


	// if there were no subject alternate names
	// then just validate against the common name
	if (!host) {
		const char	*dot=charstring::findFirst(commonname,'.');
		if (dot) {
			commonname=dot+1;
		}
	}

	if (pvt->_debug) {
		debugPreStart();
		debugPrint(pvt->_tlsvalidate);
		debugPrint(": ");
		debugPrint(server);
		debugPrint("=");
		debugPrint(commonname);
		debugPrint("\n");
		debugPreEnd();
	}

	return !charstring::compareIgnoringCase(server,commonname);
}
783
784
reConfigureSockets()785 bool sqlrconnection::reConfigureSockets() {
786
787 pvt->_ucs.setReadBufferSize(65536);
788 pvt->_ucs.setWriteBufferSize(65536);
789 //pvt->_ucs.useAsyncWrite();
790
791 pvt->_ics.setReadBufferSize(65536);
792 pvt->_ics.setWriteBufferSize(65536);
793 //pvt->_ics.useAsyncWrite();
794
795
796 if (pvt->_usekrb) {
797
798 if (pvt->_debug) {
799 debugPreStart();
800 debugPrint("kerberos encryption/"
801 "authentication enabled\n");
802 debugPrint(" service: ");
803 if (pvt->_krbservice) {
804 debugPrint(pvt->_krbservice);
805 }
806 debugPrint("\n");
807 debugPrint(" mech: ");
808 if (pvt->_krbmech) {
809 debugPrint(pvt->_krbmech);
810 }
811 debugPrint("\n");
812 debugPrint(" flags: ");
813 if (pvt->_krbflags) {
814 debugPrint(pvt->_krbflags);
815 }
816 debugPrint("\n");
817 debugPreEnd();
818 }
819
820 pvt->_gmech.clear();
821 pvt->_gmech.initialize(pvt->_krbmech);
822
823 pvt->_gcred.clearDesiredMechanisms();
824 pvt->_gcred.addDesiredMechanism(&pvt->_gmech);
825
826 if (!pvt->_gcred.acquired() &&
827 !charstring::isNullOrEmpty(pvt->_user)) {
828
829 if (pvt->_gcred.acquireForUser(pvt->_user)) {
830 if (pvt->_debug) {
831 debugPreStart();
832 debugPrint("acquired kerberos "
833 "credentials for: ");
834 debugPrint(pvt->_user);
835 debugPrint("\n");
836 debugPreEnd();
837 }
838 } else {
839 if (pvt->_debug) {
840 debugPreStart();
841 if (pvt->_gcred.getMajorStatus()) {
842 debugPrint(pvt->_gcred.
843 getMechanismMinorStatus());
844 }
845 debugPreEnd();
846 }
847 setError("Failed to acquire "
848 "kerberos credentials.");
849 return false;
850 }
851 }
852
853 pvt->_gctx.close();
854 pvt->_gctx.setDesiredMechanism(&pvt->_gmech);
855 pvt->_gctx.setDesiredFlags(pvt->_krbflags);
856 pvt->_gctx.setService(pvt->_krbservice);
857 pvt->_gctx.setCredentials(&pvt->_gcred);
858
859 pvt->_ctx=&pvt->_gctx;
860
861 } else if (pvt->_usetls) {
862
863 if (pvt->_debug) {
864 debugPreStart();
865 debugPrint("TLS encryption/authentication enabled\n");
866 debugPrint(" version: ");
867 if (pvt->_tlsversion) {
868 debugPrint(pvt->_tlsversion);
869 }
870 debugPrint("\n");
871 debugPrint(" cert: ");
872 if (pvt->_tlscert) {
873 debugPrint(pvt->_tlscert);
874 }
875 debugPrint("\n");
876 debugPrint(" private key password: ");
877 if (pvt->_tlspassword) {
878 debugPrint(pvt->_tlspassword);
879 }
880 debugPrint("\n");
881 debugPrint(" ciphers: ");
882 if (pvt->_tlsciphers) {
883 debugPrint(pvt->_tlsciphers);
884 }
885 debugPrint("\n");
886 debugPrint(" validate: ");
887 if (pvt->_tlsvalidate) {
888 debugPrint(pvt->_tlsvalidate);
889 }
890 debugPrint("\n");
891 debugPrint(" ca: ");
892 if (pvt->_tlsca) {
893 debugPrint(pvt->_tlsca);
894 }
895 debugPrint("\n");
896 debugPrint(" depth: ");
897 debugPrint((int64_t)pvt->_tlsdepth);
898 debugPrint("\n");
899 debugPreEnd();
900 }
901
902 pvt->_tctx.close();
903 pvt->_tctx.setProtocolVersion(pvt->_tlsversion);
904 pvt->_tctx.setCertificateChainFile(pvt->_tlscert);
905 pvt->_tctx.setPrivateKeyPassword(pvt->_tlspassword);
906 pvt->_tctx.setCiphers(pvt->_tlsciphers);
907 pvt->_tctx.setValidatePeer(!charstring::compareIgnoringCase(
908 pvt->_tlsvalidate,"ca",2));
909 pvt->_tctx.setCertificateAuthority(pvt->_tlsca);
910 pvt->_tctx.setValidationDepth(pvt->_tlsdepth);
911
912 pvt->_ctx=&pvt->_tctx;
913
914 } else {
915
916 if (pvt->_debug) {
917 debugPreStart();
918 debugPrint("encryption disabled\n");
919 debugPreEnd();
920 }
921
922 pvt->_ctx=NULL;
923 }
924
925 pvt->_ucs.setSecurityContext(pvt->_ctx);
926 pvt->_ics.setSecurityContext(pvt->_ctx);
927
928 return true;
929 }
930
setConnectFailedError()931 void sqlrconnection::setConnectFailedError() {
932 if (pvt->_usekrb && pvt->_gctx.getMajorStatus()) {
933 setError(pvt->_gctx.getMechanismMinorStatus());
934 } else if (pvt->_usetls && pvt->_tctx.getError()) {
935 stringbuffer err;
936 err.append("TLS error: ");
937 err.append(pvt->_tctx.getErrorString());
938 setError(err.getString());
939 } else {
940 setError("Couldn't connect to the listener.");
941 }
942 }
943
protocol()944 void sqlrconnection::protocol() {
945
946 if (pvt->_debug) {
947 debugPreStart();
948 debugPrint("Protocol : sqlrclient version 2\n");
949 debugPreEnd();
950 }
951
952 pvt->_cs->write((uint16_t)PROTOCOLVERSION);
953 pvt->_cs->write((uint16_t)2);
954 }
955
auth()956 void sqlrconnection::auth() {
957
958 if (pvt->_debug) {
959 debugPreStart();
960 debugPrint("Auth : ");
961 debugPrint(pvt->_user);
962 debugPrint(":");
963 // hide the password
964 for (uint8_t i=0; i<pvt->_passwordlen; i++) {
965 debugPrint("*");
966 }
967 debugPrint("\n");
968 debugPreEnd();
969 }
970
971 pvt->_cs->write((uint16_t)AUTH);
972
973 pvt->_cs->write(pvt->_userlen);
974 pvt->_cs->write(pvt->_user,pvt->_userlen);
975
976 pvt->_cs->write(pvt->_passwordlen);
977 pvt->_cs->write(pvt->_password,pvt->_passwordlen);
978
979 // I don't think this needs to be here...
980 // Just commenting it out for now though.
981 //flushWriteBuffer();
982 }
983
getNewPort()984 bool sqlrconnection::getNewPort() {
985
986 // get the size of the unix port string
987 uint16_t size;
988 if (pvt->_cs->read(&size)!=sizeof(uint16_t)) {
989 setError("Failed to get the size of "
990 "the unix connection port.\n"
991 "A network error may have occurred.");
992 return false;
993 }
994
995 if (size>MAXPATHLEN) {
996
997 // if size is too big, return an error
998 stringbuffer errstr;
999 errstr.append("The pathname of the unix port was too long: ");
1000 errstr.append(size);
1001 errstr.append(" bytes. The maximum size is ");
1002 errstr.append((uint16_t)MAXPATHLEN);
1003 errstr.append(" bytes.");
1004 setError(errstr.getString());
1005 return false;
1006 }
1007
1008 // get the unix port string
1009 if (size && pvt->_cs->read(pvt->_connectionunixportbuffer,size)!=size) {
1010 setError("Failed to get the unix connection port.\n"
1011 "A network error may have occurred.");
1012 return false;
1013 }
1014 pvt->_connectionunixportbuffer[size]='\0';
1015 pvt->_connectionunixport=pvt->_connectionunixportbuffer;
1016
1017 // get the inet port
1018 if (pvt->_cs->read(&pvt->_connectioninetport)!=sizeof(uint16_t)) {
1019 setError("Failed to get the inet connection port.\n"
1020 "A network error may have occurred.");
1021 return false;
1022 }
1023
1024 // the server will send 0 for both the size of the unixport and
1025 // the inet port if a server error occurred
1026 if (!size && !pvt->_connectioninetport) {
1027 setError("An error occurred on the server.");
1028 return false;
1029 }
1030 return true;
1031 }
1032
getConnectionPort()1033 uint16_t sqlrconnection::getConnectionPort() {
1034
1035 if (!pvt->_suspendsessionsent && !openSession()) {
1036 return 0;
1037 }
1038
1039 if (pvt->_debug) {
1040 debugPreStart();
1041 debugPrint("Getting connection port: ");
1042 debugPrint((int64_t)pvt->_connectioninetport);
1043 debugPrint("\n");
1044 debugPreEnd();
1045 }
1046
1047 return pvt->_connectioninetport;
1048 }
1049
getConnectionSocket()1050 const char *sqlrconnection::getConnectionSocket() {
1051
1052 if (!pvt->_suspendsessionsent && !openSession()) {
1053 return NULL;
1054 }
1055
1056 if (pvt->_debug) {
1057 debugPreStart();
1058 debugPrint("Getting connection socket: ");
1059 if (pvt->_connectionunixport) {
1060 debugPrint(pvt->_connectionunixport);
1061 }
1062 debugPrint("\n");
1063 debugPreEnd();
1064 }
1065
1066 if (pvt->_connectionunixport) {
1067 return pvt->_connectionunixport;
1068 }
1069 return NULL;
1070 }
1071
resumeSession(uint16_t port,const char * socket)1072 bool sqlrconnection::resumeSession(uint16_t port, const char *socket) {
1073
1074 // if already pvt->_connected, end the session
1075 if (pvt->_connected) {
1076 endSession();
1077 }
1078
1079 if (pvt->_debug) {
1080 debugPreStart();
1081 debugPrint("Resuming Session: \n");
1082 debugPrint("port: ");
1083 debugPrint((int64_t)port);
1084 debugPrint("\n");
1085 debugPrint("socket: ");
1086 debugPrint(socket);
1087 debugPrint("\n");
1088 debugPreEnd();
1089 }
1090
1091 // set the connectionunixport and connectioninetport values
1092 if (pvt->_copyrefs) {
1093 if (charstring::length(socket)<=MAXPATHLEN) {
1094 charstring::copy(pvt->_connectionunixportbuffer,socket);
1095 pvt->_connectionunixport=pvt->_connectionunixportbuffer;
1096 } else {
1097 pvt->_connectionunixport="";
1098 }
1099 } else {
1100 pvt->_connectionunixport=(char *)socket;
1101 }
1102 pvt->_connectioninetport=port;
1103
1104 if (!reConfigureSockets()) {
1105 pvt->_connected=false;
1106 return false;
1107 }
1108
1109 // first, try for the unix port
1110 if (!charstring::isNullOrEmpty(socket)) {
1111 pvt->_connected=(pvt->_ucs.connect(
1112 socket,-1,-1,
1113 pvt->_retrytime,
1114 pvt->_tries)==RESULT_SUCCESS);
1115 if (pvt->_connected) {
1116 pvt->_cs=&pvt->_ucs;
1117 }
1118 }
1119
1120 // then try for the inet port
1121 if (!pvt->_connected) {
1122 pvt->_connected=(pvt->_ics.connect(
1123 pvt->_server,port,-1,-1,
1124 pvt->_retrytime,
1125 pvt->_tries)==RESULT_SUCCESS);
1126 if (pvt->_connected) {
1127 pvt->_cs=&pvt->_ics;
1128 }
1129 }
1130
1131 if (pvt->_connected) {
1132
1133 // send protocol info
1134 protocol();
1135
1136 if (pvt->_debug) {
1137 debugPreStart();
1138 debugPrint("success");
1139 debugPrint("\n");
1140 debugPreEnd();
1141 }
1142 clearSessionFlags();
1143 } else {
1144 setConnectFailedError();
1145 if (pvt->_debug) {
1146 debugPreStart();
1147 debugPrint("failure");
1148 debugPrint("\n");
1149 debugPreEnd();
1150 }
1151 }
1152
1153 return pvt->_connected;
1154 }
1155
ping()1156 bool sqlrconnection::ping() {
1157
1158 if (!openSession()) {
1159 return 0;
1160 }
1161
1162 clearError();
1163
1164 if (pvt->_debug) {
1165 debugPreStart();
1166 debugPrint("Pinging...\n");
1167 debugPreEnd();
1168 }
1169
1170 pvt->_cs->write((uint16_t)PING);
1171 flushWriteBuffer();
1172
1173 return !gotError();
1174 }
1175
identify()1176 const char *sqlrconnection::identify() {
1177
1178 if (!openSession()) {
1179 return NULL;
1180 }
1181
1182 clearError();
1183
1184 if (pvt->_debug) {
1185 debugPreStart();
1186 debugPrint("Identifying...\n");
1187 debugPreEnd();
1188 }
1189
1190 // tell the server we want the identity of the db
1191 pvt->_cs->write((uint16_t)IDENTIFY);
1192 flushWriteBuffer();
1193
1194 if (gotError()) {
1195 return NULL;
1196 }
1197
1198 // get the identity size
1199 uint16_t size;
1200 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1201 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1202 setError("Failed to identify.\n"
1203 "A network error may have occurred.");
1204 return NULL;
1205 }
1206
1207 // get the identity
1208 delete[] pvt->_id;
1209 pvt->_id=new char[size+1];
1210 if (pvt->_cs->read(pvt->_id,size)!=size) {
1211 setError("Failed to identify.\n"
1212 "A network error may have occurred.");
1213 delete[] pvt->_id;
1214 pvt->_id=NULL;
1215 return NULL;
1216 }
1217 pvt->_id[size]='\0';
1218
1219 if (pvt->_debug) {
1220 debugPreStart();
1221 debugPrint(pvt->_id);
1222 debugPrint("\n");
1223 debugPreEnd();
1224 }
1225 return pvt->_id;
1226 }
1227
dbVersion()1228 const char *sqlrconnection::dbVersion() {
1229
1230 if (!openSession()) {
1231 return NULL;
1232 }
1233
1234 clearError();
1235
1236 if (pvt->_debug) {
1237 debugPreStart();
1238 debugPrint("DB Version...");
1239 debugPrint("\n");
1240 debugPreEnd();
1241 }
1242
1243 // tell the server we want the db version
1244 pvt->_cs->write((uint16_t)DBVERSION);
1245 flushWriteBuffer();
1246
1247 if (gotError()) {
1248 return NULL;
1249 }
1250
1251 // get the db version size
1252 uint16_t size;
1253 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1254 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1255 setError("Failed to get DB version.\n"
1256 "A network error may have occurred.");
1257 return NULL;
1258 }
1259
1260 // get the db version
1261 delete[] pvt->_dbversion;
1262 pvt->_dbversion=new char[size+1];
1263 if (pvt->_cs->read(pvt->_dbversion,size)!=size) {
1264 setError("Failed to get DB version.\n"
1265 "A network error may have occurred.");
1266 delete[] pvt->_dbversion;
1267 pvt->_dbversion=NULL;
1268 return NULL;
1269 }
1270 pvt->_dbversion[size]='\0';
1271
1272 if (pvt->_debug) {
1273 debugPreStart();
1274 debugPrint(pvt->_dbversion);
1275 debugPrint("\n");
1276 debugPreEnd();
1277 }
1278 return pvt->_dbversion;
1279 }
1280
dbHostName()1281 const char *sqlrconnection::dbHostName() {
1282
1283 if (!openSession()) {
1284 return NULL;
1285 }
1286
1287 clearError();
1288
1289 if (pvt->_debug) {
1290 debugPreStart();
1291 debugPrint("DB Host Name...");
1292 debugPrint("\n");
1293 debugPreEnd();
1294 }
1295
1296 // tell the server we want the db host name
1297 pvt->_cs->write((uint16_t)DBHOSTNAME);
1298 flushWriteBuffer();
1299
1300 if (gotError()) {
1301 return NULL;
1302 }
1303
1304 // get the db host name size
1305 uint16_t size;
1306 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1307 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1308 setError("Failed to get DB host name.\n"
1309 "A network error may have occurred.");
1310 return NULL;
1311 }
1312
1313 // get the db host name
1314 delete[] pvt->_dbhostname;
1315 pvt->_dbhostname=new char[size+1];
1316 if (pvt->_cs->read(pvt->_dbhostname,size)!=size) {
1317 setError("Failed to get DB host name.\n"
1318 "A network error may have occurred.");
1319 delete[] pvt->_dbhostname;
1320 pvt->_dbhostname=NULL;
1321 return NULL;
1322 }
1323 pvt->_dbhostname[size]='\0';
1324
1325 if (pvt->_debug) {
1326 debugPreStart();
1327 debugPrint(pvt->_dbhostname);
1328 debugPrint("\n");
1329 debugPreEnd();
1330 }
1331 return pvt->_dbhostname;
1332 }
1333
dbIpAddress()1334 const char *sqlrconnection::dbIpAddress() {
1335
1336 if (!openSession()) {
1337 return NULL;
1338 }
1339
1340 clearError();
1341
1342 if (pvt->_debug) {
1343 debugPreStart();
1344 debugPrint("DB Ip Address...");
1345 debugPrint("\n");
1346 debugPreEnd();
1347 }
1348
1349 // tell the server we want the db ip address
1350 pvt->_cs->write((uint16_t)DBIPADDRESS);
1351 flushWriteBuffer();
1352
1353 if (gotError()) {
1354 return NULL;
1355 }
1356
1357 // get the db ip address size
1358 uint16_t size;
1359 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1360 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1361 setError("Failed to get DB ip address.\n"
1362 "A network error may have occurred.");
1363 return NULL;
1364 }
1365
1366 // get the db ip address
1367 delete[] pvt->_dbipaddress;
1368 pvt->_dbipaddress=new char[size+1];
1369 if (pvt->_cs->read(pvt->_dbipaddress,size)!=size) {
1370 setError("Failed to get DB ip address.\n"
1371 "A network error may have occurred.");
1372 delete[] pvt->_dbipaddress;
1373 pvt->_dbipaddress=NULL;
1374 return NULL;
1375 }
1376 pvt->_dbipaddress[size]='\0';
1377
1378 if (pvt->_debug) {
1379 debugPreStart();
1380 debugPrint(pvt->_dbipaddress);
1381 debugPrint("\n");
1382 debugPreEnd();
1383 }
1384 return pvt->_dbipaddress;
1385 }
1386
serverVersion()1387 const char *sqlrconnection::serverVersion() {
1388
1389 if (!openSession()) {
1390 return NULL;
1391 }
1392
1393 clearError();
1394
1395 if (pvt->_debug) {
1396 debugPreStart();
1397 debugPrint("Server Version...");
1398 debugPrint("\n");
1399 debugPreEnd();
1400 }
1401
1402 // tell the server we want the server version
1403 pvt->_cs->write((uint16_t)SERVERVERSION);
1404 flushWriteBuffer();
1405
1406 if (gotError()) {
1407 return NULL;
1408 }
1409
1410 // get the server version size
1411 uint16_t size;
1412 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1413 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1414 setError("Failed to get Server version.\n"
1415 "A network error may have occurred.");
1416 return NULL;
1417 }
1418
1419 // get the server version
1420 delete[] pvt->_serverversion;
1421 pvt->_serverversion=new char[size+1];
1422 if (pvt->_cs->read(pvt->_serverversion,size)!=size) {
1423 setError("Failed to get Server version.\n"
1424 "A network error may have occurred.");
1425 delete[] pvt->_serverversion;
1426 pvt->_serverversion=NULL;
1427 return NULL;
1428 }
1429 pvt->_serverversion[size]='\0';
1430
1431 if (pvt->_debug) {
1432 debugPreStart();
1433 debugPrint(pvt->_serverversion);
1434 debugPrint("\n");
1435 debugPreEnd();
1436 }
1437 return pvt->_serverversion;
1438 }
1439
clientVersion()1440 const char *sqlrconnection::clientVersion() {
1441 return SQLR_VERSION;
1442 }
1443
bindFormat()1444 const char *sqlrconnection::bindFormat() {
1445
1446 if (!openSession()) {
1447 return NULL;
1448 }
1449
1450 clearError();
1451
1452 if (pvt->_debug) {
1453 debugPreStart();
1454 debugPrint("bind format...");
1455 debugPrint("\n");
1456 debugPreEnd();
1457 }
1458
1459 // tell the server we want the bind format
1460 pvt->_cs->write((uint16_t)BINDFORMAT);
1461 flushWriteBuffer();
1462
1463 if (gotError()) {
1464 return NULL;
1465 }
1466
1467
1468 // get the bindformat size
1469 uint16_t size;
1470 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1471 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1472 setError("Failed to get bind format.\n"
1473 "A network error may have occurred.");
1474 return NULL;
1475 }
1476
1477 // get the bindformat
1478 delete[] pvt->_bindformat;
1479 pvt->_bindformat=new char[size+1];
1480 if (pvt->_cs->read(pvt->_bindformat,size)!=size) {
1481 setError("Failed to get bind format.\n"
1482 "A network error may have occurred.");
1483 delete[] pvt->_bindformat;
1484 pvt->_bindformat=NULL;
1485 return NULL;
1486 }
1487 pvt->_bindformat[size]='\0';
1488
1489 if (pvt->_debug) {
1490 debugPreStart();
1491 debugPrint(pvt->_bindformat);
1492 debugPrint("\n");
1493 debugPreEnd();
1494 }
1495 return pvt->_bindformat;
1496 }
1497
selectDatabase(const char * database)1498 bool sqlrconnection::selectDatabase(const char *database) {
1499
1500 if (!charstring::length(database)) {
1501 return true;
1502 }
1503
1504 clearError();
1505
1506 if (!openSession()) {
1507 return 0;
1508 }
1509
1510 if (pvt->_debug) {
1511 debugPreStart();
1512 debugPrint("Selecting database ");
1513 debugPrint(database);
1514 debugPrint("...\n");
1515 debugPreEnd();
1516 }
1517
1518 // tell the server we want to select a db
1519 pvt->_cs->write((uint16_t)SELECT_DATABASE);
1520
1521 // send the database name
1522 uint32_t len=charstring::length(database);
1523 pvt->_cs->write(len);
1524 if (len) {
1525 pvt->_cs->write(database,len);
1526 }
1527 flushWriteBuffer();
1528
1529 return !gotError();
1530 }
1531
getCurrentDatabase()1532 const char *sqlrconnection::getCurrentDatabase() {
1533
1534 if (!openSession()) {
1535 return NULL;
1536 }
1537
1538 clearError();
1539
1540 if (pvt->_debug) {
1541 debugPreStart();
1542 debugPrint("Getting the current database...\n");
1543 debugPreEnd();
1544 }
1545
1546 clearError();
1547
1548 // tell the server we want to get the current db
1549 pvt->_cs->write((uint16_t)GET_CURRENT_DATABASE);
1550 flushWriteBuffer();
1551
1552 if (gotError()) {
1553 return NULL;
1554 }
1555
1556 // get the current db name size
1557 uint16_t size;
1558 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1559 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1560 setError("Failed to get the current database.\n"
1561 "A network error may have occurred.");
1562 return NULL;
1563 }
1564
1565 // get the current db name
1566 delete[] pvt->_currentdbname;
1567 pvt->_currentdbname=new char[size+1];
1568 if (pvt->_cs->read(pvt->_currentdbname,size)!=size) {
1569 setError("Failed to get the current database.\n"
1570 "A network error may have occurred.");
1571 delete[] pvt->_currentdbname;
1572 pvt->_currentdbname=NULL;
1573 return NULL;
1574 }
1575 pvt->_currentdbname[size]='\0';
1576
1577 if (pvt->_debug) {
1578 debugPreStart();
1579 debugPrint(pvt->_currentdbname);
1580 debugPrint("\n");
1581 debugPreEnd();
1582 }
1583 return pvt->_currentdbname;
1584 }
1585
getCurrentSchema()1586 const char *sqlrconnection::getCurrentSchema() {
1587
1588 if (!openSession()) {
1589 return NULL;
1590 }
1591
1592 clearError();
1593
1594 if (pvt->_debug) {
1595 debugPreStart();
1596 debugPrint("Getting the current schema...\n");
1597 debugPreEnd();
1598 }
1599
1600 clearError();
1601
1602 // tell the server we want to get the current schema
1603 pvt->_cs->write((uint16_t)GET_CURRENT_SCHEMA);
1604 flushWriteBuffer();
1605
1606 if (gotError()) {
1607 return NULL;
1608 }
1609
1610 // get the current schema name size
1611 uint16_t size;
1612 if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1613 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1614 setError("Failed to get the current database.\n"
1615 "A network error may have occurred.");
1616 return NULL;
1617 }
1618
1619 // get the current schema name
1620 delete[] pvt->_currentschemaname;
1621 pvt->_currentschemaname=new char[size+1];
1622 if (pvt->_cs->read(pvt->_currentschemaname,size)!=size) {
1623 setError("Failed to get the current database.\n"
1624 "A network error may have occurred.");
1625 delete[] pvt->_currentschemaname;
1626 pvt->_currentschemaname=NULL;
1627 return NULL;
1628 }
1629 pvt->_currentschemaname[size]='\0';
1630
1631 if (pvt->_debug) {
1632 debugPreStart();
1633 debugPrint(pvt->_currentschemaname);
1634 debugPrint("\n");
1635 debugPreEnd();
1636 }
1637 return pvt->_currentschemaname;
1638 }
1639
getLastInsertId()1640 uint64_t sqlrconnection::getLastInsertId() {
1641
1642 if (!openSession()) {
1643 return 0;
1644 }
1645
1646 clearError();
1647
1648 if (pvt->_debug) {
1649 debugPreStart();
1650 debugPrint("Getting the last insert id...\n");
1651 debugPreEnd();
1652 }
1653
1654 // tell the server we want the last insert id
1655 pvt->_cs->write((uint16_t)GET_LAST_INSERT_ID);
1656 flushWriteBuffer();
1657
1658 if (gotError()) {
1659 return 0;
1660 }
1661
1662 // get the last insert id
1663 uint64_t id=0;
1664 if (pvt->_cs->read(&id)!=sizeof(uint64_t)) {
1665 setError("Failed to get the last insert id.\n"
1666 "A network error may have occurred.");
1667 return 0;
1668 }
1669
1670 if (pvt->_debug) {
1671 debugPreStart();
1672 debugPrint("Got the last insert id: ");
1673 debugPrint((int64_t)id);
1674 debugPrint("\n");
1675 debugPreEnd();
1676 }
1677 return id;
1678 }
1679
autoCommitOn()1680 bool sqlrconnection::autoCommitOn() {
1681 return autoCommit(true);
1682 }
1683
autoCommitOff()1684 bool sqlrconnection::autoCommitOff() {
1685 return autoCommit(false);
1686 }
1687
autoCommit(bool on)1688 bool sqlrconnection::autoCommit(bool on) {
1689
1690 if (!openSession()) {
1691 return false;
1692 }
1693
1694 clearError();
1695
1696 if (pvt->_debug) {
1697 debugPreStart();
1698 debugPrint((on)?"Setting autocommit on\n":
1699 "Setting autocommit off\n");
1700 debugPreEnd();
1701 }
1702
1703 pvt->_cs->write((uint16_t)AUTOCOMMIT);
1704 pvt->_cs->write(on);
1705 flushWriteBuffer();
1706
1707 return !gotError();
1708 }
1709
begin()1710 bool sqlrconnection::begin() {
1711
1712 if (!openSession()) {
1713 return false;
1714 }
1715
1716 clearError();
1717
1718 if (pvt->_debug) {
1719 debugPreStart();
1720 debugPrint("Beginning...\n");
1721 debugPreEnd();
1722 }
1723
1724 pvt->_cs->write((uint16_t)BEGIN);
1725 flushWriteBuffer();
1726
1727 return !gotError();
1728 }
1729
commit()1730 bool sqlrconnection::commit() {
1731
1732 if (!openSession()) {
1733 return false;
1734 }
1735
1736 clearError();
1737
1738 if (pvt->_debug) {
1739 debugPreStart();
1740 debugPrint("Committing...\n");
1741 debugPreEnd();
1742 }
1743
1744 pvt->_cs->write((uint16_t)COMMIT);
1745 flushWriteBuffer();
1746
1747 return !gotError();
1748 }
1749
rollback()1750 bool sqlrconnection::rollback() {
1751
1752 if (!openSession()) {
1753 return false;
1754 }
1755
1756 clearError();
1757
1758 if (pvt->_debug) {
1759 debugPreStart();
1760 debugPrint("Rolling Back...\n");
1761 debugPreEnd();
1762 }
1763
1764 pvt->_cs->write((uint16_t)ROLLBACK);
1765 flushWriteBuffer();
1766
1767 return !gotError();
1768 }
1769
errorMessage()1770 const char *sqlrconnection::errorMessage() {
1771 return pvt->_error;
1772 }
1773
errorNumber()1774 int64_t sqlrconnection::errorNumber() {
1775 return pvt->_errorno;
1776 }
1777
clearError()1778 void sqlrconnection::clearError() {
1779 delete[] pvt->_error;
1780 pvt->_error=NULL;
1781 pvt->_errorno=0;
1782 }
1783
setError(const char * err)1784 void sqlrconnection::setError(const char *err) {
1785
1786 if (pvt->_debug) {
1787 debugPreStart();
1788 debugPrint("Setting Error\n");
1789 debugPreEnd();
1790 }
1791
1792 delete[] pvt->_error;
1793 pvt->_error=charstring::duplicate(err);
1794
1795 if (pvt->_debug) {
1796 debugPreStart();
1797 debugPrint(pvt->_error);
1798 debugPrint("\n");
1799 debugPreEnd();
1800 }
1801 }
1802
getError()1803 uint16_t sqlrconnection::getError() {
1804
1805 clearError();
1806
1807 if (pvt->_debug) {
1808 debugPreStart();
1809 debugPrint("Checking for error\n");
1810 debugPreEnd();
1811 }
1812
1813 // get whether an error occurred or not
1814 uint16_t status;
1815 if (pvt->_cs->read(&status,pvt->_responsetimeoutsec,
1816 pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1817 setError("Failed to get the error status.\n"
1818 "A network error may have occurred.");
1819 return ERROR_OCCURRED;
1820 }
1821
1822 // if no error occurred, return that
1823 if (status==NO_ERROR_OCCURRED) {
1824 if (pvt->_debug) {
1825 debugPreStart();
1826 debugPrint("No error occurred\n");
1827 debugPreEnd();
1828 }
1829 return status;
1830 }
1831
1832 if (pvt->_debug) {
1833 debugPreStart();
1834 debugPrint("An error occurred\n");
1835 debugPreEnd();
1836 }
1837
1838 // get the error code
1839 if (pvt->_cs->read((uint64_t *)&pvt->_errorno)!=sizeof(uint64_t)) {
1840 setError("Failed to get the error code.\n"
1841 "A network error may have occurred.");
1842 return status;
1843 }
1844
1845 // get the error size
1846 uint16_t size;
1847 if (pvt->_cs->read(&size)!=sizeof(uint16_t)) {
1848 setError("Failed to get the error size.\n"
1849 "A network error may have occurred.");
1850 return status;
1851 }
1852
1853 // get the error string
1854 pvt->_error=new char[size+1];
1855 if (pvt->_cs->read(pvt->_error,size)!=size) {
1856 setError("Failed to get the error string.\n"
1857 "A network error may have occurred.");
1858 return status;
1859 }
1860 pvt->_error[size]='\0';
1861
1862 if (pvt->_debug) {
1863 debugPreStart();
1864 debugPrint("Got error:\n");
1865 debugPrint(pvt->_errorno);
1866 debugPrint(": ");
1867 debugPrint(pvt->_error);
1868 debugPrint("\n");
1869 debugPreEnd();
1870 }
1871 return status;
1872 }
1873
gotError()1874 bool sqlrconnection::gotError() {
1875 uint16_t status=getError();
1876 if (status==NO_ERROR_OCCURRED) {
1877 return false;
1878 }
1879 if (status==ERROR_OCCURRED_DISCONNECT) {
1880 endSession();
1881 }
1882 return true;
1883 }
1884
debugOn()1885 void sqlrconnection::debugOn() {
1886 pvt->_debug=true;
1887 }
1888
debugOff()1889 void sqlrconnection::debugOff() {
1890 pvt->_debug=false;
1891 }
1892
setDebugFile(const char * filename)1893 void sqlrconnection::setDebugFile(const char *filename) {
1894 pvt->_debugfile.close();
1895 error::clearError();
1896 if (filename && *filename &&
1897 !pvt->_debugfile.open(filename,O_WRONLY|O_APPEND) &&
1898 error::getErrorNumber()==ENOENT) {
1899 pvt->_debugfile.create(filename,
1900 permissions::evalPermString("rw-r--r--"));
1901 }
1902 }
1903
getDebug()1904 bool sqlrconnection::getDebug() {
1905 return pvt->_debug;
1906 }
1907
debugPreStart()1908 void sqlrconnection::debugPreStart() {
1909 if (pvt->_webdebug==-1) {
1910 const char *docroot=environment::getValue("DOCUMENT_ROOT");
1911 if (!charstring::isNullOrEmpty(docroot)) {
1912 pvt->_webdebug=1;
1913 } else {
1914 pvt->_webdebug=0;
1915 }
1916 }
1917 if (pvt->_webdebug==1) {
1918 debugPrint("<pre>\n");
1919 }
1920 }
1921
debugPreEnd()1922 void sqlrconnection::debugPreEnd() {
1923 if (pvt->_webdebug==1) {
1924 debugPrint("</pre>\n");
1925 }
1926 }
1927
debugPrintFunction(int (* printfunction)(const char *,...))1928 void sqlrconnection::debugPrintFunction(
1929 int (*printfunction)(const char *,...)) {
1930 pvt->_printfunction=printfunction;
1931 }
1932
debugPrint(const char * string)1933 void sqlrconnection::debugPrint(const char *string) {
1934 if (pvt->_printfunction) {
1935 (*pvt->_printfunction)("%s",string);
1936 } else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1937 pvt->_debugfile.printf("%s",string);
1938 } else {
1939 stdoutput.printf("%s",string);
1940 }
1941 }
1942
debugPrint(int64_t number)1943 void sqlrconnection::debugPrint(int64_t number) {
1944 if (pvt->_printfunction) {
1945 (*pvt->_printfunction)("%lld",(long long)number);
1946 } else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1947 pvt->_debugfile.printf("%lld",(long long)number);
1948 } else {
1949 stdoutput.printf("%lld",(long long)number);
1950 }
1951 }
1952
debugPrint(double number)1953 void sqlrconnection::debugPrint(double number) {
1954 if (pvt->_printfunction) {
1955 (*pvt->_printfunction)("%f",number);
1956 } else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1957 pvt->_debugfile.printf("%f",number);
1958 } else {
1959 stdoutput.printf("%f",number);
1960 }
1961 }
1962
debugPrint(char character)1963 void sqlrconnection::debugPrint(char character) {
1964 if (pvt->_printfunction) {
1965 (*pvt->_printfunction)("%c",character);
1966 } else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1967 pvt->_debugfile.printf("%c",character);
1968 } else {
1969 stdoutput.printf("%c",character);
1970 }
1971 }
1972
debugPrintBlob(const char * blob,uint32_t length)1973 void sqlrconnection::debugPrintBlob(const char *blob, uint32_t length) {
1974 debugPrint('\n');
1975 int column=0;
1976 for (uint32_t i=0; i<length; i++) {
1977 if (blob[i]>=' ' && blob[i]<='~') {
1978 debugPrint(blob[i]);
1979 } else {
1980 debugPrint('.');
1981 }
1982 column++;
1983 if (column==80) {
1984 debugPrint('\n');
1985 column=0;
1986 }
1987 }
1988 debugPrint('\n');
1989 }
1990
debugPrintClob(const char * clob,uint32_t length)1991 void sqlrconnection::debugPrintClob(const char *clob, uint32_t length) {
1992 debugPrint('\n');
1993 for (uint32_t i=0; i<length; i++) {
1994 if (clob[i]=='\0') {
1995 debugPrint("\\0");
1996 } else {
1997 debugPrint(clob[i]);
1998 }
1999 }
2000 debugPrint('\n');
2001 }
2002
setClientInfo(const char * clientinfo)2003 void sqlrconnection::setClientInfo(const char *clientinfo) {
2004 delete[] pvt->_clientinfo;
2005 pvt->_clientinfo=charstring::duplicate(clientinfo);
2006 pvt->_clientinfolen=charstring::length(clientinfo);
2007 }
2008
getClientInfo() const2009 const char *sqlrconnection::getClientInfo() const {
2010 return pvt->_clientinfo;
2011 }
2012
cs()2013 socketclient *sqlrconnection::cs() {
2014 return pvt->_cs;
2015 }
2016
endsessionsent()2017 bool sqlrconnection::endsessionsent() {
2018 return pvt->_endsessionsent;
2019 }
2020
suspendsessionsent()2021 bool sqlrconnection::suspendsessionsent() {
2022 return pvt->_suspendsessionsent;
2023 }
2024
connected()2025 bool sqlrconnection::connected() {
2026 return pvt->_connected;
2027 }
2028
responsetimeoutsec()2029 int32_t sqlrconnection::responsetimeoutsec() {
2030 return pvt->_responsetimeoutsec;
2031 }
2032
responsetimeoutusec()2033 int32_t sqlrconnection::responsetimeoutusec() {
2034 return pvt->_responsetimeoutusec;
2035 }
2036
errorno()2037 int64_t sqlrconnection::errorno() {
2038 return pvt->_errorno;
2039 }
2040
error()2041 char *sqlrconnection::error() {
2042 return pvt->_error;
2043 }
2044
clientinfo()2045 char *sqlrconnection::clientinfo() {
2046 return pvt->_clientinfo;
2047 }
2048
clientinfolen()2049 uint64_t sqlrconnection::clientinfolen() {
2050 return pvt->_clientinfolen;
2051 }
2052
debug()2053 bool sqlrconnection::debug() {
2054 return pvt->_debug;
2055 }
2056
firstcursor(sqlrcursor * cur)2057 void sqlrconnection::firstcursor(sqlrcursor *cur) {
2058 pvt->_firstcursor=cur;
2059 }
2060
lastcursor(sqlrcursor * cur)2061 void sqlrconnection::lastcursor(sqlrcursor *cur) {
2062 pvt->_lastcursor=cur;
2063 }
2064
lastcursor()2065 sqlrcursor *sqlrconnection::lastcursor() {
2066 return pvt->_lastcursor;
2067 }
2068
isYes(const char * str)2069 bool sqlrconnection::isYes(const char *str) {
2070 return charstring::isYes(str);
2071 }
2072
isNo(const char * str)2073 bool sqlrconnection::isNo(const char *str) {
2074 return charstring::isNo(str);
2075 }
2076
setBindVariableDelimiters(const char * delimiters)2077 void sqlrconnection::setBindVariableDelimiters(const char *delimiters) {
2078 pvt->_questionmarksupported=charstring::contains(delimiters,'?');
2079 pvt->_colonsupported=charstring::contains(delimiters,':');
2080 pvt->_atsignsupported=charstring::contains(delimiters,'@');
2081 pvt->_dollarsignsupported=charstring::contains(delimiters,'$');
2082 }
2083
getBindVariableDelimiterQuestionMarkSupported()2084 bool sqlrconnection::getBindVariableDelimiterQuestionMarkSupported() {
2085 return pvt->_questionmarksupported;
2086 }
2087
getBindVariableDelimiterColonSupported()2088 bool sqlrconnection::getBindVariableDelimiterColonSupported() {
2089 return pvt->_colonsupported;
2090 }
2091
getBindVariableDelimiterAtSignSupported()2092 bool sqlrconnection::getBindVariableDelimiterAtSignSupported() {
2093 return pvt->_atsignsupported;
2094 }
2095
getBindVariableDelimiterDollarSignSupported()2096 bool sqlrconnection::getBindVariableDelimiterDollarSignSupported() {
2097 return pvt->_dollarsignsupported;
2098 }
2099