1#!/usr/bin/env python
2# vi:tabstop=4:expandtab:shiftwidth=4:softtabstop=4:autoindent:smarttab
3'''
4Usage: python sqlite2cpp.py path_to_sql_file
5'''
6
7import sys
8import os
9import datetime
10import sqlite3
11
12# https://github.com/django/django/blob/master/django/db/backends/sqlite3/introspection.py
def get_table_list(cursor):
    """Return ``(name, create_sql)`` pairs for every user table, ordered by name.

    The ``sqlite_sequence`` table (SQLite's bookkeeping table for
    AUTOINCREMENT keys) is left out.
    """
    query = """
        SELECT name, sql FROM sqlite_master
        WHERE type='table' AND NOT name='sqlite_sequence'
        ORDER BY name"""
    cursor.execute(query)
    tables = []
    for row in cursor.fetchall():
        table_name, create_sql = row[0], row[1]
        tables.append((table_name, create_sql))
    return tables
22
23
24def _table_info(cursor, name):
25    cursor.execute('PRAGMA table_info(%s)' % name)
26    # cid, name, type, notnull, dflt_value, pk
27    return [{'cid': field[0],
28        'name': field[1],
29        'type': field[2].upper(),
30        'null_ok': not field[3],
31        'pk': field[5]     # undocumented
32        } for field in cursor.fetchall()]
33
def get_index_list(cursor, tbl_name):
    """Return the ``CREATE INDEX`` statements declared for table *tbl_name*.

    Indexes SQLite creates implicitly for UNIQUE / PRIMARY KEY constraints
    (named ``sqlite_autoindex_*``) are skipped, since they follow from the
    table definition itself. Results are ordered by index name.
    """
    # Bind the table name as a parameter instead of splicing it into the
    # SQL text, so names containing quotes cannot break or alter the query.
    cursor.execute("""
        SELECT tbl_name, sql FROM sqlite_master
        WHERE type='index' AND name NOT LIKE 'sqlite_autoindex_%' AND tbl_name = ?
        ORDER BY name""", (tbl_name,))
    return [row[1] for row in cursor.fetchall()]
43
44
# Upper-cased declared SQLite column type -> C++ member type used for that
# column in the generated Data struct.
base_data_types_reverse = {
    'TEXT': 'wxString',
    'NUMERIC': 'double',
    'INTEGER': 'int',
    'REAL': 'double',
    'BLOB': 'wxString',
    'DATE': 'wxDateTime',
}

