1 /*
2  * Copyright (C) 2016 dragon jiang<jianlinlong@gmail.com>
3  * Copyright (C) 2019 Tildeslash Ltd.
4  *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice shall be included in all
13  * copies or substantial portions of the Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21  * SOFTWARE.
22  */
23 
24 #ifndef _ZDBPP_H_
25 #define _ZDBPP_H_
26 
#include "zdb.h"
#include <ctime>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
31 
32 
33 namespace zdb {
34 
/// Exception raised by the C++ wrappers in this header whenever a wrapped
/// libzdb call reports an error. It is a std::runtime_error, so callers may
/// catch either type; what() returns the message given at construction.
class sql_exception : public std::runtime_error {
public:
    /// @param msg  error text, copied into the exception ("SQLException" if omitted)
    sql_exception(const char* msg = "SQLException") : std::runtime_error{msg} {}
};
42 
// Runs the statement(s) `f` inside a libzdb TRY block. If libzdb raises an
// exception, the ELSE branch rethrows it as a C++ sql_exception built from
// the libzdb message (Exception_frame.message).
// Inside `f`, this file leaves the enclosing function with RETURN rather than
// plain `return`.
// NOTE(review): `f` should not itself throw a C++ exception (for example by
// calling another method that uses this macro). That would presumably skip
// END_TRY and leave this TRY frame on libzdb's exception stack. Confirm
// against libzdb's Exception.h.
#define except_wrapper(f)                                \
    TRY {                                                \
        f;                                               \
    } ELSE {                                             \
        throw sql_exception(Exception_frame.message);    \
    } END_TRY
44 
/// Mix-in base that removes both copy and move semantics from classes that
/// inherit from it. A derived class can still opt back into moving by
/// declaring its own move constructor (as ResultSet does).
struct noncopyable
{
    // No moving ...
    noncopyable(noncopyable&&) = delete;
    noncopyable& operator=(noncopyable&&) = delete;

    // ... and no copying.
    noncopyable(const noncopyable&) = delete;
    noncopyable& operator=(const noncopyable&) = delete;

    noncopyable() = default;
};
57 
58 
59     class URL: private noncopyable
60     {
61     public:
URL(const std::string & url)62         URL(const std::string& url)
63         :URL(url.c_str())
64         {}
65 
URL(const char * url)66         URL(const char *url) {
67             t_ = URL_new(url);
68         }
69 
~URL()70         ~URL() {
71             if (t_)
72                 URL_free(&t_);
73         }
74 
URL_T()75         operator URL_T() {
76             return t_;
77         }
78 
79     public:
protocol()80         const char *protocol() const {
81             return URL_getProtocol(t_);
82         }
83 
user()84         const char *user() const {
85             return URL_getUser(t_);
86         }
87 
password()88         const char *password() const {
89             return URL_getPassword(t_);
90         }
91 
host()92         const char *host() const {
93             return URL_getHost(t_);
94         }
95 
port()96         int port() const {
97             return URL_getPort(t_);
98         }
99 
path()100         const char *path() const {
101             return URL_getPath(t_);
102         }
103 
queryString()104         const char *queryString() const {
105             return URL_getQueryString(t_);
106         }
107 
parameterNames()108         const char **parameterNames() const {
109             return URL_getParameterNames(t_);
110         }
111 
parameter(const char * name)112         const char *parameter(const char *name) const {
113             return URL_getParameter(t_, name);
114         }
115 
tostring()116         const char *tostring() const {
117             return URL_toString(t_);
118         }
119 
120     private:
121         URL_T t_;
122     };
123 
124     class ResultSet : private noncopyable
125     {
126     public:
ResultSet_T()127         operator ResultSet_T() {
128             return t_;
129         }
130 
ResultSet(ResultSet && r)131         ResultSet(ResultSet&& r)
132         :t_(r.t_)
133         {
134             r.t_ = nullptr;
135         }
136 
137     protected:
138         friend class PreparedStatement;
139         friend class Connection;
140 
ResultSet(ResultSet_T t)141         ResultSet(ResultSet_T t)
142         :t_(t)
143         {}
144 
145     public:
columnCount()146         int columnCount() {
147             return ResultSet_getColumnCount(t_);
148         }
149 
columnName(int columnIndex)150         const char *columnName(int columnIndex) {
151             except_wrapper( RETURN ResultSet_getColumnName(t_, columnIndex) );
152         }
153 
columnSize(int columnIndex)154         long columnSize(int columnIndex) {
155             except_wrapper( RETURN ResultSet_getColumnSize(t_, columnIndex) );
156         }
157 
setFetchSize(int prefetch_rows)158         void setFetchSize(int prefetch_rows) {
159             ResultSet_setFetchSize(t_, prefetch_rows);
160         }
161 
getFetchSize()162         int getFetchSize() {
163             return ResultSet_getFetchSize(t_);
164         }
165 
next()166         bool next() {
167             except_wrapper( RETURN ResultSet_next(t_) );
168         }
169 
isnull(int columnIndex)170         bool isnull(int columnIndex) {
171             except_wrapper( RETURN ResultSet_isnull(t_, columnIndex) );
172         }
173 
getString(int columnIndex)174         const char *getString(int columnIndex) {
175             except_wrapper( RETURN ResultSet_getString(t_, columnIndex) );
176         }
177 
getString(const char * columnName)178         const char *getString(const char *columnName) {
179             except_wrapper( RETURN ResultSet_getStringByName(t_, columnName) );
180         }
181 
getInt(int columnIndex)182         int getInt(int columnIndex) {
183             except_wrapper( RETURN ResultSet_getInt(t_, columnIndex) );
184         }
185 
getInt(const char * columnName)186         int getInt(const char *columnName) {
187             except_wrapper( RETURN ResultSet_getIntByName(t_, columnName) );
188         }
189 
getLLong(int columnIndex)190         long long getLLong(int columnIndex) {
191             except_wrapper( RETURN ResultSet_getLLong(t_, columnIndex) );
192         }
193 
getLLong(const char * columnName)194         long long getLLong(const char *columnName) {
195             except_wrapper( RETURN ResultSet_getLLongByName(t_, columnName) );
196         }
197 
getDouble(int columnIndex)198         double getDouble(int columnIndex) {
199             except_wrapper( RETURN ResultSet_getDouble(t_, columnIndex) );
200         }
201 
getDouble(const char * columnName)202         double getDouble(const char *columnName) {
203             except_wrapper( RETURN ResultSet_getDoubleByName(t_, columnName) );
204         }
205 
206         template <typename T>
getBlob(T v)207         std::tuple<const void*, int> getBlob(T v) {
208             int size = 0;
209             const void *blob = NULL;
210             if constexpr (std::is_integral<T>::value)
211                 except_wrapper( blob = ResultSet_getBlob(t_, v, &size) );
212             else
213                 except_wrapper( blob = ResultSet_getBlobByName(t_, v, &size) );
214             return {blob, size};
215         }
216 
getTimestamp(int columnIndex)217         time_t getTimestamp(int columnIndex) {
218             except_wrapper( RETURN ResultSet_getTimestamp(t_, columnIndex) );
219         }
220 
getTimestamp(const char * columnName)221         time_t getTimestamp(const char *columnName) {
222             except_wrapper( RETURN ResultSet_getTimestampByName(t_, columnName) );
223         }
224 
getDateTime(int columnIndex)225         struct tm getDateTime(int columnIndex) {
226             except_wrapper( RETURN ResultSet_getDateTime(t_, columnIndex) );
227         }
228 
getDateTime(const char * columnName)229         struct tm getDateTime(const char *columnName) {
230             except_wrapper( RETURN ResultSet_getDateTimeByName(t_, columnName) );
231         }
232 
233     private:
234         ResultSet_T t_;
235     };
236 
237     class PreparedStatement : private noncopyable
238     {
239     public:
PreparedStatement_T()240         operator PreparedStatement_T() {
241             return t_;
242         }
243 
PreparedStatement(PreparedStatement && r)244         PreparedStatement(PreparedStatement&& r)
245         :t_(r.t_)
246         {
247             r.t_ = nullptr;
248         }
249 
250     protected:
251         friend class Connection;
252 
PreparedStatement(PreparedStatement_T t)253         PreparedStatement(PreparedStatement_T t)
254         :t_(t)
255         {}
256 
257     public:
setString(int parameterIndex,const char * x)258         void setString(int parameterIndex, const char *x) {
259             except_wrapper( PreparedStatement_setString(t_, parameterIndex, x) );
260         }
261 
setInt(int parameterIndex,int x)262         void setInt(int parameterIndex, int x) {
263             except_wrapper( PreparedStatement_setInt(t_, parameterIndex, x) );
264         }
265 
setLLong(int parameterIndex,long long x)266         void setLLong(int parameterIndex, long long x) {
267             except_wrapper( PreparedStatement_setLLong(t_, parameterIndex, x) );
268         }
269 
setDouble(int parameterIndex,double x)270         void setDouble(int parameterIndex, double x) {
271             except_wrapper( PreparedStatement_setDouble(t_, parameterIndex, x) );
272         }
273 
setBlob(int parameterIndex,const void * x,int size)274         void setBlob(int parameterIndex, const void *x, int size) {
275             except_wrapper( PreparedStatement_setBlob(t_, parameterIndex, x, size) );
276         }
277 
setTimestamp(int parameterIndex,time_t x)278         void setTimestamp(int parameterIndex, time_t x) {
279             except_wrapper( PreparedStatement_setTimestamp(t_, parameterIndex, x) );
280         }
281 
execute()282         void execute() {
283             except_wrapper( PreparedStatement_execute(t_) );
284         }
285 
executeQuery()286         ResultSet executeQuery() {
287             except_wrapper(
288                            ResultSet_T r = PreparedStatement_executeQuery(t_);
289                            RETURN ResultSet(r);
290                            );
291         }
292 
rowsChanged()293         long long rowsChanged() {
294             return PreparedStatement_rowsChanged(t_);
295         }
296 
getParameterCount()297         int getParameterCount() {
298             return PreparedStatement_getParameterCount(t_);
299         }
300 
301     public:
bind(int parameterIndex,const char * x)302         void bind(int parameterIndex, const char *x) {
303             this->setString(parameterIndex, x);
304         }
305 
bind(int parameterIndex,const std::string & x)306         void bind(int parameterIndex, const std::string& x) {
307             this->setString(parameterIndex, x.c_str());
308         }
309 
bind(int parameterIndex,int x)310         void bind(int parameterIndex, int x) {
311             this->setInt(parameterIndex, x);
312         }
313 
bind(int parameterIndex,long long x)314         void bind(int parameterIndex, long long x) {
315             this->setLLong(parameterIndex, x);
316         }
317 
bind(int parameterIndex,double x)318         void bind(int parameterIndex, double x) {
319             this->setDouble(parameterIndex, x);
320         }
321 
bind(int parameterIndex,time_t x)322         void bind(int parameterIndex, time_t x) {
323             this->setTimestamp(parameterIndex, x);
324         }
325 
326         //blob
bind(int parameterIndex,std::tuple<const void *,int> x)327         void bind(int parameterIndex, std::tuple<const void *, int> x) {
328             auto [blob, size] = x;
329             this->setBlob(parameterIndex, blob, size);
330         }
331 
332     private:
333         PreparedStatement_T t_;
334     };
335 
336     class Connection : private noncopyable
337     {
338     public:
Connection_T()339         operator Connection_T() {
340             return t_;
341         }
342 
~Connection()343         ~Connection() {
344             if (t_) {
345                 close();
346             }
347         }
348 
349     protected:  // for ConnectionPool
350         friend class ConnectionPool;
351 
Connection(Connection_T C)352         Connection(Connection_T C)
353         :t_(C)
354         {}
355 
setClosed()356         void setClosed() {
357             t_ = nullptr;
358         }
359 
360     public:
setQueryTimeout(int ms)361         void setQueryTimeout(int ms) {
362             Connection_setQueryTimeout(t_, ms);
363         }
364 
getQueryTimeout()365         int getQueryTimeout() {
366             return Connection_getQueryTimeout(t_);
367         }
368 
setMaxRows(int max)369         void setMaxRows(int max) {
370             Connection_setMaxRows(t_, max);
371         }
372 
getMaxRows()373         int getMaxRows() {
374             return Connection_getMaxRows(t_);
375         }
376 
setFetchSize(int rows)377         void setFetchSize(int rows) {
378             Connection_setFetchSize(t_, rows);
379         }
380 
getFetchSize()381         int getFetchSize() {
382             return Connection_getFetchSize(t_);
383         }
384 
385         //not supported
386         //URL_T Connection_getURL(T C);
387 
ping()388         bool ping() {
389             return Connection_ping(t_);
390         }
391 
clear()392         void clear() {
393             Connection_clear(t_);
394         }
395 
396         //after close(), t_ is set to NULL. so this Connection object can not be used again!
close()397         void close() {
398             if (t_) {
399                 Connection_close(t_);
400                 setClosed();
401             }
402         }
403 
beginTransaction()404         void beginTransaction() {
405             except_wrapper( Connection_beginTransaction(t_) );
406         }
407 
commit()408         void commit() {
409             except_wrapper( Connection_commit(t_) );
410         }
411 
rollback()412         void rollback() {
413             except_wrapper( Connection_rollback(t_) );
414         }
415 
lastRowId()416         long long lastRowId() {
417             return Connection_lastRowId(t_);
418         }
419 
rowsChanged()420         long long rowsChanged() {
421             return Connection_rowsChanged(t_);
422         }
423 
execute(const char * sql)424         void execute(const char *sql) {
425             except_wrapper( Connection_execute(t_, "%s", sql) );
426         }
427 
428         template <typename ...Args>
execute(const char * sql,Args...args)429         void execute(const char *sql, Args ... args) {
430             PreparedStatement p(this->prepareStatement(sql, args...));
431             p.execute();
432         }
433 
executeQuery(const char * sql)434         ResultSet executeQuery(const char *sql) {
435             except_wrapper(
436                            ResultSet_T r = Connection_executeQuery(t_, "%s", sql);
437                            RETURN ResultSet(r);
438                            );
439         }
440 
441         template <typename ...Args>
executeQuery(const char * sql,Args...args)442         ResultSet executeQuery(const char *sql, Args ... args) {
443             PreparedStatement p(this->prepareStatement(sql, args...));
444             return p.executeQuery();
445         }
446 
prepareStatement(const char * sql)447         PreparedStatement prepareStatement(const char *sql) {
448             except_wrapper(
449                            PreparedStatement_T p = Connection_prepareStatement(t_, "%s", sql);
450                            RETURN PreparedStatement(p);
451                            );
452         }
453 
454         template <typename ...Args>
prepareStatement(const char * sql,Args...args)455         PreparedStatement prepareStatement(const char *sql, Args ... args) {
456             except_wrapper(
457                            PreparedStatement p(this->prepareStatement(sql));
458                            int i = 1;
459                            (p.bind(i++, args), ...);
460                            RETURN p;
461                            );
462         }
463 
getLastError()464         const char *getLastError() {
465             return Connection_getLastError(t_);
466         }
467 
isSupported(const char * url)468         static bool isSupported(const char *url) {
469             return Connection_isSupported(url);
470         }
471 
472     private:
473         Connection_T t_;
474     };
475 
476 
477     class ConnectionPool : private noncopyable
478     {
479     public:
ConnectionPool(const std::string & url)480         ConnectionPool(const std::string& url)
481         :ConnectionPool(url.c_str())
482         {}
483 
ConnectionPool(const char * url)484         ConnectionPool(const char* url)
485         :url_(url)
486         {
487             if (!url_)
488                 throw sql_exception("Invalid URL");
489             t_ = ConnectionPool_new(url_);
490         }
491 
~ConnectionPool()492         ~ConnectionPool() {
493             ConnectionPool_free(&t_);
494         }
495 
ConnectionPool_T()496         operator ConnectionPool_T() {
497             return t_;
498         }
499 
500     public:
getURL()501         const URL& getURL() {
502             return url_;
503         }
504 
setInitialConnections(int connections)505         void setInitialConnections(int connections) {
506             ConnectionPool_setInitialConnections(t_, connections);
507         }
508 
getInitialConnections()509         int getInitialConnections() {
510             return ConnectionPool_getInitialConnections(t_);
511         }
512 
setMaxConnections(int maxConnections)513         void setMaxConnections(int maxConnections) {
514             ConnectionPool_setMaxConnections(t_, maxConnections);
515         }
516 
getMaxConnections()517         int getMaxConnections() {
518             return ConnectionPool_getMaxConnections(t_);
519         }
520 
setConnectionTimeout(int connectionTimeout)521         void setConnectionTimeout(int connectionTimeout) {
522             ConnectionPool_setConnectionTimeout(t_, connectionTimeout);
523         }
524 
getConnectionTimeout()525         int getConnectionTimeout() {
526             return ConnectionPool_getConnectionTimeout(t_);
527         }
528 
setAbortHandler(void (* abortHandler)(const char * error))529         void setAbortHandler(void(*abortHandler)(const char *error)) {
530             ConnectionPool_setAbortHandler(t_, abortHandler);
531         }
532 
setReaper(int sweepInterval)533         void setReaper(int sweepInterval) {
534             ConnectionPool_setReaper(t_, sweepInterval);
535         }
536 
size()537         int size() {
538             return ConnectionPool_size(t_);
539         }
540 
active()541         int active() {
542             return ConnectionPool_active(t_);
543         }
544 
start()545         void start() {
546             except_wrapper( ConnectionPool_start(t_) );
547         }
548 
stop()549         void stop() {
550             ConnectionPool_stop(t_);
551         }
552 
getConnection()553         Connection getConnection() {
554             Connection_T C = ConnectionPool_getConnection(t_);
555             if (!C) {
556                 throw sql_exception("maxConnection is reached (got null connection)!");
557             }
558             return Connection(C);
559         }
560 
returnConnection(Connection & con)561         void returnConnection(Connection& con) {
562             con.close();
563         }
564 
reapConnections()565         int reapConnections() {
566             return ConnectionPool_reapConnections(t_);
567         }
568 
version(void)569         static const char *version(void) {
570             return ConnectionPool_version();
571         }
572 
573     private:
574         URL url_;
575         ConnectionPool_T t_;
576     };
577 
578 
579 } // namespace
580 
581 #endif
582