// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sqlparser

import (
	"bytes"
	"fmt"
	"strings"

	"github.com/dinedal/textql/sqlparser/sqltypes"
)

// EOFCHAR is a sentinel returned by next when the input is
// exhausted; it lies outside the range of any valid byte.
const EOFCHAR = 0x100
// Tokenizer is the struct used to generate SQL
// tokens for the parser.
type Tokenizer struct {
	InStream      *strings.Reader
	AllowComments bool
	ForceEOF      bool
	lastChar      uint16 // one character of lookahead; EOFCHAR at end of input
	Position      int
	errorToken    []byte
	LastError     string
	posVarIndex   int
	ParseTree     Statement
}

// NewStringTokenizer creates a new Tokenizer for the
// sql string.
func NewStringTokenizer(sql string) *Tokenizer {
	return &Tokenizer{InStream: strings.NewReader(sql)}
}
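
// A minimal usage sketch (illustrative only, not part of this file's
// original API). Scan returns 0 at end of input, so a caller can
// drain every token like this:
//
//	tkn := NewStringTokenizer("select a from t where b = 1")
//	for {
//		typ, val := tkn.Scan()
//		if typ == 0 {
//			break
//		}
//		fmt.Printf("token %d: %q\n", typ, val)
//	}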

// keywords maps lower-cased SQL keywords to their parser token types.
var keywords = map[string]int{
	"all":           ALL,
	"alter":         ALTER,
	"analyze":       ANALYZE,
	"and":           AND,
	"as":            AS,
	"asc":           ASC,
	"between":       BETWEEN,
	"by":            BY,
	"case":          CASE,
	"create":        CREATE,
	"cross":         CROSS,
	"default":       DEFAULT,
	"delete":        DELETE,
	"desc":          DESC,
	"describe":      DESCRIBE,
	"distinct":      DISTINCT,
	"drop":          DROP,
	"duplicate":     DUPLICATE,
	"else":          ELSE,
	"end":           END,
	"except":        EXCEPT,
	"exists":        EXISTS,
	"explain":       EXPLAIN,
	"for":           FOR,
	"force":         FORCE,
	"from":          FROM,
	"group":         GROUP,
	"having":        HAVING,
	"if":            IF,
	"ignore":        IGNORE,
	"in":            IN,
	"index":         INDEX,
	"inner":         INNER,
	"insert":        INSERT,
	"intersect":     INTERSECT,
	"into":          INTO,
	"is":            IS,
	"join":          JOIN,
	"key":           KEY,
	"keyrange":      KEYRANGE,
	"left":          LEFT,
	"like":          LIKE,
	"limit":         LIMIT,
	"lock":          LOCK,
	"minus":         MINUS,
	"natural":       NATURAL,
	"not":           NOT,
	"null":          NULL,
	"on":            ON,
	"or":            OR,
	"order":         ORDER,
	"outer":         OUTER,
	"rename":        RENAME,
	"right":         RIGHT,
	"select":        SELECT,
	"set":           SET,
	"show":          SHOW,
	"straight_join": STRAIGHT_JOIN,
	"table":         TABLE,
	"then":          THEN,
	"to":            TO,
	"union":         UNION,
	"unique":        UNIQUE,
	"update":        UPDATE,
	"use":           USE,
	"using":         USING,
	"values":        VALUES,
	"view":          VIEW,
	"when":          WHEN,
	"where":         WHERE,
}

// Lex returns the next token from the Tokenizer.
// This function is used by go yacc.
func (tkn *Tokenizer) Lex(lval *yySymType) int {
	typ, val := tkn.Scan()
	for typ == COMMENT {
		if tkn.AllowComments {
			break
		}
		typ, val = tkn.Scan()
	}
	switch typ {
	case ID, STRING, NUMBER, VALUE_ARG, LIST_ARG, COMMENT:
		lval.bytes = val
	}
	tkn.errorToken = val
	return typ
}

// Error is called by go yacc if there's a parsing error.
func (tkn *Tokenizer) Error(err string) {
	buf := bytes.NewBuffer(make([]byte, 0, 32))
	if tkn.errorToken != nil {
		fmt.Fprintf(buf, "%s at position %v near %s", err, tkn.Position, tkn.errorToken)
	} else {
		fmt.Fprintf(buf, "%s at position %v", err, tkn.Position)
	}
	tkn.LastError = buf.String()
}

// Scan scans the tokenizer for the next token and returns
// the token type and an optional value.
func (tkn *Tokenizer) Scan() (int, []byte) {
	if tkn.ForceEOF {
		return 0, nil
	}

	if tkn.lastChar == 0 {
		tkn.next()
	}
	tkn.skipBlank()
	switch ch := tkn.lastChar; {
	case isLetter(ch):
		return tkn.scanIdentifier()
	case isDigit(ch):
		return tkn.scanNumber(false)
	case ch == ':':
		return tkn.scanBindVar()
	default:
		tkn.next()
		switch ch {
		case EOFCHAR:
			return 0, nil
		case '=', ',', ';', '(', ')', '+', '*', '%', '&', '|', '^', '~':
			return int(ch), nil
		case '?':
			tkn.posVarIndex++
			buf := new(bytes.Buffer)
			fmt.Fprintf(buf, ":v%d", tkn.posVarIndex)
			return VALUE_ARG, buf.Bytes()
		case '.':
			if isDigit(tkn.lastChar) {
				return tkn.scanNumber(true)
			}
			return int(ch), nil
		case '/':
			switch tkn.lastChar {
			case '/':
				tkn.next()
				return tkn.scanCommentType1("//")
			case '*':
				tkn.next()
				return tkn.scanCommentType2()
			default:
				return int(ch), nil
			}
		case '-':
			if tkn.lastChar == '-' {
				tkn.next()
				return tkn.scanCommentType1("--")
			}
			return int(ch), nil
		case '<':
			switch tkn.lastChar {
			case '>':
				tkn.next()
				return NE, nil
			case '=':
				tkn.next()
				switch tkn.lastChar {
				case '>':
					tkn.next()
					return NULL_SAFE_EQUAL, nil
				default:
					return LE, nil
				}
			default:
				return int(ch), nil
			}
		case '>':
			if tkn.lastChar == '=' {
				tkn.next()
				return GE, nil
			}
			return int(ch), nil
		case '!':
			if tkn.lastChar == '=' {
				tkn.next()
				return NE, nil
			}
			return LEX_ERROR, []byte("!")
		case '\'', '"':
			return tkn.scanString(ch, STRING)
		case '`':
			return tkn.scanLiteralIdentifier()
		default:
			return LEX_ERROR, []byte{byte(ch)}
		}
	}
}
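
// A hedged sketch of how Scan uses one character of lookahead to
// disambiguate multi-character operators:
//
//	<    ->  '<'    <=   ->  LE
//	<>   ->  NE     <=>  ->  NULL_SAFE_EQUAL
//	>=   ->  GE     !=   ->  NE
//
// A lone '!' is a LEX_ERROR, and each '?' is rewritten into a
// positional bind variable (:v1, :v2, ...).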

// skipBlank advances past spaces, tabs and line breaks.
func (tkn *Tokenizer) skipBlank() {
	ch := tkn.lastChar
	for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
		tkn.next()
		ch = tkn.lastChar
	}
}

// scanIdentifier reads a bare word and classifies it as a keyword
// token or a plain ID.
func (tkn *Tokenizer) scanIdentifier() (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	buffer.WriteByte(byte(tkn.lastChar))
	for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() {
		buffer.WriteByte(byte(tkn.lastChar))
	}
	lowered := bytes.ToLower(buffer.Bytes())
	if keywordId, found := keywords[string(lowered)]; found {
		return keywordId, lowered
	}
	return ID, buffer.Bytes()
}
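
// For example (hedged sketch): the keyword lookup is case-insensitive,
// while non-keyword identifiers keep their original casing:
//
//	tkn := NewStringTokenizer("SeLeCt MyCol")
//	typ, val := tkn.Scan() // SELECT, val == []byte("select")
//	typ, val = tkn.Scan()  // ID, val == []byte("MyCol")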

// scanLiteralIdentifier reads a back-quoted identifier such as `name`.
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	buffer.WriteByte(byte(tkn.lastChar))
	if !isLetter(tkn.lastChar) {
		return LEX_ERROR, buffer.Bytes()
	}
	for tkn.next(); isLetter(tkn.lastChar) || isDigit(tkn.lastChar); tkn.next() {
		buffer.WriteByte(byte(tkn.lastChar))
	}
	if tkn.lastChar != '`' {
		return LEX_ERROR, buffer.Bytes()
	}
	tkn.next()
	return ID, buffer.Bytes()
}

// scanBindVar reads a named bind variable (:name) or a list bind
// variable (::name).
func (tkn *Tokenizer) scanBindVar() (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	buffer.WriteByte(byte(tkn.lastChar))
	token := VALUE_ARG
	tkn.next()
	if tkn.lastChar == ':' {
		token = LIST_ARG
		buffer.WriteByte(byte(tkn.lastChar))
		tkn.next()
	}
	if !isLetter(tkn.lastChar) {
		return LEX_ERROR, buffer.Bytes()
	}
	for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || tkn.lastChar == '.' {
		buffer.WriteByte(byte(tkn.lastChar))
		tkn.next()
	}
	return token, buffer.Bytes()
}
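
// A hedged sketch of the bind variable forms scanBindVar accepts:
//
//	tkn := NewStringTokenizer(":name ::list")
//	typ, val := tkn.Scan() // VALUE_ARG, val == []byte(":name")
//	typ, val = tkn.Scan()  // LIST_ARG, val == []byte("::list")
//
// A ':' not followed by a letter (after an optional second ':')
// yields LEX_ERROR.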

// scanMantissa consumes consecutive digits that are valid in the
// given base.
func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes.Buffer) {
	for digitVal(tkn.lastChar) < base {
		tkn.ConsumeNext(buffer)
	}
}

func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	if seenDecimalPoint {
		buffer.WriteByte('.')
		tkn.scanMantissa(10, buffer)
		goto exponent
	}

	if tkn.lastChar == '0' {
		// int or float
		tkn.ConsumeNext(buffer)
		if tkn.lastChar == 'x' || tkn.lastChar == 'X' {
			// hexadecimal int
			tkn.ConsumeNext(buffer)
			tkn.scanMantissa(16, buffer)
		} else {
			// octal int or float
			seenDecimalDigit := false
			tkn.scanMantissa(8, buffer)
			if tkn.lastChar == '8' || tkn.lastChar == '9' {
				// illegal octal int or float
				seenDecimalDigit = true
				tkn.scanMantissa(10, buffer)
			}
			if tkn.lastChar == '.' || tkn.lastChar == 'e' || tkn.lastChar == 'E' {
				goto fraction
			}
			// octal int
			if seenDecimalDigit {
				return LEX_ERROR, buffer.Bytes()
			}
		}
		goto exit
	}

	// decimal int or float
	tkn.scanMantissa(10, buffer)

fraction:
	if tkn.lastChar == '.' {
		tkn.ConsumeNext(buffer)
		tkn.scanMantissa(10, buffer)
	}

exponent:
	if tkn.lastChar == 'e' || tkn.lastChar == 'E' {
		tkn.ConsumeNext(buffer)
		if tkn.lastChar == '+' || tkn.lastChar == '-' {
			tkn.ConsumeNext(buffer)
		}
		tkn.scanMantissa(10, buffer)
	}

exit:
	return NUMBER, buffer.Bytes()
}
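
// Hedged examples of literals scanNumber accepts (each returns NUMBER):
//
//	12       decimal int
//	0x1F     hexadecimal int
//	017      octal int (018 is a LEX_ERROR: illegal octal digit)
//	.5       leading-dot float, reached via scanNumber(true)
//	1.5e-3   float with fraction and signed exponent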

// scanString reads a string literal bounded by delim, handling doubled
// delimiters and backslash escapes.
func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	for {
		ch := tkn.lastChar
		tkn.next()
		if ch == delim {
			if tkn.lastChar == delim {
				tkn.next()
			} else {
				break
			}
		} else if ch == '\\' {
			if tkn.lastChar == EOFCHAR {
				return LEX_ERROR, buffer.Bytes()
			}
			if decodedChar := sqltypes.SqlDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DONTESCAPE {
				ch = tkn.lastChar
			} else {
				ch = uint16(decodedChar)
			}
			tkn.next()
		}
		if ch == EOFCHAR {
			return LEX_ERROR, buffer.Bytes()
		}
		buffer.WriteByte(byte(ch))
	}
	return typ, buffer.Bytes()
}
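
// A hedged sketch of the quoting rules scanString implements:
//
//	'it''s'  // a doubled delimiter encodes one literal quote: it's
//	"a\nb"   // backslash escapes decode via sqltypes.SqlDecodeMap
//	'abc     // an unterminated string is a LEX_ERROR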

// scanCommentType1 reads a line comment started by prefix ("--" or "//")
// up to and including the trailing newline.
func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	buffer.WriteString(prefix)
	for tkn.lastChar != EOFCHAR {
		if tkn.lastChar == '\n' {
			tkn.ConsumeNext(buffer)
			break
		}
		tkn.ConsumeNext(buffer)
	}
	return COMMENT, buffer.Bytes()
}

// scanCommentType2 reads a /* ... */ block comment; hitting EOF before
// the closing */ is a LEX_ERROR.
func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
	buffer := bytes.NewBuffer(make([]byte, 0, 8))
	buffer.WriteString("/*")
	for {
		if tkn.lastChar == '*' {
			tkn.ConsumeNext(buffer)
			if tkn.lastChar == '/' {
				tkn.ConsumeNext(buffer)
				break
			}
			continue
		}
		if tkn.lastChar == EOFCHAR {
			return LEX_ERROR, buffer.Bytes()
		}
		tkn.ConsumeNext(buffer)
	}
	return COMMENT, buffer.Bytes()
}

// ConsumeNext copies the current character into buffer and advances
// to the next one.
func (tkn *Tokenizer) ConsumeNext(buffer *bytes.Buffer) {
	if tkn.lastChar == EOFCHAR {
		// This should never happen.
		panic("unexpected EOF")
	}
	buffer.WriteByte(byte(tkn.lastChar))
	tkn.next()
}

// next advances lastChar by one byte, setting it to EOFCHAR at the
// end of the input.
func (tkn *Tokenizer) next() {
	if ch, err := tkn.InStream.ReadByte(); err != nil {
		// Only EOF is possible.
		tkn.lastChar = EOFCHAR
	} else {
		tkn.lastChar = uint16(ch)
	}
	tkn.Position++
}

func isLetter(ch uint16) bool {
	return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@'
}

func digitVal(ch uint16) int {
	switch {
	case '0' <= ch && ch <= '9':
		return int(ch) - '0'
	case 'a' <= ch && ch <= 'f':
		return int(ch) - 'a' + 10
	case 'A' <= ch && ch <= 'F':
		return int(ch) - 'A' + 10
	}
	return 16 // larger than any legal digit val
}

func isDigit(ch uint16) bool {
	return '0' <= ch && ch <= '9'
}