# Upper-cased declared SQLite column type -> wxSQLite3ResultSet getter used
# to read that column. A type missing here makes header generation fail
# with KeyError, so every type mapped to wxString/double/int above needs one.
# NOTE(review): DATE (wxDateTime) still has no getter, so tables with DATE
# columns cannot be generated yet.
base_data_types_function = {
    'TEXT': 'GetString',
    'NUMERIC': 'GetDouble',
    'INTEGER': 'GetInt',
    'REAL': 'GetDouble',
    'BLOB': 'GetString',  # BLOB columns are held as wxString (see above)
}
60
class DB_Table:
    """Code generator for a single SQLite table.

    Holds the introspected schema of one table and renders the C++ header
    ``DB_Table_<Table>.h`` that wraps it: column descriptor structs, a Data
    record struct, CRUD helpers and an in-memory cache keyed by primary key.
    """

    def __init__(self, table, fields, index):
        """
        :param table: table name as stored in ``sqlite_master``.
        :param fields: column dicts as produced by ``_table_info`` (keys
            ``cid``, ``name``, ``type``, ``null_ok``, ``pk``).
        :param index: ``CREATE INDEX`` statements declared for the table.
        :raises ValueError: if no column is flagged as primary key; the
            generated code keys its cache and updates on the primary key.
        """
        self._table = table
        self._fields = fields
        pk_names = [field['name'] for field in self._fields if field['pk']]
        if not pk_names:
            raise ValueError('table %s has no primary key column' % table)
        # Attribute keeps its historical spelling.
        self._primay_key = pk_names[0]
        self._index = index

    def generate_class(self, header, sql):
        """Write ``DB_Table_<Table>.h`` (table name in ``str.title()`` case)
        to the current directory: *header* followed by the generated class.

        The code is rendered before the file is opened, so a generation
        error does not leave an empty header behind.
        """
        code = header + self.to_string(sql)
        with open('DB_Table_' + self._table.title() + '.h', 'w') as fp:
            fp.write(code)

    @staticmethod
    def _cpp_string_body(text):
        """Return *text* as the body of a C++ double-quoted string literal.

        Line breaks become single spaces, so SQL tokens on adjacent lines are
        not glued together. Backslashes and double quotes are escaped so the
        generated header still compiles.
        """
        lines = [line.strip() for line in text.splitlines()]
        one_line = ' '.join(line for line in lines if line)
        return one_line.replace('\\', '\\\\').replace('"', '\\"')

    @staticmethod
    def _index_sql_if_not_exists(index_sql):
        """Rewrite ``CREATE [UNIQUE] INDEX name ...`` as
        ``CREATE [UNIQUE] INDEX IF NOT EXISTS name ...``.

        The clause goes right after the INDEX keyword, so UNIQUE indexes stay
        valid. It is not added twice when the statement already has it.
        Whitespace is normalized to single spaces.
        """
        tokens = index_sql.split()
        upper = [token.upper() for token in tokens]
        if 'INDEX' in upper:
            pos = upper.index('INDEX') + 1
            if upper[pos:pos + 3] != ['IF', 'NOT', 'EXISTS']:
                tokens[pos:pos] = ['IF', 'NOT', 'EXISTS']
        return ' '.join(tokens)

    def to_string(self, sql = None):
        """Return the C++ source of the ``DB_Table_<table>`` struct.

        :param sql: the table's ``CREATE TABLE`` statement, embedded in the
            generated ``ensure()``. It must be supplied despite the ``None``
            default.
        """
        # Columns other than the primary key, in declaration order. These
        # are the INSERT/UPDATE columns and COL_* enumerators 1..n.
        non_pk_names = [field['name'] for field in self._fields if not field['pk']]

        s = '''
#ifndef DB_TABLE_%s_H
#define DB_TABLE_%s_H

#include "DB_Table.h"

struct DB_Table_%s : public DB_Table
{
    struct Data;
    typedef DB_Table_%s Self;
    /** A container to hold list of Data records for the table*/
    struct Data_Set : public std::vector<Self::Data>
    {
        std::wstring to_json() const
        {
            json::Array a;
            for (const auto & item: *this)
            {
                json::Object o;
                item.to_json(o);
                a.Insert(o);
            }
            std::wstringstream ss;
            json::Writer::Write(a, ss);
            return ss.str();
        }
    };
    /** A container to hold a list of Data record pointers for the table in memory*/
    typedef std::vector<Self::Data*> Cache;
    typedef std::map<int, Self::Data*> Index_By_Id;
    Cache cache_;
    Index_By_Id index_by_id_;

    /** Destructor: clears any data records stored in memory */
    ~DB_Table_%s()
    {
        destroy_cache();
    }

    /** Removes all records stored in memory (cache) for the table*/
    void destroy_cache()
    {
        std::for_each(cache_.begin(), cache_.end(), std::mem_fun(&Data::destroy));
        cache_.clear();
        index_by_id_.clear(); // no memory release since it just stores pointer and the according objects are in cache
    }
''' % (self._table.upper(), self._table.upper(), self._table, self._table, self._table)

        s += '''
    /** Creates the database table if the table does not exist*/
    bool ensure(wxSQLite3Database* db)
    {
        if (!exists(db))
        {
            try
            {
                db->ExecuteUpdate("%s");
            }
            catch(const wxSQLite3Exception &e)
            {
                wxLogError("%s: Exception %%s", e.GetMessage().c_str());
                return false;
            }
        }

        this->ensure_index(db);

        return true;
    }
''' % (self._cpp_string_body(sql), self._table)

        s += '''
    bool ensure_index(wxSQLite3Database* db)
    {
        try
        {'''
        for index_sql in self._index:
            s += '''
            db->ExecuteUpdate("%s");''' % self._cpp_string_body(
                self._index_sql_if_not_exists(index_sql))

        s += '''
        }
        catch(const wxSQLite3Exception &e)
        {
            wxLogError("%s: Exception %%s", e.GetMessage().c_str());
            return false;
        }

        return true;
    }
''' % (self._table)

        # One typed column descriptor per field, used for find_by()/match().
        for field in self._fields:
            cpp_type = base_data_types_reverse[field['type']]
            s += '''
    struct %s : public DB_Column<%s>
    {
        static wxString name() { return "%s"; }
        explicit %s(const %s &v, OP op = EQUAL): DB_Column<%s>(v, op) {}
    };''' % (field['name'], cpp_type, field['name'],
             field['name'], cpp_type, cpp_type)

        s += '''
    typedef %s PRIMARY;''' % self._primay_key

        # COL_* enumerators: the primary key is 0, the other columns 1..n.
        s += '''
    enum COLUMN
    {
        COL_%s = 0''' % self._primay_key.upper()

        for number, name in enumerate(non_pk_names, 1):
            s += '''
        , COL_%s = %d''' % (name.upper(), number)

        s += '''
    };
'''
        s += '''
    /** Returns the column name as a string*/
    static wxString column_to_name(COLUMN col)
    {
        switch(col)
        {
            case COL_%s: return "%s";''' % (self._primay_key.upper(), self._primay_key)

        for name in non_pk_names:
            s += '''
            case COL_%s: return "%s";''' % (name.upper(), name)
        s += '''
            default: break;
        }

        return "UNKNOWN";
    }
'''
        s += '''
    /** Returns the column number from the given column name*/
    static COLUMN name_to_column(const wxString& name)
    {
        if ("%s" == name) return COL_%s;''' % (self._primay_key, self._primay_key.upper())

        for name in non_pk_names:
            s += '''
        else if ("%s" == name) return COL_%s;''' % (name, name.upper())

        s += '''

        return COLUMN(-1);
    }
    '''
        # The friend declaration must name the struct exactly as declared
        # above (raw table name, not upper-cased).
        s += '''
    /** Data is a single record in the database table*/
    struct Data
    {
        friend struct DB_Table_%s;
        /** This is a instance pointer to itself in memory. */
        Self* view_;
    ''' % self._table
        for field in self._fields:
            s += '''
        %s %s;%s''' % (base_data_types_reverse[field['type']], field['name'],
                       '//  primary key' if field['pk'] else '')

        s += '''
        int id() const { return %s; }
        void id(int id) { %s = id; }
        bool operator < (const Data& r) const
        {
            return this->id() < r.id();
        }
        bool operator < (const Data* r) const
        {
            return this->id() < r->id();
        }
''' % (self._primay_key, self._primay_key)

        s += '''
        explicit Data(Self* view = 0)
        {
            view_ = view;
        '''

        # Default values for numeric members; wxString (and any other type)
        # keeps its own default constructor.
        for field in self._fields:
            cpp_type = base_data_types_reverse[field['type']]
            if cpp_type == 'double':
                s += '''
            %s = 0.0;''' % field['name']
            elif cpp_type == 'int':
                s += '''
            %s = -1;''' % field['name']

        s += '''
        }

        explicit Data(wxSQLite3ResultSet& q, Self* view = 0)
        {
            view_ = view;
        '''
        # Columns are read by cid, which matches "SELECT *" column order.
        for field in self._fields:
            func = base_data_types_function[field['type']]
            s += '''
            %s = q.%s(%d); // %s''' % (field['name'], func, field['cid'], field['name'])

        s += '''
        }

        Data& operator=(const Data& other)
        {
            if (this == &other) return *this;
'''
        for field in self._fields:
            s += '''
            %s = other.%s;''' % (field['name'], field['name'])
        s += '''
            return *this;
        }
'''
        s += '''
        template<typename C>
        bool match(const C &c) const
        {
            return false;
        }'''
        # String columns match case-insensitively; others by equality.
        for field in self._fields:
            if base_data_types_reverse[field['type']] == 'wxString':
                s += '''
        bool match(const Self::%s &in) const
        {
            return this->%s.CmpNoCase(in.v_) == 0;
        }''' % (field['name'], field['name'])
            else:
                s += '''
        bool match(const Self::%s &in) const
        {
            return this->%s == in.v_;
        }''' % (field['name'], field['name'])

        s += '''
        wxString to_json() const
        {
            json::Object o;
            this->to_json(o);
            std::wstringstream ss;
            json::Writer::Write(o, ss);
            return ss.str();
        }

        int to_json(json::Object& o) const
        {'''

        for field in self._fields:
            if base_data_types_reverse[field['type']] == 'wxString':
                s += '''
            o[L"%s"] = json::String(this->%s.ToStdWstring());''' % (field['name'], field['name'])
            else:
                s += '''
            o[L"%s"] = json::Number(this->%s);''' % (field['name'], field['name'])

        s += '''
            return 0;
        }'''

        s += '''
        row_t to_row_t() const
        {
            row_t row;'''
        for field in self._fields:
            s += '''
            row(L"%s") = %s;''' % (field['name'], field['name'])

        s += '''
            return row;
        }'''

        s += '''
        void to_template(html_template& t) const
        {'''
        for field in self._fields:
            s += '''
            t(L"%s") = %s;''' % (field['name'], field['name'])

        s += '''
        }'''

        s += '''

        /** Save the record instance in memory to the database. */
        bool save(wxSQLite3Database* db)
        {
            if (db && db->IsReadOnly()) return false;
            if (!view_ || !db)
            {
                wxLogError("can not save %s");
                return false;
            }

            return view_->save(this, db);
        }

        /** Remove the record instance from memory and the database. */
        bool remove(wxSQLite3Database* db)
        {
            if (!view_ || !db)
            {
                wxLogError("can not remove %s");
                return false;
            }

            return view_->remove(this, db);
        }

        void destroy()
        {
            //if (this->id() < 0)
            //    wxSafeShowMessage("unsaved object", this->to_json());
            delete this;
        }
    };
''' % (self._table.upper(), self._table.upper())
        s += '''
    enum
    {
        NUM_COLUMNS = %d
    };

    size_t num_columns() const { return NUM_COLUMNS; }
''' % len(self._fields)

        s += '''
    /** Name of the table*/
    wxString name() const { return "%s"; }
''' % self._table

        s += '''
    DB_Table_%s()
    {
        query_ = "SELECT * FROM %s ";
    }
''' % (self._table, self._table)

        s += '''
    /** Create a new Data record and add to memory table (cache)*/
    Self::Data* create()
    {
        Self::Data* entity = new Self::Data(this);
        cache_.push_back(entity);
        return entity;
    }

    /** Create a copy of the Data record and add to memory table (cache)*/
    Self::Data* clone(const Data* e)
    {
        Self::Data* entity = create();
        *entity = *e;
        entity->id(-1);
        return entity;
    }
'''
        s += '''
    /**
    * Saves the Data record to the database table.
    * Either create a new record or update the existing record.
    * Remove old record from the memory table (cache)
    */
    bool save(Self::Data* entity, wxSQLite3Database* db)
    {
        wxString sql = wxEmptyString;
        if (entity->id() <= 0) //  new & insert
        {
            sql = "INSERT INTO %s(%s) VALUES(%s)";
        }''' % (self._table, ', '.join(non_pk_names), ', '.join('?' for _ in non_pk_names))

        s += '''
        else
        {
            sql = "UPDATE %s SET %s WHERE %s = ?";
        }

        try
        {
            wxSQLite3Statement stmt = db->PrepareStatement(sql);
''' % (self._table, ', '.join(name + ' = ?' for name in non_pk_names), self._primay_key)

        for number, name in enumerate(non_pk_names, 1):
            s += '''
            stmt.Bind(%d, entity->%s);''' % (number, name)

        # The WHERE placeholder of the UPDATE comes right after the SET
        # placeholders, one per non-primary-key column.
        s += '''
            if (entity->id() > 0)
                stmt.Bind(%d, entity->%s);

            stmt.ExecuteUpdate();
            stmt.Finalize();

            if (entity->id() > 0) // existent
            {
                for(Cache::iterator it = cache_.begin(); it != cache_.end(); ++ it)
                {
                    Self::Data* e = *it;
                    if (e->id() == entity->id())
                        *e = *entity;  // in-place update
                }
            }
        }
        catch(const wxSQLite3Exception &e)
        {
            wxLogError("%s: Exception %%s, %%s", e.GetMessage().c_str(), entity->to_json());
            return false;
        }

        if (entity->id() <= 0)
        {
            entity->id((db->GetLastRowId()).ToLong());
            index_by_id_.insert(std::make_pair(entity->id(), entity));
        }
        return true;
    }
''' % (len(non_pk_names) + 1, self._primay_key, self._table)
        # remove(int) deletes cached records, which may include the entity
        # passed to remove(Data*); the generated code must not touch it after.
        s += '''
    /** Remove the Data record from the database and the memory table (cache) */
    bool remove(int id, wxSQLite3Database* db)
    {
        if (id <= 0) return false;
        try
        {
            wxString sql = "DELETE FROM %s WHERE %s = ?";
            wxSQLite3Statement stmt = db->PrepareStatement(sql);
            stmt.Bind(1, id);
            stmt.ExecuteUpdate();
            stmt.Finalize();

            Cache c;
            for(Cache::iterator it = cache_.begin(); it != cache_.end(); ++ it)
            {
                Self::Data* entity = *it;
                if (entity->id() == id)
                {
                    index_by_id_.erase(entity->id());
                    delete entity;
                }
                else
                {
                    c.push_back(entity);
                }
            }
            cache_.clear();
            cache_.swap(c);
        }
        catch(const wxSQLite3Exception &e)
        {
            wxLogError("%s: Exception %%s", e.GetMessage().c_str());
            return false;
        }

        return true;
    }

    /** Remove the Data record from the database and the memory table (cache) */
    bool remove(Self::Data* entity, wxSQLite3Database* db)
    {
        // remove(id, db) deletes every cached record with this id, which
        // includes entity itself when it lives in the cache; only reset the
        // id of records that survive (e.g. copies held in a Data_Set).
        const bool cached = std::find(cache_.begin(), cache_.end(), entity) != cache_.end();
        if (remove(entity->id(), db))
        {
            if (!cached) entity->id(-1);
            return true;
        }

        return false;
    }
''' % (self._table, self._primay_key, self._table)

        s += '''
    template<typename... Args>
    Self::Data* get_one(const Args& ... args)
    {
        for (Index_By_Id::iterator it = index_by_id_.begin(); it != index_by_id_.end(); ++ it)
        {
            Self::Data* item = it->second;
            if (item->id() > 0 && match(item, args...))
            {
                ++ hit_;
                return item;
            }
        }

        ++ miss_;

        return 0;
    }'''

        s += '''

    /**
    * Search the memory table (Cache) for the data record.
    * If not found in memory, search the database and update the cache.
    */
    Self::Data* get(int id, wxSQLite3Database* db)
    {
        if (id <= 0)
        {
            ++ skip_;
            return 0;
        }

        Index_By_Id::iterator it = index_by_id_.find(id);
        if (it != index_by_id_.end())
        {
            ++ hit_;
            return it->second;
        }

        ++ miss_;
        Self::Data* entity = 0;
        wxString where = wxString::Format(" WHERE %s = ?", PRIMARY::name().c_str());
        try
        {
            wxSQLite3Statement stmt = db->PrepareStatement(this->query() + where);
            stmt.Bind(1, id);

            wxSQLite3ResultSet q = stmt.ExecuteQuery();
            if(q.NextRow())
            {
                entity = new Self::Data(q, this);
                cache_.push_back(entity);
                index_by_id_.insert(std::make_pair(id, entity));
            }
            stmt.Finalize();
        }
        catch(const wxSQLite3Exception &e)
        {
            wxLogError("%s: Exception %s", this->name().c_str(), e.GetMessage().c_str());
        }

        if (!entity)
        {
            wxLogError("%s: %d not found", this->name().c_str(), id);
        }

        return entity;
    }
'''
        s += '''
    /**
    * Return a list of Data records (Data_Set) derived directly from the database.
    * The Data_Set is sorted based on the column number.
    */
    const Data_Set all(wxSQLite3Database* db, COLUMN col = COLUMN(0), bool asc = true)
    {
        Data_Set result;
        try
        {
            wxSQLite3ResultSet q = db->ExecuteQuery(col == COLUMN(0) ? this->query() : this->query() + " ORDER BY " + column_to_name(col) + " COLLATE NOCASE " + (asc ? " ASC " : " DESC "));

            while(q.NextRow())
            {
                Self::Data entity(q, this);
                result.push_back(entity);
            }

            q.Finalize();
        }
        catch(const wxSQLite3Exception &e)
        {
            wxLogError("%s: Exception %s", this->name().c_str(), e.GetMessage().c_str());
        }

        return result;
    }
'''
        s += '''};
#endif //
'''
        return s
