1// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file.
5
6// +build !sqlite_omit_load_extension
7
8package sqlite3
9
10/*
11#ifndef USE_LIBSQLITE3
12#include "sqlite3-binding.h"
13#else
14#include <sqlite3.h>
15#endif
16#include <stdlib.h>
17*/
18import "C"
19import (
20	"errors"
21	"unsafe"
22)
23
24func (c *SQLiteConn) loadExtensions(extensions []string) error {
25	rv := C.sqlite3_enable_load_extension(c.db, 1)
26	if rv != C.SQLITE_OK {
27		return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
28	}
29
30	for _, extension := range extensions {
31		if err := c.loadExtension(extension, nil); err != nil {
32			C.sqlite3_enable_load_extension(c.db, 0)
33			return err
34		}
35	}
36
37	rv = C.sqlite3_enable_load_extension(c.db, 0)
38	if rv != C.SQLITE_OK {
39		return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
40	}
41
42	return nil
43}
44
45// LoadExtension load the sqlite3 extension.
46func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
47	rv := C.sqlite3_enable_load_extension(c.db, 1)
48	if rv != C.SQLITE_OK {
49		return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
50	}
51
52	if err := c.loadExtension(lib, &entry); err != nil {
53		C.sqlite3_enable_load_extension(c.db, 0)
54		return err
55	}
56
57	rv = C.sqlite3_enable_load_extension(c.db, 0)
58	if rv != C.SQLITE_OK {
59		return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
60	}
61
62	return nil
63}
64
65func (c *SQLiteConn) loadExtension(lib string, entry *string) error {
66	clib := C.CString(lib)
67	defer C.free(unsafe.Pointer(clib))
68
69	var centry *C.char
70	if entry != nil {
71		centry = C.CString(*entry)
72		defer C.free(unsafe.Pointer(centry))
73	}
74
75	var errMsg *C.char
76	defer C.sqlite3_free(unsafe.Pointer(errMsg))
77
78	rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg)
79	if rv != C.SQLITE_OK {
80		return errors.New(C.GoString(errMsg))
81	}
82
83	return nil
84}
85