1package sqlutil 2 3import ( 4 "context" 5 "database/sql" 6 "reflect" 7 "testing" 8 9 sqlmock "github.com/DATA-DOG/go-sqlmock" 10) 11 12func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { 13 db, mock, err := sqlmock.New() 14 assertNoError(t, err, "Failed to make DB") 15 limit := uint(4) 16 17 r := mock.NewRows([]string{"id"}). 18 AddRow(1). 19 AddRow(2). 20 AddRow(3) 21 22 mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) 23 // nolint:goconst 24 q := "SELECT id WHERE id IN ($1)" 25 v := []int{1, 2, 3} 26 iKeyIDs := make([]interface{}, len(v)) 27 for i, d := range v { 28 iKeyIDs[i] = d 29 } 30 31 ctx := context.Background() 32 var result = make([]int, 0) 33 err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { 34 for rows.Next() { 35 var id int 36 err = rows.Scan(&id) 37 assertNoError(t, err, "rows.Scan returned an error") 38 result = append(result, id) 39 } 40 return nil 41 }) 42 assertNoError(t, err, "Call returned an error") 43 if len(result) != len(v) { 44 t.Fatalf("Result should be 3 long") 45 } 46} 47 48func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) { 49 db, mock, err := sqlmock.New() 50 assertNoError(t, err, "Failed to make DB") 51 limit := uint(4) 52 53 r := mock.NewRows([]string{"id"}). 54 AddRow(1). 55 AddRow(2). 56 AddRow(3). 57 AddRow(4) 58 59 mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r) 60 // nolint:goconst 61 q := "SELECT id WHERE id IN ($1)" 62 v := []int{1, 2, 3, 4} 63 iKeyIDs := make([]interface{}, len(v)) 64 for i, d := range v { 65 iKeyIDs[i] = d 66 } 67 68 ctx := context.Background() 69 var result = make([]int, 0) 70 err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { 71 for rows.Next() { 72 var id int 73 err = rows.Scan(&id) 74 assertNoError(t, err, "rows.Scan returned an error") 75 result = append(result, id) 76 } 77 return nil 78 }) 79 assertNoError(t, err, "Call returned an error") 80 if len(result) != len(v) { 81 t.Fatalf("Result should be 4 long") 82 } 83} 84 85func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) { 86 db, mock, err := sqlmock.New() 87 assertNoError(t, err, "Failed to make DB") 88 limit := uint(4) 89 90 r1 := mock.NewRows([]string{"id"}). 91 AddRow(1). 92 AddRow(2). 93 AddRow(3). 94 AddRow(4) 95 96 r2 := mock.NewRows([]string{"id"}). 97 AddRow(5) 98 99 mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1) 100 mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2) 101 // nolint:goconst 102 q := "SELECT id WHERE id IN ($1)" 103 v := []int{1, 2, 3, 4, 5} 104 iKeyIDs := make([]interface{}, len(v)) 105 for i, d := range v { 106 iKeyIDs[i] = d 107 } 108 109 ctx := context.Background() 110 var result = make([]int, 0) 111 err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { 112 for rows.Next() { 113 var id int 114 err = rows.Scan(&id) 115 assertNoError(t, err, "rows.Scan returned an error") 116 result = append(result, id) 117 } 118 return nil 119 }) 120 assertNoError(t, err, "Call returned an error") 121 if len(result) != len(v) { 122 t.Fatalf("Result should be 5 long") 123 } 124 if !reflect.DeepEqual(v, result) { 125 t.Fatalf("Result is not as expected: got %v want %v", v, result) 126 } 127} 128 129func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { 130 db, mock, err := sqlmock.New() 131 assertNoError(t, err, "Failed to make DB") 132 limit := uint(4) 133 134 // adding a string ID should result in rows.Scan returning an error 135 r := mock.NewRows([]string{"id"}). 136 AddRow("hej"). 137 AddRow(2). 138 AddRow(3) 139 140 mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) 141 // nolint:goconst 142 q := "SELECT id WHERE id IN ($1)" 143 v := []int{-1, -2, 3} 144 iKeyIDs := make([]interface{}, len(v)) 145 for i, d := range v { 146 iKeyIDs[i] = d 147 } 148 149 ctx := context.Background() 150 var result = make([]uint, 0) 151 err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { 152 for rows.Next() { 153 var id uint 154 err = rows.Scan(&id) 155 if err != nil { 156 return err 157 } 158 result = append(result, id) 159 } 160 return nil 161 }) 162 if err == nil { 163 t.Fatalf("Call did not return an error") 164 } 165} 166 167func assertNoError(t *testing.T, err error, msg string) { 168 t.Helper() 169 if err == nil { 170 return 171 } 172 t.Fatalf(msg) 173} 174