656
def generate_base_class(header, fields=()):
    """Write ``DB_Table.h``, the shared base header included by every
    generated ``DB_Table_<Table>.h``, into the current directory.

    :param header: banner text prepended to the file.
    :param fields: column names (any iterable); one ``SorterBy<name>``
        comparator struct is emitted per name, in sorted order. Defaults to
        no sorters.
    """
    code = header + '''
#ifndef DB_TABLE_H
#define DB_TABLE_H

#include <vector>
#include <map>
#include <algorithm>
#include <functional>
#include <wx/wxsqlite3.h>

#include "cajun/json/elements.h"
#include "cajun/json/reader.h"
#include "cajun/json/writer.h"
#include "html_template.h"
using namespace tmpl;

class wxString;
enum OP { EQUAL = 0, GREATER, LESS, GREATER_OR_EQUAL, LESS_OR_EQUAL, NOT_EQUAL };
template<class V>
struct DB_Column
{
    V v_;
    OP op_;
    DB_Column(const V& v, OP op = EQUAL): v_(v), op_(op)
    {}
};

struct DB_Table
{
    DB_Table(): hit_(0), miss_(0), skip_(0) {};
    virtual ~DB_Table() {};
    wxString query_;
    size_t hit_, miss_, skip_;
    virtual wxString query() const { return this->query_; }
    virtual size_t num_columns() const = 0;
    virtual wxString name() const = 0;

    bool exists(wxSQLite3Database* db) const
    {
       return db->TableExists(this->name());
    }
};

template<typename Arg1>
void condition(wxString& out, bool /*op_and*/, const Arg1& arg1)
{
    out += Arg1::name();
    switch (arg1.op_)
    {
    case GREATER:           out += " > ? ";     break;
    case GREATER_OR_EQUAL:  out += " >= ? ";    break;
    case LESS:              out += " < ? ";     break;
    case LESS_OR_EQUAL:     out += " <= ? ";    break;
    case NOT_EQUAL:         out += " != ? ";    break;
    default:
        out += " = ? "; break;
    }
}

template<typename Arg1, typename... Args>
void condition(wxString& out, bool op_and, const Arg1& arg1, const Args&... args)
{
    out += Arg1::name();
    switch (arg1.op_)
    {
    case GREATER:           out += " > ? ";     break;
    case GREATER_OR_EQUAL:  out += " >= ? ";    break;
    case LESS:              out += " < ? ";     break;
    case LESS_OR_EQUAL:     out += " <= ? ";    break;
    case NOT_EQUAL:         out += " != ? ";    break;
    default:
        out += " = ? "; break;
    }
    out += op_and? " AND " : " OR ";
    condition(out, op_and, args...);
}

template<typename Arg1>
void bind(wxSQLite3Statement& stmt, int index, const Arg1& arg1)
{
    stmt.Bind(index, arg1.v_);
}

template<typename Arg1, typename... Args>
void bind(wxSQLite3Statement& stmt, int index, const Arg1& arg1, const Args&... args)
{
    stmt.Bind(index, arg1.v_);
    bind(stmt, index+1, args...);
}

template<typename TABLE, typename... Args>
const typename TABLE::Data_Set find_by(TABLE* table, wxSQLite3Database* db, bool op_and, const Args&... args)
{
    typename TABLE::Data_Set result;
    try
    {
        wxString query = table->query() + " WHERE ";
        condition(query, op_and, args...);
        wxSQLite3Statement stmt = db->PrepareStatement(query);
        bind(stmt, 1, args...);

        wxSQLite3ResultSet q = stmt.ExecuteQuery();

        while(q.NextRow())
        {
            typename TABLE::Data entity(q, table);
            result.push_back(entity);
        }

        q.Finalize();
    }
    catch(const wxSQLite3Exception &e)
    {
        wxLogError("%s: Exception %s", table->name().c_str(), e.GetMessage().c_str());
    }

    return result;
}

template<class DATA, typename Arg1>
bool match(const DATA* data, const Arg1& arg1)
{
    return data->match(arg1);
}

template<class DATA, typename Arg1, typename... Args>
bool match(const DATA* data, const Arg1& arg1, const Args&... args)
{
    if (data->match(arg1))
        return match(data, args...);
    else
        return false; // Short-circuit evaluation
}
'''
    for field in sorted(fields):
        code += '''
struct SorterBy%s
{
    template<class DATA>
    bool operator()(const DATA& x, const DATA& y)
    {
        return x.%s < y.%s;
    }
};
''' % (field, field, field)

    code += '''
#endif //
'''
    # The per-table headers #include "DB_Table.h", so write exactly that
    # file name once; the with-block guarantees the data is flushed/closed.
    with open('DB_Table.h', 'w') as fp:
        fp.write(code)
