1package goqu_test
2
3import (
4	"testing"
5	"time"
6
7	"github.com/DATA-DOG/go-sqlmock"
8	"github.com/doug-martin/goqu/v9"
9	"github.com/doug-martin/goqu/v9/exp"
10	"github.com/doug-martin/goqu/v9/internal/errors"
11	"github.com/doug-martin/goqu/v9/internal/sb"
12	"github.com/doug-martin/goqu/v9/mocks"
13	"github.com/stretchr/testify/mock"
14	"github.com/stretchr/testify/suite"
15)
16
17type (
18	insertTestCase struct {
19		ds      *goqu.InsertDataset
20		clauses exp.InsertClauses
21	}
22	insertDatasetSuite struct {
23		suite.Suite
24	}
25)
26
27func (ids *insertDatasetSuite) assertCases(cases ...insertTestCase) {
28	for _, s := range cases {
29		ids.Equal(s.clauses, s.ds.GetClauses())
30	}
31}
32
33func (ids *insertDatasetSuite) TestInsert() {
34	ds := goqu.Insert("test")
35	ids.IsType(&goqu.InsertDataset{}, ds)
36	ids.Implements((*exp.Expression)(nil), ds)
37	ids.Implements((*exp.AppendableExpression)(nil), ds)
38}
39
40func (ids *insertDatasetSuite) TestClone() {
41	ds := goqu.Insert("test")
42	ids.Equal(ds.Clone(), ds)
43}
44
45func (ids *insertDatasetSuite) TestExpression() {
46	ds := goqu.Insert("test")
47	ids.Equal(ds.Expression(), ds)
48}
49
50func (ids *insertDatasetSuite) TestDialect() {
51	ds := goqu.Insert("test")
52	ids.NotNil(ds.Dialect())
53}
54
55func (ids *insertDatasetSuite) TestWithDialect() {
56	ds := goqu.Insert("test")
57	md := new(mocks.SQLDialect)
58	ds = ds.SetDialect(md)
59
60	dialect := goqu.GetDialect("default")
61	dialectDs := ds.WithDialect("default")
62	ids.Equal(md, ds.Dialect())
63	ids.Equal(dialect, dialectDs.Dialect())
64}
65
66func (ids *insertDatasetSuite) TestPrepared() {
67	ds := goqu.Insert("test")
68	preparedDs := ds.Prepared(true)
69	ids.True(preparedDs.IsPrepared())
70	ids.False(ds.IsPrepared())
71	// should apply the prepared to any datasets created from the root
72	ids.True(preparedDs.Returning(goqu.C("col")).IsPrepared())
73
74	defer goqu.SetDefaultPrepared(false)
75	goqu.SetDefaultPrepared(true)
76
77	// should be prepared by default
78	ds = goqu.Insert("test")
79	ids.True(ds.IsPrepared())
80}
81
82func (ids *insertDatasetSuite) TestGetClauses() {
83	ds := goqu.Insert("test")
84	ce := exp.NewInsertClauses().SetInto(goqu.I("test"))
85	ids.Equal(ce, ds.GetClauses())
86}
87
88func (ids *insertDatasetSuite) TestWith() {
89	from := goqu.From("cte")
90	bd := goqu.Insert("items")
91	ids.assertCases(
92		insertTestCase{
93			ds: bd.With("test-cte", from),
94			clauses: exp.NewInsertClauses().
95				SetInto(goqu.C("items")).
96				CommonTablesAppend(exp.NewCommonTableExpression(false, "test-cte", from)),
97		},
98		insertTestCase{
99			ds:      bd,
100			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
101		},
102	)
103}
104
105func (ids *insertDatasetSuite) TestWithRecursive() {
106	from := goqu.From("cte")
107	bd := goqu.Insert("items")
108	ids.assertCases(
109		insertTestCase{
110			ds: bd.WithRecursive("test-cte", from),
111			clauses: exp.NewInsertClauses().
112				SetInto(goqu.C("items")).
113				CommonTablesAppend(exp.NewCommonTableExpression(true, "test-cte", from)),
114		},
115		insertTestCase{
116			ds:      bd,
117			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
118		},
119	)
120}
121
122func (ids *insertDatasetSuite) TestInto() {
123	bd := goqu.Insert("items")
124	ids.assertCases(
125		insertTestCase{
126			ds:      bd.Into("items2"),
127			clauses: exp.NewInsertClauses().SetInto(goqu.C("items2")),
128		},
129		insertTestCase{
130			ds:      bd.Into(goqu.L("items2")),
131			clauses: exp.NewInsertClauses().SetInto(goqu.L("items2")),
132		},
133		insertTestCase{
134			ds:      bd,
135			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
136		},
137	)
138
139	ids.PanicsWithValue(goqu.ErrUnsupportedIntoType, func() {
140		bd.Into(true)
141	})
142}
143
144func (ids *insertDatasetSuite) TestCols() {
145	bd := goqu.Insert("items")
146	ids.assertCases(
147		insertTestCase{
148			ds: bd.Cols("a", "b"),
149			clauses: exp.NewInsertClauses().
150				SetInto(goqu.C("items")).
151				SetCols(exp.NewColumnListExpression("a", "b")),
152		},
153		insertTestCase{
154			ds: bd.Cols("a", "b").Cols("c", "d"),
155			clauses: exp.NewInsertClauses().
156				SetInto(goqu.C("items")).
157				SetCols(exp.NewColumnListExpression("c", "d")),
158		},
159		insertTestCase{
160			ds:      bd,
161			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
162		},
163	)
164}
165
166func (ids *insertDatasetSuite) TestClearCols() {
167	bd := goqu.Insert("items").Cols("a", "b")
168	ids.assertCases(
169		insertTestCase{
170			ds:      bd.ClearCols(),
171			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
172		},
173		insertTestCase{
174			ds:      bd,
175			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetCols(exp.NewColumnListExpression("a", "b")),
176		},
177	)
178}
179
180func (ids *insertDatasetSuite) TestColsAppend() {
181	bd := goqu.Insert("items").Cols("a")
182	ids.assertCases(
183		insertTestCase{
184			ds:      bd.ColsAppend("b"),
185			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetCols(exp.NewColumnListExpression("a", "b")),
186		},
187		insertTestCase{
188			ds:      bd,
189			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetCols(exp.NewColumnListExpression("a")),
190		},
191	)
192}
193
194func (ids *insertDatasetSuite) TestFromQuery() {
195	bd := goqu.Insert("items")
196	ids.assertCases(
197		insertTestCase{
198			ds: bd.FromQuery(goqu.From("other_items").Where(goqu.C("b").Gt(10))),
199			clauses: exp.NewInsertClauses().
200				SetInto(goqu.C("items")).
201				SetFrom(goqu.From("other_items").Where(goqu.C("b").Gt(10))),
202		},
203		insertTestCase{
204			ds: bd.FromQuery(goqu.From("other_items").Where(goqu.C("b").Gt(10))).Cols("a", "b"),
205			clauses: exp.NewInsertClauses().
206				SetInto(goqu.C("items")).
207				SetCols(exp.NewColumnListExpression("a", "b")).
208				SetFrom(goqu.From("other_items").Where(goqu.C("b").Gt(10))),
209		},
210		insertTestCase{
211			ds:      bd,
212			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
213		},
214	)
215}
216
217func (ids *insertDatasetSuite) TestFromQueryDialectInheritance() {
218	md := new(mocks.SQLDialect)
219	md.On("Dialect").Return("dialect")
220
221	ids.Run("ok, default dialect is replaced with insert dialect", func() {
222		bd := goqu.Insert("items").SetDialect(md).FromQuery(goqu.From("other_items"))
223		ids.Require().Equal(md, bd.GetClauses().From().(*goqu.SelectDataset).Dialect())
224	})
225
226	ids.Run("ok, insert and select dialects coincide", func() {
227		bd := goqu.Insert("items").SetDialect(md).FromQuery(goqu.From("other_items").SetDialect(md))
228		ids.Require().Equal(md, bd.GetClauses().From().(*goqu.SelectDataset).Dialect())
229	})
230
231	ids.Run("ok, insert and select dialects are default", func() {
232		bd := goqu.Insert("items").FromQuery(goqu.From("other_items"))
233		ids.Require().Equal(goqu.GetDialect("default"), bd.GetClauses().From().(*goqu.SelectDataset).Dialect())
234	})
235
236	ids.Run("panic, insert and select dialects are different", func() {
237		defer func() {
238			r := recover()
239			if r == nil {
240				ids.Fail("there should be a panic")
241			}
242			ids.Require().Equal(
243				"incompatible dialects for INSERT (\"dialect\") and SELECT (\"other_dialect\")",
244				r.(error).Error(),
245			)
246		}()
247
248		otherDialect := new(mocks.SQLDialect)
249		otherDialect.On("Dialect").Return("other_dialect")
250		goqu.Insert("items").SetDialect(md).FromQuery(goqu.From("otherItems").SetDialect(otherDialect))
251	})
252}
253
254func (ids *insertDatasetSuite) TestVals() {
255	val1 := []interface{}{
256		"a", "b",
257	}
258	val2 := []interface{}{
259		"c", "d",
260	}
261
262	bd := goqu.Insert("items")
263	ids.assertCases(
264		insertTestCase{
265			ds: bd.Vals(val1),
266			clauses: exp.NewInsertClauses().
267				SetInto(goqu.C("items")).
268				SetVals([][]interface{}{val1}),
269		},
270		insertTestCase{
271			ds: bd.Vals(val1, val2),
272			clauses: exp.NewInsertClauses().
273				SetInto(goqu.C("items")).
274				SetVals([][]interface{}{val1, val2}),
275		},
276		insertTestCase{
277			ds: bd.Vals(val1).Vals(val2),
278			clauses: exp.NewInsertClauses().
279				SetInto(goqu.C("items")).
280				SetVals([][]interface{}{val1, val2}),
281		},
282		insertTestCase{
283			ds:      bd,
284			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
285		},
286	)
287}
288
289func (ids *insertDatasetSuite) TestClearVals() {
290	val := []interface{}{
291		"a", "b",
292	}
293	bd := goqu.Insert("items").Vals(val)
294	ids.assertCases(
295		insertTestCase{
296			ds:      bd.ClearVals(),
297			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
298		},
299		insertTestCase{
300			ds:      bd,
301			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetVals([][]interface{}{val}),
302		},
303	)
304}
305
306func (ids *insertDatasetSuite) TestRows() {
307	type item struct {
308		CreatedAt *time.Time `db:"created_at"`
309	}
310	n := time.Now()
311	r := item{CreatedAt: nil}
312	r2 := item{CreatedAt: &n}
313	bd := goqu.Insert("items")
314	ids.assertCases(
315		insertTestCase{
316			ds:      bd.Rows(r),
317			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetRows([]interface{}{r}),
318		},
319		insertTestCase{
320			ds:      bd.Rows(r).Rows(r2),
321			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetRows([]interface{}{r2}),
322		},
323		insertTestCase{
324			ds:      bd,
325			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
326		},
327	)
328}
329
330func (ids *insertDatasetSuite) TestClearRows() {
331	type item struct {
332		CreatedAt *time.Time `db:"created_at"`
333	}
334	r := item{CreatedAt: nil}
335	bd := goqu.Insert("items").Rows(r)
336	ids.assertCases(
337		insertTestCase{
338			ds:      bd.ClearRows(),
339			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
340		},
341		insertTestCase{
342			ds:      bd,
343			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetRows([]interface{}{r}),
344		},
345	)
346}
347
348func (ids *insertDatasetSuite) TestOnConflict() {
349	du := goqu.DoUpdate("other_items", goqu.Record{"a": 1})
350
351	bd := goqu.Insert("items")
352	ids.assertCases(
353		insertTestCase{
354			ds:      bd.OnConflict(nil),
355			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
356		},
357		insertTestCase{
358			ds:      bd.OnConflict(goqu.DoNothing()),
359			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(goqu.DoNothing()),
360		},
361		insertTestCase{
362			ds:      bd.OnConflict(du),
363			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(du),
364		},
365		insertTestCase{
366			ds:      bd,
367			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
368		},
369	)
370}
371
372func (ids *insertDatasetSuite) TestClearOnConflict() {
373	du := goqu.DoUpdate("other_items", goqu.Record{"a": 1})
374
375	bd := goqu.Insert("items").OnConflict(du)
376	ids.assertCases(
377		insertTestCase{
378			ds:      bd.ClearOnConflict(),
379			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
380		},
381		insertTestCase{
382			ds:      bd,
383			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")).SetOnConflict(du),
384		},
385	)
386}
387
388func (ids *insertDatasetSuite) TestReturning() {
389	bd := goqu.Insert("items")
390	ids.assertCases(
391		insertTestCase{
392			ds: bd.Returning("a"),
393			clauses: exp.NewInsertClauses().
394				SetInto(goqu.C("items")).
395				SetReturning(exp.NewColumnListExpression("a")),
396		},
397		insertTestCase{
398			ds: bd.Returning(),
399			clauses: exp.NewInsertClauses().
400				SetInto(goqu.C("items")).
401				SetReturning(exp.NewColumnListExpression()),
402		},
403		insertTestCase{
404			ds: bd.Returning(nil),
405			clauses: exp.NewInsertClauses().
406				SetInto(goqu.C("items")).
407				SetReturning(exp.NewColumnListExpression()),
408		},
409		insertTestCase{
410			ds: bd.Returning(),
411			clauses: exp.NewInsertClauses().
412				SetInto(goqu.C("items")).
413				SetReturning(exp.NewColumnListExpression()),
414		},
415		insertTestCase{
416			ds: bd.Returning("a").Returning("b"),
417			clauses: exp.NewInsertClauses().
418				SetInto(goqu.C("items")).
419				SetReturning(exp.NewColumnListExpression("b")),
420		},
421		insertTestCase{
422			ds:      bd,
423			clauses: exp.NewInsertClauses().SetInto(goqu.C("items")),
424		},
425	)
426}
427
428func (ids *insertDatasetSuite) TestReturnsColumns() {
429	ds := goqu.Insert("test")
430	ids.False(ds.ReturnsColumns())
431	ids.True(ds.Returning("foo", "bar").ReturnsColumns())
432}
433
434func (ids *insertDatasetSuite) TestExecutor() {
435	mDB, _, err := sqlmock.New()
436	ids.NoError(err)
437
438	ds := goqu.New("mock", mDB).Insert("items").
439		Rows(goqu.Record{"address": "111 Test Addr", "name": "Test1"})
440
441	isql, args, err := ds.Executor().ToSQL()
442	ids.NoError(err)
443	ids.Empty(args)
444	ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1')`, isql)
445
446	isql, args, err = ds.Prepared(true).Executor().ToSQL()
447	ids.NoError(err)
448	ids.Equal([]interface{}{"111 Test Addr", "Test1"}, args)
449	ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, isql)
450
451	defer goqu.SetDefaultPrepared(false)
452	goqu.SetDefaultPrepared(true)
453
454	isql, args, err = ds.Executor().ToSQL()
455	ids.NoError(err)
456	ids.Equal([]interface{}{"111 Test Addr", "Test1"}, args)
457	ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, isql)
458}
459
460func (ids *insertDatasetSuite) TestInsertStruct() {
461	defer goqu.SetIgnoreUntaggedFields(false)
462
463	mDB, _, err := sqlmock.New()
464	ids.NoError(err)
465
466	item := dsUntaggedTestActionItem{
467		Address:  "111 Test Addr",
468		Name:     "Test1",
469		Untagged: "Test2",
470	}
471
472	ds := goqu.New("mock", mDB).Insert("items").
473		Rows(item)
474
475	isql, args, err := ds.Executor().ToSQL()
476	ids.NoError(err)
477	ids.Empty(args)
478	ids.Equal(`INSERT INTO "items" ("address", "name", "untagged") VALUES ('111 Test Addr', 'Test1', 'Test2')`, isql)
479
480	isql, args, err = ds.Prepared(true).Executor().ToSQL()
481	ids.NoError(err)
482	ids.Equal([]interface{}{"111 Test Addr", "Test1", "Test2"}, args)
483	ids.Equal(`INSERT INTO "items" ("address", "name", "untagged") VALUES (?, ?, ?)`, isql)
484
485	goqu.SetIgnoreUntaggedFields(true)
486
487	isql, args, err = ds.Executor().ToSQL()
488	ids.NoError(err)
489	ids.Empty(args)
490	ids.Equal(`INSERT INTO "items" ("address", "name") VALUES ('111 Test Addr', 'Test1')`, isql)
491
492	isql, args, err = ds.Prepared(true).Executor().ToSQL()
493	ids.NoError(err)
494	ids.Equal([]interface{}{"111 Test Addr", "Test1"}, args)
495	ids.Equal(`INSERT INTO "items" ("address", "name") VALUES (?, ?)`, isql)
496}
497
498func (ids *insertDatasetSuite) TestToSQL() {
499	md := new(mocks.SQLDialect)
500	ds := goqu.Insert("test").SetDialect(md)
501	c := ds.GetClauses()
502	sqlB := sb.NewSQLBuilder(false)
503	md.On("ToInsertSQL", sqlB, c).Return(nil).Once()
504	insertSQL, args, err := ds.ToSQL()
505	ids.Empty(insertSQL)
506	ids.Empty(args)
507	ids.Nil(err)
508	md.AssertExpectations(ids.T())
509}
510
511func (ids *insertDatasetSuite) TestToSQL_Prepared() {
512	md := new(mocks.SQLDialect)
513	ds := goqu.Insert("test").SetDialect(md).Prepared(true)
514	c := ds.GetClauses()
515	sqlB := sb.NewSQLBuilder(true)
516	md.On("ToInsertSQL", sqlB, c).Return(nil).Once()
517	insertSQL, args, err := ds.ToSQL()
518	ids.Empty(insertSQL)
519	ids.Empty(args)
520	ids.Nil(err)
521	md.AssertExpectations(ids.T())
522}
523
524func (ids *insertDatasetSuite) TestToSQL_ReturnedError() {
525	md := new(mocks.SQLDialect)
526	ds := goqu.Insert("test").SetDialect(md)
527	c := ds.GetClauses()
528	sqlB := sb.NewSQLBuilder(false)
529	ee := errors.New("expected error")
530	md.On("ToInsertSQL", sqlB, c).Run(func(args mock.Arguments) {
531		args.Get(0).(sb.SQLBuilder).SetError(ee)
532	}).Once()
533
534	insertSQL, args, err := ds.ToSQL()
535	ids.Empty(insertSQL)
536	ids.Empty(args)
537	ids.Equal(ee, err)
538	md.AssertExpectations(ids.T())
539}
540
541func (ids *insertDatasetSuite) TestSetError() {
542	err1 := errors.New("error #1")
543	err2 := errors.New("error #2")
544	err3 := errors.New("error #3")
545
546	// Verify initial error set/get works properly
547	md := new(mocks.SQLDialect)
548	ds := goqu.Insert("test").SetDialect(md)
549	ds = ds.SetError(err1)
550	ids.Equal(err1, ds.Error())
551	sql, args, err := ds.ToSQL()
552	ids.Empty(sql)
553	ids.Empty(args)
554	ids.Equal(err1, err)
555
556	// Repeated SetError calls on Dataset should not overwrite the original error
557	ds = ds.SetError(err2)
558	ids.Equal(err1, ds.Error())
559	sql, args, err = ds.ToSQL()
560	ids.Empty(sql)
561	ids.Empty(args)
562	ids.Equal(err1, err)
563
564	// Builder functions should not lose the error
565	ds = ds.Cols("a", "b")
566	ids.Equal(err1, ds.Error())
567	sql, args, err = ds.ToSQL()
568	ids.Empty(sql)
569	ids.Empty(args)
570	ids.Equal(err1, err)
571
572	// Deeper errors inside SQL generation should still return original error
573	c := ds.GetClauses()
574	sqlB := sb.NewSQLBuilder(false)
575	md.On("ToInsertSQL", sqlB, c).Run(func(args mock.Arguments) {
576		args.Get(0).(sb.SQLBuilder).SetError(err3)
577	}).Once()
578
579	sql, args, err = ds.ToSQL()
580	ids.Empty(sql)
581	ids.Empty(args)
582	ids.Equal(err1, err)
583}
584
585func TestInsertDataset(t *testing.T) {
586	suite.Run(t, new(insertDatasetSuite))
587}
588