1// Copyright 2012, Google Inc. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package sqltypes implements interfaces and types that represent SQL values.
6package sqltypes
7
8import (
9	"encoding/base64"
10	"encoding/gob"
11	"encoding/json"
12	"fmt"
13	"reflect"
14	"strconv"
15	"time"
16	"unsafe"
17)
18
19var (
20	NULL       = Value{}
21	DONTESCAPE = byte(255)
22	nullstr    = []byte("null")
23)
24
// BinWriter is the sink that values are encoded into. Its method set is a
// subset of bytes.Buffer's, so a *bytes.Buffer can be passed directly.
// Implementations are expected to be in-memory buffers whose writes do not
// fail; the encoders in this package do not return write errors to callers
// (they generally panic instead).
type BinWriter interface {
	// Write writes all of p.
	Write(p []byte) (n int, err error)
	// WriteByte writes the single byte c.
	WriteByte(c byte) error
}
33
34// Value can store any SQL value. NULL is stored as nil.
35type Value struct {
36	Inner InnerValue
37}
38
39// Numeric represents non-fractional SQL number.
40type Numeric []byte
41
42// Fractional represents fractional types like float and decimal
43// It's functionally equivalent to Numeric other than how it's constructed
44type Fractional []byte
45
46// String represents any SQL type that needs to be represented using quotes.
47type String []byte
48
49// MakeNumeric makes a Numeric from a []byte without validation.
50func MakeNumeric(b []byte) Value {
51	return Value{Numeric(b)}
52}
53
54// MakeFractional makes a Fractional value from a []byte without validation.
55func MakeFractional(b []byte) Value {
56	return Value{Fractional(b)}
57}
58
59// MakeString makes a String value from a []byte.
60func MakeString(b []byte) Value {
61	return Value{String(b)}
62}
63
64// Raw returns the raw bytes. All types are currently implemented as []byte.
65func (v Value) Raw() []byte {
66	if v.Inner == nil {
67		return nil
68	}
69	return v.Inner.raw()
70}
71
72// String returns the raw value as a string
73func (v Value) String() string {
74	if v.Inner == nil {
75		return ""
76	}
77	return toString(v.Inner.raw())
78}
79
// toString converts b to a string without copying: the returned string
// shares b's backing array. Mutating b afterwards also changes the string,
// which breaks Go's string immutability, so b must not be modified while the
// result is in use. USE AT YOUR OWN RISK.
func toString(b []byte) (s string) {
	if len(b) > 0 {
		// Point s's header at b's data instead of allocating a copy.
		src := (*reflect.SliceHeader)(unsafe.Pointer(&b))
		dst := (*reflect.StringHeader)(unsafe.Pointer(&s))
		dst.Data = src.Data
		dst.Len = src.Len
	}
	return s
}
92
93// ParseInt64 will parse a Numeric value into an int64
94func (v Value) ParseInt64() (val int64, err error) {
95	if v.Inner == nil {
96		return 0, fmt.Errorf("value is null")
97	}
98	n, ok := v.Inner.(Numeric)
99	if !ok {
100		return 0, fmt.Errorf("value is not Numeric")
101	}
102	return strconv.ParseInt(string(n.raw()), 10, 64)
103}
104
105// ParseUint64 will parse a Numeric value into a uint64
106func (v Value) ParseUint64() (val uint64, err error) {
107	if v.Inner == nil {
108		return 0, fmt.Errorf("value is null")
109	}
110	n, ok := v.Inner.(Numeric)
111	if !ok {
112		return 0, fmt.Errorf("value is not Numeric")
113	}
114	return strconv.ParseUint(string(n.raw()), 10, 64)
115}
116
117// ParseFloat64 will parse a Fractional value into an float64
118func (v Value) ParseFloat64() (val float64, err error) {
119	if v.Inner == nil {
120		return 0, fmt.Errorf("value is null")
121	}
122	n, ok := v.Inner.(Fractional)
123	if !ok {
124		return 0, fmt.Errorf("value is not Fractional")
125	}
126	return strconv.ParseFloat(string(n.raw()), 64)
127}
128
129// EncodeSql encodes the value into an SQL statement. Can be binary.
130func (v Value) EncodeSql(b BinWriter) {
131	if v.Inner == nil {
132		if _, err := b.Write(nullstr); err != nil {
133			panic(err)
134		}
135	} else {
136		v.Inner.encodeSql(b)
137	}
138}
139
140// EncodeAscii encodes the value using 7-bit clean ascii bytes.
141func (v Value) EncodeAscii(b BinWriter) {
142	if v.Inner == nil {
143		if _, err := b.Write(nullstr); err != nil {
144			panic(err)
145		}
146	} else {
147		v.Inner.encodeAscii(b)
148	}
149}
150
151func (v Value) IsNull() bool {
152	return v.Inner == nil
153}
154
155func (v Value) IsNumeric() (ok bool) {
156	if v.Inner != nil {
157		_, ok = v.Inner.(Numeric)
158	}
159	return ok
160}
161
162func (v Value) IsFractional() (ok bool) {
163	if v.Inner != nil {
164		_, ok = v.Inner.(Fractional)
165	}
166	return ok
167}
168
169func (v Value) IsString() (ok bool) {
170	if v.Inner != nil {
171		_, ok = v.Inner.(String)
172	}
173	return ok
174}
175
176// MarshalJSON should only be used for testing.
177// It's not a complete implementation.
178func (v Value) MarshalJSON() ([]byte, error) {
179	return json.Marshal(v.Inner)
180}
181
182// UnmarshalJSON should only be used for testing.
183// It's not a complete implementation.
184func (v *Value) UnmarshalJSON(b []byte) error {
185	if len(b) == 0 {
186		return fmt.Errorf("error unmarshaling empty bytes")
187	}
188	var val interface{}
189	var err error
190	switch b[0] {
191	case '-':
192		var ival int64
193		err = json.Unmarshal(b, &ival)
194		val = ival
195	case '"':
196		var bval []byte
197		err = json.Unmarshal(b, &bval)
198		val = bval
199	case 'n': // null
200		err = json.Unmarshal(b, &val)
201	default:
202		var uval uint64
203		err = json.Unmarshal(b, &uval)
204		val = uval
205	}
206	if err != nil {
207		return err
208	}
209	*v, err = BuildValue(val)
210	return err
211}
212
213// InnerValue defines methods that need to be supported by all non-null value types.
214type InnerValue interface {
215	raw() []byte
216	encodeSql(BinWriter)
217	encodeAscii(BinWriter)
218}
219
220func BuildValue(goval interface{}) (v Value, err error) {
221	switch bindVal := goval.(type) {
222	case nil:
223		// no op
224	case int:
225		v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
226	case int32:
227		v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
228	case int64:
229		v = Value{Numeric(strconv.AppendInt(nil, int64(bindVal), 10))}
230	case uint:
231		v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
232	case uint32:
233		v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
234	case uint64:
235		v = Value{Numeric(strconv.AppendUint(nil, uint64(bindVal), 10))}
236	case float64:
237		v = Value{Fractional(strconv.AppendFloat(nil, bindVal, 'f', -1, 64))}
238	case string:
239		v = Value{String([]byte(bindVal))}
240	case []byte:
241		v = Value{String(bindVal)}
242	case time.Time:
243		v = Value{String([]byte(bindVal.Format("2006-01-02 15:04:05")))}
244	case Numeric, Fractional, String:
245		v = Value{bindVal.(InnerValue)}
246	case Value:
247		v = bindVal
248	default:
249		return Value{}, fmt.Errorf("unsupported bind variable type %T: %v", goval, goval)
250	}
251	return v, nil
252}
253
254// BuildNumeric builds a Numeric type that represents any whole number.
255// It normalizes the representation to ensure 1:1 mapping between the
256// number and its representation.
257func BuildNumeric(val string) (n Value, err error) {
258	if val[0] == '-' || val[0] == '+' {
259		signed, err := strconv.ParseInt(val, 0, 64)
260		if err != nil {
261			return Value{}, err
262		}
263		n = Value{Numeric(strconv.AppendInt(nil, signed, 10))}
264	} else {
265		unsigned, err := strconv.ParseUint(val, 0, 64)
266		if err != nil {
267			return Value{}, err
268		}
269		n = Value{Numeric(strconv.AppendUint(nil, unsigned, 10))}
270	}
271	return n, nil
272}
273
274func (n Numeric) raw() []byte {
275	return []byte(n)
276}
277
278func (n Numeric) encodeSql(b BinWriter) {
279	if _, err := b.Write(n.raw()); err != nil {
280		panic(err)
281	}
282}
283
284func (n Numeric) encodeAscii(b BinWriter) {
285	if _, err := b.Write(n.raw()); err != nil {
286		panic(err)
287	}
288}
289
290func (n Numeric) MarshalJSON() ([]byte, error) {
291	return n.raw(), nil
292}
293
294func (f Fractional) raw() []byte {
295	return []byte(f)
296}
297
298func (f Fractional) encodeSql(b BinWriter) {
299	if _, err := b.Write(f.raw()); err != nil {
300		panic(err)
301	}
302}
303
304func (f Fractional) encodeAscii(b BinWriter) {
305	if _, err := b.Write(f.raw()); err != nil {
306		panic(err)
307	}
308}
309
310func (s String) MarshalJSON() ([]byte, error) {
311	return json.Marshal(string(s.raw()))
312}
313
314func (s String) raw() []byte {
315	return []byte(s)
316}
317
318func (s String) encodeSql(b BinWriter) {
319	writebyte(b, '\'')
320	for _, ch := range s.raw() {
321		if encodedChar := SqlEncodeMap[ch]; encodedChar == DONTESCAPE {
322			writebyte(b, ch)
323		} else {
324			writebyte(b, '\\')
325			writebyte(b, encodedChar)
326		}
327	}
328	writebyte(b, '\'')
329}
330
331func (s String) encodeAscii(b BinWriter) {
332	writebyte(b, '\'')
333	encoder := base64.NewEncoder(base64.StdEncoding, b)
334	encoder.Write(s.raw())
335	encoder.Close()
336	writebyte(b, '\'')
337}
338
339func writebyte(b BinWriter, c byte) {
340	if err := b.WriteByte(c); err != nil {
341		panic(err)
342	}
343}
344
var (
	// SqlEncodeMap specifies how to escape binary data with '\'. For each
	// byte it holds the character written after the backslash, or DONTESCAPE
	// if the byte is written as-is. It complies with
	// http://dev.mysql.com/doc/refman/5.1/en/string-syntax.html
	SqlEncodeMap [256]byte

	// SqlDecodeMap is the inverse of SqlEncodeMap: for each escape character
	// it holds the original byte, or DONTESCAPE if there is none.
	SqlDecodeMap [256]byte
)

// encodeRef lists the bytes that must be backslash-escaped in SQL string
// literals (keys), with the character that follows the backslash (values).
var encodeRef = map[byte]byte{
	'\x00': '0',
	'\b':   'b',
	'\t':   't',
	'\n':   'n',
	'\r':   'r',
	'\x1a': 'Z', // ctl-Z
	'"':    '"',
	'\'':   '\'',
	'\\':   '\\',
}
363
364func init() {
365	for i := range SqlEncodeMap {
366		SqlEncodeMap[i] = DONTESCAPE
367		SqlDecodeMap[i] = DONTESCAPE
368	}
369	for i := range SqlEncodeMap {
370		if to, ok := encodeRef[byte(i)]; ok {
371			SqlEncodeMap[byte(i)] = to
372			SqlDecodeMap[to] = byte(i)
373		}
374	}
375	gob.Register(Numeric(nil))
376	gob.Register(Fractional(nil))
377	gob.Register(String(nil))
378}
379