1#!/usr/bin/env python3
2
3from mako.template import Template
4
5import re
6import sys
7
8
# When False, eprint() returns without printing anything; set to True to get
# the diagnostic messages on stderr.
DEBUG = False
10
11
def eprint(*args, **kwargs):
    """Print to stderr, but only while the module-level DEBUG flag is set.

    Takes the same positional and keyword arguments as the builtin print();
    the output stream is always sys.stderr.
    """
    if DEBUG:
        print(*args, file=sys.stderr, **kwargs)
16
17
class Rewriter:
    """Base class for the dialect rewriters; leaves statements unchanged."""

    def rewrite_types(self, query, mapping):
        """Apply every regex substitution in `mapping` to `query`.

        `mapping` maps a regular expression to its replacement template. The
        substitutions are applied one after the other, in the iteration order
        of the mapping, each one working on the result of the previous one.
        Returns the resulting string.
        """
        result = query
        for pattern, replacement in mapping.items():
            result = re.sub(pattern, replacement, result)
        return result

    def rewrite_single(self, query):
        """Rewrite one statement; the base implementation returns it as is."""
        return query

    def rewrite(self, queries):
        """Rewrite the 'query' entry of every dict in `queries` in place.

        Each statement goes through rewrite_single(); the before/after pair
        is reported through eprint(). Returns the same `queries` object.
        """
        for idx, entry in enumerate(queries):
            before = entry['query']
            after = self.rewrite_single(before)
            queries[idx]['query'] = after
            eprint("Rewritten statement\n\tfrom {}\n\t  to {}".format(before, entry['query']))
        return queries
34
35
class Sqlite3Rewriter(Rewriter):
    """Rewriter that produces the sqlite3 variant of each statement."""

    def rewrite_single(self, query):
        """Return `query` translated for sqlite3.

        A statement containing the marker "/*PSQL*/" is replaced wholesale by
        a fixed UPDATE statement that is meant to act as a no-op. Every other
        statement has the regex substitutions below applied to it.
        """
        # Replace DB specific queries with a no-op
        if "/*PSQL*/" in query:
            return "UPDATE vars SET intval=1 WHERE name='doesnotexist'"  # Return a no-op

        # The substitutions are applied one after the other in the order they
        # are listed here, so a pattern that starts with another pattern must
        # come first: with BIGINT ahead of BIGINTEGER, "BIGINTEGER" would be
        # turned into "INTEGEREGER" and its own rule would never match.
        typemapping = {
            r'BIGINTEGER': 'INTEGER',
            r'BIGINT': 'INTEGER',
            r'BIGSERIAL': 'INTEGER',
            r'CURRENT_TIMESTAMP\(\)': "strftime('%s', 'now')",
            r'INSERT INTO[ \t]+(.*)[ \t]+ON CONFLICT.*DO NOTHING;': 'INSERT OR IGNORE INTO \\1;',
            # Rewrite "decode('abcd', 'hex')" to become "x'abcd'"
            r'decode\((.*),\s*[\'\"]hex[\'\"]\)': 'x\\1',
            # GREATEST() of multiple columns is simple MAX in sqlite3.
            r'GREATEST\(([^)]*)\)': "MAX(\\1)",
        }
        return self.rewrite_types(query, typemapping)
54
55
class PostgresRewriter(Rewriter):
    """Rewriter that produces the postgres variant of each statement."""

    def rewrite_single(self, q):
        """Return `q` translated for postgres.

        A statement containing the marker "/*SQLITE*/" is replaced wholesale
        by a fixed UPDATE statement that is meant to act as a no-op. In every
        other statement the '?' characters are numbered ($1, $2, ...) and the
        regex substitutions below are applied afterwards.
        """
        # Replace DB specific queries with a no-op
        if "/*SQLITE*/" in q:
            return "UPDATE vars SET intval=1 WHERE name='doesnotexist'"  # Return a no-op

        # Number the placeholders: the n-th '?' in the statement becomes '$n'.
        head, *rest = q.split('?')
        pieces = [head]
        for number, tail in enumerate(rest, start=1):
            pieces.append("${}".format(number))
            pieces.append(tail)
        numbered = "".join(pieces)

        return self.rewrite_types(numbered, {
            r'BLOB': 'BYTEA',
            r'CURRENT_TIMESTAMP\(\)': "EXTRACT(epoch FROM now())",
        })
