1// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file.
5
6// +build sqlite_vtable vtable
7
8package sqlite3
9
10import (
11	"database/sql"
12	"errors"
13	"fmt"
14	"os"
15	"reflect"
16	"strings"
17	"testing"
18)
19
type (
	// testModule is the virtual-table module TestCreateModule registers as
	// "test". Create validates its arguments through t, and every table it
	// creates is backed by intarray.
	testModule struct {
		t        *testing.T
		intarray []int
	}

	// testVTab is a single-column virtual table whose rows are intarray.
	testVTab struct {
		intarray []int
	}

	// testVTabCursor scans a testVTab. index is the current row position and
	// is also what Rowid reports.
	testVTabCursor struct {
		vTab  *testVTab
		index int
	}
)
33
34func (m testModule) Create(c *SQLiteConn, args []string) (VTab, error) {
35	if len(args) != 6 {
36		m.t.Fatal("six arguments expected")
37	}
38	if args[0] != "test" {
39		m.t.Fatal("module name")
40	}
41	if args[1] != "main" {
42		m.t.Fatal("db name")
43	}
44	if args[2] != "vtab" {
45		m.t.Fatal("table name")
46	}
47	if args[3] != "'1'" {
48		m.t.Fatal("first arg")
49	}
50	if args[4] != "2" {
51		m.t.Fatal("second arg")
52	}
53	if args[5] != "three" {
54		m.t.Fatal("third argsecond arg")
55	}
56	err := c.DeclareVTab("CREATE TABLE x(test TEXT)")
57	if err != nil {
58		return nil, err
59	}
60	return &testVTab{m.intarray}, nil
61}
62
63func (m testModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
64	return m.Create(c, args)
65}
66
67func (m testModule) DestroyModule() {}
68
69func (v *testVTab) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
70	used := make([]bool, 0, len(cst))
71	for range cst {
72		used = append(used, false)
73	}
74	return &IndexResult{
75		Used:           used,
76		IdxNum:         0,
77		IdxStr:         "test-index",
78		AlreadyOrdered: true,
79		EstimatedCost:  100,
80		EstimatedRows:  200,
81	}, nil
82}
83
84func (v *testVTab) Disconnect() error {
85	return nil
86}
87
88func (v *testVTab) Destroy() error {
89	return nil
90}
91
92func (v *testVTab) Open() (VTabCursor, error) {
93	return &testVTabCursor{v, 0}, nil
94}
95
96func (vc *testVTabCursor) Close() error {
97	return nil
98}
99
100func (vc *testVTabCursor) Filter(idxNum int, idxStr string, vals []interface{}) error {
101	vc.index = 0
102	return nil
103}
104
105func (vc *testVTabCursor) Next() error {
106	vc.index++
107	return nil
108}
109
110func (vc *testVTabCursor) EOF() bool {
111	return vc.index >= len(vc.vTab.intarray)
112}
113
114func (vc *testVTabCursor) Column(c *SQLiteContext, col int) error {
115	if col != 0 {
116		return fmt.Errorf("column index out of bounds: %d", col)
117	}
118	c.ResultInt(vc.vTab.intarray[vc.index])
119	return nil
120}
121
122func (vc *testVTabCursor) Rowid() (int64, error) {
123	return int64(vc.index), nil
124}
125
126func TestCreateModule(t *testing.T) {
127	tempFilename := TempFilename(t)
128	defer os.Remove(tempFilename)
129	intarray := []int{1, 2, 3}
130	sql.Register("sqlite3_TestCreateModule", &SQLiteDriver{
131		ConnectHook: func(conn *SQLiteConn) error {
132			return conn.CreateModule("test", testModule{t, intarray})
133		},
134	})
135	db, err := sql.Open("sqlite3_TestCreateModule", tempFilename)
136	if err != nil {
137		t.Fatalf("could not open db: %v", err)
138	}
139	_, err = db.Exec("CREATE VIRTUAL TABLE vtab USING test('1', 2, three)")
140	if err != nil {
141		t.Fatalf("could not create vtable: %v", err)
142	}
143
144	var i, value int
145	rows, err := db.Query("SELECT rowid, * FROM vtab WHERE test = '3'")
146	if err != nil {
147		t.Fatalf("couldn't select from virtual table: %v", err)
148	}
149	for rows.Next() {
150		rows.Scan(&i, &value)
151		if intarray[i] != value {
152			t.Fatalf("want %v but %v", intarray[i], value)
153		}
154	}
155
156	_, err = db.Exec("DROP TABLE vtab")
157	if err != nil {
158		t.Fatalf("couldn't drop virtual table: %v", err)
159	}
160}
161
162func TestVUpdate(t *testing.T) {
163	tempFilename := TempFilename(t)
164	defer os.Remove(tempFilename)
165
166	// create module
167	updateMod := &vtabUpdateModule{t, make(map[string]*vtabUpdateTable)}
168
169	// register module
170	sql.Register("sqlite3_TestVUpdate", &SQLiteDriver{
171		ConnectHook: func(conn *SQLiteConn) error {
172			return conn.CreateModule("updatetest", updateMod)
173		},
174	})
175
176	// connect
177	db, err := sql.Open("sqlite3_TestVUpdate", tempFilename)
178	if err != nil {
179		t.Fatalf("could not open db: %v", err)
180	}
181
182	// create test table
183	_, err = db.Exec(`CREATE VIRTUAL TABLE vt USING updatetest(f1 integer, f2 text, f3 text)`)
184	if err != nil {
185		t.Fatalf("could not create updatetest vtable vt, got: %v", err)
186	}
187
188	// check that table is defined properly
189	if len(updateMod.tables) != 1 {
190		t.Fatalf("expected exactly 1 table to exist, got: %d", len(updateMod.tables))
191	}
192	if _, ok := updateMod.tables["vt"]; !ok {
193		t.Fatalf("expected table `vt` to exist in tables")
194	}
195
196	// check nothing in updatetest
197	rows, err := db.Query(`select * from vt`)
198	if err != nil {
199		t.Fatalf("could not query vt, got: %v", err)
200	}
201	i, err := getRowCount(rows)
202	if err != nil {
203		t.Fatalf("expected no error, got: %v", err)
204	}
205	if i != 0 {
206		t.Fatalf("expected no rows in vt, got: %d", i)
207	}
208
209	_, err = db.Exec(`delete from vt where f1 = 'yes'`)
210	if err != nil {
211		t.Fatalf("expected error on delete, got nil")
212	}
213
214	// test bad column name
215	_, err = db.Exec(`insert into vt (f4) values('a')`)
216	if err == nil {
217		t.Fatalf("expected error on insert, got nil")
218	}
219
220	// insert to vt
221	res, err := db.Exec(`insert into vt (f1, f2, f3) values (115, 'b', 'c'), (116, 'd', 'e')`)
222	if err != nil {
223		t.Fatalf("expected no error on insert, got: %v", err)
224	}
225	n, err := res.RowsAffected()
226	if err != nil {
227		t.Fatalf("expected no error, got: %v", err)
228	}
229	if n != 2 {
230		t.Fatalf("expected 1 row affected, got: %d", n)
231	}
232
233	// check vt table
234	vt := updateMod.tables["vt"]
235	if len(vt.data) != 2 {
236		t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
237	}
238	if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
239		t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
240	}
241	if !reflect.DeepEqual(vt.data[1], []interface{}{int64(116), "d", "e"}) {
242		t.Fatalf("expected table vt entry 1 to be [116 d e], instead: %v", vt.data[1])
243	}
244
245	// query vt
246	var f1 int
247	var f2, f3 string
248	err = db.QueryRow(`select * from vt where f1 = 115`).Scan(&f1, &f2, &f3)
249	if err != nil {
250		t.Fatalf("expected no error on vt query, got: %v", err)
251	}
252
253	// check column values
254	if f1 != 115 || f2 != "b" || f3 != "c" {
255		t.Errorf("expected f1==115, f2==b, f3==c, got: %d, %q, %q", f1, f2, f3)
256	}
257
258	// update vt
259	res, err = db.Exec(`update vt set f1=117, f2='f' where f3='e'`)
260	if err != nil {
261		t.Fatalf("expected no error, got: %v", err)
262	}
263	n, err = res.RowsAffected()
264	if err != nil {
265		t.Fatalf("expected no error, got: %v", err)
266	}
267	if n != 1 {
268		t.Fatalf("expected exactly one row updated, got: %d", n)
269	}
270
271	// check vt table
272	if len(vt.data) != 2 {
273		t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
274	}
275	if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
276		t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
277	}
278	if !reflect.DeepEqual(vt.data[1], []interface{}{int64(117), "f", "e"}) {
279		t.Fatalf("expected table vt entry 1 to be [117 f e], instead: %v", vt.data[1])
280	}
281
282	// delete from vt
283	res, err = db.Exec(`delete from vt where f1 = 117`)
284	if err != nil {
285		t.Fatalf("expected no error, got: %v", err)
286	}
287	n, err = res.RowsAffected()
288	if err != nil {
289		t.Fatalf("expected no error, got: %v", err)
290	}
291	if n != 1 {
292		t.Fatalf("expected exactly one row deleted, got: %d", n)
293	}
294
295	// check vt table
296	if len(vt.data) != 1 {
297		t.Fatalf("expected table vt to have exactly 1 row, got: %d", len(vt.data))
298	}
299	if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
300		t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
301	}
302
303	// check updatetest has 1 result
304	rows, err = db.Query(`select * from vt`)
305	if err != nil {
306		t.Fatalf("could not query vt, got: %v", err)
307	}
308	i, err = getRowCount(rows)
309	if err != nil {
310		t.Fatalf("expected no error, got: %v", err)
311	}
312	if i != 1 {
313		t.Fatalf("expected 1 row in vt, got: %d", i)
314	}
315}
316
// getRowCount drains rows and returns how many rows it produced. rows is
// always closed before returning. Any error encountered while iterating, as
// reported by rows.Err, is returned alongside the count reached so far.
// Previously that error was silently dropped and nil was always returned.
func getRowCount(rows *sql.Rows) (int, error) {
	defer rows.Close()

	count := 0
	for rows.Next() {
		count++
	}
	return count, rows.Err()
}
324
325type vtabUpdateModule struct {
326	t      *testing.T
327	tables map[string]*vtabUpdateTable
328}
329
330func (m *vtabUpdateModule) Create(c *SQLiteConn, args []string) (VTab, error) {
331	if len(args) < 2 {
332		return nil, errors.New("must declare at least one column")
333	}
334
335	// get database name, table name, and column declarations ...
336	dbname, tname, decls := args[1], args[2], args[3:]
337
338	// extract column names + types from parameters declarations
339	cols, typs := make([]string, len(decls)), make([]string, len(decls))
340	for i := 0; i < len(decls); i++ {
341		n, typ := decls[i], ""
342		if j := strings.IndexAny(n, " \t\n"); j != -1 {
343			typ, n = strings.TrimSpace(n[j+1:]), n[:j]
344		}
345		cols[i], typs[i] = n, typ
346	}
347
348	// declare table
349	err := c.DeclareVTab(fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, dbname, tname, strings.Join(decls, ",")))
350	if err != nil {
351		return nil, err
352	}
353
354	// create table
355	vtab := &vtabUpdateTable{m.t, dbname, tname, cols, typs, make([][]interface{}, 0)}
356	m.tables[tname] = vtab
357	return vtab, nil
358}
359
360func (m *vtabUpdateModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
361	return m.Create(c, args)
362}
363
364func (m *vtabUpdateModule) DestroyModule() {}
365
// vtabUpdateTable is an in-memory "updatetest" table. A row's index in data
// is its rowid: Insert returns it and the cursor reports it.
type vtabUpdateTable struct {
	t    *testing.T      // owning test
	db   string          // database name from CREATE VIRTUAL TABLE
	name string          // table name
	cols []string        // column names parsed from the declarations
	typs []string        // column types parsed from the declarations ("" if none)
	data [][]interface{} // stored rows, one value per column
}
374
375func (t *vtabUpdateTable) Open() (VTabCursor, error) {
376	return &vtabUpdateCursor{t, 0}, nil
377}
378
379func (t *vtabUpdateTable) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
380	return &IndexResult{Used: make([]bool, len(cst))}, nil
381}
382
383func (t *vtabUpdateTable) Disconnect() error {
384	return nil
385}
386
387func (t *vtabUpdateTable) Destroy() error {
388	return nil
389}
390
391func (t *vtabUpdateTable) Insert(id interface{}, vals []interface{}) (int64, error) {
392	var i int64
393	if id == nil {
394		i, t.data = int64(len(t.data)), append(t.data, vals)
395		return i, nil
396	}
397
398	var ok bool
399	i, ok = id.(int64)
400	if !ok {
401		return 0, fmt.Errorf("id is invalid type: %T", id)
402	}
403
404	t.data[i] = vals
405
406	return i, nil
407}
408
409func (t *vtabUpdateTable) Update(id interface{}, vals []interface{}) error {
410	i, ok := id.(int64)
411	if !ok {
412		return fmt.Errorf("id is invalid type: %T", id)
413	}
414
415	if int(i) >= len(t.data) || i < 0 {
416		return fmt.Errorf("invalid row id %d", i)
417	}
418
419	t.data[int(i)] = vals
420
421	return nil
422}
423
424func (t *vtabUpdateTable) Delete(id interface{}) error {
425	i, ok := id.(int64)
426	if !ok {
427		return fmt.Errorf("id is invalid type: %T", id)
428	}
429
430	if int(i) >= len(t.data) || i < 0 {
431		return fmt.Errorf("invalid row id %d", i)
432	}
433
434	t.data = append(t.data[:i], t.data[i+1:]...)
435
436	return nil
437}
438
439type vtabUpdateCursor struct {
440	t *vtabUpdateTable
441	i int
442}
443
444func (c *vtabUpdateCursor) Column(ctxt *SQLiteContext, col int) error {
445	switch x := c.t.data[c.i][col].(type) {
446	case []byte:
447		ctxt.ResultBlob(x)
448	case bool:
449		ctxt.ResultBool(x)
450	case float64:
451		ctxt.ResultDouble(x)
452	case int:
453		ctxt.ResultInt(x)
454	case int64:
455		ctxt.ResultInt64(x)
456	case nil:
457		ctxt.ResultNull()
458	case string:
459		ctxt.ResultText(x)
460	default:
461		ctxt.ResultText(fmt.Sprintf("%v", x))
462	}
463
464	return nil
465}
466
467func (c *vtabUpdateCursor) Filter(ixNum int, ixName string, vals []interface{}) error {
468	return nil
469}
470
471func (c *vtabUpdateCursor) Next() error {
472	c.i++
473	return nil
474}
475
476func (c *vtabUpdateCursor) EOF() bool {
477	return c.i >= len(c.t.data)
478}
479
480func (c *vtabUpdateCursor) Rowid() (int64, error) {
481	return int64(c.i), nil
482}
483
484func (c *vtabUpdateCursor) Close() error {
485	return nil
486}
487
type (
	// testModuleEponymousOnly is the module TestCreateModuleEponymousOnly
	// registers as "test". Its tables are backed by intarray.
	testModuleEponymousOnly struct {
		t        *testing.T
		intarray []int
	}

	// testVTabEponymousOnly is a single-column table whose rows are intarray.
	testVTabEponymousOnly struct {
		intarray []int
	}

	// testVTabCursorEponymousOnly scans a testVTabEponymousOnly. index is the
	// current row position and is also what Rowid reports.
	testVTabCursorEponymousOnly struct {
		vTab  *testVTabEponymousOnly
		index int
	}
)
501
502func (m testModuleEponymousOnly) EponymousOnlyModule() {}
503
504func (m testModuleEponymousOnly) Create(c *SQLiteConn, args []string) (VTab, error) {
505	err := c.DeclareVTab("CREATE TABLE x(test INT)")
506	if err != nil {
507		return nil, err
508	}
509	return &testVTabEponymousOnly{m.intarray}, nil
510}
511
512func (m testModuleEponymousOnly) Connect(c *SQLiteConn, args []string) (VTab, error) {
513	return m.Create(c, args)
514}
515
516func (m testModuleEponymousOnly) DestroyModule() {}
517
518func (v *testVTabEponymousOnly) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
519	used := make([]bool, 0, len(cst))
520	for range cst {
521		used = append(used, false)
522	}
523	return &IndexResult{
524		Used:           used,
525		IdxNum:         0,
526		IdxStr:         "test-index",
527		AlreadyOrdered: true,
528		EstimatedCost:  100,
529		EstimatedRows:  200,
530	}, nil
531}
532
533func (v *testVTabEponymousOnly) Disconnect() error {
534	return nil
535}
536
537func (v *testVTabEponymousOnly) Destroy() error {
538	return nil
539}
540
541func (v *testVTabEponymousOnly) Open() (VTabCursor, error) {
542	return &testVTabCursorEponymousOnly{v, 0}, nil
543}
544
545func (vc *testVTabCursorEponymousOnly) Close() error {
546	return nil
547}
548
549func (vc *testVTabCursorEponymousOnly) Filter(idxNum int, idxStr string, vals []interface{}) error {
550	vc.index = 0
551	return nil
552}
553
554func (vc *testVTabCursorEponymousOnly) Next() error {
555	vc.index++
556	return nil
557}
558
559func (vc *testVTabCursorEponymousOnly) EOF() bool {
560	return vc.index >= len(vc.vTab.intarray)
561}
562
563func (vc *testVTabCursorEponymousOnly) Column(c *SQLiteContext, col int) error {
564	if col != 0 {
565		return fmt.Errorf("column index out of bounds: %d", col)
566	}
567	c.ResultInt(vc.vTab.intarray[vc.index])
568	return nil
569}
570
571func (vc *testVTabCursorEponymousOnly) Rowid() (int64, error) {
572	return int64(vc.index), nil
573}
574
575func TestCreateModuleEponymousOnly(t *testing.T) {
576	tempFilename := TempFilename(t)
577	defer os.Remove(tempFilename)
578	intarray := []int{1, 2, 3}
579	sql.Register("sqlite3_TestCreateModuleEponymousOnly", &SQLiteDriver{
580		ConnectHook: func(conn *SQLiteConn) error {
581			return conn.CreateModule("test", testModuleEponymousOnly{t, intarray})
582		},
583	})
584	db, err := sql.Open("sqlite3_TestCreateModuleEponymousOnly", tempFilename)
585	if err != nil {
586		t.Fatalf("could not open db: %v", err)
587	}
588
589	var i, value int
590	rows, err := db.Query("SELECT rowid, * FROM test")
591	if err != nil {
592		t.Fatalf("couldn't select from virtual table: %v", err)
593	}
594	for rows.Next() {
595		err := rows.Scan(&i, &value)
596		if err != nil {
597			t.Fatal(err)
598		}
599		if intarray[i] != value {
600			t.Fatalf("want %v but %v", intarray[i], value)
601		}
602	}
603
604	_, err = db.Exec("DROP TABLE test")
605	if err != nil {
606		t.Fatalf("couldn't drop virtual table: %v", err)
607	}
608}
609