1// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file.
5
6// +build sqlite_trace trace
7
8package sqlite3
9
10/*
11#ifndef USE_LIBSQLITE3
12#include <sqlite3-binding.h>
13#else
14#include <sqlite3.h>
15#endif
16#include <stdlib.h>
17
18int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
19*/
20import "C"
21
22import (
23	"fmt"
24	"strings"
25	"sync"
26	"unsafe"
27)
28
// Trace... constants identify the possible events causing callback invocation.
// Values are the same as the corresponding SQLite trace event codes, so they
// can be OR-ed together to form TraceConfig.EventMask.
const (
	// TraceStmt: the driver fills TraceInfo.StmtHandle, TraceInfo.StmtOrTrigger
	// and, when requested via TraceConfig.WantExpandedSQL, TraceInfo.ExpandedSQL.
	TraceStmt    = uint32(C.SQLITE_TRACE_STMT)
	// TraceProfile: the driver fills TraceInfo.StmtHandle, TraceInfo.RunTimeNanosec
	// and a sample of the connection's current error in TraceInfo.DBError.
	TraceProfile = uint32(C.SQLITE_TRACE_PROFILE)
	// TraceRow: the driver fills only TraceInfo.StmtHandle.
	TraceRow     = uint32(C.SQLITE_TRACE_ROW)
	// TraceClose: on this event the driver drops the connection's stored
	// TraceConfig. SetTrace always enables it internally for that cleanup,
	// but the user callback only sees it if it is in TraceConfig.EventMask.
	TraceClose   = uint32(C.SQLITE_TRACE_CLOSE)
)
37
// TraceInfo describes a single trace event; it is the value passed to a
// TraceUserCallback. Which fields are meaningful depends on EventCode.
type TraceInfo struct {
	// Pack together the shorter fields, to keep the struct smaller.
	// On a 64-bit machine there would be padding
	// between EventCode and ConnHandle; having AutoCommit here is "free":
	EventCode  uint32 // one of the Trace... constants
	AutoCommit bool   // sqlite3_get_autocommit() sampled when the event fired
	ConnHandle uintptr // address of the sqlite3 connection the event belongs to

	// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
	// identifier for a prepared statement:
	StmtHandle uintptr

	// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
	// (1) either the unexpanded SQL text of the prepared statement, or
	//     an SQL comment that indicates the invocation of a trigger;
	// (2) expanded SQL, if requested and if (1) is not an SQL comment.
	StmtOrTrigger string
	ExpandedSQL   string // only if requested (TraceConfig.WantExpandedSQL = true)

	// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
	// estimated number of nanoseconds that the prepared statement took to run:
	RunTimeNanosec int64

	// DBError is filled from the connection's current error state
	// (sqlite3_errcode / sqlite3_extended_errcode / sqlite3_errmsg) in two cases:
	// on every TraceProfile event, and on a TraceStmt event when expanded SQL
	// was requested but sqlite3_expanded_sql returned NULL.
	DBError Error
}
63
// TraceUserCallback gives the signature for a trace function
// provided by the user (Go application programmer).
// It is invoked only for events selected in TraceConfig.EventMask, and its
// return value is handed back to SQLite as the trace callback's result.
// SQLite 3.14 documentation (as of September 2, 2016)
// for SQL Trace Hook = sqlite3_trace_v2():
// The integer return value from the callback is currently ignored,
// though this may change in future releases. Callback implementations
// should return zero to ensure future compatibility.
type TraceUserCallback func(TraceInfo) int
72
// TraceConfig is the tracing configuration accepted by SQLiteConn.SetTrace.
type TraceConfig struct {
	// Callback receives the selected events; if nil, events are processed
	// but no user code is called.
	Callback TraceUserCallback
	// EventMask is a bitwise OR of the Trace... constants the callback wants.
	EventMask uint32
	// WantExpandedSQL requests TraceInfo.ExpandedSQL on TraceStmt events.
	// SetTrace clears it when TraceStmt is not in EventMask, and it is not
	// applied to trigger comments (statement text starting with "--").
	WantExpandedSQL bool
}
78
79func fillDBError(dbErr *Error, db *C.sqlite3) {
80	// See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016)
81	dbErr.Code = ErrNo(C.sqlite3_errcode(db))
82	dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db))
83	dbErr.err = C.GoString(C.sqlite3_errmsg(db))
84}
85
86func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
87	if pStmt == nil {
88		panic("No SQLite statement pointer in P arg of trace_v2 callback")
89	}
90
91	expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt))
92	defer C.sqlite3_free(unsafe.Pointer(expSQLiteCStr))
93	if expSQLiteCStr == nil {
94		fillDBError(&info.DBError, db)
95		return
96	}
97	info.ExpandedSQL = C.GoString(expSQLiteCStr)
98}
99
100//export traceCallbackTrampoline
101func traceCallbackTrampoline(
102	traceEventCode C.uint,
103	// Parameter named 'C' in SQLite docs = Context given at registration:
104	ctx unsafe.Pointer,
105	// Parameter named 'P' in SQLite docs (Primary event data?):
106	p unsafe.Pointer,
107	// Parameter named 'X' in SQLite docs (eXtra event data?):
108	xValue unsafe.Pointer) C.int {
109
110	eventCode := uint32(traceEventCode)
111
112	if ctx == nil {
113		panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
114	}
115
116	contextDB := (*C.sqlite3)(ctx)
117	connHandle := uintptr(ctx)
118
119	var traceConf TraceConfig
120	var found bool
121	if eventCode == TraceClose {
122		// clean up traceMap: 'pop' means get and delete
123		traceConf, found = popTraceMapping(connHandle)
124	} else {
125		traceConf, found = lookupTraceMapping(connHandle)
126	}
127
128	if !found {
129		panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)",
130			connHandle, eventCode))
131	}
132
133	var info TraceInfo
134
135	info.EventCode = eventCode
136	info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0)
137	info.ConnHandle = connHandle
138
139	switch eventCode {
140	case TraceStmt:
141		info.StmtHandle = uintptr(p)
142
143		var xStr string
144		if xValue != nil {
145			xStr = C.GoString((*C.char)(xValue))
146		}
147		info.StmtOrTrigger = xStr
148		if !strings.HasPrefix(xStr, "--") {
149			// Not SQL comment, therefore the current event
150			// is not related to a trigger.
151			// The user might want to receive the expanded SQL;
152			// let's check:
153			if traceConf.WantExpandedSQL {
154				fillExpandedSQL(&info, contextDB, p)
155			}
156		}
157
158	case TraceProfile:
159		info.StmtHandle = uintptr(p)
160
161		if xValue == nil {
162			panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event")
163		}
164
165		info.RunTimeNanosec = *(*int64)(xValue)
166
167		// sample the error //TODO: is it safe? is it useful?
168		fillDBError(&info.DBError, contextDB)
169
170	case TraceRow:
171		info.StmtHandle = uintptr(p)
172
173	case TraceClose:
174		handle := uintptr(p)
175		if handle != info.ConnHandle {
176			panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.",
177				handle, info.ConnHandle))
178		}
179
180	default:
181		// Pass unsupported events to the user callback (if configured);
182		// let the user callback decide whether to panic or ignore them.
183	}
184
185	// Do not execute user callback when the event was not requested by user!
186	// Remember that the Close event is always selected when
187	// registering this callback trampoline with SQLite --- for cleanup.
188	// In the future there may be more events forced to "selected" in SQLite
189	// for the driver's needs.
190	if traceConf.EventMask&eventCode == 0 {
191		return 0
192	}
193
194	r := 0
195	if traceConf.Callback != nil {
196		r = traceConf.Callback(info)
197	}
198	return C.int(r)
199}
200
// traceMapEntry is the value stored in traceMap for each connection that
// has tracing enabled; it holds the (sanitized) configuration passed to SetTrace.
type traceMapEntry struct {
	config TraceConfig
}
204
205var traceMapLock sync.Mutex
206var traceMap = make(map[uintptr]traceMapEntry)
207
208func addTraceMapping(connHandle uintptr, traceConf TraceConfig) {
209	traceMapLock.Lock()
210	defer traceMapLock.Unlock()
211
212	oldEntryCopy, found := traceMap[connHandle]
213	if found {
214		panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).",
215			traceConf, connHandle, oldEntryCopy.config))
216	}
217	traceMap[connHandle] = traceMapEntry{config: traceConf}
218	fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle)
219}
220
221func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) {
222	traceMapLock.Lock()
223	defer traceMapLock.Unlock()
224
225	entryCopy, found := traceMap[connHandle]
226	return entryCopy.config, found
227}
228
229// 'pop' = get and delete from map before returning the value to the caller
230func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
231	traceMapLock.Lock()
232	defer traceMapLock.Unlock()
233
234	entryCopy, found := traceMap[connHandle]
235	if found {
236		delete(traceMap, connHandle)
237		fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config)
238	}
239	return entryCopy.config, found
240}
241
242// SetTrace installs or removes the trace callback for the given database connection.
243// It's not named 'RegisterTrace' because only one callback can be kept and called.
244// Calling SetTrace a second time on same database connection
245// overrides (cancels) any prior callback and all its settings:
246// event mask, etc.
247func (c *SQLiteConn) SetTrace(requested *TraceConfig) error {
248	connHandle := uintptr(unsafe.Pointer(c.db))
249
250	_, _ = popTraceMapping(connHandle)
251
252	if requested == nil {
253		// The traceMap entry was deleted already by popTraceMapping():
254		// can disable all events now, no need to watch for TraceClose.
255		err := c.setSQLiteTrace(0)
256		return err
257	}
258
259	reqCopy := *requested
260
261	// Disable potentially expensive operations
262	// if their result will not be used. We are doing this
263	// just in case the caller provided nonsensical input.
264	if reqCopy.EventMask&TraceStmt == 0 {
265		reqCopy.WantExpandedSQL = false
266	}
267
268	addTraceMapping(connHandle, reqCopy)
269
270	// The callback trampoline function does cleanup on Close event,
271	// regardless of the presence or absence of the user callback.
272	// Therefore it needs the Close event to be selected:
273	actualEventMask := uint(reqCopy.EventMask | TraceClose)
274	err := c.setSQLiteTrace(actualEventMask)
275	return err
276}
277
278func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error {
279	rv := C.sqlite3_trace_v2(c.db,
280		C.uint(sqliteEventMask),
281		(*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)),
282		unsafe.Pointer(c.db)) // Fourth arg is same as first: we are
283	// passing the database connection handle as callback context.
284
285	if rv != C.SQLITE_OK {
286		return c.lastError()
287	}
288	return nil
289}
290