811
if __name__ == '__main__':
    # Banner prepended to every generated header: which script produced it,
    # when, and a warning against hand edits.
    header = '''// -*- C++ -*-
//=============================================================================
/**
 *      Copyright (c) 2013,2014,2015 Guan Lisheng (guanlisheng@gmail.com)
 *
 *      @file
 *
 *      @author [%s]
 *
 *      @brief
 *
 *      Revision History:
 *          AUTO GENERATED at %s.
 *          DO NOT EDIT!
 */
//=============================================================================
'''% (os.path.basename(__file__), str(datetime.datetime.now()))

    if len(sys.argv) < 2:
        # print(...) with a single argument works on both Python 2 and 3.
        print(__doc__)
        sys.exit(1)
    sql_file = sys.argv[1]

    # Read the whole schema script at once (text mode, so it is a str on
    # Python 3 and line endings are normalized); the file is closed after.
    with open(sql_file) as schema:
        sql = schema.read()

    # Load the schema into a throw-away in-memory database so SQLite itself
    # parses it; tables, columns and indexes are then introspected from it.
    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()
        cur.executescript(sql)

        all_fields = set()
        for table, table_sql in get_table_list(cur):
            fields = _table_info(cur, table)
            index = get_index_list(cur, table)
            view = DB_Table(table, fields, index)
            view.generate_class(header, table_sql)
            all_fields.update(field['name'] for field in fields)

        generate_base_class(header, all_fields)
    finally:
        conn.close()
859
860