1 #include "PostgreSQL.h"
2
3 #ifdef PLATFORM_WIN32
4 #include "Winsock2.h"
5 #endif
6
7 #ifdef PLATFORM_POSIX
8 #include <sys/socket.h>
9 #endif
10
11 #define LLOG(x) // DLOG(x)
12
13 #ifndef flagNOPOSTGRESQL
14
15 namespace Upp {
16
// Numeric type OIDs recognised by this driver. They are compared against the
// column type OIDs reported by PQftype() (see OidToType and GetColumn).
enum PGSQL_StandardOid {
	PGSQL_BOOLOID = 16,
	PGSQL_BYTEAOID = 17,
	PGSQL_CHAROID = 18,
	PGSQL_NAMEOID = 19,
	PGSQL_INT8OID = 20,
	PGSQL_INT2OID = 21,
	PGSQL_INT2VECTOROID = 22,
	PGSQL_INT4OID = 23,
	PGSQL_REGPROCOID = 24,
	PGSQL_TEXTOID = 25,
	PGSQL_OIDOID = 26,
	PGSQL_TIDOID = 27,
	PGSQL_XIDOID = 28,
	PGSQL_CIDOID = 29,
	PGSQL_OIDVECTOROID = 30,
	PGSQL_FLOAT4OID = 700,
	PGSQL_FLOAT8OID = 701,
	PGSQL_DATEOID = 1082,
	PGSQL_TIMEOID = 1083,
	PGSQL_TIMESTAMPOID = 1114,
	PGSQL_TIMESTAMPZOID = 1184,
	PGSQL_NUMERICOID = 1700
};
41
OidToType(Oid oid)42 int OidToType(Oid oid)
43 {
44 switch(oid) {
45 case PGSQL_BOOLOID:
46 return BOOL_V;
47 case PGSQL_INT8OID:
48 return INT64_V;
49 case PGSQL_INT2OID:
50 case PGSQL_INT2VECTOROID:
51 case PGSQL_INT4OID:
52 return INT_V;
53 case PGSQL_FLOAT4OID:
54 case PGSQL_FLOAT8OID:
55 case PGSQL_NUMERICOID:
56 return DOUBLE_V;
57 case PGSQL_DATEOID:
58 return DATE_V;
59 case PGSQL_TIMEOID:
60 case PGSQL_TIMESTAMPOID:
61 case PGSQL_TIMESTAMPZOID:
62 return TIME_V;
63 case PGSQL_BYTEAOID:
64 return SQLRAW_V;
65 }
66 return STRING_V;
67 }
68
// SqlConnection implementation that executes one statement at a time on the
// PGconn handed over by the owning PostgreSQLSession and exposes the rows of
// the resulting PGresult.
class PostgreSQLConnection : public SqlConnection {
protected:
	// SqlConnection interface
	virtual void        SetParam(int i, const Value& r);
	virtual bool        Execute();
	virtual int         GetRowsProcessed() const;
	virtual Value       GetInsertedId() const;
	virtual bool        Fetch();
	virtual void        GetColumn(int i, Ref f) const;
	virtual void        Cancel();
	virtual SqlSession& GetSession() const;
	virtual String      GetUser() const;
	virtual String      ToString() const;

private:
	PostgreSQLSession& session;        // owning session (charset conversion, errors, reconnect)

	PGconn         *conn;              // libpq connection passed to the constructor
	Vector<String>  param;             // parameters already rendered as SQL text by SetParam
	PGresult       *result;            // result of the last Execute, NULL when none is held
	Vector<Oid>     oid;               // type OID of every result column
	int             rows;              // row count reported by the last Execute
	int             fetched_row; //-1, if not fetched yet
	String          last_insert_table; // table name parsed from the last "insert into" statement

	void    FreeResult();
	String  ErrorMessage();
	String  ErrorCode();

	// Charset conversion is delegated to the session.
	String  FromCharset(const String& s) const   { return session.FromCharset(s); }
	String  ToCharset(const String& s) const     { return session.ToCharset(s); }

public:
	PostgreSQLConnection(PostgreSQLSession& a_session, PGconn *a_conn);
	// Cancel() releases any PGresult still held.
	virtual ~PostgreSQLConnection() { Cancel(); }
};
104
PostgreSQLReadString(const char * s,String & stmt)105 const char *PostgreSQLReadString(const char *s, String& stmt)
106 {
107 //TODO: to clear this, currently this is based on sqlite
108 stmt.Cat(*s);
109 int c = *s++;
110 for(;;) {
111 if(*s == '\0') break;
112 else
113 if(*s == '\'' && s[1] == '\'') {
114 stmt.Cat("\'\'");
115 s += 2;
116 }
117 else
118 if(*s == c) {
119 stmt.Cat(c);
120 s++;
121 break;
122 }
123 else
124 if(*s == '\\') {
125 stmt.Cat('\\');
126 if(*++s)
127 stmt.Cat(*s++);
128 }
129 else
130 stmt.Cat(*s++);
131 }
132 return s;
133 }
134
PostgreSQLPerformScript(const String & txt,StatementExecutor & se,Gate<int,int> progress_canceled)135 bool PostgreSQLPerformScript(const String& txt, StatementExecutor& se, Gate<int, int> progress_canceled)
136 {
137 const char *text = txt;
138 for(;;) {
139 String stmt;
140 while(*text <= 32 && *text > 0) text++;
141 if(*text == '\0') break;
142 for(;;) {
143 if(*text == '\0')
144 break;
145 if(*text == ';')
146 break;
147 else
148 if(*text == '\'')
149 text = PostgreSQLReadString(text, stmt);
150 else
151 if(*text == '\"')
152 text = PostgreSQLReadString(text, stmt);
153 else
154 stmt.Cat(*text++);
155 }
156 if(progress_canceled(int(text - txt.Begin()), txt.GetLength()))
157 return false;
158 if(!se.Execute(stmt))
159 return false;
160 if(*text) text++;
161 }
162 return true;
163 }
164
ErrorMessage()165 String PostgreSQLConnection::ErrorMessage()
166 {
167 return FromCharset(PQerrorMessage(conn));
168 }
169
ErrorCode()170 String PostgreSQLConnection::ErrorCode()
171 {
172 return PQresultErrorField(result, PG_DIAG_SQLSTATE);
173 }
174
ErrorMessage()175 String PostgreSQLSession::ErrorMessage()
176 {
177 return FromCharset(PQerrorMessage(conn));
178 }
179
ErrorCode()180 String PostgreSQLSession::ErrorCode()
181 {
182 return PQresultErrorField(result, PG_DIAG_SQLSTATE);
183 }
184
EnumUsers()185 Vector<String> PostgreSQLSession::EnumUsers()
186 {
187 Vector<String> vec;
188 Sql sql(*this);
189 sql.Execute("select rolname from pg_authid where rolcanlogin");
190 while(sql.Fetch())
191 vec.Add(sql[0]);
192 return vec;
193 }
194
EnumDatabases()195 Vector<String> PostgreSQLSession::EnumDatabases()
196 {// For now, we really enumerate namespaces rather than databases here
197 Vector<String> vec;
198 Sql sql(*this);
199 sql.Execute("select nspname from pg_namespace where nspacl is not null");
200 while(sql.Fetch())
201 vec.Add(sql[0]);
202 return vec;
203 }
204
EnumData(char type,const char * schema)205 Vector<String> PostgreSQLSession::EnumData(char type, const char *schema)
206 {
207 Vector<String> vec;
208 Sql sql(Format("select n.nspname || '.' || c.relname from pg_catalog.pg_class c "
209 "left join pg_catalog.pg_namespace n "
210 "on n.oid = c.relnamespace "
211 "where c.relkind = '%c' "
212 "and n.nspname like '%s' "
213 "and pg_catalog.pg_table_is_visible(c.oid)",
214 type, schema ? schema : "%"), *this);
215 sql.Execute();
216 while(sql.Fetch())
217 vec.Add(sql[0]);
218 return vec;
219 }
220
EnumTables(String database)221 Vector<String> PostgreSQLSession::EnumTables(String database)
222 {
223 return EnumData('r', database);
224 }
225
EnumViews(String database)226 Vector<String> PostgreSQLSession::EnumViews(String database)
227 {
228 return EnumData('v', database);
229 }
230
EnumSequences(String database)231 Vector<String> PostgreSQLSession::EnumSequences(String database)
232 {
233 return EnumData('S', database);
234 }
235
EnumColumns(String database,String table)236 Vector<SqlColumnInfo> PostgreSQLSession::EnumColumns(String database, String table)
237 {
238 /* database means schema here - support for schemas is a something to fix in sql interface */
239
240 int q = table.Find('.');
241 if(q) table = table.Mid(q + 1);
242 Vector<SqlColumnInfo> vec;
243 Sql sql(Format("select a.attname, a.atttypid, a.attlen, a.atttypmod, a.attnotnull "
244 "from pg_catalog.pg_attribute a "
245 "inner join pg_catalog.pg_class c "
246 "on a.attrelid = c.oid "
247 "inner join pg_catalog.pg_namespace n "
248 "on c.relnamespace = n.oid "
249 "where c.relname = '%s' "
250 "and n.nspname = '%s' "
251 "and a.attnum > 0 "
252 "and a.attisdropped = '0' "
253 "order by a.attnum", table, database), *this);
254 sql.Execute();
255 while(sql.Fetch())
256 {
257 SqlColumnInfo &ci = vec.Add();
258 int type_mod = int(sql[3]) - sizeof(int32);
259 ci.name = sql[0];
260 ci.type = OidToType(IsString(sql[1]) ? atoi(String(sql[1])) : (int)sql[1]);
261 ci.width = sql[2];
262 if(ci.width < 0)
263 ci.width = type_mod;
264 ci.precision = (type_mod >> 16) & 0xffff;
265 ci.scale = type_mod & 0xffff;
266 ci.nullable = AsString(sql[4]) == "0";
267 ci.binary = false;
268 }
269 return vec;
270 }
271
EnumPrimaryKey(String database,String table)272 Vector<String> PostgreSQLSession::EnumPrimaryKey(String database, String table)
273 {
274 // SELECT cc.conname, a.attname
275 // FROM pg_constraint cc
276 // INNER JOIN pg_class c
277 // ON c.oid=conrelid
278 // INNER JOIN pg_attribute a
279 // ON a.attnum = ANY(conkey)
280 // AND a.attrelid = c.oid
281 // WHERE contype='p'
282 // AND relname = '?'
283 return Vector<String>(); //TODO
284 }
285
EnumRowID(String database,String table)286 String PostgreSQLSession::EnumRowID(String database, String table)
287 {
288 return ""; //TODO
289 }
290
EnumReservedWords()291 Vector<String> PostgreSQLSession::EnumReservedWords()
292 {
293 return Vector<String>(); //TODO
294 }
295
CreateConnection()296 SqlConnection * PostgreSQLSession::CreateConnection()
297 {
298 return new PostgreSQLConnection(*this, conn);
299 }
300
ExecTrans(const char * statement)301 void PostgreSQLSession::ExecTrans(const char * statement)
302 {
303 if(trace)
304 *trace << statement << UPP::EOL;
305
306 int itry = 0;
307
308 do {
309 result = PQexec(conn, statement);
310 if(PQresultStatus(result) == PGRES_COMMAND_OK) {
311 PQclear(result);
312 return;
313 }
314 }
315 while(level == 0 && (!ConnectionOK() || ErrorMessage().Find("connection") >= 0 && itry == 0)
316 && WhenReconnect(itry++));
317
318 if(trace)
319 *trace << statement << " failed: " << ErrorMessage() << " (level " << level << ")\n";
320 SetError(ErrorMessage(), statement, 0, ErrorCode());
321 PQclear(result);
322 }
323
// Converts text from the session charset to the default charset.
// When no session charset is set the text is returned unchanged.
String PostgreSQLSession::FromCharset(const String& s) const
{
	return charset ? UPP::ToCharset(GetDefaultCharset(), s, charset) : s;
}
331
// Converts text to the session charset.
// When no session charset is set the text is returned unchanged.
String PostgreSQLSession::ToCharset(const String& s) const
{
	return charset ? UPP::ToCharset(charset, s) : s;
}
339
DoKeepAlive()340 void PostgreSQLSession::DoKeepAlive()
341 {
342 if(keepalive && conn) {
343 int optval = 1;
344 setsockopt(PQsocket(conn), SOL_SOCKET, SO_KEEPALIVE, (char *) &optval, sizeof(optval));
345 }
346 }
347
Open(const char * connect)348 bool PostgreSQLSession::Open(const char *connect)
349 {
350 Close();
351 conns = connect;
352
353 {
354 MemoryIgnoreLeaksBlock __;
355 // PGSQL, when sharing .dll SSL, does not free SSL data
356 conn = PQconnectdb(connect);
357 }
358
359 if(PQstatus(conn) != CONNECTION_OK)
360 {
361 SetError(FromSystemCharset(PQerrorMessage(conn)), "Opening database");
362 Close();
363 return false;
364 }
365 level = 0;
366
367 if(PQclientEncoding(conn)) {
368 if(PQsetClientEncoding(conn, "UTF8")) {
369 SetError("Cannot set UTF8 charset", "Opening database");
370 return false;
371 }
372 charset = CHARSET_UTF8;
373 }
374 else
375 charset = CHARSET_DEFAULT;
376
377 DoKeepAlive();
378
379 LLOG( String("Postgresql client encoding: ") + pg_encoding_to_char( PQclientEncoding(conn) ) );
380
381 Sql sql(*this);
382 if(sql.Execute("select setting from pg_settings where name = 'bytea_output'") && sql.Fetch() && sql[0] == "hex")
383 hex_blobs = true;
384
385 return true;
386 }
387
ConnectionOK()388 bool PostgreSQLSession::ConnectionOK()
389 {
390 return conn && PQstatus(conn) == CONNECTION_OK;
391 }
392
ReOpen()393 bool PostgreSQLSession::ReOpen()
394 {
395 PQreset(conn);
396 if(PQstatus(conn) != CONNECTION_OK)
397 {
398 SetError(ErrorMessage(), "Opening database");
399 return false;
400 }
401 DoKeepAlive();
402 level = 0;
403 return true;
404 }
405
Close()406 void PostgreSQLSession::Close()
407 {
408 if(!conn)
409 return;
410 SessionClose();
411 PQfinish(conn);
412 conn = NULL;
413 level = 0;
414 }
415
Begin()416 void PostgreSQLSession::Begin()
417 {
418 ExecTrans("begin");
419 level++;
420 }
421
Commit()422 void PostgreSQLSession::Commit()
423 {
424 ExecTrans("commit");
425 level--;
426 }
427
Rollback()428 void PostgreSQLSession::Rollback()
429 {
430 ExecTrans("rollback");
431 if(level > 0) level--;
432 }
433
GetTransactionLevel() const434 int PostgreSQLSession::GetTransactionLevel() const
435 {
436 return level;
437 }
438
SetParam(int i,const Value & r)439 void PostgreSQLConnection::SetParam(int i, const Value& r)
440 {
441 String p;
442 if(IsNull(r))
443 p = "NULL";
444 else
445 switch(r.GetType()) {
446 case SQLRAW_V: {
447 String raw = SqlRaw(r);
448 size_t rl;
449 unsigned char *s = PQescapeByteaConn(conn, (const byte *)~raw, raw.GetLength(), &rl);
450 p.Reserve(int(rl + 16));
451 p = "\'" + String(s, int(rl - 1)) + "\'::bytea";
452 PQfreemem(s);
453 break;
454 }
455 case WSTRING_V:
456 case STRING_V: {
457 String v = r;
458 v = ToCharset(v);
459 StringBuffer b(v.GetLength() * 2 + 3);
460 char *q = b;
461 *q = '\'';
462 int *err = NULL;
463 int n = (int)PQescapeStringConn(conn, q + 1, v, v.GetLength(), err);
464 q[1 + n] = '\'';
465 b.SetCount(2 + n);
466 p = b;
467 }
468 break;
469 case BOOL_V:
470 case INT_V:
471 p << int(r);
472 break;
473 case INT64_V:
474 p << int64(r);
475 break;
476 case DOUBLE_V:
477 p = FormatDouble(double(r), 20);
478 break;
479 case DATE_V: {
480 Date d = r;
481 p = Format("\'%04d-%02d-%02d\'", d.year, d.month, d.day);
482 }
483 break;
484 case TIME_V: {
485 Time t = r;
486 p = Format("\'%04d-%02d-%02d %02d:%02d:%02d\'",
487 t.year, t.month, t.day, t.hour, t.minute, t.second);
488 }
489 break;
490 default:
491 NEVER();
492 }
493 param.At(i, p);
494 }
495
Execute()496 bool PostgreSQLConnection::Execute()
497 {
498 Cancel();
499 if(statement.GetLength() == 0) {
500 session.SetError("Empty statement", statement);
501 return false;
502 }
503
504 CParser p(statement);
505 if((p.Id("insert") || p.Id("INSERT")) && (p.Id("into") || p.Id("INTO")) && p.IsId())
506 last_insert_table = p.ReadId();
507
508 String query;
509 int pi = 0;
510 const char *s = statement;
511 while(s < statement.End())
512 if(*s == '\'' || *s == '\"')
513 s = PostgreSQLReadString(s, query);
514 else {
515 if(*s == '?' && !session.noquestionparams) {
516 if(s[1] == '?') {
517 query.Cat('?');
518 s++;
519 }
520 else {
521 if(pi >= param.GetCount()) {
522 session.SetError("Invalid number of parameters", statement);
523 return false;
524 }
525 query.Cat(param[pi++]);
526 }
527 }
528 else
529 query.Cat(*s);
530 s++;
531 }
532 param.Clear();
533
534 Stream *trace = session.GetTrace();
535 dword time;
536 if(session.IsTraceTime())
537 time = msecs();
538
539 int itry = 0;
540 int stat;
541 do {
542 result = PQexecParams(conn, query, 0, NULL, NULL, NULL, NULL, 0);
543 stat = PQresultStatus(result);
544 }
545 while(stat != PGRES_TUPLES_OK && stat != PGRES_COMMAND_OK && session.level == 0 &&
546 (!session.ConnectionOK() || ErrorMessage().Find("connection") >= 0 && itry == 0) && session.WhenReconnect(itry++));
547
548 if(trace) {
549 if(session.IsTraceTime())
550 *trace << Format("--------------\nexec %d ms:\n", msecs(time));
551 }
552 if(stat == PGRES_TUPLES_OK) //result set
553 {
554 rows = PQntuples(result);
555 int fields = PQnfields(result);
556 info.SetCount(fields);
557 oid.SetCount(fields);
558 for(int i = 0; i < fields; i++)
559 {
560 SqlColumnInfo& f = info[i];
561 f.name = ToUpper(PQfname(result, i));
562 f.width = PQfsize(result, i);
563 int type_mod = PQfmod(result, i) - sizeof(int32);
564 if(f.width < 0)
565 f.width = type_mod;
566 f.precision = (type_mod >> 16) & 0xffff;
567 f.scale = type_mod & 0xffff;
568 f.nullable = true;
569 Oid type_oid = PQftype(result, i);
570 f.type = OidToType(type_oid);
571 oid[i] = type_oid;
572 }
573 return true;
574 }
575 if(stat == PGRES_COMMAND_OK) //command executed OK
576 {
577 rows = atoi(PQcmdTuples(result));
578 return true;
579 }
580
581 session.SetError(ErrorMessage(), query, 0, ErrorCode());
582 FreeResult();
583 return false;
584 }
585
GetRowsProcessed() const586 int PostgreSQLConnection::GetRowsProcessed() const
587 {
588 return rows;
589 }
590
GetInsertedId() const591 Value PostgreSQLConnection::GetInsertedId() const
592 {
593 String pk = session.pkache.Get(last_insert_table, Null);
594 if(IsNull(pk)) {
595 String sqlc_expr;
596 sqlc_expr <<
597 "SELECT " <<
598 "pg_attribute.attname " <<
599 "FROM pg_index, pg_class, pg_attribute " <<
600 "WHERE " <<
601 "pg_class.oid = '" << last_insert_table << "'::regclass AND "
602 "indrelid = pg_class.oid AND "
603 "pg_attribute.attrelid = pg_class.oid AND "
604 "pg_attribute.attnum = any(pg_index.indkey) "
605 "AND indisprimary";
606 Sql sqlc(sqlc_expr, session);
607 pk = sqlc.Execute() && sqlc.Fetch() ? sqlc[0] : "ID";
608 session.pkache.Add(last_insert_table, pk);
609 }
610 Sql sql("select currval('" + last_insert_table + "_" + pk +"_seq')", session);
611 if(sql.Execute() && sql.Fetch())
612 return sql[0];
613 else
614 return Null;
615 }
616
Fetch()617 bool PostgreSQLConnection::Fetch()
618 {
619 fetched_row++;
620 if(result && rows > 0 && fetched_row < rows)
621 return true;
622 Cancel();
623 return false;
624 }
625
sDate(const char * s)626 static Date sDate(const char *s)
627 {
628 // 0123456789012345678
629 // YYYY-MM-DD HH-MM-SS
630 return Date(atoi(s), atoi(s + 5), atoi(s + 8));
631 }
632
GetColumn(int i,Ref f) const633 void PostgreSQLConnection::GetColumn(int i, Ref f) const
634 {
635 if(PQgetisnull(result, fetched_row, i))
636 {
637 f = Null;
638 return;
639 }
640 char *s = PQgetvalue(result, fetched_row, i);
641 switch(info[i].type)
642 {
643 case INT64_V:
644 f.SetValue(ScanInt64(s));
645 break;
646 case INT_V:
647 f.SetValue(ScanInt(s));
648 break;
649 case DOUBLE_V: {
650 double d = ScanDouble(s);
651 f.SetValue(IsNull(d) ? NAN : d);
652 }
653 break;
654 case BOOL_V:
655 f.SetValue(*s == 't' ? "1" : "0");
656 break;
657 case DATE_V:
658 f.SetValue(sDate(s));
659 break;
660 case TIME_V: {
661 Time t = ToTime(sDate(s));
662 t.hour = atoi(s + 11);
663 t.minute = atoi(s + 14);
664 t.second = atoi(s + 17);
665 f.SetValue(t);
666 }
667 break;
668 default: {
669 if(oid[i] == PGSQL_BYTEAOID) {
670 if(session.hex_blobs)
671 f.SetValue(ScanHexString(s, (int)strlen(s)));
672 else {
673 size_t len;
674 unsigned char *q = PQunescapeBytea((const unsigned char *)s, &len);
675 f.SetValue(String(q, (int)len));
676 PQfreemem(q);
677 }
678 }
679 else
680 f.SetValue(FromCharset(String(s)));
681 }
682 }
683 }
684
Cancel()685 void PostgreSQLConnection::Cancel()
686 {
687 info.Clear();
688 rows = 0;
689 fetched_row = -1;
690 FreeResult();
691 }
692
GetSession() const693 SqlSession& PostgreSQLConnection::GetSession() const
694 {
695 return session;
696 }
697
GetUser() const698 String PostgreSQLConnection::GetUser() const
699 {
700 return PQuser(conn);
701 }
702
ToString() const703 String PostgreSQLConnection::ToString() const
704 {
705 return statement;
706 }
707
FreeResult()708 void PostgreSQLConnection::FreeResult()
709 {
710 if(result)
711 {
712 PQclear(result);
713 result = NULL;
714 }
715 }
716
PostgreSQLConnection(PostgreSQLSession & a_session,PGconn * a_conn)717 PostgreSQLConnection::PostgreSQLConnection(PostgreSQLSession& a_session, PGconn *a_conn)
718 : session(a_session), conn(a_conn)
719 {
720 result = NULL;
721 }
722
Get()723 Value PgSequence::Get()
724 {
725 #ifndef NOAPPSQL
726 Sql sql(session ? *session : SQL.GetSession());
727 #else
728 ASSERT(session);
729 Sql sql(*session);
730 #endif
731 if(!sql.Execute(Select(NextVal(seq)).Get()) || !sql.Fetch())
732 return ErrorValue();
733 return sql[0];
734 }
735
736 }
737
738 #endif
739