1package sqlutil
2
3import (
4	"context"
5	"database/sql"
6	"reflect"
7	"testing"
8
9	sqlmock "github.com/DATA-DOG/go-sqlmock"
10)
11
12func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
13	db, mock, err := sqlmock.New()
14	assertNoError(t, err, "Failed to make DB")
15	limit := uint(4)
16
17	r := mock.NewRows([]string{"id"}).
18		AddRow(1).
19		AddRow(2).
20		AddRow(3)
21
22	mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
23	// nolint:goconst
24	q := "SELECT id WHERE id IN ($1)"
25	v := []int{1, 2, 3}
26	iKeyIDs := make([]interface{}, len(v))
27	for i, d := range v {
28		iKeyIDs[i] = d
29	}
30
31	ctx := context.Background()
32	var result = make([]int, 0)
33	err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
34		for rows.Next() {
35			var id int
36			err = rows.Scan(&id)
37			assertNoError(t, err, "rows.Scan returned an error")
38			result = append(result, id)
39		}
40		return nil
41	})
42	assertNoError(t, err, "Call returned an error")
43	if len(result) != len(v) {
44		t.Fatalf("Result should be 3 long")
45	}
46}
47
48func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
49	db, mock, err := sqlmock.New()
50	assertNoError(t, err, "Failed to make DB")
51	limit := uint(4)
52
53	r := mock.NewRows([]string{"id"}).
54		AddRow(1).
55		AddRow(2).
56		AddRow(3).
57		AddRow(4)
58
59	mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r)
60	// nolint:goconst
61	q := "SELECT id WHERE id IN ($1)"
62	v := []int{1, 2, 3, 4}
63	iKeyIDs := make([]interface{}, len(v))
64	for i, d := range v {
65		iKeyIDs[i] = d
66	}
67
68	ctx := context.Background()
69	var result = make([]int, 0)
70	err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
71		for rows.Next() {
72			var id int
73			err = rows.Scan(&id)
74			assertNoError(t, err, "rows.Scan returned an error")
75			result = append(result, id)
76		}
77		return nil
78	})
79	assertNoError(t, err, "Call returned an error")
80	if len(result) != len(v) {
81		t.Fatalf("Result should be 4 long")
82	}
83}
84
85func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
86	db, mock, err := sqlmock.New()
87	assertNoError(t, err, "Failed to make DB")
88	limit := uint(4)
89
90	r1 := mock.NewRows([]string{"id"}).
91		AddRow(1).
92		AddRow(2).
93		AddRow(3).
94		AddRow(4)
95
96	r2 := mock.NewRows([]string{"id"}).
97		AddRow(5)
98
99	mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1)
100	mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2)
101	// nolint:goconst
102	q := "SELECT id WHERE id IN ($1)"
103	v := []int{1, 2, 3, 4, 5}
104	iKeyIDs := make([]interface{}, len(v))
105	for i, d := range v {
106		iKeyIDs[i] = d
107	}
108
109	ctx := context.Background()
110	var result = make([]int, 0)
111	err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
112		for rows.Next() {
113			var id int
114			err = rows.Scan(&id)
115			assertNoError(t, err, "rows.Scan returned an error")
116			result = append(result, id)
117		}
118		return nil
119	})
120	assertNoError(t, err, "Call returned an error")
121	if len(result) != len(v) {
122		t.Fatalf("Result should be 5 long")
123	}
124	if !reflect.DeepEqual(v, result) {
125		t.Fatalf("Result is not as expected: got %v want %v", v, result)
126	}
127}
128
129func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
130	db, mock, err := sqlmock.New()
131	assertNoError(t, err, "Failed to make DB")
132	limit := uint(4)
133
134	// adding a string ID should result in rows.Scan returning an error
135	r := mock.NewRows([]string{"id"}).
136		AddRow("hej").
137		AddRow(2).
138		AddRow(3)
139
140	mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
141	// nolint:goconst
142	q := "SELECT id WHERE id IN ($1)"
143	v := []int{-1, -2, 3}
144	iKeyIDs := make([]interface{}, len(v))
145	for i, d := range v {
146		iKeyIDs[i] = d
147	}
148
149	ctx := context.Background()
150	var result = make([]uint, 0)
151	err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
152		for rows.Next() {
153			var id uint
154			err = rows.Scan(&id)
155			if err != nil {
156				return err
157			}
158			result = append(result, id)
159		}
160		return nil
161	})
162	if err == nil {
163		t.Fatalf("Call did not return an error")
164	}
165}
166
// assertNoError fails the test immediately if err is non-nil, reporting msg
// together with the underlying error. It does nothing when err is nil.
func assertNoError(t *testing.T, err error, msg string) {
	t.Helper()
	if err != nil {
		// msg is passed as an argument, not as the format string, so '%' in it
		// is printed literally (and go vet's printf check is satisfied).
		t.Fatalf("%s: %v", msg, err)
	}
}
174