1 /* Sqlite3.hpp
2  *
3  * Copyright (C) 2017 Red Hat, Inc.
4  *
5  * Licensed under the GNU Lesser General Public License Version 2.1
6  *
7  * This library is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU Lesser General Public
9  * License as published by the Free Software Foundation; either
10  * version 2.1 of the License, or (at your option) any later version.
11  *
12  * This library is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15  * Lesser General Public License for more details.
16  *
17  * You should have received a copy of the GNU Lesser General Public
18  * License along with this library; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
20  */
21 
22 #ifndef _SQLITE3_HPP
23 #define _SQLITE3_HPP
24 
25 #include "../../error.hpp"
26 #include "../../log.hpp"
27 
28 #include <sqlite3.h>
29 
30 #include <map>
31 #include <memory>
32 #include <stdexcept>
33 #include <string>
34 #include <vector>
35 
36 class SQLite3 {
37 public:
38     class Error : public libdnf::Error {
39     public:
Error(const SQLite3 & s,int code,const std::string & msg)40         Error(const SQLite3& s, int code, const std::string &msg) :
41             libdnf::Error("SQLite error on \"" + s.getPath() + "\": " + msg + ": " + s.getError()),
42             ecode{code}
43         {}
44 
code() const45         int code() const noexcept { return ecode; }
codeStr() const46         const char *codeStr() const noexcept { return sqlite3_errstr(ecode); }
47 
48     protected:
49         int ecode;
50     };
51 
52     struct Blob {
53         size_t size;
54         const void *data;
55     };
56 
57     class Statement {
58     public:
59         /**
60          * An error class that will log the SQL statement in its constructor.
61          */
62         class Error : public SQLite3::Error {
63         public:
Error(Statement & stmt,int code,const std::string & msg)64             Error(Statement& stmt, int code, const std::string& msg) :
65                 SQLite3::Error(stmt.db, code, msg)
66             {
67                 auto logger(libdnf::Log::getLogger());
68                 logger->debug(std::string("SQL statement being executed: ") + stmt.getExpandedSql());
69             }
70         };
71 
72         enum class StepResult { DONE, ROW, BUSY };
73 
74         Statement(const Statement &) = delete;
75         Statement &operator=(const Statement &) = delete;
76 
Statement(SQLite3 & db,const char * sql)77         Statement(SQLite3 &db, const char *sql)
78           : db(db)
79         {
80             auto result = sqlite3_prepare_v2(db.db, sql, -1, &stmt, nullptr);
81             if (result != SQLITE_OK)
82                 throw SQLite3::Error(db, result, "Creating statement failed");
83         };
84 
Statement(SQLite3 & db,const std::string & sql)85         Statement(SQLite3 &db, const std::string &sql)
86           : db(db)
87         {
88             auto result = sqlite3_prepare_v2(db.db, sql.c_str(), sql.length() + 1, &stmt, nullptr);
89             if (result != SQLITE_OK)
90                 throw SQLite3::Error(db, result, "Creating statement failed");
91         };
92 
bind(int pos,int val)93         void bind(int pos, int val)
94         {
95             auto result = sqlite3_bind_int(stmt, pos, val);
96             if (result != SQLITE_OK)
97                 throw Error(*this, result, "Integer bind failed");
98         }
99 
bind(int pos,std::int64_t val)100         void bind(int pos, std::int64_t val)
101         {
102             auto result = sqlite3_bind_int64(stmt, pos, val);
103             if (result != SQLITE_OK)
104                 throw Error(*this, result, "Integer64 bind failed");
105         }
106 
bind(int pos,std::uint32_t val)107         void bind(int pos, std::uint32_t val)
108         {
109             auto result = sqlite3_bind_int(stmt, pos, val);
110             if (result != SQLITE_OK)
111                 throw Error(*this, result, "Unsigned integer bind failed");
112         }
113 
bind(int pos,double val)114         void bind(int pos, double val)
115         {
116             auto result = sqlite3_bind_double(stmt, pos, val);
117             if (result != SQLITE_OK)
118                 throw Error(*this, result, "Double bind failed");
119         }
120 
bind(int pos,bool val)121         void bind(int pos, bool val)
122         {
123             auto result = sqlite3_bind_int(stmt, pos, val ? 1 : 0);
124             if (result != SQLITE_OK)
125                 throw Error(*this, result, "Bool bind failed");
126         }
127 
bind(int pos,const char * val)128         void bind(int pos, const char *val)
129         {
130             auto result = sqlite3_bind_text(stmt, pos, val, -1, SQLITE_TRANSIENT);
131             if (result != SQLITE_OK)
132                 throw Error(*this, result, "Text bind failed");
133         }
134 
bind(int pos,const std::string & val)135         void bind(int pos, const std::string &val)
136         {
137             auto result = sqlite3_bind_text(stmt, pos, val.c_str(), -1, SQLITE_TRANSIENT);
138             if (result != SQLITE_OK)
139                 throw Error(*this, result, "Text bind failed");
140         }
141 
bind(int pos,const Blob & val)142         void bind(int pos, const Blob &val)
143         {
144             auto result = sqlite3_bind_blob(stmt, pos, val.data, val.size, SQLITE_TRANSIENT);
145             if (result != SQLITE_OK)
146                 throw Error(*this, result, "Blob bind failed");
147         }
148 
bind(int pos,const std::vector<unsigned char> & val)149         void bind(int pos, const std::vector< unsigned char > &val)
150         {
151             auto result = sqlite3_bind_blob(stmt, pos, val.data(), val.size(), SQLITE_TRANSIENT);
152             if (result != SQLITE_OK)
153                 throw Error(*this, result, "Blob bind failed");
154         }
155 
156         template< typename... Args >
bindv(Args &&...args)157         Statement &bindv(Args &&... args)
158         {
159             using Pass = int[];
160             size_t pos{0};
161             (void)Pass{(bind(++pos, args), 0)...};
162             return *this;
163         }
164 
step()165         StepResult step()
166         {
167             auto result = sqlite3_step(stmt);
168             switch (result) {
169                 case SQLITE_DONE:
170                     return StepResult::DONE;
171                 case SQLITE_ROW:
172                     return StepResult::ROW;
173                 case SQLITE_BUSY:
174                     return StepResult::BUSY;
175                 default:
176                     throw Error(*this, result, "Reading a row failed");
177             }
178         }
179 
getColumnCount() const180         int getColumnCount() const { return sqlite3_column_count(stmt); }
181 
getColumnDatabaseName(int idx) const182         const char *getColumnDatabaseName(int idx) const
183         {
184             return sqlite3_column_database_name(stmt, idx);
185         }
186 
getColumnTableName(int idx) const187         const char *getColumnTableName(int idx) const
188         {
189             return sqlite3_column_table_name(stmt, idx);
190         }
191 
getColumnOriginName(int idx) const192         const char *getColumnOriginName(int idx) const
193         {
194             return sqlite3_column_origin_name(stmt, idx);
195         }
196 
getColumnName(int idx) const197         const char *getColumnName(int idx) const { return sqlite3_column_name(stmt, idx); }
198 
getSql() const199         const char *getSql() const { return sqlite3_sql(stmt); }
200 
getExpandedSql()201         const char *getExpandedSql()
202         {
203 #if SQLITE_VERSION_NUMBER < 3014000
204             // sqlite3_expanded_sql was added in sqlite 3.14; return sql instead
205             return getSql();
206 #else
207             expandSql = sqlite3_expanded_sql(stmt);
208             if (!expandSql) {
209                 throw libdnf::Exception(
210                     "getExpandedSql(): insufficient memory or result "
211                     "exceed the maximum SQLite3 string length");
212             }
213             return expandSql;
214 #endif
215         }
216 
freeExpandedSql()217         void freeExpandedSql() { sqlite3_free(expandSql); }
218 
219         /**
220          * Reset prepared query to its initial state, ready to be re-executed.
221          * All the bound values remain untouched - retain their values.
222          * Use clearBindings if you need to reset them as well.
223          */
reset()224         void reset() { sqlite3_reset(stmt); }
225 
clearBindings()226         void clearBindings() { sqlite3_clear_bindings(stmt); }
227 
228         template< typename T >
get(int idx)229         T get(int idx)
230         {
231             return get(idx, identity< T >{});
232         }
233 
~Statement()234         ~Statement()
235         {
236             freeExpandedSql();
237             sqlite3_finalize(stmt);
238         };
239 
240     protected:
241         template< typename T >
242         struct identity {
243             typedef T type;
244         };
245 
246         /*template<typename TL>
247         TL get(size_t, identity<TL>)
248         {
249             static_assert(sizeof(TL) == 0, "Not implemented");
250         }*/
251 
get(int idx,identity<int>)252         int get(int idx, identity< int >) { return sqlite3_column_int(stmt, idx); }
253 
get(int idx,identity<uint32_t>)254         uint32_t get(int idx, identity< uint32_t >) { return sqlite3_column_int(stmt, idx); }
255 
get(int idx,identity<int64_t>)256         int64_t get(int idx, identity< int64_t >) { return sqlite3_column_int64(stmt, idx); }
257 
get(int idx,identity<double>)258         double get(int idx, identity< double >) { return sqlite3_column_double(stmt, idx); }
259 
get(int idx,identity<bool>)260         bool get(int idx, identity< bool >) { return sqlite3_column_int(stmt, idx) != 0; }
261 
get(int idx,identity<const char * >)262         const char *get(int idx, identity< const char * >)
263         {
264             return reinterpret_cast< const char * >(sqlite3_column_text(stmt, idx));
265         }
266 
get(int idx,identity<std::string>)267         std::string get(int idx, identity< std::string >)
268         {
269             auto ret = reinterpret_cast< const char * >(sqlite3_column_text(stmt, idx));
270             return ret ? ret : "";
271         }
272 
get(int idx,identity<Blob>)273         Blob get(int idx, identity< Blob >)
274         {
275             return {static_cast< size_t >(sqlite3_column_bytes(stmt, idx)),
276                     sqlite3_column_blob(stmt, idx)};
277         }
278 
279         SQLite3 &db;
280         sqlite3_stmt *stmt;
281         char *expandSql{nullptr};
282     };
283 
284     class Query : public Statement {
285     public:
Query(SQLite3 & db,const char * sql)286         Query(SQLite3 &db, const char *sql)
287           : Statement{db, sql}
288         {
289             mapColsName();
290         }
Query(SQLite3 & db,const std::string & sql)291         Query(SQLite3 &db, const std::string &sql)
292           : Statement{db, sql}
293         {
294             mapColsName();
295         }
296 
getColumnIndex(const std::string & colName)297         int getColumnIndex(const std::string &colName)
298         {
299             auto it = colsName2idx.find(colName);
300             if (it == colsName2idx.end())
301                 throw libdnf::Exception("get() column \"" + colName + "\" not found");
302             return it->second;
303         }
304 
305         using Statement::get;
306 
307         template< typename T >
get(const std::string & colName)308         T get(const std::string &colName)
309         {
310             return get(getColumnIndex(colName), identity< T >{});
311         }
312 
313     private:
mapColsName()314         void mapColsName()
315         {
316             for (int idx = 0; idx < getColumnCount(); ++idx) {
317                 const char *name = getColumnName(idx);
318                 if (name)
319                     colsName2idx[name] = idx;
320             }
321         }
322 
323         std::map< std::string, int > colsName2idx;
324     };
325 
326     SQLite3(const SQLite3 &) = delete;
327     SQLite3 &operator=(const SQLite3 &) = delete;
328 
SQLite3(const std::string & dbPath)329     SQLite3(const std::string &dbPath)
330       : path{dbPath}
331       , db{nullptr}
332     {
333         open();
334     }
335 
~SQLite3()336     ~SQLite3() { close(); }
337 
getPath() const338     const std::string &getPath() const { return path; }
339 
340     void open();
341     void close();
isOpened()342     bool isOpened() { return db != nullptr; };
343 
exec(const char * sql)344     void exec(const char *sql)
345     {
346         auto result = sqlite3_exec(db, sql, nullptr, nullptr, nullptr);
347         if (result != SQLITE_OK) {
348             throw Error(*this, result, "Executing an SQL statement failed");
349         }
350     }
351 
changes()352     int changes() { return sqlite3_changes(db); }
353 
lastInsertRowID()354     int64_t lastInsertRowID() { return sqlite3_last_insert_rowid(db); }
355 
getError() const356     std::string getError() const { return sqlite3_errmsg(db); }
357 
358     void backup(const std::string &outputFile);
359     void restore(const std::string &inputFile);
360 
361 protected:
362     std::string path;
363 
364     sqlite3 *db;
365 };
366 
367 typedef std::shared_ptr< SQLite3 > SQLite3Ptr;
368 
369 #endif
370