1 #include "Sql.h"
2 
3 #define LTIMING(x)
4 
5 namespace Upp {
6 
7 template <class T>
MakeSqlValue(int code,T & value)8 String MakeSqlValue(int code, T& value)
9 {
10 	StringBuffer b(sizeof(T) + 1);
11 	b[0] = code;
12 	memcpy(~b + 1, &value, sizeof(T));
13 	return String(b);
14 }
15 
16 template <class T>
ReadSqlValue(T & x,const char * & s)17 T ReadSqlValue(T& x, const char *&s) {
18 	memcpy(&x, s, sizeof(T));
19 	s += sizeof(T);
20 	return x;
21 }
22 
23 static bool sSqlIdQuoted;
24 static bool sToUpperCase;
25 static bool sToLowerCase;
26 
IsUseQuotes()27 bool SqlId::IsUseQuotes()
28 {
29 	return sSqlIdQuoted;
30 }
31 
UseQuotes(bool b)32 void SqlId::UseQuotes(bool b)
33 {
34 	sSqlIdQuoted = b;
35 }
36 
ToLowerCase(bool b)37 void SqlId::ToLowerCase(bool b)
38 {
39 	sToUpperCase = sToUpperCase && !b;
40 	sToLowerCase = b;
41 }
42 
ToUpperCase(bool b)43 void SqlId::ToUpperCase(bool b)
44 {
45 	sToLowerCase = sToLowerCase && !b;
46 	sToUpperCase = b;
47 }
48 
Quoted() const49 String SqlId::Quoted() const
50 {
51 	if(!id.IsNull())
52 		return String().Cat() << '\t' << id << '\t';
53 	return id.ToString();
54 }
55 
SqlCompile(const char * & s,StringBuffer * r,byte dialect,Vector<SqlVal> * split)56 void SqlCompile(const char *&s, StringBuffer *r, byte dialect, Vector<SqlVal> *split)
57 {
58 	char quote = dialect == MY_SQL ? '`' : '\"';
59 	const char *b = s;
60 	int lvl = 0;
61 	for(;;) {
62 		int c = *s++;
63 		switch(c) {
64 		case SQLC_OF:
65 			if(r)
66 				*r << '.';
67 			break;
68 		case SQLC_AS:
69 			if(r) {
70 				if(dialect & (MSSQL | PGSQL))
71 					*r << " as ";
72 				else
73 					*r << ' ';
74 			}
75 			break;
76 		case SQLC_COMMA:
77 			if(r)
78 				*r << ", ";
79 			break;
80 		case SQLC_ID: {
81 				for(;;) {
82 					const char *b = s;
83 					bool do_quote = sSqlIdQuoted && *s != '*';
84 					while((byte)*s >= 32)
85 						s++;
86 					int c = *s;
87 					if(r) {
88 						if(do_quote)
89 							*r << quote;
90 						if(sToUpperCase)
91 							r->Cat(ToUpper(String(b, s)));
92 						else
93 						if(sToLowerCase)
94 							r->Cat(ToLower(String(b, s)));
95 						else
96 							r->Cat(b, s);
97 						if(do_quote)
98 							*r << quote;
99 						if(c == SQLC_AS) {
100 							if(dialect & (MSSQL | PGSQL))
101 								*r << " as ";
102 							else
103 								*r << ' ';
104 						}
105 						else
106 						if(c == SQLC_OF)
107 							*r << '.';
108 						else
109 						if(c == SQLC_COMMA)
110 							*r << ", ";
111 					}
112 					s++;
113 					if(c == SQLC_ID)
114 						break;
115 					if(c == '\0')
116 						return;
117 				}
118 			}
119 			break;
120 		case SQLC_IF: {
121 			LTIMING("SqlCompile IF");
122 			StringBuffer *er = r;
123 			for(;;) {
124 				c = *s++;
125 				if(c & dialect) {
126 					SqlCompile(s, er, dialect, NULL);
127 					er = NULL;
128 				}
129 				else
130 					SqlCompile(s, NULL, dialect, NULL);
131 				if(*s == '\0')
132 					return;
133 				c = *s++;
134 				if(c == SQLC_ELSE) {
135 					SqlCompile(s, er, dialect, NULL);
136 					ASSERT(*s == SQLC_ENDIF);
137 					s++;
138 					break;
139 				}
140 				if(c == SQLC_ENDIF)
141 					break;
142 				ASSERT(c == SQLC_ELSEIF);
143 			}
144 			break;
145 		}
146 		case SQLC_DATE: {
147 			LTIMING("SqlCompile DATE");
148 			Date x;
149 			ReadSqlValue(x, s);
150 			if(!r) break;
151 			if(IsNull(x)) {
152 				*r << "NULL";
153 				break;
154 			}
155 			switch(dialect) {
156 			case MSSQL:
157 				if(x.year < 1753) x.year = 1753; // Date::Low()
158 				*r << Format("convert(datetime, '%d/%d/%d', 120)", x.year, x.month, x.day);
159 				break;
160 			case ORACLE:
161 				*r << Format("to_date('%d/%d/%d', 'SYYYY/MM/DD')", x.year, x.month, x.day);
162 				break;
163 			case PGSQL:
164 				if(x.year < 1) x.year = 1; // Date::Low()
165 				*r << "date ";
166 			default:
167 				*r << Format("\'%04d-%02d-%02d\'", x.year, x.month, x.day);
168 			}
169 			break;
170 		}
171 		case SQLC_TIME: {
172 			LTIMING("SqlCompile TIME");
173 			Time x;
174 			ReadSqlValue(x, s);
175 			if(!r) break;
176 			if(IsNull(x)) {
177 				*r << "NULL";
178 				break;
179 			}
180 			switch(dialect) {
181 			case MSSQL:
182 				if(x.year < 1753) x.year = 1753; // Date::Low()
183 				*r << Format(x.hour || x.minute || x.second
184 				             ? "convert(datetime, '%d/%d/%d %d:%d:%d', 120)"
185 				             : "convert(datetime, '%d/%d/%d', 120)",
186 				             x.year, x.month, x.day, x.hour, x.minute, x.second);
187 				break;
188 			case ORACLE:
189 				*r << Format("to_date('%d/%d/%d/%d', 'SYYYY/MM/DD/SSSSS')",
190 				             x.year, x.month, x.day, x.second + 60 * (x.minute + 60 * x.hour));
191 				break;
192 			case PGSQL:
193 				if(x.year < 1) x.year = 1; // Date::Low()
194 				*r << "timestamp ";
195 			default:
196 				*r << Format("\'%04d-%02d-%02d %02d:%02d:%02d\'",
197 				             x.year, x.month, x.day, x.hour, x.minute, x.second);
198 			}
199 			break;
200 		}
201 		case SQLC_BINARY: {
202 			int l;
203 			ReadSqlValue(l, s);
204 			if(r) {
205 				if(l == 0)
206 					*r << "NULL";
207 				else
208 					switch(dialect) {
209 					case PGSQL: {
210 						*r << "E\'";
211 						const char *e = s + l;
212 						while(s < e) {
213 							byte c = *s++;
214 							if(c < 32 || c > 126 || c == 39 || c == 92) {
215 								*r << '\\' << '\\';
216 								r->Cat(((c >> 6) & 3) + '0');
217 								r->Cat(((c >> 3) & 7) + '0');
218 								r->Cat((c & 7) + '0');
219 							}
220 							else
221 								r->Cat(c);
222 						}
223 						*r << "\'::bytea";
224 						break;
225 					}
226 					case MSSQL:
227 						*r << "0x" << HexString(s, l);
228 						s += l;
229 						break;
230 					case SQLITE3:
231 					case MY_SQL:
232 						*r << "X";
233 					default:
234 						*r << "\'" << HexString(s, l) << "\'";
235 						s += l;
236 						break;
237 					}
238 			}
239 			else
240 				s += l;
241 			break;
242 		}
243 		case SQLC_STRING: {
244 			LTIMING("SqlCompile STRING");
245 			int l;
246 			ReadSqlValue(l, s);
247 			String x = String(s, l);
248 			s += l;
249 			if(!r) break;
250 			if(IsNull(x)) {
251 				*r << "NULL";
252 				break;
253 			}
254 			if(dialect == PGSQL)
255 				r->Cat('E');
256 			r->Cat('\'');
257 			for(const char *q = x; *q; q++) {
258 				int c = (byte)*q;
259 				if(c == '\'') {
260 					if(dialect == MY_SQL)
261 						r->Cat("\\\'");
262 					else if(dialect == PGSQL)
263 						r->Cat("\\'");
264 					else
265 					 	r->Cat("\'\'");
266 				}
267 				else {
268 					if((c == '\"' || c == '\\') && (dialect == MY_SQL || dialect == PGSQL))
269 						r->Cat('\\');
270 					if(dialect == PGSQL && c < 32) {
271 						if(c == '\n')
272 							r->Cat("\\n");
273 						else
274 						if(c == '\r')
275 							r->Cat("\\r");
276 						else
277 						if(c == '\t')
278 							r->Cat("\\t");
279 						else {
280 							char h[4];
281 							h[0] = '\\';
282 							h[1] = (3 & (c >> 6)) + '0';
283 							h[2] = (7 & (c >> 3)) + '0';
284 							h[3] = (7 & c) + '0';
285 							r->Cat(h, 4);
286 						}
287 					}
288 					else
289 						r->Cat(c);
290 				}
291 			}
292 			r->Cat('\'');
293 			break;
294 		}
295 		default:
296 			bool end = c >= 0 && c < 32;
297 			if(split) {
298 				if(c == '(')
299 					lvl++;
300 				if(c == ')')
301 					lvl--;
302 				if((c == ',' && lvl == 0 || end) && s - 1 > b) {
303 					while(*b == ' ')
304 						b++;
305 					split->Add(SqlVal(String(b, s - 1), SqlS::HIGH));
306 					b = s;
307 				}
308 			}
309 			if(end) {
310 				s--;
311 				return;
312 			}
313 			else
314 				if(r) {
315 					const char *p = s - 1;
316 					while((byte)*s >= 32)
317 						s++;
318 					r->Cat(p, s);
319 				}
320 		}
321 	}
322 }
323 
SqlCompile(byte dialect,const String & s)324 String SqlCompile(byte dialect, const String& s)
325 {
326 	StringBuffer b;
327 	b.Reserve(s.GetLength() + 100);
328 	const char *q = s;
329 	SqlCompile(q, &b, dialect, NULL);
330 	return String(b);
331 }
332 
333 #ifndef NOAPPSQL
SqlCompile(const String & s)334 String SqlCompile(const String& s)
335 {
336 	return SqlCompile(SQL.GetDialect(), s);
337 }
338 #endif
339 
SplitSqlSet(const SqlSet & set)340 Vector<SqlVal> SplitSqlSet(const SqlSet& set)
341 {
342 	String h = ~set;
343 	const char *q = h;
344 	Vector<SqlVal> r;
345 	SqlCompile(q, NULL, ORACLE, &r);
346 	return r;
347 }
348 
SqlFormat(int x)349 String SqlFormat(int x)
350 {
351 	if(IsNull(x)) return "NULL";
352 	return Format("%d", x);
353 }
354 
SqlFormat(double x)355 String SqlFormat(double x)
356 {
357 	if(IsNull(x)) return "NULL";
358 	return FormatDouble(x, 20);
359 }
360 
SqlFormat(int64 x)361 String SqlFormat(int64 x)
362 {
363 	if(IsNull(x)) return "NULL";
364 	return FormatInt64(x);
365 }
366 
SqlFormat0(const char * s,int l,int code)367 String SqlFormat0(const char *s, int l, int code)
368 {
369 	StringBuffer b(1 + sizeof(int) + l);
370 	b[0] = code;
371 	memcpy(~b + 1, &l, sizeof(int));
372 	memcpy(~b + 1 + sizeof(int), s, l);
373 	return String(b);
374 }
375 
SqlFormat(const char * s,int l)376 String SqlFormat(const char *s, int l)
377 {
378 	return SqlFormat0(s, l, SQLC_STRING);
379 }
380 
SqlFormat(const char * s)381 String SqlFormat(const char *s)
382 {
383 	return SqlFormat(s, (int)strlen(s));
384 }
385 
SqlFormat(const String & x)386 String SqlFormat(const String& x)
387 {
388 	return SqlFormat(x, x.GetLength());
389 }
390 
SqlFormatBinary(const char * s,int l)391 String SqlFormatBinary(const char *s, int l)
392 {
393 	return SqlFormat0(s, l, SQLC_BINARY);
394 }
395 
SqlFormatBinary(const String & x)396 String SqlFormatBinary(const String& x)
397 {
398 	return SqlFormatBinary(x, x.GetLength());
399 }
400 
SqlFormat(Date x)401 String SqlFormat(Date x)
402 {
403 	return MakeSqlValue(SQLC_DATE, x);
404 }
405 
SqlFormat(Time x)406 String SqlFormat(Time x)
407 {
408 	return MakeSqlValue(SQLC_TIME, x);
409 }
410 
SqlFormat(const Value & x)411 String SqlFormat(const Value& x)
412 {
413 	if(x.IsNull()) return "NULL";
414 	switch(x.GetType()) {
415 	case BOOL_V:
416 	case INT_V:
417 		return SqlFormat((int) x);
418 	case INT64_V:
419 		return SqlFormat((int64) x);
420 	case DOUBLE_V:
421 		return SqlFormat((double) x);
422 	case STRING_V:
423 	case WSTRING_V:
424 		return SqlFormat(String(x));
425 	case DATE_V:
426 		return SqlFormat(Date(x));
427 	case TIME_V:
428 		return SqlFormat(Time(x));
429 	case SQLRAW_V:
430 		return SqlFormatBinary(SqlRaw(x));
431 	}
432 	NEVER();
433 	return "NULL";
434 }
435 
operator ()(const String & text)436 String SqlCode::operator()(const String& text) {
437 	return s << (char)SQLC_ELSE << text << (char)SQLC_ENDIF;
438 }
439 
operator ()()440 String SqlCode::operator()() {
441 	return s << (char)SQLC_ENDIF;
442 }
443 
operator ()(byte cond,const String & text)444 SqlCode SqlCode::operator()(byte cond, const String& text) {
445 	s << (char)SQLC_ELSEIF << (char)cond << text;
446 	return *this;
447 }
448 
SqlCode(byte cond,const String & text)449 SqlCode::SqlCode(byte cond, const String& text) {
450 	s << (char)SQLC_IF << (char)cond << text;
451 }
452 
453 }
454