1 // src/Database.cc
2 // This file is part of libpbe; see http://decimail.org
3 // (C) 2004 - 2007 Philip Endecott
4 
5 // This program is free software; you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation; either version 2 of the License, or
8 // any later version.
9 //
10 // This program is distributed in the hope that it will be useful,
11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 // GNU General Public License for more details.
14 //
15 // You should have received a copy of the GNU General Public License
16 // along with this program; if not, write to the Free Software
17 // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 
19 #include "Database.hh"
20 
21 #include "Exception.hh"
22 #include "utils.hh"
23 #include "StringTransformer.hh"
24 
25 #include <map>
26 #include <string>
27 
28 #include <boost/lexical_cast.hpp>
29 #include <boost/scoped_ptr.hpp>
30 #include <boost/algorithm/string/predicate.hpp>
31 #include <boost/algorithm/string/classification.hpp>
32 #include <boost/algorithm/string/trim.hpp>
33 
34 #include <libpq-fe.h>
35 #include <postgres.h>
36 #include <catalog/pg_type.h>
37 
38 using namespace std;
39 
40 
41 namespace pbe {
42 
43 
oid_to_typecode(Oid oid)44 static typecode_t oid_to_typecode(Oid oid)
45 {
46   switch(oid) {
47     case TEXTOID:        return text_type;
48     case BYTEAOID:       return text_type;
49     case INT4OID:        return numeric_type;
50     case INT8OID:        return numeric_type;
51     case TIMESTAMPTZOID: return timestamptz_type;
52     case TIMESTAMPOID:   return timestamptz_type;  // Hmm, should distinguish it is local time
53     case FLOAT4OID:      return float_type;
54     case FLOAT8OID:      return double_type;
55     default:             throw StrException("type error, unrecognised oid "+boost::lexical_cast<string>(oid));
56   }
57 }
58 
59 
typecode_to_oid(typecode_t typecode)60 static Oid typecode_to_oid(typecode_t typecode)
61 {
62   switch(typecode) {
63     case null_type:        return 0;
64     case text_type:        return TEXTOID;
65     case numeric_type:     return INT4OID;
66     case timestamptz_type: return TIMESTAMPTZOID;
67     case bytea_type:       return BYTEAOID;
68     case float_type:       return FLOAT4OID;
69     case double_type:      return FLOAT8OID;
70     default:               throw "type error, unrecognised typecode";
71   }
72 }
73 
74 
75 class DatabaseConnectionFailed: public DatabaseException {
76 public:
DatabaseConnectionFailed(PGconn * pgconn)77   DatabaseConnectionFailed(PGconn* pgconn):
78     DatabaseException(pgconn,"Connecting to database") {}
79 };
80 
81 
Database(string conninfo)82 Database::Database(string conninfo):
83   pgconn(PQconnectdb(conninfo.c_str())),
84   conn_fd(PQsocket(pgconn),false),
85   transaction_in_progress(false)
86 {
87   if (PQstatus(pgconn)!=CONNECTION_OK) {
88     throw DatabaseConnectionFailed(pgconn);
89   }
90 }
91 
92 
~Database()93 Database::~Database()
94 {
95   PQfinish(pgconn);
96 }
97 
98 
get_fd(void) const99 const FileDescriptor& Database::get_fd(void) const
100 {
101   return conn_fd;
102 }
103 
104 
get_any_notification(void)105 string Database::get_any_notification(void)
106 {
107   int rc = PQconsumeInput(pgconn);
108   if (rc==0) {
109     throw QueryFailed(pgconn,"checking for notifications");
110   }
111   boost::shared_ptr<PGnotify> p(PQnotifies(pgconn), PQfreemem);
112   if (!p) {
113     return "";
114   }
115   return p.get()->relname;
116 }
117 
118 
exec_sql(string cmd)119 void Database::exec_sql(string cmd)
120 {
121   boost::shared_ptr<PGresult> res(PQexec(pgconn, cmd.c_str()), PQclear);
122   if (!res || PQresultStatus(res.get()) !=PGRES_COMMAND_OK) {
123     throw QueryFailed(pgconn, cmd.c_str());
124   }
125 }
126 
127 
report(ostream & s) const128 void DatabaseException::report(ostream& s) const
129 {
130   s << "Database exception: " << postgres_error;
131   if (doing_what!="") {
132     s << " while " << doing_what;
133   }
134   s << "\n";
135 }
136 
137 
Transaction(Database & database_)138 Transaction::Transaction(Database& database_):
139   database(database_),
140   committed(false)
141 {
142   if (database.transaction_in_progress) {
143     nested=true;
144   } else {
145     nested=false;
146     database.transaction_in_progress=true;
147     database.exec_sql("begin");
148   }
149 }
150 
151 
~Transaction()152 Transaction::~Transaction()
153 {
154   if (!nested) {
155     if (!committed) {
156       database.transaction_in_progress=false;
157       try {
158         database.exec_sql("rollback;");
159       }
160       catch(...) {
161         // Mustn't throw an exception from inside a destructor, in case it is being
162         // invoked during exception processing.
163         // (TODO is there a better fix for this?)
164       }
165     }
166   }
167 }
168 
169 
commit(void)170 void Transaction::commit(void)
171 {
172   if (!nested) {
173     database.transaction_in_progress=false;
174     database.exec_sql("commit");
175     committed=true;
176   }
177 }
178 
179 
Result(boost::shared_ptr<PGresult> res_)180 Result::Result(boost::shared_ptr<PGresult> res_):
181   rows(PQntuples(res_.get())),
182   cols(PQnfields(res_.get())),
183   res(res_)
184 {}
185 
186 
get_rows_changed(void) const187 int Result::get_rows_changed(void) const
188 {
189   return boost::lexical_cast<int>(PQcmdTuples(res.get()));
190 }
191 
192 
193 class ColumnNotFound: public StrException {
194 public:
ColumnNotFound(string colname)195   ColumnNotFound(string colname):
196     StrException("Table has no column named " + colname)
197   {}
198 };
199 
200 
column(std::string name) const201 int Result::column(std::string name) const
202 {
203   int n = PQfnumber(res.get(),name.c_str());
204   if (n==-1) {
205     throw ColumnNotFound(name);
206   }
207   return n;
208 }
209 
210 
column_name(int col) const211 std::string Result::column_name(int col) const
212 {
213   return PQfname(res.get(),col);
214 }
215 
rawget(int row,int col) const216 char* Result::rawget(int row, int col) const
217 {
218   return PQgetvalue(res.get(),row,col);
219 }
220 
getlength(int row,int col) const221 int Result::getlength(int row, int col) const
222 {
223   return PQgetlength(res.get(),row,col);
224 }
225 
226 
is_null(int row,int col) const227 bool Result::is_null(int row, int col) const
228 {
229   return PQgetisnull(res.get(),row,col);
230 }
231 
232 
column_typecode(int col) const233 typecode_t Result::column_typecode(int col) const
234 {
235   Oid oid = PQftype(res.get(),col);
236   return oid_to_typecode(oid);
237 }
238 
239 
240 int statement_name_t::counter = 0;
241 
242 
can_use_pqexecparams(string querystr)243 static bool can_use_pqexecparams(string querystr)
244 {
245   boost::algorithm::trim_left_if(querystr,boost::algorithm::is_any_of(" ("));
246   return boost::algorithm::istarts_with(querystr,"select")
247       || boost::algorithm::istarts_with(querystr,"update")
248       || boost::algorithm::istarts_with(querystr,"insert")
249       || boost::algorithm::istarts_with(querystr,"delete");
250 }
251 
252 
QueryCore(Database & database_,std::string querystr_,int nparams_,typecode_t * argtypecodes,int * lengths,int * formats)253 QueryCore::QueryCore(Database& database_, std::string querystr_, int nparams_,
254                      typecode_t* argtypecodes, int* lengths, int* formats):
255   database(database_),
256   querystr(querystr_),
257   params_ok(can_use_pqexecparams(querystr_)),
258   nparams(nparams_),
259   param_lengths(lengths),
260   param_formats(formats),
261   prepared(false)
262 {
263   while (nparams>0 && argtypecodes[nparams-1]==null_type) {
264     nparams--;
265   }
266   argoids = new Oid[nparams];  // hmm, use something smart
267   for (int i=0; i<nparams; ++i) {
268     argoids[i] = typecode_to_oid(argtypecodes[i]);
269   }
270 }
271 
272 
~QueryCore()273 QueryCore::~QueryCore()
274 {
275   if (prepared) {
276     try {
277       database.exec_sql("deallocate "+statement_name);
278     }
279     catch(...) {
280       // Mustn't throw an exception from inside a destructor, in case it is being
281       // invoked during exception processing.
282       // (TODO is there a better fix for this?)
283     }
284   }
285   delete[] argoids;
286 }
287 
288 
operator ()(const char * enc_args[])289 Result QueryCore::operator()(const char* enc_args[])
290 {
291   if (!params_ok) {
292     return runonce(enc_args);
293   }
294 
295   if (!prepared) {
296     prepare();
297   }
298   boost::shared_ptr<PGresult>
299   result(PQexecPrepared(database.pgconn, statement_name.c_str(), nparams,
300                         enc_args, param_lengths, param_formats, 1),
301          PQclear);
302   if (result) {
303     ExecStatusType status = PQresultStatus(result.get());
304     if (status==PGRES_TUPLES_OK || status==PGRES_COMMAND_OK) {
305       return Result(result);
306     }
307   }
308   throw QueryFailed(database.pgconn, querystr);
309 }
310 
311 
wrap_PQexecParams(PGconn * conn,string command,int nparams,const Oid * paramTypes,const char * const * paramValues,const int * paramLengths,const int * paramFormats)312 static PGresult* wrap_PQexecParams(PGconn* conn, string command, int nparams,
313                                    const Oid* paramTypes, const char* const * paramValues,
314                                    const int* paramLengths, const int* paramFormats)
315 {
316   string new_command;
317   string::size_type p=0;
318   while (p<command.length()) {
319     const string::size_type q = command.find('$',p);
320     if (q==string::npos) {
321       new_command += command.substr(p);
322       break;
323     }
324     new_command += command.substr(p,(q-p));
325     string::size_type r = command.find_first_not_of("0123456789",q+1);
326     if (r==string::npos) {
327       r = command.length();
328     }
329     int n = boost::lexical_cast<int>(command.substr(q+1,(r-q-1)));
330     if (n==0) {
331       throw "$0 not allowed";
332     }
333     if (n>nparams) {
334       throw "Not enough parameters";
335     }
336     --n;
337     Oid o = paramTypes[n];
338     switch (o) {
339       case TEXTOID: {      boost::scoped_array<char> buf(new char[paramLengths[n]*2+1]);
340                            PQescapeStringConn(conn,buf.get(),paramValues[n],
341                                               paramLengths[n],NULL);
342                            new_command += string("\'") + buf.get() + "\'";
343                            break; }
344       case BYTEAOID: {     boost::shared_ptr<unsigned char> buf (
345                              PQescapeByteaConn(conn,
346                                                reinterpret_cast<const unsigned char*>(paramValues[n]),
347                                                paramLengths[n],NULL),
348                              PQfreemem);
349                            new_command += string("\'")
350                                        + reinterpret_cast<const char*>(buf.get()) + "\'";
351                            break; }
352       case INT4OID: {      int32_t i = ntohl(*reinterpret_cast<const int32_t*>(paramValues[n]));
353                            new_command += boost::lexical_cast<string>(i);
354                            break; }
355       case INT8OID: {      int64_t i = ntoh64(*reinterpret_cast<const int64_t*>(paramValues[n]));
356                            new_command += boost::lexical_cast<string>(i);
357                            break; }
358       case TIMESTAMPTZOID: throw "timestamptz not implemented";
359                            break;
360       default:             throw "unrecognised oid";
361     }
362     p = r;
363   }
364   //cout << "converted '" << command << "' to '" << new_command << "'\n";
365   return PQexec(conn, new_command.c_str());
366 }
367 
368 
runonce(const char * enc_args[])369 Result QueryCore::runonce(const char* enc_args[])
370 {
371   if (params_ok) {
372     boost::shared_ptr<PGresult>
373     result(PQexecParams(database.pgconn, querystr.c_str(), nparams,
374                         argoids, enc_args, param_lengths, param_formats, 1),
375            PQclear);
376     if (result) {
377       ExecStatusType status = PQresultStatus(result.get());
378       if (status==PGRES_TUPLES_OK || status==PGRES_COMMAND_OK) {
379         return Result(result);
380       }
381     }
382   } else {
383     boost::shared_ptr<PGresult>
384     result(wrap_PQexecParams(database.pgconn, querystr, nparams,
385                              argoids, enc_args, param_lengths, param_formats),
386            PQclear);
387     if (result) {
388       ExecStatusType status = PQresultStatus(result.get());
389       if (status==PGRES_TUPLES_OK) {
390         throw StrException("Not expecting tuples in result from "
391                            "non-pqexecparams query '"+querystr+"'");
392       }
393       if (status==PGRES_COMMAND_OK) {
394         return Result(result);
395       }
396     }
397   }
398   throw QueryFailed(database.pgconn, querystr);
399 }
400 
401 
prepare(void)402 void QueryCore::prepare(void)
403 {
404 //cout << "Preparing query with nparams=" << nparams << "\n";
405   boost::shared_ptr<PGresult>
406   result(PQprepare(database.pgconn, statement_name.c_str(),
407                    querystr.c_str(), nparams, argoids),
408          PQclear);
409   if (result) {
410     ExecStatusType status = PQresultStatus(result.get());
411     if (status==PGRES_COMMAND_OK) {
412       prepared=true;
413       return;
414     }
415   }
416   throw QueryFailed(database.pgconn, querystr);
417 }
418 
419 
420 
421 };
422 
423