1 // Copyright (c) 1999-2018 David Muse
2 // See the file COPYING for more information
3 
4 #include <config.h>
5 #include <sqlrelay/sqlrclient.h>
6 #include <rudiments/inetsocketclient.h>
7 #include <rudiments/unixsocketclient.h>
8 #include <rudiments/file.h>
9 #include <rudiments/environment.h>
10 #include <rudiments/charstring.h>
11 #include <rudiments/stdio.h>
12 #include <rudiments/error.h>
13 #include <rudiments/permissions.h>
14 #include <rudiments/gss.h>
15 #include <rudiments/tls.h>
16 #include <rudiments/sys.h>
17 #include <defines.h>
18 #include <defaults.h>
19 
20 #ifndef MAXPATHLEN
21         #define MAXPATHLEN 256
22 #endif
23 
// Private state behind sqlrconnection.  An instance is allocated in
// sqlrconnection::init() and deleted by sqlrconnection's destructor.
class sqlrconnectionprivate {
	friend class sqlrconnection;
	private:

		// clients
		// (_cs points at whichever of _ics/_ucs is currently in use)
		inetsocketclient	_ics;
		unixsocketclient	_ucs;
		socketclient		*_cs;

		// session state
		bool	_endsessionsent;
		bool	_suspendsessionsent;
		bool	_connected;

		// connection
		// (_listener* describe where a new session is opened;
		// _connection* describe where a suspended session is resumed,
		// as filled in by getNewPort() or resumeSession())
		char		*_server;
		uint16_t	_listenerinetport;
		uint16_t	_connectioninetport;
		char		*_listenerunixport;
		// points at _connectionunixportbuffer or at a caller's string
		const char	*_connectionunixport;
		// +1 leaves room for the terminating NUL
		char		_connectionunixportbuffer[MAXPATHLEN+1];
		// timeouts are seconds/microseconds pairs; they are -1/-1 when
		// the corresponding environment variable isn't a number
		int32_t		_connecttimeoutsec;
		int32_t		_connecttimeoutusec;
		int32_t		_authtimeoutsec;
		int32_t		_authtimeoutusec;
		int32_t		_responsetimeoutsec;
		int32_t		_responsetimeoutusec;
		// passed through to the socket clients' connect() calls
		int32_t		_retrytime;
		int32_t		_tries;

		// auth
		// (strings are owned, and freed, only when _copyrefs is set)
		char		*_user;
		uint32_t	_userlen;
		char		*_password;
		uint32_t	_passwordlen;

		// gss
		bool		_usekrb;
		char		*_krbservice;
		char		*_krbmech;
		char		*_krbflags;
		gsscredentials	_gcred;
		gssmechanism	_gmech;
		gsscontext	_gctx;

		// tls
		bool		_usetls;
		char		*_tlsversion;
		char		*_tlscert;
		char		*_tlspassword;
		char		*_tlsciphers;
		// "no", "ca", "ca+host", or any other value (which results in
		// domain validation in validateCertificate())
		char		*_tlsvalidate;
		char		*_tlsca;
		uint16_t	_tlsdepth;
		tlscontext	_tctx;

		// &_gctx, &_tctx, or NULL, as chosen by reConfigureSockets()
		securitycontext	*_ctx;

		// error
		int64_t		_errorno;
		char		*_error;

		// identify
		char		*_id;

		// db version
		char		*_dbversion;

		// db host name
		char		*_dbhostname;

		// db ip address
		char		*_dbipaddress;

		// server version
		char		*_serverversion;

		// current database name
		char		*_currentdbname;

		// current schema name
		char		*_currentschemaname;

		// bind format
		char		*_bindformat;

		// bind delimiters
		bool		_questionmarksupported;
		bool		_colonsupported;
		bool		_atsignsupported;
		bool		_dollarsignsupported;

		// client info
		char		*_clientinfo;
		uint64_t	_clientinfolen;

		// debug
		bool		_debug;
		int32_t		_webdebug;
		// presumably a printf-style output hook; NULL after init()
		int		(*_printfunction)(const char *,...);
		file		_debugfile;

		// copy references flag
		// (when true, string arguments are duplicated and owned)
		bool		_copyrefs;

