1// Copyright (C) 2019 G.J.R. Timmer <gjr.timmer@gmail.com>.
2// Copyright (C) 2018 segment.com <friends@segment.com>
3//
4// Use of this source code is governed by an MIT-style
5// license that can be found in the LICENSE file.
6
7// +build sqlite_preupdate_hook
8
9package sqlite3
10
11/*
12#cgo CFLAGS: -DSQLITE_ENABLE_PREUPDATE_HOOK
13#cgo LDFLAGS: -lm
14
15#ifndef USE_LIBSQLITE3
16#include <sqlite3-binding.h>
17#else
18#include <sqlite3.h>
19#endif
20#include <stdlib.h>
21#include <string.h>
22
23void preUpdateHookTrampoline(void*, sqlite3 *, int, char *, char *, sqlite3_int64, sqlite3_int64);
24*/
25import "C"
import (
	"errors"
	"reflect"
	"unsafe"
)
30
31// RegisterPreUpdateHook sets the pre-update hook for a connection.
32//
33// The callback is passed a SQLitePreUpdateData struct with the data for
34// the update, as well as methods for fetching copies of impacted data.
35//
36// If there is an existing update hook for this connection, it will be
37// removed. If callback is nil the existing hook (if any) will be removed
38// without creating a new one.
39func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) {
40	if callback == nil {
41		C.sqlite3_preupdate_hook(c.db, nil, nil)
42	} else {
43		C.sqlite3_preupdate_hook(c.db, (*[0]byte)(unsafe.Pointer(C.preUpdateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
44	}
45}
46
47// Depth returns the source path of the write, see sqlite3_preupdate_depth()
48func (d *SQLitePreUpdateData) Depth() int {
49	return int(C.sqlite3_preupdate_depth(d.Conn.db))
50}
51
52// Count returns the number of columns in the row
53func (d *SQLitePreUpdateData) Count() int {
54	return int(C.sqlite3_preupdate_count(d.Conn.db))
55}
56
57func (d *SQLitePreUpdateData) row(dest []interface{}, new bool) error {
58	for i := 0; i < d.Count() && i < len(dest); i++ {
59		var val *C.sqlite3_value
60		var src interface{}
61
62		// Initially I tried making this just a function pointer argument, but
63		// it's absurdly complicated to pass C function pointers.
64		if new {
65			C.sqlite3_preupdate_new(d.Conn.db, C.int(i), &val)
66		} else {
67			C.sqlite3_preupdate_old(d.Conn.db, C.int(i), &val)
68		}
69
70		switch C.sqlite3_value_type(val) {
71		case C.SQLITE_INTEGER:
72			src = int64(C.sqlite3_value_int64(val))
73		case C.SQLITE_FLOAT:
74			src = float64(C.sqlite3_value_double(val))
75		case C.SQLITE_BLOB:
76			len := C.sqlite3_value_bytes(val)
77			blobptr := C.sqlite3_value_blob(val)
78			src = C.GoBytes(blobptr, len)
79		case C.SQLITE_TEXT:
80			len := C.sqlite3_value_bytes(val)
81			cstrptr := unsafe.Pointer(C.sqlite3_value_text(val))
82			src = C.GoBytes(cstrptr, len)
83		case C.SQLITE_NULL:
84			src = nil
85		}
86
87		err := convertAssign(&dest[i], src)
88		if err != nil {
89			return err
90		}
91	}
92
93	return nil
94}
95
96// Old populates dest with the row data to be replaced. This works similar to
97// database/sql's Rows.Scan()
98func (d *SQLitePreUpdateData) Old(dest ...interface{}) error {
99	if d.Op == SQLITE_INSERT {
100		return errors.New("There is no old row for INSERT operations")
101	}
102	return d.row(dest, false)
103}
104
105// New populates dest with the replacement row data. This works similar to
106// database/sql's Rows.Scan()
107func (d *SQLitePreUpdateData) New(dest ...interface{}) error {
108	if d.Op == SQLITE_DELETE {
109		return errors.New("There is no new row for DELETE operations")
110	}
111	return d.row(dest, true)
112}
113