1// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>. 2// 3// Use of this source code is governed by an MIT-style 4// license that can be found in the LICENSE file. 5 6package sqlite3 7 8// You can't export a Go function to C and have definitions in the C 9// preamble in the same file, so we have to have callbackTrampoline in 10// its own file. Because we need a separate file anyway, the support 11// code for SQLite custom functions is in here. 12 13/* 14#ifndef USE_LIBSQLITE3 15#include <sqlite3-binding.h> 16#else 17#include <sqlite3.h> 18#endif 19#include <stdlib.h> 20 21void _sqlite3_result_text(sqlite3_context* ctx, const char* s); 22void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); 23*/ 24import "C" 25 26import ( 27 "errors" 28 "fmt" 29 "math" 30 "reflect" 31 "sync" 32 "unsafe" 33) 34 35//export callbackTrampoline 36func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { 37 args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] 38 fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo) 39 fi.Call(ctx, args) 40} 41 42//export stepTrampoline 43func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) { 44 args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)] 45 ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo) 46 ai.Step(ctx, args) 47} 48 49//export doneTrampoline 50func doneTrampoline(ctx *C.sqlite3_context) { 51 handle := uintptr(C.sqlite3_user_data(ctx)) 52 ai := lookupHandle(handle).(*aggInfo) 53 ai.Done(ctx) 54} 55 56//export compareTrampoline 57func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int { 58 cmp := lookupHandle(handlePtr).(func(string, string) int) 59 return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb))) 60} 61 62//export commitHookTrampoline 63func commitHookTrampoline(handle uintptr) int { 64 callback := lookupHandle(handle).(func() int) 65 return callback() 66} 67 68//export rollbackHookTrampoline 69func rollbackHookTrampoline(handle uintptr) { 70 callback := lookupHandle(handle).(func()) 71 callback() 72} 73 74//export updateHookTrampoline 75func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) { 76 callback := lookupHandle(handle).(func(int, string, string, int64)) 77 callback(op, C.GoString(db), C.GoString(table), rowid) 78} 79 80//export authorizerTrampoline 81func authorizerTrampoline(handle uintptr, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int { 82 callback := lookupHandle(handle).(func(int, string, string, string) int) 83 return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3)) 84} 85 86// Use handles to avoid passing Go pointers to C. 87 88type handleVal struct { 89 db *SQLiteConn 90 val interface{} 91} 92 93var handleLock sync.Mutex 94var handleVals = make(map[uintptr]handleVal) 95var handleIndex uintptr = 100 96 97func newHandle(db *SQLiteConn, v interface{}) uintptr { 98 handleLock.Lock() 99 defer handleLock.Unlock() 100 i := handleIndex 101 handleIndex++ 102 handleVals[i] = handleVal{db, v} 103 return i 104} 105 106func lookupHandle(handle uintptr) interface{} { 107 handleLock.Lock() 108 defer handleLock.Unlock() 109 r, ok := handleVals[handle] 110 if !ok { 111 if handle >= 100 && handle < handleIndex { 112 panic("deleted handle") 113 } else { 114 panic("invalid handle") 115 } 116 } 117 return r.val 118} 119 120func deleteHandles(db *SQLiteConn) { 121 handleLock.Lock() 122 defer handleLock.Unlock() 123 for handle, val := range handleVals { 124 if val.db == db { 125 delete(handleVals, handle) 126 } 127 } 128} 129 130// This is only here so that tests can refer to it. 131type callbackArgRaw C.sqlite3_value 132 133type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) 134 135type callbackArgCast struct { 136 f callbackArgConverter 137 typ reflect.Type 138} 139 140func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { 141 val, err := c.f(v) 142 if err != nil { 143 return reflect.Value{}, err 144 } 145 if !val.Type().ConvertibleTo(c.typ) { 146 return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) 147 } 148 return val.Convert(c.typ), nil 149} 150 151func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { 152 if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { 153 return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") 154 } 155 return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil 156} 157 158func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { 159 if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { 160 return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") 161 } 162 i := int64(C.sqlite3_value_int64(v)) 163 val := false 164 if i != 0 { 165 val = true 166 } 167 return reflect.ValueOf(val), nil 168} 169 170func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { 171 if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { 172 return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") 173 } 174 return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil 175} 176 177func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { 178 switch C.sqlite3_value_type(v) { 179 case C.SQLITE_BLOB: 180 l := C.sqlite3_value_bytes(v) 181 p := C.sqlite3_value_blob(v) 182 return reflect.ValueOf(C.GoBytes(p, l)), nil 183 case C.SQLITE_TEXT: 184 l := C.sqlite3_value_bytes(v) 185 c := unsafe.Pointer(C.sqlite3_value_text(v)) 186 return reflect.ValueOf(C.GoBytes(c, l)), nil 187 default: 188 return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") 189 } 190} 191 192func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { 193 switch C.sqlite3_value_type(v) { 194 case C.SQLITE_BLOB: 195 l := C.sqlite3_value_bytes(v) 196 p := (*C.char)(C.sqlite3_value_blob(v)) 197 return reflect.ValueOf(C.GoStringN(p, l)), nil 198 case C.SQLITE_TEXT: 199 c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) 200 return reflect.ValueOf(C.GoString(c)), nil 201 default: 202 return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") 203 } 204} 205 206func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { 207 switch C.sqlite3_value_type(v) { 208 case C.SQLITE_INTEGER: 209 return callbackArgInt64(v) 210 case C.SQLITE_FLOAT: 211 return callbackArgFloat64(v) 212 case C.SQLITE_TEXT: 213 return callbackArgString(v) 214 case C.SQLITE_BLOB: 215 return callbackArgBytes(v) 216 case C.SQLITE_NULL: 217 // Interpret NULL as a nil byte slice. 218 var ret []byte 219 return reflect.ValueOf(ret), nil 220 default: 221 panic("unreachable") 222 } 223} 224 225func callbackArg(typ reflect.Type) (callbackArgConverter, error) { 226 switch typ.Kind() { 227 case reflect.Interface: 228 if typ.NumMethod() != 0 { 229 return nil, errors.New("the only supported interface type is interface{}") 230 } 231 return callbackArgGeneric, nil 232 case reflect.Slice: 233 if typ.Elem().Kind() != reflect.Uint8 { 234 return nil, errors.New("the only supported slice type is []byte") 235 } 236 return callbackArgBytes, nil 237 case reflect.String: 238 return callbackArgString, nil 239 case reflect.Bool: 240 return callbackArgBool, nil 241 case reflect.Int64: 242 return callbackArgInt64, nil 243 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: 244 c := callbackArgCast{callbackArgInt64, typ} 245 return c.Run, nil 246 case reflect.Float64: 247 return callbackArgFloat64, nil 248 case reflect.Float32: 249 c := callbackArgCast{callbackArgFloat64, typ} 250 return c.Run, nil 251 default: 252 return nil, fmt.Errorf("don't know how to convert to %s", typ) 253 } 254} 255 256func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) { 257 var args []reflect.Value 258 259 if len(argv) < len(converters) { 260 return nil, fmt.Errorf("function requires at least %d arguments", len(converters)) 261 } 262 263 for i, arg := range argv[:len(converters)] { 264 v, err := converters[i](arg) 265 if err != nil { 266 return nil, err 267 } 268 args = append(args, v) 269 } 270 271 if variadic != nil { 272 for _, arg := range argv[len(converters):] { 273 v, err := variadic(arg) 274 if err != nil { 275 return nil, err 276 } 277 args = append(args, v) 278 } 279 } 280 return args, nil 281} 282 283type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error 284 285func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { 286 switch v.Type().Kind() { 287 case reflect.Int64: 288 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: 289 v = v.Convert(reflect.TypeOf(int64(0))) 290 case reflect.Bool: 291 b := v.Interface().(bool) 292 if b { 293 v = reflect.ValueOf(int64(1)) 294 } else { 295 v = reflect.ValueOf(int64(0)) 296 } 297 default: 298 return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) 299 } 300 301 C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) 302 return nil 303} 304 305func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { 306 switch v.Type().Kind() { 307 case reflect.Float64: 308 case reflect.Float32: 309 v = v.Convert(reflect.TypeOf(float64(0))) 310 default: 311 return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) 312 } 313 314 C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) 315 return nil 316} 317 318func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { 319 if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { 320 return fmt.Errorf("cannot convert %s to BLOB", v.Type()) 321 } 322 i := v.Interface() 323 if i == nil || len(i.([]byte)) == 0 { 324 C.sqlite3_result_null(ctx) 325 } else { 326 bs := i.([]byte) 327 C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs))) 328 } 329 return nil 330} 331 332func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { 333 if v.Type().Kind() != reflect.String { 334 return fmt.Errorf("cannot convert %s to TEXT", v.Type()) 335 } 336 C._sqlite3_result_text(ctx, C.CString(v.Interface().(string))) 337 return nil 338} 339 340func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error { 341 return nil 342} 343 344func callbackRet(typ reflect.Type) (callbackRetConverter, error) { 345 switch typ.Kind() { 346 case reflect.Interface: 347 errorInterface := reflect.TypeOf((*error)(nil)).Elem() 348 if typ.Implements(errorInterface) { 349 return callbackRetNil, nil 350 } 351 fallthrough 352 case reflect.Slice: 353 if typ.Elem().Kind() != reflect.Uint8 { 354 return nil, errors.New("the only supported slice type is []byte") 355 } 356 return callbackRetBlob, nil 357 case reflect.String: 358 return callbackRetText, nil 359 case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: 360 return callbackRetInteger, nil 361 case reflect.Float32, reflect.Float64: 362 return callbackRetFloat, nil 363 default: 364 return nil, fmt.Errorf("don't know how to convert to %s", typ) 365 } 366} 367 368func callbackError(ctx *C.sqlite3_context, err error) { 369 cstr := C.CString(err.Error()) 370 defer C.free(unsafe.Pointer(cstr)) 371 C.sqlite3_result_error(ctx, cstr, -1) 372} 373 374// Test support code. Tests are not allowed to import "C", so we can't 375// declare any functions that use C.sqlite3_value. 376func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { 377 return func(*C.sqlite3_value) (reflect.Value, error) { 378 return v, err 379 } 380} 381