79
80
# Maps each dialect name accepted on the command line to the rewriter
# instance that translates the statements for it.
rewriters = {
    "sqlite3": Sqlite3Rewriter(),
    "postgres": PostgresRewriter(),
}
85
# Mako template of the generated C header. It is rendered with `f` set to the
# dialect name and `queries` set to the list of statement dicts, and emits one
# `struct db_query` initializer per statement plus a query-count define, all
# wrapped in an include guard and an `#if HAVE_<DIALECT>` check.
template = Template("""#ifndef LIGHTNINGD_WALLET_GEN_DB_${f.upper()}
#define LIGHTNINGD_WALLET_GEN_DB_${f.upper()}

#include <config.h>
#include <wallet/db_common.h>

#if HAVE_${f.upper()}

struct db_query db_${f}_queries[] = {

% for elem in queries:
    {
         .name = "${elem['name']}",
         .query = "${elem['query']}",
         .placeholders = ${elem['placeholders']},
         .readonly = ${elem['readonly']},
    },
% endfor
};

#define DB_${f.upper()}_QUERY_COUNT ${len(queries)}

#endif /* HAVE_${f.upper()} */

#endif /* LIGHTNINGD_WALLET_GEN_DB_${f.upper()} */
""")
112
113
def extract_queries(pofile):
    """Extract the SQL statements stored in a gettext po-file.

    The file is split into entries at blank lines. Within each entry the
    first line starting with ``msgid "`` holds the statement; that keyword
    and the surrounding double quotes are stripped from it. Entries without
    such a line (for instance blocks made up only of comments) are skipped.

    Args:
        pofile: Path of the po-file to read.

    Returns:
        A list with one dict per statement, in file order, with the keys
        ``name`` and ``query`` (both the statement text), ``placeholders``
        (the number of ``?`` characters in it) and ``readonly`` (the string
        ``"true"`` if the statement starts with SELECT, compared
        case-insensitively, otherwise ``"false"``).
    """
    msgid_prefix = 'msgid "'

    def chunk(pofile):
        # Yield the groups of stripped, non-blank lines that are separated by
        # blank lines. Runs of several blank lines (or a blank line at the
        # very start) must not produce empty groups.
        with open(pofile, 'r') as f:
            lines = []
            for line in f:
                line = line.strip()
                if line:
                    lines.append(line)
                elif lines:
                    yield lines
                    lines = []
            if lines:
                yield lines

    queries = []
    for c in chunk(pofile):
        # Comment lines start with '#', so looking for the msgid keyword
        # skips them wherever they are in the entry.
        msgid = next((line for line in c if line.startswith(msgid_prefix)), None)
        if msgid is None:
            continue

        # Strip header and surrounding quotes
        query = msgid[len(msgid_prefix):-1]

        queries.append({
            'name': query,
            'query': query,
            'placeholders': query.count('?'),
            'readonly': "true" if query.upper().startswith("SELECT") else "false",
        })
    return queries
150
151
if __name__ == "__main__":
    # Usage: <script> <statements.po-file> <output-dialect>
    #
    # The generated header is written to stdout, so every diagnostic goes to
    # stderr instead: sys.exit() with a string argument prints that string to
    # stderr and exits with status 1. This keeps error messages visible and
    # out of the output when stdout is redirected.
    if len(sys.argv) != 3:
        sys.exit("Usage:\n\t{} <statements.po-file> <output-dialect>".format(sys.argv[0]))

    pofile, dialect = sys.argv[1:]

    if dialect not in rewriters:
        sys.exit("Unknown dialect {}. The following are available: {}".format(
            dialect,
            ", ".join(rewriters)
        ))

    rewriter = rewriters[dialect]

    # Read the statements, translate them for the dialect, emit the header.
    queries = extract_queries(pofile)
    queries = rewriter.rewrite(queries)

    print(template.render(f=dialect, queries=queries))
172