1// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file.
5
6package sqlite3
7
8// You can't export a Go function to C and have definitions in the C
9// preamble in the same file, so we have to have callbackTrampoline in
10// its own file. Because we need a separate file anyway, the support
11// code for SQLite custom functions is in here.
12
13/*
14#ifndef USE_LIBSQLITE3
15#include <sqlite3-binding.h>
16#else
17#include <sqlite3.h>
18#endif
19#include <stdlib.h>
20
21void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
22void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
23*/
24import "C"
25
26import (
27	"errors"
28	"fmt"
29	"math"
30	"reflect"
31	"sync"
32	"unsafe"
33)
34
35//export callbackTrampoline
36func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
37	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
38	fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
39	fi.Call(ctx, args)
40}
41
42//export stepTrampoline
43func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
44	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
45	ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
46	ai.Step(ctx, args)
47}
48
49//export doneTrampoline
50func doneTrampoline(ctx *C.sqlite3_context) {
51	handle := uintptr(C.sqlite3_user_data(ctx))
52	ai := lookupHandle(handle).(*aggInfo)
53	ai.Done(ctx)
54}
55
56//export compareTrampoline
57func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
58	cmp := lookupHandle(handlePtr).(func(string, string) int)
59	return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
60}
61
62//export commitHookTrampoline
63func commitHookTrampoline(handle uintptr) int {
64	callback := lookupHandle(handle).(func() int)
65	return callback()
66}
67
68//export rollbackHookTrampoline
69func rollbackHookTrampoline(handle uintptr) {
70	callback := lookupHandle(handle).(func())
71	callback()
72}
73
74//export updateHookTrampoline
75func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
76	callback := lookupHandle(handle).(func(int, string, string, int64))
77	callback(op, C.GoString(db), C.GoString(table), rowid)
78}
79
80//export authorizerTrampoline
81func authorizerTrampoline(handle uintptr, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
82	callback := lookupHandle(handle).(func(int, string, string, string) int)
83	return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3))
84}
85
86// Use handles to avoid passing Go pointers to C.
87
88type handleVal struct {
89	db  *SQLiteConn
90	val interface{}
91}
92
93var handleLock sync.Mutex
94var handleVals = make(map[uintptr]handleVal)
95var handleIndex uintptr = 100
96
97func newHandle(db *SQLiteConn, v interface{}) uintptr {
98	handleLock.Lock()
99	defer handleLock.Unlock()
100	i := handleIndex
101	handleIndex++
102	handleVals[i] = handleVal{db, v}
103	return i
104}
105
106func lookupHandle(handle uintptr) interface{} {
107	handleLock.Lock()
108	defer handleLock.Unlock()
109	r, ok := handleVals[handle]
110	if !ok {
111		if handle >= 100 && handle < handleIndex {
112			panic("deleted handle")
113		} else {
114			panic("invalid handle")
115		}
116	}
117	return r.val
118}
119
120func deleteHandles(db *SQLiteConn) {
121	handleLock.Lock()
122	defer handleLock.Unlock()
123	for handle, val := range handleVals {
124		if val.db == db {
125			delete(handleVals, handle)
126		}
127	}
128}
129
130// This is only here so that tests can refer to it.
131type callbackArgRaw C.sqlite3_value
132
133type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
134
135type callbackArgCast struct {
136	f   callbackArgConverter
137	typ reflect.Type
138}
139
140func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
141	val, err := c.f(v)
142	if err != nil {
143		return reflect.Value{}, err
144	}
145	if !val.Type().ConvertibleTo(c.typ) {
146		return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
147	}
148	return val.Convert(c.typ), nil
149}
150
151func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
152	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
153		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
154	}
155	return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
156}
157
158func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
159	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
160		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
161	}
162	i := int64(C.sqlite3_value_int64(v))
163	val := false
164	if i != 0 {
165		val = true
166	}
167	return reflect.ValueOf(val), nil
168}
169
170func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
171	if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
172		return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
173	}
174	return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
175}
176
177func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
178	switch C.sqlite3_value_type(v) {
179	case C.SQLITE_BLOB:
180		l := C.sqlite3_value_bytes(v)
181		p := C.sqlite3_value_blob(v)
182		return reflect.ValueOf(C.GoBytes(p, l)), nil
183	case C.SQLITE_TEXT:
184		l := C.sqlite3_value_bytes(v)
185		c := unsafe.Pointer(C.sqlite3_value_text(v))
186		return reflect.ValueOf(C.GoBytes(c, l)), nil
187	default:
188		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
189	}
190}
191
192func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
193	switch C.sqlite3_value_type(v) {
194	case C.SQLITE_BLOB:
195		l := C.sqlite3_value_bytes(v)
196		p := (*C.char)(C.sqlite3_value_blob(v))
197		return reflect.ValueOf(C.GoStringN(p, l)), nil
198	case C.SQLITE_TEXT:
199		c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
200		return reflect.ValueOf(C.GoString(c)), nil
201	default:
202		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
203	}
204}
205
206func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
207	switch C.sqlite3_value_type(v) {
208	case C.SQLITE_INTEGER:
209		return callbackArgInt64(v)
210	case C.SQLITE_FLOAT:
211		return callbackArgFloat64(v)
212	case C.SQLITE_TEXT:
213		return callbackArgString(v)
214	case C.SQLITE_BLOB:
215		return callbackArgBytes(v)
216	case C.SQLITE_NULL:
217		// Interpret NULL as a nil byte slice.
218		var ret []byte
219		return reflect.ValueOf(ret), nil
220	default:
221		panic("unreachable")
222	}
223}
224
225func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
226	switch typ.Kind() {
227	case reflect.Interface:
228		if typ.NumMethod() != 0 {
229			return nil, errors.New("the only supported interface type is interface{}")
230		}
231		return callbackArgGeneric, nil
232	case reflect.Slice:
233		if typ.Elem().Kind() != reflect.Uint8 {
234			return nil, errors.New("the only supported slice type is []byte")
235		}
236		return callbackArgBytes, nil
237	case reflect.String:
238		return callbackArgString, nil
239	case reflect.Bool:
240		return callbackArgBool, nil
241	case reflect.Int64:
242		return callbackArgInt64, nil
243	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
244		c := callbackArgCast{callbackArgInt64, typ}
245		return c.Run, nil
246	case reflect.Float64:
247		return callbackArgFloat64, nil
248	case reflect.Float32:
249		c := callbackArgCast{callbackArgFloat64, typ}
250		return c.Run, nil
251	default:
252		return nil, fmt.Errorf("don't know how to convert to %s", typ)
253	}
254}
255
256func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
257	var args []reflect.Value
258
259	if len(argv) < len(converters) {
260		return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
261	}
262
263	for i, arg := range argv[:len(converters)] {
264		v, err := converters[i](arg)
265		if err != nil {
266			return nil, err
267		}
268		args = append(args, v)
269	}
270
271	if variadic != nil {
272		for _, arg := range argv[len(converters):] {
273			v, err := variadic(arg)
274			if err != nil {
275				return nil, err
276			}
277			args = append(args, v)
278		}
279	}
280	return args, nil
281}
282
283type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
284
285func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
286	switch v.Type().Kind() {
287	case reflect.Int64:
288	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
289		v = v.Convert(reflect.TypeOf(int64(0)))
290	case reflect.Bool:
291		b := v.Interface().(bool)
292		if b {
293			v = reflect.ValueOf(int64(1))
294		} else {
295			v = reflect.ValueOf(int64(0))
296		}
297	default:
298		return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
299	}
300
301	C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
302	return nil
303}
304
305func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
306	switch v.Type().Kind() {
307	case reflect.Float64:
308	case reflect.Float32:
309		v = v.Convert(reflect.TypeOf(float64(0)))
310	default:
311		return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
312	}
313
314	C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
315	return nil
316}
317
318func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
319	if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
320		return fmt.Errorf("cannot convert %s to BLOB", v.Type())
321	}
322	i := v.Interface()
323	if i == nil || len(i.([]byte)) == 0 {
324		C.sqlite3_result_null(ctx)
325	} else {
326		bs := i.([]byte)
327		C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
328	}
329	return nil
330}
331
332func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
333	if v.Type().Kind() != reflect.String {
334		return fmt.Errorf("cannot convert %s to TEXT", v.Type())
335	}
336	C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
337	return nil
338}
339
340func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
341	return nil
342}
343
344func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
345	switch typ.Kind() {
346	case reflect.Interface:
347		errorInterface := reflect.TypeOf((*error)(nil)).Elem()
348		if typ.Implements(errorInterface) {
349			return callbackRetNil, nil
350		}
351		fallthrough
352	case reflect.Slice:
353		if typ.Elem().Kind() != reflect.Uint8 {
354			return nil, errors.New("the only supported slice type is []byte")
355		}
356		return callbackRetBlob, nil
357	case reflect.String:
358		return callbackRetText, nil
359	case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
360		return callbackRetInteger, nil
361	case reflect.Float32, reflect.Float64:
362		return callbackRetFloat, nil
363	default:
364		return nil, fmt.Errorf("don't know how to convert to %s", typ)
365	}
366}
367
368func callbackError(ctx *C.sqlite3_context, err error) {
369	cstr := C.CString(err.Error())
370	defer C.free(unsafe.Pointer(cstr))
371	C.sqlite3_result_error(ctx, cstr, -1)
372}
373
374// Test support code. Tests are not allowed to import "C", so we can't
375// declare any functions that use C.sqlite3_value.
376func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
377	return func(*C.sqlite3_value) (reflect.Value, error) {
378		return v, err
379	}
380}
381