		// cursor list
		// (cursors attached to this connection, chained via next())
		sqlrcursor	*_firstcursor;
		sqlrcursor	*_lastcursor;
};
133 
sqlrconnection(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries,bool copyreferences)134 sqlrconnection::sqlrconnection(const char *server, uint16_t port,
135 					const char *socket,
136 					const char *user, const char *password,
137 					int32_t retrytime, int32_t tries,
138 					bool copyreferences) {
139 	init(server,port,socket,user,password,retrytime,tries,copyreferences);
140 }
141 
sqlrconnection(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries)142 sqlrconnection::sqlrconnection(const char *server, uint16_t port,
143 					const char *socket,
144 					const char *user, const char *password,
145 					int32_t retrytime, int32_t tries) {
146 	init(server,port,socket,user,password,retrytime,tries,false);
147 }
148 
init(const char * server,uint16_t port,const char * socket,const char * user,const char * password,int32_t retrytime,int32_t tries,bool copyreferences)149 void sqlrconnection::init(const char *server, uint16_t port,
150 					const char *socket,
151 					const char *user, const char *password,
152 					int32_t retrytime, int32_t tries,
153 					bool copyreferences) {
154 
155 	pvt=new sqlrconnectionprivate;
156 
157 	pvt->_copyrefs=copyreferences;
158 
159 	// retry reads if they get interrupted by signals
160 	pvt->_ucs.translateByteOrder();
161 	pvt->_ucs.retryInterruptedReads();
162 	pvt->_ics.retryInterruptedReads();
163 	pvt->_cs=&pvt->_ucs;
164 
165 	// connection
166 	pvt->_server=(pvt->_copyrefs)?
167 			charstring::duplicate(server):
168 			(char *)server;
169 	pvt->_listenerinetport=port;
170 	pvt->_listenerunixport=(pvt->_copyrefs)?
171 				charstring::duplicate(socket):
172 				(char *)socket;
173 	pvt->_retrytime=retrytime;
174 	pvt->_tries=tries;
175 
176 	// initialize timeouts
177 	setTimeoutFromEnv("SQLR_CLIENT_CONNECT_TIMEOUT",
178 			&pvt->_connecttimeoutsec,&pvt->_connecttimeoutusec);
179 	setTimeoutFromEnv("SQLR_CLIENT_AUTHENTICATION_TIMEOUT",
180 			&pvt->_authtimeoutsec,&pvt->_authtimeoutusec);
181 	setTimeoutFromEnv("SQLR_CLIENT_RESPONSE_TIMEOUT",
182 			&pvt->_responsetimeoutsec,&pvt->_responsetimeoutusec);
183 
184 	// authentication
185 	pvt->_user=(pvt->_copyrefs)?
186 			charstring::duplicate(user):
187 			(char *)user;
188 	pvt->_password=(pvt->_copyrefs)?
189 			charstring::duplicate(password):
190 			(char *)password;
191 	pvt->_userlen=charstring::length(user);
192 	pvt->_passwordlen=charstring::length(password);
193 	pvt->_usekrb=false;
194 	pvt->_krbservice=NULL;
195 	pvt->_krbmech=NULL;
196 	pvt->_krbflags=NULL;
197 
198 	pvt->_usetls=false;
199 	pvt->_tlsversion=NULL;
200 	pvt->_tlscert=NULL;
201 	pvt->_tlspassword=NULL;
202 	pvt->_tlsciphers=NULL;
203 	pvt->_tlsvalidate=(pvt->_copyrefs)?
204 				charstring::duplicate("no"):
205 				(char *)"no";
206 	pvt->_tlsca=NULL;
207 	pvt->_tlsdepth=0;
208 
209 	pvt->_ctx=NULL;
210 
211 	// database id
212 	pvt->_id=NULL;
213 
214 	// db version
215 	pvt->_dbversion=NULL;
216 
217 	// db host name
218 	pvt->_dbhostname=NULL;
219 
220 	// db ip address
221 	pvt->_dbipaddress=NULL;
222 
223 	// server version
224 	pvt->_serverversion=NULL;
225 
226 	// current database name
227 	pvt->_currentdbname=NULL;
228 
229 	// current schema name
230 	pvt->_currentschemaname=NULL;
231 
232 	// bind format
233 	pvt->_bindformat=NULL;
234 
235 	// bind delimiters
236 	pvt->_questionmarksupported=true;
237 	pvt->_colonsupported=true;
238 	pvt->_atsignsupported=true;
239 	pvt->_dollarsignsupported=true;
240 
241 	// client info
242 	pvt->_clientinfo=NULL;
243 	pvt->_clientinfolen=0;
244 
245 	// session state
246 	pvt->_connected=false;
247 	clearSessionFlags();
248 
249 	// debug print function
250 	pvt->_printfunction=NULL;
251 
252 	// enable/disable debug
253 	const char	*sqlrdebug=environment::getValue("SQLRDEBUG");
254 	if (!sqlrdebug || !*sqlrdebug) {
255 		sqlrdebug=environment::getValue("SQLR_CLIENT_DEBUG");
256 	}
257 	pvt->_debug=(sqlrdebug && *sqlrdebug && !charstring::isNo(sqlrdebug));
258 	if (pvt->_debug && !charstring::isYes(sqlrdebug) &&
259 				!charstring::isNo(sqlrdebug)) {
260 		setDebugFile(sqlrdebug);
261 	}
262 	pvt->_webdebug=-1;
263 
264 	// error
265 	pvt->_errorno=0;
266 	pvt->_error=NULL;
267 
268 	// cursor list
269 	pvt->_firstcursor=NULL;
270 	pvt->_lastcursor=NULL;
271 }
272 
clearSessionFlags()273 void sqlrconnection::clearSessionFlags() {
274 
275 	// indicate that the session hasn't been suspended or ended
276 	pvt->_endsessionsent=false;
277 	pvt->_suspendsessionsent=false;
278 }
279 
~sqlrconnection()280 sqlrconnection::~sqlrconnection() {
281 
282 	// unless it was already ended or suspended, end the session
283 	if (!pvt->_endsessionsent && !pvt->_suspendsessionsent) {
284 		endSession();
285 	}
286 
287 	// deallocate error
288 	delete[] pvt->_error;
289 
290 	// deallocate id
291 	delete[] pvt->_id;
292 
293 	// deallocate dbversion
294 	delete[] pvt->_dbversion;
295 
296 	// deallocate db host name
297 	delete[] pvt->_dbhostname;
298 
299 	// deallocate db ip address
300 	delete[] pvt->_dbipaddress;
301 
302 	// deallocate server version
303 	delete[] pvt->_serverversion;
304 
305 	// deallocate current database name
306 	delete[] pvt->_currentdbname;
307 
308 	// deallocate current schema name
309 	delete[] pvt->_currentschemaname;
310 
311 	// deallocate bindformat
312 	delete[] pvt->_bindformat;
313 
314 	// deallocate client info
315 	delete[] pvt->_clientinfo;
316 
317 	// deallocate copied references
318 	if (pvt->_copyrefs) {
319 		delete[] pvt->_server;
320 		delete[] pvt->_listenerunixport;
321 		delete[] pvt->_user;
322 		delete[] pvt->_password;
323 		delete[] pvt->_krbservice;
324 		delete[] pvt->_krbmech;
325 		delete[] pvt->_krbflags;
326 		delete[] pvt->_tlsversion;
327 		delete[] pvt->_tlscert;
328 		delete[] pvt->_tlspassword;
329 		delete[] pvt->_tlsvalidate;
330 		delete[] pvt->_tlsca;
331 	}
332 
333 	// detach all cursors attached to this client
334 	sqlrcursor	*currentcursor=pvt->_firstcursor;
335 	while (currentcursor) {
336 		pvt->_firstcursor=currentcursor;
337 		currentcursor=currentcursor->next();
338 		pvt->_firstcursor->sqlrc(NULL);
339 	}
340 
341 	if (pvt->_debug) {
342 		debugPreStart();
343 		debugPrint("Deallocated connection\n");
344 		debugPreEnd();
345 	}
346 
347 	pvt->_debugfile.close();
348 
349 	delete pvt;
350 }
351 
enableKerberos(const char * service,const char * mech,const char * flags)352 void sqlrconnection::enableKerberos(const char *service,
353 					const char *mech,
354 					const char *flags) {
355 
356 	// clear any existing configuration
357 	if (pvt->_usekrb || pvt->_usetls) {
358 		disableEncryption();
359 	}
360 
361 	if (!gss::supported()) {
362 		return;
363 	}
364 
365 	pvt->_usekrb=true;
366 
367 	// "Negotiate" is Windows' default mech.  Force Kerberos instead.
368 	char	*os=sys::getOperatingSystemName();
369 	if (!charstring::compare(os,"Windows",7) &&
370 			charstring::isNullOrEmpty(mech)) {
371 		mech="Kerberos";
372 	}
373 	delete[] os;
374 
375 	if (pvt->_copyrefs) {
376 		delete[] pvt->_krbservice;
377 		pvt->_krbservice=charstring::duplicate(
378 			(!charstring::isNullOrEmpty(service))?
379 					service:DEFAULT_KRBSERVICE);
380 		delete[] pvt->_krbmech;
381 		pvt->_krbmech=charstring::duplicate(mech);
382 		delete[] pvt->_krbflags;
383 		pvt->_krbflags=charstring::duplicate(flags);
384 	} else {
385 		pvt->_krbservice=(char *)
386 			(!charstring::isNullOrEmpty(service)?
387 					service:DEFAULT_KRBSERVICE);
388 		pvt->_krbmech=(char *)mech;
389 		pvt->_krbflags=(char *)flags;
390 	}
391 }
392 
enableTls(const char * version,const char * cert,const char * password,const char * ciphers,const char * validate,const char * ca,uint16_t depth)393 void sqlrconnection::enableTls(const char *version,
394 					const char *cert,
395 					const char *password,
396 					const char *ciphers,
397 					const char *validate,
398 					const char *ca,
399 					uint16_t depth) {
400 
401 	// clear any existing configuration
402 	if (pvt->_usekrb || pvt->_usetls) {
403 		disableEncryption();
404 	}
405 
406 	if (!tls::supported()) {
407 		return;
408 	}
409 
410 	pvt->_usetls=true;
411 
412 	if (pvt->_copyrefs) {
413 		delete[] pvt->_tlsversion;
414 		pvt->_tlsversion=charstring::duplicate(version);
415 		delete[] pvt->_tlscert;
416 		pvt->_tlscert=charstring::duplicate(cert);
417 		delete[] pvt->_tlspassword;
418 		pvt->_tlspassword=charstring::duplicate(password);
419 		delete[] pvt->_tlsciphers;
420 		pvt->_tlsciphers=charstring::duplicate(ciphers);
421 		delete[] pvt->_tlsvalidate;
422 		pvt->_tlsvalidate=charstring::duplicate(validate);
423 		delete[] pvt->_tlsca;
424 		pvt->_tlsca=charstring::duplicate(ca);
425 	} else {
426 		pvt->_tlsversion=(char *)version;
427 		pvt->_tlscert=(char *)cert;
428 		pvt->_tlspassword=(char *)password;
429 		pvt->_tlsciphers=(char *)ciphers;
430 		pvt->_tlsvalidate=(char *)validate;
431 		pvt->_tlsca=(char *)ca;
432 	}
433 	pvt->_tlsdepth=depth;
434 }
435 
disableEncryption()436 void sqlrconnection::disableEncryption() {
437 
438 	if (pvt->_copyrefs) {
439 		delete[] pvt->_krbservice;
440 		delete[] pvt->_krbmech;
441 		delete[] pvt->_krbflags;
442 
443 		delete[] pvt->_tlsversion;
444 		delete[] pvt->_tlscert;
445 		delete[] pvt->_tlspassword;
446 		delete[] pvt->_tlsciphers;
447 		delete[] pvt->_tlsvalidate;
448 		delete[] pvt->_tlsca;
449 	}
450 	pvt->_krbservice=NULL;
451 	pvt->_krbmech=NULL;
452 	pvt->_krbflags=NULL;
453 	pvt->_usekrb=false;
454 
455 	pvt->_tlsversion=NULL;
456 	pvt->_tlscert=NULL;
457 	pvt->_tlspassword=NULL;
458 	pvt->_tlsciphers=NULL;
459 	pvt->_tlsvalidate=NULL;
460 	pvt->_tlsca=NULL;
461 	pvt->_tlsdepth=0;
462 	pvt->_usetls=false;
463 }
464 
setConnectTimeout(int32_t timeoutsec,int32_t timeoutusec)465 void sqlrconnection::setConnectTimeout(int32_t timeoutsec,
466 					int32_t timeoutusec) {
467 	pvt->_connecttimeoutsec=timeoutsec;
468 	pvt->_connecttimeoutusec=timeoutusec;
469 }
470 
setAuthenticationTimeout(int32_t timeoutsec,int32_t timeoutusec)471 void sqlrconnection::setAuthenticationTimeout(int32_t timeoutsec,
472 						int32_t timeoutusec) {
473 	pvt->_authtimeoutsec=timeoutsec;
474 	pvt->_authtimeoutusec=timeoutusec;
475 }
476 
setResponseTimeout(int32_t timeoutsec,int32_t timeoutusec)477 void sqlrconnection::setResponseTimeout(int32_t timeoutsec,
478 						int32_t timeoutusec) {
479 	pvt->_responsetimeoutsec=timeoutsec;
480 	pvt->_responsetimeoutusec=timeoutusec;
481 }
482 
setTimeoutFromEnv(const char * var,int32_t * timeoutsec,int32_t * timeoutusec)483 void sqlrconnection::setTimeoutFromEnv(const char *var,
484 					int32_t *timeoutsec,
485 					int32_t *timeoutusec) {
486 	const char	*timeout=environment::getValue(var);
487 	if (charstring::isNumber(timeout)) {
488 		*timeoutsec=charstring::toInteger(timeout);
489 		long double	dbl=charstring::toFloatC(timeout);
490 		dbl=dbl-(long double)(*timeoutsec);
491 		*timeoutusec=(int32_t)(dbl*1000000.0);
492 	} else {
493 		*timeoutsec=-1;
494 		*timeoutusec=-1;
495 	}
496 }
497 
getConnectTimeout(int32_t * timeoutsec,int32_t * timeoutusec)498 void sqlrconnection::getConnectTimeout(int32_t *timeoutsec,
499 					int32_t *timeoutusec) {
500 	*timeoutsec=pvt->_connecttimeoutsec;
501 	*timeoutusec=pvt->_connecttimeoutusec;
502 }
503 
getAuthenticationTimeout(int32_t * timeoutsec,int32_t * timeoutusec)504 void sqlrconnection::getAuthenticationTimeout(int32_t *timeoutsec,
505 						int32_t *timeoutusec) {
506 	*timeoutsec=pvt->_authtimeoutsec;
507 	*timeoutusec=pvt->_authtimeoutusec;
508 }
509 
getResponseTimeout(int32_t * timeoutsec,int32_t * timeoutusec)510 void sqlrconnection::getResponseTimeout(int32_t *timeoutsec,
511 						int32_t *timeoutusec) {
512 	*timeoutsec=pvt->_responsetimeoutsec;
513 	*timeoutusec=pvt->_responsetimeoutusec;
514 }
515 
endSession()516 void sqlrconnection::endSession() {
517 
518 	if (pvt->_debug) {
519 		debugPreStart();
520 		debugPrint("Ending Session\n");
521 		debugPreEnd();
522 	}
523 
524 	// abort each cursor's result set
525 	sqlrcursor	*currentcursor=pvt->_firstcursor;
526 	while (currentcursor) {
527 		// FIXME: do we need to clearResultSet() here too?
528 		if (!currentcursor->endofresultset()) {
529 			currentcursor->closeResultSet(false);
530 		}
531 		currentcursor->havecursorid(false);
532 		currentcursor=currentcursor->next();
533 	}
534 
535 	// write an END_SESSION to the connection
536 	if (pvt->_connected) {
537 		pvt->_cs->write((uint16_t)END_SESSION);
538 		flushWriteBuffer();
539 		pvt->_endsessionsent=true;
540 		closeConnection();
541 	}
542 }
543 
flushWriteBuffer()544 void sqlrconnection::flushWriteBuffer() {
545 	pvt->_cs->flushWriteBuffer(-1,-1);
546 }
547 
closeConnection()548 void sqlrconnection::closeConnection() {
549 	pvt->_cs->close();
550 	pvt->_connected=false;
551 }
552 
suspendSession()553 bool sqlrconnection::suspendSession() {
554 
555 	if (!openSession()) {
556 		return 0;
557 	}
558 
559 	clearError();
560 
561 	if (pvt->_debug) {
562 		debugPreStart();
563 		debugPrint("Suspending Session\n");
564 		debugPreEnd();
565 	}
566 
567 	// suspend the session
568 	pvt->_cs->write((uint16_t)SUSPEND_SESSION);
569 	flushWriteBuffer();
570 	pvt->_suspendsessionsent=true;
571 
572 	// check for error
573 	if (gotError()) {
574 		return false;
575 	}
576 
577 	// get port to resume on
578 	bool	retval=getNewPort();
579 
580 	closeConnection();
581 
582 	return retval;
583 }
584 
openSession()585 bool sqlrconnection::openSession() {
586 
587 	if (pvt->_connected) {
588 		return true;
589 	}
590 
591 	if (!reConfigureSockets()) {
592 		pvt->_connected=false;
593 		return false;
594 	}
595 
596 	if (pvt->_debug) {
597 		debugPreStart();
598 		debugPrint("Connecting to listener...");
599 		debugPrint("\n");
600 		debugPreEnd();
601 	}
602 
603 	// open a connection to the listener
604 	int	openresult=RESULT_ERROR;
605 
606 	// first, try for a unix connection
607 	if (!charstring::isNullOrEmpty(pvt->_listenerunixport)) {
608 
609 		if (pvt->_debug) {
610 			debugPreStart();
611 			debugPrint("Unix socket: ");
612 			debugPrint(pvt->_listenerunixport);
613 			debugPrint("\n");
614 			debugPreEnd();
615 		}
616 
617 		openresult=pvt->_ucs.connect(pvt->_listenerunixport,
618 						pvt->_connecttimeoutsec,
619 						pvt->_connecttimeoutusec,
620 						pvt->_retrytime,pvt->_tries);
621 		if (openresult==RESULT_SUCCESS) {
622 
623 			pvt->_ucs.setSocketReadBufferSize(65536);
624 			pvt->_ucs.setSocketWriteBufferSize(65536);
625 
626 			pvt->_cs=&pvt->_ucs;
627 		}
628 	}
629 
630 	// then try for an inet connection
631 	if (openresult!=RESULT_SUCCESS && pvt->_listenerinetport) {
632 
633 		if (pvt->_debug) {
634 			debugPreStart();
635 			debugPrint("Inet socket: ");
636 			debugPrint(pvt->_server);
637 			debugPrint(":");
638 			debugPrint((int64_t)pvt->_listenerinetport);
639 			debugPrint("\n");
640 			debugPreEnd();
641 		}
642 
643 		openresult=pvt->_ics.connect(pvt->_server,
644 						pvt->_listenerinetport,
645 						pvt->_connecttimeoutsec,
646 						pvt->_connecttimeoutusec,
647 						pvt->_retrytime,pvt->_tries);
648 		if (openresult==RESULT_SUCCESS) {
649 
650 			pvt->_ics.setSocketReadBufferSize(65536);
651 			pvt->_ics.setSocketWriteBufferSize(65536);
652 
653 			pvt->_ics.dontUseNaglesAlgorithm();
654 
655 			pvt->_cs=&pvt->_ics;
656 		}
657 	}
658 
659 	// handle failures
660 	if (openresult!=RESULT_SUCCESS) {
661 		setConnectFailedError();
662 		return false;
663 	}
664 
665 	// if tls is enabled and we're using an inet socket,
666 	// then we may need to validate the host
667 	if (pvt->_usetls && pvt->_cs==&pvt->_ics) {
668 
669 		if (!validateCertificate()) {
670 			setError("TLS error: common name mismatch");
671 			pvt->_cs->close();
672 			return false;
673 		}
674 	}
675 
676 	// if we made it here then everything went
677 	// well and we are successfully connected
678 	pvt->_connected=true;
679 
680 	// send protocol info
681 	protocol();
682 
683 	// auth
684 	auth();
685 
686 	return true;
687 }
688 
validateCertificate()689 bool sqlrconnection::validateCertificate() {
690 
691 	// If we're not doing any validation then just return true. If we're
692 	// just doing ca validation then the connect would have failed if the
693 	// certificate was invalid, so we can just return true for that too.
694 	if (charstring::isNo(pvt->_tlsvalidate) ||
695 		!charstring::compareIgnoringCase(pvt->_tlsvalidate,"ca")) {
696 		return true;
697 	}
698 
699 	// get the cert from the server
700 	tlscertificate		*cert=((tlscontext *)pvt->_ctx)->
701 						getPeerCertificate();
702 	if (!cert) {
703 		// this should never happen, the connect()
704 		// should have failed if no cert was supplied
705 		return false;
706 	}
707 
708 	// get the subject alternate names and common name from the cert
709 	linkedlist< char * >	*sans=cert->getSubjectAlternateNames();
710 	const char		*commonname=cert->getCommonName();
711 
712 	// should we validate the host name or domain?
713 	bool	host=!charstring::compareIgnoringCase(
714 					pvt->_tlsvalidate,"ca+host");
715 
716 	// get the server name to validate against
717 	const char	*server=pvt->_server;
718 	if (!host) {
719 		const char	*dot=charstring::findFirst(server,'.');
720 		if (dot) {
721 			server=dot+1;
722 		}
723 	}
724 
725 	// if there are any subject alternate
726 	// names then validate against those
727 	if (sans && sans->getLength()) {
728 
729 		for (linkedlistnode< char * > *node=sans->getFirst();
730 					node; node=node->getNext()) {
731 
732 			const char	*san=node->getValue();
733 			if (!host) {
734 				const char	*dot=
735 					charstring::findFirst(san,'.');
736 				if (dot) {
737 					san=dot+1;
738 				}
739 			}
740 
741 			if (pvt->_debug) {
742 				debugPreStart();
743 				debugPrint(pvt->_tlsvalidate);
744 				debugPrint(": ");
745 				debugPrint(server);
746 				debugPrint("=");
747 				debugPrint(san);
748 				debugPrint("\n");
749 				debugPreEnd();
750 			}
751 
752 			if (!charstring::compareIgnoringCase(server,san)) {
753 				return true;
754 			}
755 		}
756 
757 		return false;
758 	}
759 
760 
761 	// if there were no subject alternate names
762 	// then just validate against the common name
763 	if (!host) {
764 		const char	*dot=charstring::findFirst(commonname,'.');
765 		if (dot) {
766 			commonname=dot+1;
767 		}
768 	}
769 
770 	if (pvt->_debug) {
771 		debugPreStart();
772 		debugPrint(pvt->_tlsvalidate);
773 		debugPrint(": ");
774 		debugPrint(server);
775 		debugPrint("=");
776 		debugPrint(commonname);
777 		debugPrint("\n");
778 		debugPreEnd();
779 	}
780 
781 	return !charstring::compareIgnoringCase(server,commonname);
782 }
783 
784 
reConfigureSockets()785 bool sqlrconnection::reConfigureSockets() {
786 
787 	pvt->_ucs.setReadBufferSize(65536);
788 	pvt->_ucs.setWriteBufferSize(65536);
789 	//pvt->_ucs.useAsyncWrite();
790 
791 	pvt->_ics.setReadBufferSize(65536);
792 	pvt->_ics.setWriteBufferSize(65536);
793 	//pvt->_ics.useAsyncWrite();
794 
795 
796 	if (pvt->_usekrb) {
797 
798 		if (pvt->_debug) {
799 			debugPreStart();
800 			debugPrint("kerberos encryption/"
801 					"authentication enabled\n");
802 			debugPrint("  service: ");
803 			if (pvt->_krbservice) {
804 				debugPrint(pvt->_krbservice);
805 			}
806 			debugPrint("\n");
807 			debugPrint("  mech: ");
808 			if (pvt->_krbmech) {
809 				debugPrint(pvt->_krbmech);
810 			}
811 			debugPrint("\n");
812 			debugPrint("  flags: ");
813 			if (pvt->_krbflags) {
814 				debugPrint(pvt->_krbflags);
815 			}
816 			debugPrint("\n");
817 			debugPreEnd();
818 		}
819 
820 		pvt->_gmech.clear();
821 		pvt->_gmech.initialize(pvt->_krbmech);
822 
823 		pvt->_gcred.clearDesiredMechanisms();
824 		pvt->_gcred.addDesiredMechanism(&pvt->_gmech);
825 
826 		if (!pvt->_gcred.acquired() &&
827 				!charstring::isNullOrEmpty(pvt->_user)) {
828 
829 			if (pvt->_gcred.acquireForUser(pvt->_user)) {
830 				if (pvt->_debug) {
831 					debugPreStart();
832 					debugPrint("acquired kerberos "
833 							"credentials for: ");
834 					debugPrint(pvt->_user);
835 					debugPrint("\n");
836 					debugPreEnd();
837 				}
838 			} else {
839 				if (pvt->_debug) {
840 					debugPreStart();
841 					if (pvt->_gcred.getMajorStatus()) {
842 						debugPrint(pvt->_gcred.
843 						getMechanismMinorStatus());
844 					}
845 					debugPreEnd();
846 				}
847 				setError("Failed to acquire "
848 						"kerberos credentials.");
849 				return false;
850 			}
851 		}
852 
853 		pvt->_gctx.close();
854 		pvt->_gctx.setDesiredMechanism(&pvt->_gmech);
855 		pvt->_gctx.setDesiredFlags(pvt->_krbflags);
856 		pvt->_gctx.setService(pvt->_krbservice);
857 		pvt->_gctx.setCredentials(&pvt->_gcred);
858 
859 		pvt->_ctx=&pvt->_gctx;
860 
861 	} else if (pvt->_usetls) {
862 
863 		if (pvt->_debug) {
864 			debugPreStart();
865 			debugPrint("TLS encryption/authentication enabled\n");
866 			debugPrint("  version: ");
867 			if (pvt->_tlsversion) {
868 				debugPrint(pvt->_tlsversion);
869 			}
870 			debugPrint("\n");
871 			debugPrint("  cert: ");
872 			if (pvt->_tlscert) {
873 				debugPrint(pvt->_tlscert);
874 			}
875 			debugPrint("\n");
876 			debugPrint("  private key password: ");
877 			if (pvt->_tlspassword) {
878 				debugPrint(pvt->_tlspassword);
879 			}
880 			debugPrint("\n");
881 			debugPrint("  ciphers: ");
882 			if (pvt->_tlsciphers) {
883 				debugPrint(pvt->_tlsciphers);
884 			}
885 			debugPrint("\n");
886 			debugPrint("  validate: ");
887 			if (pvt->_tlsvalidate) {
888 				debugPrint(pvt->_tlsvalidate);
889 			}
890 			debugPrint("\n");
891 			debugPrint("  ca: ");
892 			if (pvt->_tlsca) {
893 				debugPrint(pvt->_tlsca);
894 			}
895 			debugPrint("\n");
896 			debugPrint("  depth: ");
897 			debugPrint((int64_t)pvt->_tlsdepth);
898 			debugPrint("\n");
899 			debugPreEnd();
900 		}
901 
902 		pvt->_tctx.close();
903 		pvt->_tctx.setProtocolVersion(pvt->_tlsversion);
904 		pvt->_tctx.setCertificateChainFile(pvt->_tlscert);
905 		pvt->_tctx.setPrivateKeyPassword(pvt->_tlspassword);
906 		pvt->_tctx.setCiphers(pvt->_tlsciphers);
907 		pvt->_tctx.setValidatePeer(!charstring::compareIgnoringCase(
908 						pvt->_tlsvalidate,"ca",2));
909 		pvt->_tctx.setCertificateAuthority(pvt->_tlsca);
910 		pvt->_tctx.setValidationDepth(pvt->_tlsdepth);
911 
912 		pvt->_ctx=&pvt->_tctx;
913 
914 	} else {
915 
916 		if (pvt->_debug) {
917 			debugPreStart();
918 			debugPrint("encryption disabled\n");
919 			debugPreEnd();
920 		}
921 
922 		pvt->_ctx=NULL;
923 	}
924 
925 	pvt->_ucs.setSecurityContext(pvt->_ctx);
926 	pvt->_ics.setSecurityContext(pvt->_ctx);
927 
928 	return true;
929 }
930 
setConnectFailedError()931 void sqlrconnection::setConnectFailedError() {
932 	if (pvt->_usekrb && pvt->_gctx.getMajorStatus()) {
933 		setError(pvt->_gctx.getMechanismMinorStatus());
934 	} else if (pvt->_usetls && pvt->_tctx.getError()) {
935 		stringbuffer	err;
936 		err.append("TLS error: ");
937 		err.append(pvt->_tctx.getErrorString());
938 		setError(err.getString());
939 	} else {
940 		setError("Couldn't connect to the listener.");
941 	}
942 }
943 
protocol()944 void sqlrconnection::protocol() {
945 
946 	if (pvt->_debug) {
947 		debugPreStart();
948 		debugPrint("Protocol : sqlrclient version 2\n");
949 		debugPreEnd();
950 	}
951 
952 	pvt->_cs->write((uint16_t)PROTOCOLVERSION);
953 	pvt->_cs->write((uint16_t)2);
954 }
955 
auth()956 void sqlrconnection::auth() {
957 
958 	if (pvt->_debug) {
959 		debugPreStart();
960 		debugPrint("Auth : ");
961 		debugPrint(pvt->_user);
962 		debugPrint(":");
963 		// hide the password
964 		for (uint8_t i=0; i<pvt->_passwordlen; i++) {
965 			debugPrint("*");
966 		}
967 		debugPrint("\n");
968 		debugPreEnd();
969 	}
970 
971 	pvt->_cs->write((uint16_t)AUTH);
972 
973 	pvt->_cs->write(pvt->_userlen);
974 	pvt->_cs->write(pvt->_user,pvt->_userlen);
975 
976 	pvt->_cs->write(pvt->_passwordlen);
977 	pvt->_cs->write(pvt->_password,pvt->_passwordlen);
978 
979 	// I don't think this needs to be here...
980 	// Just commenting it out for now though.
981 	//flushWriteBuffer();
982 }
983 
getNewPort()984 bool sqlrconnection::getNewPort() {
985 
986 	// get the size of the unix port string
987 	uint16_t	size;
988 	if (pvt->_cs->read(&size)!=sizeof(uint16_t)) {
989 		setError("Failed to get the size of "
990 				"the unix connection port.\n"
991 				"A network error may have occurred.");
992 		return false;
993 	}
994 
995 	if (size>MAXPATHLEN) {
996 
997 		// if size is too big, return an error
998 		stringbuffer	errstr;
999 		errstr.append("The pathname of the unix port was too long: ");
1000 		errstr.append(size);
1001 		errstr.append(" bytes.  The maximum size is ");
1002 		errstr.append((uint16_t)MAXPATHLEN);
1003 		errstr.append(" bytes.");
1004 		setError(errstr.getString());
1005 		return false;
1006 	}
1007 
1008 	// get the unix port string
1009 	if (size && pvt->_cs->read(pvt->_connectionunixportbuffer,size)!=size) {
1010 		setError("Failed to get the unix connection port.\n"
1011 				"A network error may have occurred.");
1012 		return false;
1013 	}
1014 	pvt->_connectionunixportbuffer[size]='\0';
1015 	pvt->_connectionunixport=pvt->_connectionunixportbuffer;
1016 
1017 	// get the inet port
1018 	if (pvt->_cs->read(&pvt->_connectioninetport)!=sizeof(uint16_t)) {
1019 		setError("Failed to get the inet connection port.\n"
1020 				"A network error may have occurred.");
1021 		return false;
1022 	}
1023 
1024 	// the server will send 0 for both the size of the unixport and
1025 	// the inet port if a server error occurred
1026 	if (!size && !pvt->_connectioninetport) {
1027 		setError("An error occurred on the server.");
1028 		return false;
1029 	}
1030 	return true;
1031 }
1032 
getConnectionPort()1033 uint16_t sqlrconnection::getConnectionPort() {
1034 
1035 	if (!pvt->_suspendsessionsent && !openSession()) {
1036 		return 0;
1037 	}
1038 
1039 	if (pvt->_debug) {
1040 		debugPreStart();
1041 		debugPrint("Getting connection port: ");
1042 		debugPrint((int64_t)pvt->_connectioninetport);
1043 		debugPrint("\n");
1044 		debugPreEnd();
1045 	}
1046 
1047 	return pvt->_connectioninetport;
1048 }
1049 
getConnectionSocket()1050 const char *sqlrconnection::getConnectionSocket() {
1051 
1052 	if (!pvt->_suspendsessionsent && !openSession()) {
1053 		return NULL;
1054 	}
1055 
1056 	if (pvt->_debug) {
1057 		debugPreStart();
1058 		debugPrint("Getting connection socket: ");
1059 		if (pvt->_connectionunixport) {
1060 			debugPrint(pvt->_connectionunixport);
1061 		}
1062 		debugPrint("\n");
1063 		debugPreEnd();
1064 	}
1065 
1066 	if (pvt->_connectionunixport) {
1067 		return pvt->_connectionunixport;
1068 	}
1069 	return NULL;
1070 }
1071 
resumeSession(uint16_t port,const char * socket)1072 bool sqlrconnection::resumeSession(uint16_t port, const char *socket) {
1073 
1074 	// if already pvt->_connected, end the session
1075 	if (pvt->_connected) {
1076 		endSession();
1077 	}
1078 
1079 	if (pvt->_debug) {
1080 		debugPreStart();
1081 		debugPrint("Resuming Session: \n");
1082 		debugPrint("port: ");
1083 		debugPrint((int64_t)port);
1084 		debugPrint("\n");
1085 		debugPrint("socket: ");
1086 		debugPrint(socket);
1087 		debugPrint("\n");
1088 		debugPreEnd();
1089 	}
1090 
1091 	// set the connectionunixport and connectioninetport values
1092 	if (pvt->_copyrefs) {
1093 		if (charstring::length(socket)<=MAXPATHLEN) {
1094 			charstring::copy(pvt->_connectionunixportbuffer,socket);
1095 			pvt->_connectionunixport=pvt->_connectionunixportbuffer;
1096 		} else {
1097 			pvt->_connectionunixport="";
1098 		}
1099 	} else {
1100 		pvt->_connectionunixport=(char *)socket;
1101 	}
1102 	pvt->_connectioninetport=port;
1103 
1104 	if (!reConfigureSockets()) {
1105 		pvt->_connected=false;
1106 		return false;
1107 	}
1108 
1109 	// first, try for the unix port
1110 	if (!charstring::isNullOrEmpty(socket)) {
1111 		pvt->_connected=(pvt->_ucs.connect(
1112 					socket,-1,-1,
1113 					pvt->_retrytime,
1114 					pvt->_tries)==RESULT_SUCCESS);
1115 		if (pvt->_connected) {
1116 			pvt->_cs=&pvt->_ucs;
1117 		}
1118 	}
1119 
1120 	// then try for the inet port
1121 	if (!pvt->_connected) {
1122 		pvt->_connected=(pvt->_ics.connect(
1123 					pvt->_server,port,-1,-1,
1124 					pvt->_retrytime,
1125 					pvt->_tries)==RESULT_SUCCESS);
1126 		if (pvt->_connected) {
1127 			pvt->_cs=&pvt->_ics;
1128 		}
1129 	}
1130 
1131 	if (pvt->_connected) {
1132 
1133 		// send protocol info
1134 		protocol();
1135 
1136 		if (pvt->_debug) {
1137 			debugPreStart();
1138 			debugPrint("success");
1139 			debugPrint("\n");
1140 			debugPreEnd();
1141 		}
1142 		clearSessionFlags();
1143 	} else {
1144 		setConnectFailedError();
1145 		if (pvt->_debug) {
1146 			debugPreStart();
1147 			debugPrint("failure");
1148 			debugPrint("\n");
1149 			debugPreEnd();
1150 		}
1151 	}
1152 
1153 	return pvt->_connected;
1154 }
1155 
ping()1156 bool sqlrconnection::ping() {
1157 
1158 	if (!openSession()) {
1159 		return 0;
1160 	}
1161 
1162 	clearError();
1163 
1164 	if (pvt->_debug) {
1165 		debugPreStart();
1166 		debugPrint("Pinging...\n");
1167 		debugPreEnd();
1168 	}
1169 
1170 	pvt->_cs->write((uint16_t)PING);
1171 	flushWriteBuffer();
1172 
1173 	return !gotError();
1174 }
1175 
identify()1176 const char *sqlrconnection::identify() {
1177 
1178 	if (!openSession()) {
1179 		return NULL;
1180 	}
1181 
1182 	clearError();
1183 
1184 	if (pvt->_debug) {
1185 		debugPreStart();
1186 		debugPrint("Identifying...\n");
1187 		debugPreEnd();
1188 	}
1189 
1190 	// tell the server we want the identity of the db
1191 	pvt->_cs->write((uint16_t)IDENTIFY);
1192 	flushWriteBuffer();
1193 
1194 	if (gotError()) {
1195 		return NULL;
1196 	}
1197 
1198 	// get the identity size
1199 	uint16_t	size;
1200 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1201 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1202 		setError("Failed to identify.\n"
1203 			"A network error may have occurred.");
1204 		return NULL;
1205 	}
1206 
1207 	// get the identity
1208 	delete[] pvt->_id;
1209 	pvt->_id=new char[size+1];
1210 	if (pvt->_cs->read(pvt->_id,size)!=size) {
1211 		setError("Failed to identify.\n"
1212 			"A network error may have occurred.");
1213 		delete[] pvt->_id;
1214 		pvt->_id=NULL;
1215 		return NULL;
1216 	}
1217 	pvt->_id[size]='\0';
1218 
1219 	if (pvt->_debug) {
1220 		debugPreStart();
1221 		debugPrint(pvt->_id);
1222 		debugPrint("\n");
1223 		debugPreEnd();
1224 	}
1225 	return pvt->_id;
1226 }
1227 
dbVersion()1228 const char *sqlrconnection::dbVersion() {
1229 
1230 	if (!openSession()) {
1231 		return NULL;
1232 	}
1233 
1234 	clearError();
1235 
1236 	if (pvt->_debug) {
1237 		debugPreStart();
1238 		debugPrint("DB Version...");
1239 		debugPrint("\n");
1240 		debugPreEnd();
1241 	}
1242 
1243 	// tell the server we want the db version
1244 	pvt->_cs->write((uint16_t)DBVERSION);
1245 	flushWriteBuffer();
1246 
1247 	if (gotError()) {
1248 		return NULL;
1249 	}
1250 
1251 	// get the db version size
1252 	uint16_t	size;
1253 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1254 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1255 		setError("Failed to get DB version.\n"
1256 			"A network error may have occurred.");
1257 		return NULL;
1258 	}
1259 
1260 	// get the db version
1261 	delete[] pvt->_dbversion;
1262 	pvt->_dbversion=new char[size+1];
1263 	if (pvt->_cs->read(pvt->_dbversion,size)!=size) {
1264 		setError("Failed to get DB version.\n"
1265 			"A network error may have occurred.");
1266 		delete[] pvt->_dbversion;
1267 		pvt->_dbversion=NULL;
1268 		return NULL;
1269 	}
1270 	pvt->_dbversion[size]='\0';
1271 
1272 	if (pvt->_debug) {
1273 		debugPreStart();
1274 		debugPrint(pvt->_dbversion);
1275 		debugPrint("\n");
1276 		debugPreEnd();
1277 	}
1278 	return pvt->_dbversion;
1279 }
1280 
dbHostName()1281 const char *sqlrconnection::dbHostName() {
1282 
1283 	if (!openSession()) {
1284 		return NULL;
1285 	}
1286 
1287 	clearError();
1288 
1289 	if (pvt->_debug) {
1290 		debugPreStart();
1291 		debugPrint("DB Host Name...");
1292 		debugPrint("\n");
1293 		debugPreEnd();
1294 	}
1295 
1296 	// tell the server we want the db host name
1297 	pvt->_cs->write((uint16_t)DBHOSTNAME);
1298 	flushWriteBuffer();
1299 
1300 	if (gotError()) {
1301 		return NULL;
1302 	}
1303 
1304 	// get the db host name size
1305 	uint16_t	size;
1306 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1307 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1308 		setError("Failed to get DB host name.\n"
1309 				"A network error may have occurred.");
1310 		return NULL;
1311 	}
1312 
1313 	// get the db host name
1314 	delete[] pvt->_dbhostname;
1315 	pvt->_dbhostname=new char[size+1];
1316 	if (pvt->_cs->read(pvt->_dbhostname,size)!=size) {
1317 		setError("Failed to get DB host name.\n"
1318 				"A network error may have occurred.");
1319 		delete[] pvt->_dbhostname;
1320 		pvt->_dbhostname=NULL;
1321 		return NULL;
1322 	}
1323 	pvt->_dbhostname[size]='\0';
1324 
1325 	if (pvt->_debug) {
1326 		debugPreStart();
1327 		debugPrint(pvt->_dbhostname);
1328 		debugPrint("\n");
1329 		debugPreEnd();
1330 	}
1331 	return pvt->_dbhostname;
1332 }
1333 
dbIpAddress()1334 const char *sqlrconnection::dbIpAddress() {
1335 
1336 	if (!openSession()) {
1337 		return NULL;
1338 	}
1339 
1340 	clearError();
1341 
1342 	if (pvt->_debug) {
1343 		debugPreStart();
1344 		debugPrint("DB Ip Address...");
1345 		debugPrint("\n");
1346 		debugPreEnd();
1347 	}
1348 
1349 	// tell the server we want the db ip address
1350 	pvt->_cs->write((uint16_t)DBIPADDRESS);
1351 	flushWriteBuffer();
1352 
1353 	if (gotError()) {
1354 		return NULL;
1355 	}
1356 
1357 	// get the db ip address size
1358 	uint16_t	size;
1359 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1360 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1361 		setError("Failed to get DB ip address.\n"
1362 				"A network error may have occurred.");
1363 		return NULL;
1364 	}
1365 
1366 	// get the db ip address
1367 	delete[] pvt->_dbipaddress;
1368 	pvt->_dbipaddress=new char[size+1];
1369 	if (pvt->_cs->read(pvt->_dbipaddress,size)!=size) {
1370 		setError("Failed to get DB ip address.\n"
1371 				"A network error may have occurred.");
1372 		delete[] pvt->_dbipaddress;
1373 		pvt->_dbipaddress=NULL;
1374 		return NULL;
1375 	}
1376 	pvt->_dbipaddress[size]='\0';
1377 
1378 	if (pvt->_debug) {
1379 		debugPreStart();
1380 		debugPrint(pvt->_dbipaddress);
1381 		debugPrint("\n");
1382 		debugPreEnd();
1383 	}
1384 	return pvt->_dbipaddress;
1385 }
1386 
serverVersion()1387 const char *sqlrconnection::serverVersion() {
1388 
1389 	if (!openSession()) {
1390 		return NULL;
1391 	}
1392 
1393 	clearError();
1394 
1395 	if (pvt->_debug) {
1396 		debugPreStart();
1397 		debugPrint("Server Version...");
1398 		debugPrint("\n");
1399 		debugPreEnd();
1400 	}
1401 
1402 	// tell the server we want the server version
1403 	pvt->_cs->write((uint16_t)SERVERVERSION);
1404 	flushWriteBuffer();
1405 
1406 	if (gotError()) {
1407 		return NULL;
1408 	}
1409 
1410 	// get the server version size
1411 	uint16_t	size;
1412 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1413 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1414 		setError("Failed to get Server version.\n"
1415 				"A network error may have occurred.");
1416 		return NULL;
1417 	}
1418 
1419 	// get the server version
1420 	delete[] pvt->_serverversion;
1421 	pvt->_serverversion=new char[size+1];
1422 	if (pvt->_cs->read(pvt->_serverversion,size)!=size) {
1423 		setError("Failed to get Server version.\n"
1424 				"A network error may have occurred.");
1425 		delete[] pvt->_serverversion;
1426 		pvt->_serverversion=NULL;
1427 		return NULL;
1428 	}
1429 	pvt->_serverversion[size]='\0';
1430 
1431 	if (pvt->_debug) {
1432 		debugPreStart();
1433 		debugPrint(pvt->_serverversion);
1434 		debugPrint("\n");
1435 		debugPreEnd();
1436 	}
1437 	return pvt->_serverversion;
1438 }
1439 
clientVersion()1440 const char *sqlrconnection::clientVersion() {
1441 	return SQLR_VERSION;
1442 }
1443 
bindFormat()1444 const char *sqlrconnection::bindFormat() {
1445 
1446 	if (!openSession()) {
1447 		return NULL;
1448 	}
1449 
1450 	clearError();
1451 
1452 	if (pvt->_debug) {
1453 		debugPreStart();
1454 		debugPrint("bind format...");
1455 		debugPrint("\n");
1456 		debugPreEnd();
1457 	}
1458 
1459 	// tell the server we want the bind format
1460 	pvt->_cs->write((uint16_t)BINDFORMAT);
1461 	flushWriteBuffer();
1462 
1463 	if (gotError()) {
1464 		return NULL;
1465 	}
1466 
1467 
1468 	// get the bindformat size
1469 	uint16_t	size;
1470 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1471 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1472 		setError("Failed to get bind format.\n"
1473 				"A network error may have occurred.");
1474 		return NULL;
1475 	}
1476 
1477 	// get the bindformat
1478 	delete[] pvt->_bindformat;
1479 	pvt->_bindformat=new char[size+1];
1480 	if (pvt->_cs->read(pvt->_bindformat,size)!=size) {
1481 		setError("Failed to get bind format.\n"
1482 				"A network error may have occurred.");
1483 		delete[] pvt->_bindformat;
1484 		pvt->_bindformat=NULL;
1485 		return NULL;
1486 	}
1487 	pvt->_bindformat[size]='\0';
1488 
1489 	if (pvt->_debug) {
1490 		debugPreStart();
1491 		debugPrint(pvt->_bindformat);
1492 		debugPrint("\n");
1493 		debugPreEnd();
1494 	}
1495 	return pvt->_bindformat;
1496 }
1497 
selectDatabase(const char * database)1498 bool sqlrconnection::selectDatabase(const char *database) {
1499 
1500 	if (!charstring::length(database)) {
1501 		return true;
1502 	}
1503 
1504 	clearError();
1505 
1506 	if (!openSession()) {
1507 		return 0;
1508 	}
1509 
1510 	if (pvt->_debug) {
1511 		debugPreStart();
1512 		debugPrint("Selecting database ");
1513 		debugPrint(database);
1514 		debugPrint("...\n");
1515 		debugPreEnd();
1516 	}
1517 
1518 	// tell the server we want to select a db
1519 	pvt->_cs->write((uint16_t)SELECT_DATABASE);
1520 
1521 	// send the database name
1522 	uint32_t	len=charstring::length(database);
1523 	pvt->_cs->write(len);
1524 	if (len) {
1525 		pvt->_cs->write(database,len);
1526 	}
1527 	flushWriteBuffer();
1528 
1529 	return !gotError();
1530 }
1531 
getCurrentDatabase()1532 const char *sqlrconnection::getCurrentDatabase() {
1533 
1534 	if (!openSession()) {
1535 		return NULL;
1536 	}
1537 
1538 	clearError();
1539 
1540 	if (pvt->_debug) {
1541 		debugPreStart();
1542 		debugPrint("Getting the current database...\n");
1543 		debugPreEnd();
1544 	}
1545 
1546 	clearError();
1547 
1548 	// tell the server we want to get the current db
1549 	pvt->_cs->write((uint16_t)GET_CURRENT_DATABASE);
1550 	flushWriteBuffer();
1551 
1552 	if (gotError()) {
1553 		return NULL;
1554 	}
1555 
1556 	// get the current db name size
1557 	uint16_t	size;
1558 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1559 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1560 		setError("Failed to get the current database.\n"
1561 				"A network error may have occurred.");
1562 		return NULL;
1563 	}
1564 
1565 	// get the current db name
1566 	delete[] pvt->_currentdbname;
1567 	pvt->_currentdbname=new char[size+1];
1568 	if (pvt->_cs->read(pvt->_currentdbname,size)!=size) {
1569 		setError("Failed to get the current database.\n"
1570 				"A network error may have occurred.");
1571 		delete[] pvt->_currentdbname;
1572 		pvt->_currentdbname=NULL;
1573 		return NULL;
1574 	}
1575 	pvt->_currentdbname[size]='\0';
1576 
1577 	if (pvt->_debug) {
1578 		debugPreStart();
1579 		debugPrint(pvt->_currentdbname);
1580 		debugPrint("\n");
1581 		debugPreEnd();
1582 	}
1583 	return pvt->_currentdbname;
1584 }
1585 
getCurrentSchema()1586 const char *sqlrconnection::getCurrentSchema() {
1587 
1588 	if (!openSession()) {
1589 		return NULL;
1590 	}
1591 
1592 	clearError();
1593 
1594 	if (pvt->_debug) {
1595 		debugPreStart();
1596 		debugPrint("Getting the current schema...\n");
1597 		debugPreEnd();
1598 	}
1599 
1600 	clearError();
1601 
1602 	// tell the server we want to get the current schema
1603 	pvt->_cs->write((uint16_t)GET_CURRENT_SCHEMA);
1604 	flushWriteBuffer();
1605 
1606 	if (gotError()) {
1607 		return NULL;
1608 	}
1609 
1610 	// get the current schema name size
1611 	uint16_t	size;
1612 	if (pvt->_cs->read(&size,pvt->_responsetimeoutsec,
1613 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1614 		setError("Failed to get the current database.\n"
1615 				"A network error may have occurred.");
1616 		return NULL;
1617 	}
1618 
1619 	// get the current schema name
1620 	delete[] pvt->_currentschemaname;
1621 	pvt->_currentschemaname=new char[size+1];
1622 	if (pvt->_cs->read(pvt->_currentschemaname,size)!=size) {
1623 		setError("Failed to get the current database.\n"
1624 				"A network error may have occurred.");
1625 		delete[] pvt->_currentschemaname;
1626 		pvt->_currentschemaname=NULL;
1627 		return NULL;
1628 	}
1629 	pvt->_currentschemaname[size]='\0';
1630 
1631 	if (pvt->_debug) {
1632 		debugPreStart();
1633 		debugPrint(pvt->_currentschemaname);
1634 		debugPrint("\n");
1635 		debugPreEnd();
1636 	}
1637 	return pvt->_currentschemaname;
1638 }
1639 
getLastInsertId()1640 uint64_t sqlrconnection::getLastInsertId() {
1641 
1642 	if (!openSession()) {
1643 		return 0;
1644 	}
1645 
1646 	clearError();
1647 
1648 	if (pvt->_debug) {
1649 		debugPreStart();
1650 		debugPrint("Getting the last insert id...\n");
1651 		debugPreEnd();
1652 	}
1653 
1654 	// tell the server we want the last insert id
1655 	pvt->_cs->write((uint16_t)GET_LAST_INSERT_ID);
1656 	flushWriteBuffer();
1657 
1658 	if (gotError()) {
1659 		return 0;
1660 	}
1661 
1662 	// get the last insert id
1663 	uint64_t	id=0;
1664 	if (pvt->_cs->read(&id)!=sizeof(uint64_t)) {
1665 		setError("Failed to get the last insert id.\n"
1666 				"A network error may have occurred.");
1667 		return 0;
1668 	}
1669 
1670 	if (pvt->_debug) {
1671 		debugPreStart();
1672 		debugPrint("Got the last insert id: ");
1673 		debugPrint((int64_t)id);
1674 		debugPrint("\n");
1675 		debugPreEnd();
1676 	}
1677 	return id;
1678 }
1679 
autoCommitOn()1680 bool sqlrconnection::autoCommitOn() {
1681 	return autoCommit(true);
1682 }
1683 
autoCommitOff()1684 bool sqlrconnection::autoCommitOff() {
1685 	return autoCommit(false);
1686 }
1687 
autoCommit(bool on)1688 bool sqlrconnection::autoCommit(bool on) {
1689 
1690 	if (!openSession()) {
1691 		return false;
1692 	}
1693 
1694 	clearError();
1695 
1696 	if (pvt->_debug) {
1697 		debugPreStart();
1698 		debugPrint((on)?"Setting autocommit on\n":
1699 				"Setting autocommit off\n");
1700 		debugPreEnd();
1701 	}
1702 
1703 	pvt->_cs->write((uint16_t)AUTOCOMMIT);
1704 	pvt->_cs->write(on);
1705 	flushWriteBuffer();
1706 
1707 	return !gotError();
1708 }
1709 
begin()1710 bool sqlrconnection::begin() {
1711 
1712 	if (!openSession()) {
1713 		return false;
1714 	}
1715 
1716 	clearError();
1717 
1718 	if (pvt->_debug) {
1719 		debugPreStart();
1720 		debugPrint("Beginning...\n");
1721 		debugPreEnd();
1722 	}
1723 
1724 	pvt->_cs->write((uint16_t)BEGIN);
1725 	flushWriteBuffer();
1726 
1727 	return !gotError();
1728 }
1729 
commit()1730 bool sqlrconnection::commit() {
1731 
1732 	if (!openSession()) {
1733 		return false;
1734 	}
1735 
1736 	clearError();
1737 
1738 	if (pvt->_debug) {
1739 		debugPreStart();
1740 		debugPrint("Committing...\n");
1741 		debugPreEnd();
1742 	}
1743 
1744 	pvt->_cs->write((uint16_t)COMMIT);
1745 	flushWriteBuffer();
1746 
1747 	return !gotError();
1748 }
1749 
rollback()1750 bool sqlrconnection::rollback() {
1751 
1752 	if (!openSession()) {
1753 		return false;
1754 	}
1755 
1756 	clearError();
1757 
1758 	if (pvt->_debug) {
1759 		debugPreStart();
1760 		debugPrint("Rolling Back...\n");
1761 		debugPreEnd();
1762 	}
1763 
1764 	pvt->_cs->write((uint16_t)ROLLBACK);
1765 	flushWriteBuffer();
1766 
1767 	return !gotError();
1768 }
1769 
errorMessage()1770 const char *sqlrconnection::errorMessage() {
1771 	return pvt->_error;
1772 }
1773 
errorNumber()1774 int64_t sqlrconnection::errorNumber() {
1775 	return pvt->_errorno;
1776 }
1777 
clearError()1778 void sqlrconnection::clearError() {
1779 	delete[] pvt->_error;
1780 	pvt->_error=NULL;
1781 	pvt->_errorno=0;
1782 }
1783 
setError(const char * err)1784 void sqlrconnection::setError(const char *err) {
1785 
1786 	if (pvt->_debug) {
1787 		debugPreStart();
1788 		debugPrint("Setting Error\n");
1789 		debugPreEnd();
1790 	}
1791 
1792 	delete[] pvt->_error;
1793 	pvt->_error=charstring::duplicate(err);
1794 
1795 	if (pvt->_debug) {
1796 		debugPreStart();
1797 		debugPrint(pvt->_error);
1798 		debugPrint("\n");
1799 		debugPreEnd();
1800 	}
1801 }
1802 
getError()1803 uint16_t sqlrconnection::getError() {
1804 
1805 	clearError();
1806 
1807 	if (pvt->_debug) {
1808 		debugPreStart();
1809 		debugPrint("Checking for error\n");
1810 		debugPreEnd();
1811 	}
1812 
1813 	// get whether an error occurred or not
1814 	uint16_t	status;
1815 	if (pvt->_cs->read(&status,pvt->_responsetimeoutsec,
1816 				pvt->_responsetimeoutusec)!=sizeof(uint16_t)) {
1817 		setError("Failed to get the error status.\n"
1818 				"A network error may have occurred.");
1819 		return ERROR_OCCURRED;
1820 	}
1821 
1822 	// if no error occurred, return that
1823 	if (status==NO_ERROR_OCCURRED) {
1824 		if (pvt->_debug) {
1825 			debugPreStart();
1826 			debugPrint("No error occurred\n");
1827 			debugPreEnd();
1828 		}
1829 		return status;
1830 	}
1831 
1832 	if (pvt->_debug) {
1833 		debugPreStart();
1834 		debugPrint("An error occurred\n");
1835 		debugPreEnd();
1836 	}
1837 
1838 	// get the error code
1839 	if (pvt->_cs->read((uint64_t *)&pvt->_errorno)!=sizeof(uint64_t)) {
1840 		setError("Failed to get the error code.\n"
1841 				"A network error may have occurred.");
1842 		return status;
1843 	}
1844 
1845 	// get the error size
1846 	uint16_t	size;
1847 	if (pvt->_cs->read(&size)!=sizeof(uint16_t)) {
1848 		setError("Failed to get the error size.\n"
1849 				"A network error may have occurred.");
1850 		return status;
1851 	}
1852 
1853 	// get the error string
1854 	pvt->_error=new char[size+1];
1855 	if (pvt->_cs->read(pvt->_error,size)!=size) {
1856 		setError("Failed to get the error string.\n"
1857 				"A network error may have occurred.");
1858 		return status;
1859 	}
1860 	pvt->_error[size]='\0';
1861 
1862 	if (pvt->_debug) {
1863 		debugPreStart();
1864 		debugPrint("Got error:\n");
1865 		debugPrint(pvt->_errorno);
1866 		debugPrint(": ");
1867 		debugPrint(pvt->_error);
1868 		debugPrint("\n");
1869 		debugPreEnd();
1870 	}
1871 	return status;
1872 }
1873 
gotError()1874 bool sqlrconnection::gotError() {
1875 	uint16_t	status=getError();
1876 	if (status==NO_ERROR_OCCURRED) {
1877 		return false;
1878 	}
1879 	if (status==ERROR_OCCURRED_DISCONNECT) {
1880 		endSession();
1881 	}
1882 	return true;
1883 }
1884 
debugOn()1885 void sqlrconnection::debugOn() {
1886 	pvt->_debug=true;
1887 }
1888 
debugOff()1889 void sqlrconnection::debugOff() {
1890 	pvt->_debug=false;
1891 }
1892 
setDebugFile(const char * filename)1893 void sqlrconnection::setDebugFile(const char *filename) {
1894 	pvt->_debugfile.close();
1895 	error::clearError();
1896 	if (filename && *filename &&
1897 		!pvt->_debugfile.open(filename,O_WRONLY|O_APPEND) &&
1898 				error::getErrorNumber()==ENOENT) {
1899 		pvt->_debugfile.create(filename,
1900 				permissions::evalPermString("rw-r--r--"));
1901 	}
1902 }
1903 
getDebug()1904 bool sqlrconnection::getDebug() {
1905 	return pvt->_debug;
1906 }
1907 
debugPreStart()1908 void sqlrconnection::debugPreStart() {
1909 	if (pvt->_webdebug==-1) {
1910 		const char	*docroot=environment::getValue("DOCUMENT_ROOT");
1911 		if (!charstring::isNullOrEmpty(docroot)) {
1912 			pvt->_webdebug=1;
1913 		} else {
1914 			pvt->_webdebug=0;
1915 		}
1916 	}
1917 	if (pvt->_webdebug==1) {
1918 		debugPrint("<pre>\n");
1919 	}
1920 }
1921 
debugPreEnd()1922 void sqlrconnection::debugPreEnd() {
1923 	if (pvt->_webdebug==1) {
1924 		debugPrint("</pre>\n");
1925 	}
1926 }
1927 
debugPrintFunction(int (* printfunction)(const char *,...))1928 void sqlrconnection::debugPrintFunction(
1929 				int (*printfunction)(const char *,...)) {
1930 	pvt->_printfunction=printfunction;
1931 }
1932 
debugPrint(const char * string)1933 void sqlrconnection::debugPrint(const char *string) {
1934 	if (pvt->_printfunction) {
1935 		(*pvt->_printfunction)("%s",string);
1936 	} else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1937 		pvt->_debugfile.printf("%s",string);
1938 	} else {
1939 		stdoutput.printf("%s",string);
1940 	}
1941 }
1942 
debugPrint(int64_t number)1943 void sqlrconnection::debugPrint(int64_t number) {
1944 	if (pvt->_printfunction) {
1945 		(*pvt->_printfunction)("%lld",(long long)number);
1946 	} else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1947 		pvt->_debugfile.printf("%lld",(long long)number);
1948 	} else {
1949 		stdoutput.printf("%lld",(long long)number);
1950 	}
1951 }
1952 
debugPrint(double number)1953 void sqlrconnection::debugPrint(double number) {
1954 	if (pvt->_printfunction) {
1955 		(*pvt->_printfunction)("%f",number);
1956 	} else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1957 		pvt->_debugfile.printf("%f",number);
1958 	} else {
1959 		stdoutput.printf("%f",number);
1960 	}
1961 }
1962 
debugPrint(char character)1963 void sqlrconnection::debugPrint(char character) {
1964 	if (pvt->_printfunction) {
1965 		(*pvt->_printfunction)("%c",character);
1966 	} else if (pvt->_debugfile.getFileDescriptor()!=-1) {
1967 		pvt->_debugfile.printf("%c",character);
1968 	} else {
1969 		stdoutput.printf("%c",character);
1970 	}
1971 }
1972 
debugPrintBlob(const char * blob,uint32_t length)1973 void sqlrconnection::debugPrintBlob(const char *blob, uint32_t length) {
1974 	debugPrint('\n');
1975 	int	column=0;
1976 	for (uint32_t i=0; i<length; i++) {
1977 		if (blob[i]>=' ' && blob[i]<='~') {
1978 			debugPrint(blob[i]);
1979 		} else {
1980 			debugPrint('.');
1981 		}
1982 		column++;
1983 		if (column==80) {
1984 			debugPrint('\n');
1985 			column=0;
1986 		}
1987 	}
1988 	debugPrint('\n');
1989 }
1990 
debugPrintClob(const char * clob,uint32_t length)1991 void sqlrconnection::debugPrintClob(const char *clob, uint32_t length) {
1992 	debugPrint('\n');
1993 	for (uint32_t i=0; i<length; i++) {
1994 		if (clob[i]=='\0') {
1995 			debugPrint("\\0");
1996 		} else {
1997 			debugPrint(clob[i]);
1998 		}
1999 	}
2000 	debugPrint('\n');
2001 }
2002 
setClientInfo(const char * clientinfo)2003 void sqlrconnection::setClientInfo(const char *clientinfo) {
2004 	delete[] pvt->_clientinfo;
2005 	pvt->_clientinfo=charstring::duplicate(clientinfo);
2006 	pvt->_clientinfolen=charstring::length(clientinfo);
2007 }
2008 
getClientInfo() const2009 const char *sqlrconnection::getClientInfo() const {
2010 	return pvt->_clientinfo;
2011 }
2012 
cs()2013 socketclient *sqlrconnection::cs() {
2014 	return pvt->_cs;
2015 }
2016 
endsessionsent()2017 bool sqlrconnection::endsessionsent() {
2018 	return pvt->_endsessionsent;
2019 }
2020 
suspendsessionsent()2021 bool sqlrconnection::suspendsessionsent() {
2022 	return pvt->_suspendsessionsent;
2023 }
2024 
connected()2025 bool sqlrconnection::connected() {
2026 	return pvt->_connected;
2027 }
2028 
responsetimeoutsec()2029 int32_t sqlrconnection::responsetimeoutsec() {
2030 	return pvt->_responsetimeoutsec;
2031 }
2032 
responsetimeoutusec()2033 int32_t sqlrconnection::responsetimeoutusec() {
2034 	return pvt->_responsetimeoutusec;
2035 }
2036 
errorno()2037 int64_t sqlrconnection::errorno() {
2038 	return pvt->_errorno;
2039 }
2040 
error()2041 char *sqlrconnection::error() {
2042 	return pvt->_error;
2043 }
2044 
clientinfo()2045 char *sqlrconnection::clientinfo() {
2046 	return pvt->_clientinfo;
2047 }
2048 
clientinfolen()2049 uint64_t sqlrconnection::clientinfolen() {
2050 	return pvt->_clientinfolen;
2051 }
2052 
debug()2053 bool sqlrconnection::debug() {
2054 	return pvt->_debug;
2055 }
2056 
firstcursor(sqlrcursor * cur)2057 void sqlrconnection::firstcursor(sqlrcursor *cur) {
2058 	pvt->_firstcursor=cur;
2059 }
2060 
lastcursor(sqlrcursor * cur)2061 void sqlrconnection::lastcursor(sqlrcursor *cur) {
2062 	pvt->_lastcursor=cur;
2063 }
2064 
lastcursor()2065 sqlrcursor *sqlrconnection::lastcursor() {
2066 	return pvt->_lastcursor;
2067 }
2068 
isYes(const char * str)2069 bool sqlrconnection::isYes(const char *str) {
2070 	return charstring::isYes(str);
2071 }
2072 
isNo(const char * str)2073 bool sqlrconnection::isNo(const char *str) {
2074 	return charstring::isNo(str);
2075 }
2076 
setBindVariableDelimiters(const char * delimiters)2077 void sqlrconnection::setBindVariableDelimiters(const char *delimiters) {
2078 	pvt->_questionmarksupported=charstring::contains(delimiters,'?');
2079 	pvt->_colonsupported=charstring::contains(delimiters,':');
2080 	pvt->_atsignsupported=charstring::contains(delimiters,'@');
2081 	pvt->_dollarsignsupported=charstring::contains(delimiters,'$');
2082 }
2083 
getBindVariableDelimiterQuestionMarkSupported()2084 bool sqlrconnection::getBindVariableDelimiterQuestionMarkSupported() {
2085 	return pvt->_questionmarksupported;
2086 }
2087 
getBindVariableDelimiterColonSupported()2088 bool sqlrconnection::getBindVariableDelimiterColonSupported() {
2089 	return pvt->_colonsupported;
2090 }
2091 
getBindVariableDelimiterAtSignSupported()2092 bool sqlrconnection::getBindVariableDelimiterAtSignSupported() {
2093 	return pvt->_atsignsupported;
2094 }
2095 
getBindVariableDelimiterDollarSignSupported()2096 bool sqlrconnection::getBindVariableDelimiterDollarSignSupported() {
2097 	return pvt->_dollarsignsupported;
2098 }
2099