1package native
2
3import (
4	"bytes"
5	"fmt"
6	"github.com/ziutek/mymysql/mysql"
7	"io/ioutil"
8	"os"
9	"reflect"
10	"testing"
11	"time"
12)
13
14var (
15	my     mysql.Conn
16	user   = "testuser"
17	passwd = "TestPasswd9"
18	dbname = "test"
19	//conn   = []string{"unix", "", "/var/run/mysqld/mysqld.sock"}
20	conn  = []string{"", "", "127.0.0.1:3306"}
21	debug = false
22)
23
// RowsResErr bundles the three values returned by a query or statement
// execution so that a whole outcome can be passed around and compared with
// an expected outcome (see query, exec, cmdOK and the check* helpers).
// Note: it is constructed with positional literals, so field order matters.
type RowsResErr struct {
	rows []mysql.Row  // rows read from the result
	res  mysql.Result // result object returned alongside the rows
	err  error        // error returned by the call, if any
}
29
30func query(sql string, params ...interface{}) *RowsResErr {
31	rows, res, err := my.Query(sql, params...)
32	return &RowsResErr{rows, res, err}
33}
34
35func exec(stmt *Stmt, params ...interface{}) *RowsResErr {
36	rows, res, err := stmt.Exec(params...)
37	return &RowsResErr{rows, res, err}
38}
39
// checkErr stops the test unless err is identical (==) to exp_err. When no
// error was expected, only the actual error is reported.
func checkErr(t *testing.T, err error, exp_err error) {
	switch {
	case err == exp_err:
		return
	case exp_err == nil:
		t.Fatalf("Error: %v", err)
	default:
		t.Fatalf("Error: %v\nExpected error: %v", err, exp_err)
	}
}
49
50func checkWarnCount(t *testing.T, res_cnt, exp_cnt int) {
51	if res_cnt != exp_cnt {
52		t.Errorf("Warning count: res=%d exp=%d", res_cnt, exp_cnt)
53		rows, res, err := my.Query("show warnings")
54		if err != nil {
55			t.Fatal("Can't get warrnings from MySQL", err)
56		}
57		for _, row := range rows {
58			t.Errorf("%s: \"%s\"", row.Str(res.Map("Level")),
59				row.Str(res.Map("Message")))
60		}
61		t.FailNow()
62	}
63}
64
65func checkErrWarn(t *testing.T, res, exp *RowsResErr) {
66	checkErr(t, res.err, exp.err)
67	checkWarnCount(t, res.res.WarnCount(), exp.res.WarnCount())
68}
69
70func types(row mysql.Row) (tt []reflect.Type) {
71	tt = make([]reflect.Type, len(row))
72	for ii, val := range row {
73		tt[ii] = reflect.TypeOf(val)
74	}
75	return
76}
77
78func checkErrWarnRows(t *testing.T, res, exp *RowsResErr) {
79	checkErrWarn(t, res, exp)
80	if !reflect.DeepEqual(res.rows, exp.rows) {
81		rlen := len(res.rows)
82		elen := len(exp.rows)
83		t.Error("Rows are different!")
84		t.Errorf("len/cap: res=%d/%d exp=%d/%d",
85			rlen, cap(res.rows), elen, cap(exp.rows))
86		max := rlen
87		if elen > max {
88			max = elen
89		}
90		for ii := 0; ii < max; ii++ {
91			if ii < len(res.rows) {
92				t.Errorf("%d: res type: %s", ii, types(res.rows[ii]))
93			} else {
94				t.Errorf("%d: res: ------", ii)
95			}
96			if ii < len(exp.rows) {
97				t.Errorf("%d: exp type: %s", ii, types(exp.rows[ii]))
98			} else {
99				t.Errorf("%d: exp: ------", ii)
100			}
101			if ii < len(res.rows) {
102				t.Error(" res: ", res.rows[ii])
103			}
104			if ii < len(exp.rows) {
105				t.Error(" exp: ", exp.rows[ii])
106			}
107			if ii < len(res.rows) {
108				t.Errorf(" res: %#v", res.rows[ii][2])
109			}
110			if ii < len(exp.rows) {
111				t.Errorf(" exp: %#v", exp.rows[ii][2])
112			}
113		}
114		t.FailNow()
115	}
116}
117
118func checkResult(t *testing.T, res, exp *RowsResErr) {
119	checkErrWarnRows(t, res, exp)
120	r, e := res.res.(*Result), exp.res.(*Result)
121	if r.my != e.my || r.binary != e.binary || r.status_only != e.status_only ||
122		r.status&0xdf != e.status || !bytes.Equal(r.message, e.message) ||
123		r.affected_rows != e.affected_rows ||
124		r.eor_returned != e.eor_returned ||
125		!reflect.DeepEqual(res.rows, exp.rows) || res.err != exp.err {
126		t.Fatalf("Bad result:\nres=%+v\nexp=%+v", res.res, exp.res)
127	}
128}
129
130func cmdOK(affected uint64, binary, eor bool) *RowsResErr {
131	return &RowsResErr{
132		res: &Result{
133			my:            my.(*Conn),
134			binary:        binary,
135			status_only:   true,
136			status:        0x2,
137			message:       []byte{},
138			affected_rows: affected,
139			eor_returned:  eor,
140		},
141	}
142}
143
144func selectOK(rows []mysql.Row, binary bool) (exp *RowsResErr) {
145	exp = cmdOK(0, binary, true)
146	exp.rows = rows
147	return
148}
149
150func myConnect(t *testing.T, with_dbname bool, max_pkt_size int) {
151	if with_dbname {
152		my = New(conn[0], conn[1], conn[2], user, passwd, dbname)
153	} else {
154		my = New(conn[0], conn[1], conn[2], user, passwd)
155	}
156
157	if max_pkt_size != 0 {
158		my.SetMaxPktSize(max_pkt_size)
159	}
160	my.(*Conn).Debug = debug
161
162	checkErr(t, my.Connect(), nil)
163	checkResult(t, query("set names utf8"), cmdOK(0, false, true))
164}
165
166func myClose(t *testing.T) {
167	checkErr(t, my.Close(), nil)
168}
169
170// Text queries tests
171
172func TestUse(t *testing.T) {
173	myConnect(t, false, 0)
174	checkErr(t, my.Use(dbname), nil)
175	myClose(t)
176}
177
178func TestPing(t *testing.T) {
179	myConnect(t, false, 0)
180	checkErr(t, my.Ping(), nil)
181	myClose(t)
182}
183
184func TestQuery(t *testing.T) {
185	myConnect(t, true, 0)
186	query("drop table t") // Drop test table if exists
187	checkResult(t, query("create table t (s varchar(40))"),
188		cmdOK(0, false, true))
189
190	exp := &RowsResErr{
191		res: &Result{
192			my:          my.(*Conn),
193			field_count: 1,
194			fields: []*mysql.Field{
195				&mysql.Field{
196					Catalog:  "def",
197					Db:       "test",
198					Table:    "Test",
199					OrgTable: "T",
200					Name:     "Str",
201					OrgName:  "s",
202					DispLen:  3 * 40, //varchar(40)
203					Flags:    0,
204					Type:     MYSQL_TYPE_VAR_STRING,
205					Scale:    0,
206				},
207			},
208			status:       _SERVER_STATUS_AUTOCOMMIT,
209			eor_returned: true,
210		},
211	}
212
213	for ii := 0; ii > 10000; ii += 3 {
214		var val interface{}
215		if ii%10 == 0 {
216			checkResult(t, query("insert t values (null)"),
217				cmdOK(1, false, true))
218			val = nil
219		} else {
220			txt := []byte(fmt.Sprintf("%d %d %d %d %d", ii, ii, ii, ii, ii))
221			checkResult(t,
222				query("insert t values ('%s')", txt), cmdOK(1, false, true))
223			val = txt
224		}
225		exp.rows = append(exp.rows, mysql.Row{val})
226	}
227
228	checkResult(t, query("select s as Str from t as Test"), exp)
229	checkResult(t, query("drop table t"), cmdOK(0, false, true))
230	myClose(t)
231}
232
233// Prepared statements tests
234
// StmtErr pairs a prepared statement with the error returned while
// preparing it; checkStmt compares an actual StmtErr with an expected one.
// Note: it is constructed with positional literals, so field order matters.
type StmtErr struct {
	stmt *Stmt
	err  error
}
239
240func prepare(sql string) *StmtErr {
241	stmt, err := my.Prepare(sql)
242	return &StmtErr{stmt.(*Stmt), err}
243}
244
245func checkStmt(t *testing.T, res, exp *StmtErr) {
246	ok := res.err == exp.err &&
247		// Skipping id
248		reflect.DeepEqual(res.stmt.fields, exp.stmt.fields) &&
249		res.stmt.field_count == exp.stmt.field_count &&
250		res.stmt.param_count == exp.stmt.param_count &&
251		res.stmt.warning_count == exp.stmt.warning_count &&
252		res.stmt.status == exp.stmt.status
253
254	if !ok {
255		if exp.err == nil {
256			checkErr(t, res.err, nil)
257			checkWarnCount(t, res.stmt.warning_count, exp.stmt.warning_count)
258			for _, v := range res.stmt.fields {
259				fmt.Printf("%+v\n", v)
260			}
261			t.Fatalf("Bad result statement: res=%v exp=%v", res.stmt, exp.stmt)
262		}
263	}
264}
265
266func TestPrepared(t *testing.T) {
267	myConnect(t, true, 0)
268	query("drop table p") // Drop test table if exists
269	checkResult(t,
270		query(
271			"create table p ("+
272				"   ii int not null, ss varchar(20), dd datetime"+
273				") default charset=utf8",
274		),
275		cmdOK(0, false, true),
276	)
277
278	exp := Stmt{
279		fields: []*mysql.Field{
280			&mysql.Field{
281				Catalog: "def", Db: "test", Table: "p", OrgTable: "p",
282				Name:    "i",
283				OrgName: "ii",
284				DispLen: 11,
285				Flags:   _FLAG_NO_DEFAULT_VALUE | _FLAG_NOT_NULL,
286				Type:    MYSQL_TYPE_LONG,
287				Scale:   0,
288			},
289			&mysql.Field{
290				Catalog: "def", Db: "test", Table: "p", OrgTable: "p",
291				Name:    "s",
292				OrgName: "ss",
293				DispLen: 3 * 20, // varchar(20)
294				Flags:   0,
295				Type:    MYSQL_TYPE_VAR_STRING,
296				Scale:   0,
297			},
298			&mysql.Field{
299				Catalog: "def", Db: "test", Table: "p", OrgTable: "p",
300				Name:    "d",
301				OrgName: "dd",
302				DispLen: 19,
303				Flags:   _FLAG_BINARY,
304				Type:    MYSQL_TYPE_DATETIME,
305				Scale:   0,
306			},
307		},
308		field_count:   3,
309		param_count:   2,
310		warning_count: 0,
311		status:        0x2,
312	}
313
314	sel := prepare("select ii i, ss s, dd d from p where ii = ? and ss = ?")
315	checkStmt(t, sel, &StmtErr{&exp, nil})
316
317	all := prepare("select * from p")
318	checkErr(t, all.err, nil)
319
320	ins := prepare("insert p values (?, ?, ?)")
321	checkErr(t, ins.err, nil)
322
323	parsed, err := mysql.ParseTime("2012-01-17 01:10:10", time.Local)
324	checkErr(t, err, nil)
325	parsedZero, err := mysql.ParseTime("0000-00-00 00:00:00", time.Local)
326	checkErr(t, err, nil)
327	if !parsedZero.IsZero() {
328		t.Fatalf("time '%s' isn't zero", parsedZero)
329	}
330	exp_rows := []mysql.Row{
331		mysql.Row{
332			2, "Taki tekst", time.Unix(123456789, 0),
333		},
334		mysql.Row{
335			5, "Pąk róży", parsed,
336		},
337		mysql.Row{
338			-3, "基础体温", parsed,
339		},
340		mysql.Row{
341			11, "Zero UTC datetime", time.Unix(0, 0),
342		},
343		mysql.Row{
344			17, mysql.Blob([]byte("Zero datetime")), parsedZero,
345		},
346		mysql.Row{
347			23, []byte("NULL datetime"), (*time.Time)(nil),
348		},
349		mysql.Row{
350			23, "NULL", nil,
351		},
352	}
353
354	for _, row := range exp_rows {
355		checkErrWarn(t,
356			exec(ins.stmt, row[0], row[1], row[2]),
357			cmdOK(1, true, true),
358		)
359	}
360
361	// Convert values to expected result types
362	for _, row := range exp_rows {
363		for ii, col := range row {
364			val := reflect.ValueOf(col)
365			// Dereference pointers
366			if val.Kind() == reflect.Ptr {
367				val = val.Elem()
368			}
369			switch val.Kind() {
370			case reflect.Invalid:
371				row[ii] = nil
372
373			case reflect.String:
374				row[ii] = []byte(val.String())
375
376			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
377				reflect.Int64:
378				row[ii] = int32(val.Int())
379
380			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
381				reflect.Uint64:
382				row[ii] = int32(val.Uint())
383
384			case reflect.Slice:
385				if val.Type().Elem().Kind() == reflect.Uint8 {
386					bytes := make([]byte, val.Len())
387					for ii := range bytes {
388						bytes[ii] = val.Index(ii).Interface().(uint8)
389					}
390					row[ii] = bytes
391				}
392			}
393		}
394	}
395
396	checkErrWarn(t, exec(sel.stmt, 2, "Taki tekst"), selectOK(exp_rows, true))
397	checkErrWarnRows(t, exec(all.stmt), selectOK(exp_rows, true))
398
399	checkResult(t, query("drop table p"), cmdOK(0, false, true))
400
401	checkErr(t, sel.stmt.Delete(), nil)
402	checkErr(t, all.stmt.Delete(), nil)
403	checkErr(t, ins.stmt.Delete(), nil)
404
405	myClose(t)
406}
407
408// Bind testing
409
410func TestVarBinding(t *testing.T) {
411	myConnect(t, true, 0)
412	query("drop table t") // Drop test table if exists
413	checkResult(t,
414		query("create table t (id int primary key, str varchar(20))"),
415		cmdOK(0, false, true),
416	)
417
418	ins, err := my.Prepare("insert t values (?, ?)")
419	checkErr(t, err, nil)
420
421	var (
422		rre RowsResErr
423		id  *int
424		str *string
425		ii  int
426		ss  string
427	)
428	ins.Bind(&id, &str)
429
430	i1 := 1
431	s1 := "Ala"
432	id = &i1
433	str = &s1
434	rre.res, rre.err = ins.Run()
435	checkResult(t, &rre, cmdOK(1, true, false))
436
437	i2 := 2
438	s2 := "Ma kota!"
439	id = &i2
440	str = &s2
441
442	rre.res, rre.err = ins.Run()
443	checkResult(t, &rre, cmdOK(1, true, false))
444
445	ins.Bind(&ii, &ss)
446	ii = 3
447	ss = "A kot ma Ale!"
448
449	rre.res, rre.err = ins.Run()
450	checkResult(t, &rre, cmdOK(1, true, false))
451
452	sel, err := my.Prepare("select str from t where id = ?")
453	checkErr(t, err, nil)
454
455	rows, _, err := sel.Exec(1)
456	checkErr(t, err, nil)
457	if len(rows) != 1 || bytes.Compare([]byte(s1), rows[0].Bin(0)) != 0 {
458		t.Fatal("First string don't match")
459	}
460
461	rows, _, err = sel.Exec(2)
462	checkErr(t, err, nil)
463	if len(rows) != 1 || bytes.Compare([]byte(s2), rows[0].Bin(0)) != 0 {
464		t.Fatal("Second string don't match")
465	}
466
467	rows, _, err = sel.Exec(3)
468	checkErr(t, err, nil)
469	if len(rows) != 1 || bytes.Compare([]byte(ss), rows[0].Bin(0)) != 0 {
470		t.Fatal("Thrid string don't match")
471	}
472
473	checkResult(t, query("drop table t"), cmdOK(0, false, true))
474	myClose(t)
475}
476
477func TestBindStruct(t *testing.T) {
478	myConnect(t, true, 0)
479	query("drop table t") // Drop test table if exists
480	checkResult(t,
481		query("create table t (id int primary key, txt varchar(20), b bool)"),
482		cmdOK(0, false, true),
483	)
484
485	ins, err := my.Prepare("insert t values (?, ?, ?)")
486	checkErr(t, err, nil)
487	sel, err := my.Prepare("select txt, b from t where id = ?")
488	checkErr(t, err, nil)
489
490	var (
491		s struct {
492			Id  int
493			Txt string
494			B   bool
495		}
496		rre RowsResErr
497	)
498
499	ins.Bind(&s)
500
501	s.Id = 2
502	s.Txt = "Ala ma kota."
503	s.B = true
504
505	rre.res, rre.err = ins.Run()
506	checkResult(t, &rre, cmdOK(1, true, false))
507
508	rows, _, err := sel.Exec(s.Id)
509	checkErr(t, err, nil)
510	if len(rows) != 1 || rows[0].Str(0) != s.Txt || rows[0].Bool(1) != s.B {
511		t.Fatal("selected data don't match inserted data")
512	}
513
514	checkResult(t, query("drop table t"), cmdOK(0, false, true))
515	myClose(t)
516}
517
518func TestDate(t *testing.T) {
519	myConnect(t, true, 0)
520	query("drop table d") // Drop test table if exists
521	checkResult(t,
522		query("create table d (id int, dd date, dt datetime, tt time)"),
523		cmdOK(0, false, true),
524	)
525
526	test := []struct {
527		dd, dt string
528		tt     time.Duration
529	}{
530		{
531			"2011-11-13",
532			"2010-12-12 11:24:00",
533			-time.Duration((128*3600 + 3*60 + 2) * 1e9),
534		}, {
535			"0000-00-00",
536			"0000-00-00 00:00:00",
537			time.Duration(0),
538		},
539	}
540
541	ins, err := my.Prepare("insert d values (?, ?, ?, ?)")
542	checkErr(t, err, nil)
543
544	sel, err := my.Prepare("select id, tt from d where dd = ? && dt = ?")
545	checkErr(t, err, nil)
546
547	for i, r := range test {
548		_, err = ins.Run(i, r.dd, r.dt, r.tt)
549		checkErr(t, err, nil)
550
551		sdt, err := mysql.ParseTime(r.dt, time.Local)
552		checkErr(t, err, nil)
553		sdd, err := mysql.ParseDate(r.dd)
554		checkErr(t, err, nil)
555
556		rows, _, err := sel.Exec(sdd, sdt)
557		checkErr(t, err, nil)
558		if rows == nil {
559			t.Fatal("nil result")
560		}
561		if rows[0].Int(0) != i {
562			t.Fatal("Bad id", rows[0].Int(1))
563		}
564		if rows[0][1].(time.Duration) != r.tt {
565			t.Fatal("Bad tt", rows[0].Duration(1))
566		}
567	}
568
569	checkResult(t, query("drop table d"), cmdOK(0, false, true))
570	myClose(t)
571}
572
573func TestDateTimeZone(t *testing.T) {
574	myConnect(t, true, 0)
575	query("drop table d") // Drop test table if exists
576	checkResult(t,
577		query("create table d (dt datetime)"),
578		cmdOK(0, false, true),
579	)
580
581	ins, err := my.Prepare("insert d values (?)")
582	checkErr(t, err, nil)
583
584	sel, err := my.Prepare("select dt from d")
585	checkErr(t, err, nil)
586
587	tstr := "2013-05-10 15:26:00.000000000"
588
589	_, err = ins.Run(tstr)
590	checkErr(t, err, nil)
591
592	tt := make([]time.Time, 4)
593
594	row, _, err := sel.ExecFirst()
595	checkErr(t, err, nil)
596	tt[0] = row.Time(0, time.UTC)
597	tt[1] = row.Time(0, time.Local)
598	row, _, err = my.QueryFirst("select dt from d")
599	checkErr(t, err, nil)
600	tt[2] = row.Time(0, time.UTC)
601	tt[3] = row.Time(0, time.Local)
602	for _, v := range tt {
603		if v.Format(mysql.TimeFormat) != tstr {
604			t.Fatal("Timezone problem:", tstr, "!=", v)
605		}
606	}
607
608	checkResult(t, query("drop table d"), cmdOK(0, false, true))
609	myClose(t)
610}
611
612// Big blob
613func TestBigBlob(t *testing.T) {
614	myConnect(t, true, 34*1024*1024)
615	query("drop table p") // Drop test table if exists
616	checkResult(t,
617		query("create table p (id int primary key, bb longblob)"),
618		cmdOK(0, false, true),
619	)
620
621	ins, err := my.Prepare("insert p values (?, ?)")
622	checkErr(t, err, nil)
623
624	sel, err := my.Prepare("select bb from p where id = ?")
625	checkErr(t, err, nil)
626
627	big_blob := make(mysql.Blob, 33*1024*1024)
628	for ii := range big_blob {
629		big_blob[ii] = byte(ii)
630	}
631
632	var (
633		rre RowsResErr
634		bb  mysql.Blob
635		id  int
636	)
637	data := struct {
638		Id int
639		Bb mysql.Blob
640	}{}
641
642	// Individual parameters binding
643	ins.Bind(&id, &bb)
644	id = 1
645	bb = big_blob
646
647	// Insert full blob. Three packets are sended. First two has maximum length
648	rre.res, rre.err = ins.Run()
649	checkResult(t, &rre, cmdOK(1, true, false))
650
651	// Struct binding
652	ins.Bind(&data)
653	data.Id = 2
654	data.Bb = big_blob[0 : 32*1024*1024-31]
655
656	// Insert part of blob - Two packets are sended. All has maximum length.
657	rre.res, rre.err = ins.Run()
658	checkResult(t, &rre, cmdOK(1, true, false))
659
660	sel.Bind(&id)
661
662	// Check first insert.
663	tmr := "Too many rows"
664
665	id = 1
666	res, err := sel.Run()
667	checkErr(t, err, nil)
668
669	row, err := res.GetRow()
670	checkErr(t, err, nil)
671	end, err := res.GetRow()
672	checkErr(t, err, nil)
673	if end != nil {
674		t.Fatal(tmr)
675	}
676
677	if bytes.Compare(row[0].([]byte), big_blob) != 0 {
678		t.Fatal("Full blob data don't match")
679	}
680
681	// Check second insert.
682	id = 2
683	res, err = sel.Run()
684	checkErr(t, err, nil)
685
686	row, err = res.GetRow()
687	checkErr(t, err, nil)
688	end, err = res.GetRow()
689	checkErr(t, err, nil)
690	if end != nil {
691		t.Fatal(tmr)
692	}
693
694	if bytes.Compare(row.Bin(res.Map("bb")), data.Bb) != 0 {
695		t.Fatal("Partial blob data don't match")
696	}
697
698	checkResult(t, query("drop table p"), cmdOK(0, false, true))
699	myClose(t)
700}
701
702// Test for empty result
703func TestEmpty(t *testing.T) {
704	checkNil := func(r mysql.Row) {
705		if r != nil {
706			t.Error("Not empty result")
707		}
708	}
709	myConnect(t, true, 0)
710	query("drop table e") // Drop test table if exists
711	// Create table
712	checkResult(t,
713		query("create table e (id int)"),
714		cmdOK(0, false, true),
715	)
716	// Text query
717	res, err := my.Start("select * from e")
718	checkErr(t, err, nil)
719	row, err := res.GetRow()
720	checkErr(t, err, nil)
721	checkNil(row)
722	row, err = res.GetRow()
723	checkErr(t, err, mysql.ErrReadAfterEOR)
724	checkNil(row)
725	// Prepared statement
726	sel, err := my.Prepare("select * from e")
727	checkErr(t, err, nil)
728	res, err = sel.Run()
729	checkErr(t, err, nil)
730	row, err = res.GetRow()
731	checkErr(t, err, nil)
732	checkNil(row)
733	row, err = res.GetRow()
734	checkErr(t, err, mysql.ErrReadAfterEOR)
735	checkNil(row)
736	// Drop test table
737	checkResult(t, query("drop table e"), cmdOK(0, false, true))
738}
739
740// Reconnect test
741func TestReconnect(t *testing.T) {
742	myConnect(t, true, 0)
743	query("drop table r") // Drop test table if exists
744	checkResult(t,
745		query("create table r (id int primary key, str varchar(20))"),
746		cmdOK(0, false, true),
747	)
748
749	ins, err := my.Prepare("insert r values (?, ?)")
750	checkErr(t, err, nil)
751	sel, err := my.Prepare("select str from r where id = ?")
752	checkErr(t, err, nil)
753
754	params := struct {
755		Id  int
756		Str string
757	}{}
758	var sel_id int
759
760	ins.Bind(&params)
761	sel.Bind(&sel_id)
762
763	checkErr(t, my.Reconnect(), nil)
764
765	params.Id = 1
766	params.Str = "Bla bla bla"
767	_, err = ins.Run()
768	checkErr(t, err, nil)
769
770	checkErr(t, my.Reconnect(), nil)
771
772	sel_id = 1
773	res, err := sel.Run()
774	checkErr(t, err, nil)
775
776	row, err := res.GetRow()
777	checkErr(t, err, nil)
778
779	checkErr(t, res.End(), nil)
780
781	if row == nil || row[0] == nil ||
782		params.Str != row.Str(0) {
783		t.Fatal("Bad result")
784	}
785
786	checkErr(t, my.Reconnect(), nil)
787
788	checkResult(t, query("drop table r"), cmdOK(0, false, true))
789	myClose(t)
790}
791
792// StmtSendLongData test
793
794func TestSendLongData(t *testing.T) {
795	myConnect(t, true, 64*1024*1024)
796	query("drop table l") // Drop test table if exists
797	checkResult(t,
798		query("create table l (id int primary key, bb longblob)"),
799		cmdOK(0, false, true),
800	)
801	ins, err := my.Prepare("insert l values (?, ?)")
802	checkErr(t, err, nil)
803
804	sel, err := my.Prepare("select bb from l where id = ?")
805	checkErr(t, err, nil)
806
807	var (
808		rre RowsResErr
809		id  int64
810	)
811
812	ins.Bind(&id, []byte(nil))
813	sel.Bind(&id)
814
815	// Prepare data
816	data := make([]byte, 4*1024*1024)
817	for ii := range data {
818		data[ii] = byte(ii)
819	}
820	// Send long data twice
821	checkErr(t, ins.SendLongData(1, data, 256*1024), nil)
822	checkErr(t, ins.SendLongData(1, data, 512*1024), nil)
823
824	id = 1
825	rre.res, rre.err = ins.Run()
826	checkResult(t, &rre, cmdOK(1, true, false))
827
828	res, err := sel.Run()
829	checkErr(t, err, nil)
830
831	row, err := res.GetRow()
832	checkErr(t, err, nil)
833
834	checkErr(t, res.End(), nil)
835
836	if row == nil || row[0] == nil ||
837		bytes.Compare(append(data, data...), row.Bin(0)) != 0 {
838		t.Fatal("Bad result")
839	}
840
841	file, err := ioutil.TempFile("", "mymysql_test-")
842	checkErr(t, err, nil)
843	filename := file.Name()
844	defer os.Remove(filename)
845
846	buf := make([]byte, 1024)
847	for i := 0; i < 2048; i++ {
848		_, err := file.Write(buf)
849		checkErr(t, err, nil)
850	}
851	checkErr(t, file.Close(), nil)
852
853	// Send long data from io.Reader twice
854	file, err = os.Open(filename)
855	checkErr(t, err, nil)
856	checkErr(t, ins.SendLongData(1, file, 128*1024), nil)
857	checkErr(t, file.Close(), nil)
858	file, err = os.Open(filename)
859	checkErr(t, err, nil)
860	checkErr(t, ins.SendLongData(1, file, 1024*1024), nil)
861	checkErr(t, file.Close(), nil)
862
863	id = 2
864	rre.res, rre.err = ins.Run()
865	checkResult(t, &rre, cmdOK(1, true, false))
866
867	res, err = sel.Run()
868	checkErr(t, err, nil)
869
870	row, err = res.GetRow()
871	checkErr(t, err, nil)
872
873	checkErr(t, res.End(), nil)
874
875	// Read file for check result
876	data, err = ioutil.ReadFile(filename)
877	checkErr(t, err, nil)
878
879	if row == nil || row[0] == nil ||
880		bytes.Compare(append(data, data...), row.Bin(0)) != 0 {
881		t.Fatal("Bad result")
882	}
883
884	checkResult(t, query("drop table l"), cmdOK(0, false, true))
885	myClose(t)
886}
887
888func TestNull(t *testing.T) {
889	myConnect(t, true, 0)
890	query("drop table if exists n")
891	checkResult(t,
892		query("create table n (i int not null, n int)"),
893		cmdOK(0, false, true),
894	)
895	ins, err := my.Prepare("insert n values (?, ?)")
896	checkErr(t, err, nil)
897
898	var (
899		p   struct{ I, N *int }
900		rre RowsResErr
901	)
902	ins.Bind(&p)
903
904	p.I = new(int)
905	p.N = new(int)
906
907	*p.I = 0
908	*p.N = 1
909	rre.res, rre.err = ins.Run()
910	checkResult(t, &rre, cmdOK(1, true, false))
911	*p.I = 1
912	p.N = nil
913	rre.res, rre.err = ins.Run()
914	checkResult(t, &rre, cmdOK(1, true, false))
915
916	checkResult(t, query("insert n values (2, 1)"), cmdOK(1, false, true))
917	checkResult(t, query("insert n values (3, NULL)"), cmdOK(1, false, true))
918
919	rows, res, err := my.Query("select * from n")
920	checkErr(t, err, nil)
921	if len(rows) != 4 {
922		t.Fatal("str: len(rows) != 4")
923	}
924	i := res.Map("i")
925	n := res.Map("n")
926	for k, row := range rows {
927		switch {
928		case row[i] == nil || row.Int(i) != k:
929		case k%2 == 1 && row[n] != nil:
930		case k%2 == 0 && (row[n] == nil || row.Int(n) != 1):
931		default:
932			continue
933		}
934		t.Fatalf("str row: %d = (%s, %s)", k, row[i], row[n])
935	}
936
937	sel, err := my.Prepare("select * from n")
938	checkErr(t, err, nil)
939	rows, res, err = sel.Exec()
940	checkErr(t, err, nil)
941	if len(rows) != 4 {
942		t.Fatal("bin: len(rows) != 4")
943	}
944	i = res.Map("i")
945	n = res.Map("n")
946	for k, row := range rows {
947		switch {
948		case row[i] == nil || row.Int(i) != k:
949		case k%2 == 1 && row[n] != nil:
950		case k%2 == 0 && (row[n] == nil || row.Int(n) != 1):
951		default:
952			continue
953		}
954		t.Fatalf("bin row: %d = (%v, %v)", k, row[i], row[n])
955	}
956
957	checkResult(t, query("drop table n"), cmdOK(0, false, true))
958}
959
960func TestMultipleResults(t *testing.T) {
961	myConnect(t, true, 0)
962	query("drop table m") // Drop test table if exists
963	checkResult(t,
964		query("create table m (id int primary key, str varchar(20))"),
965		cmdOK(0, false, true),
966	)
967
968	str := []string{"zero", "jeden", "dwa"}
969
970	checkResult(t, query("insert m values (0, '%s')", str[0]),
971		cmdOK(1, false, true))
972	checkResult(t, query("insert m values (1, '%s')", str[1]),
973		cmdOK(1, false, true))
974	checkResult(t, query("insert m values (2, '%s')", str[2]),
975		cmdOK(1, false, true))
976
977	res, err := my.Start("select id from m; select str from m")
978	checkErr(t, err, nil)
979
980	for ii := 0; ; ii++ {
981		row, err := res.GetRow()
982		checkErr(t, err, nil)
983		if row == nil {
984			break
985		}
986		if row.Int(0) != ii {
987			t.Fatal("Bad result")
988		}
989	}
990	res, err = res.NextResult()
991	checkErr(t, err, nil)
992	for ii := 0; ; ii++ {
993		row, err := res.GetRow()
994		checkErr(t, err, nil)
995		if row == nil {
996			break
997		}
998		if row.Str(0) != str[ii] {
999			t.Fatal("Bad result")
1000		}
1001	}
1002
1003	checkResult(t, query("drop table m"), cmdOK(0, false, true))
1004	myClose(t)
1005}
1006
1007func TestDecimal(t *testing.T) {
1008	myConnect(t, true, 0)
1009
1010	query("drop table if exists d")
1011	checkResult(t,
1012		query("create table d (d decimal(4,2))"),
1013		cmdOK(0, false, true),
1014	)
1015
1016	checkResult(t, query("insert d values (10.01)"), cmdOK(1, false, true))
1017	sql := "select * from d"
1018	sel, err := my.Prepare(sql)
1019	checkErr(t, err, nil)
1020	rows, res, err := sel.Exec()
1021	checkErr(t, err, nil)
1022	if len(rows) != 1 || rows[0][res.Map("d")].(float64) != 10.01 {
1023		t.Fatal(sql)
1024	}
1025
1026	checkResult(t, query("drop table d"), cmdOK(0, false, true))
1027	myClose(t)
1028}
1029
1030func TestMediumInt(t *testing.T) {
1031	myConnect(t, true, 0)
1032	query("DROP TABLE mi")
1033	checkResult(t,
1034		query(
1035			`CREATE TABLE mi (
1036				id INT PRIMARY KEY AUTO_INCREMENT,
1037				m MEDIUMINT
1038			)`,
1039		),
1040		cmdOK(0, false, true),
1041	)
1042
1043	const n = 9
1044
1045	for i := 0; i < n; i++ {
1046		res, err := my.Start("INSERT mi VALUES (0, %d)", i)
1047		checkErr(t, err, nil)
1048		if res.InsertId() != uint64(i+1) {
1049			t.Fatalf("Wrong insert id: %d, expected: %d", res.InsertId(), i+1)
1050		}
1051	}
1052
1053	sel, err := my.Prepare("SELECT * FROM mi")
1054	checkErr(t, err, nil)
1055
1056	res, err := sel.Run()
1057	checkErr(t, err, nil)
1058
1059	i := 0
1060	for {
1061		row, err := res.GetRow()
1062		checkErr(t, err, nil)
1063		if row == nil {
1064			break
1065		}
1066		id, m := row.Int(0), row.Int(1)
1067		if id != i+1 || m != i {
1068			t.Fatalf("i=%d id=%d m=%d", i, id, m)
1069		}
1070		i++
1071	}
1072	if i != n {
1073		t.Fatalf("%d rows read, %d expected", i, n)
1074	}
1075	checkResult(t, query("drop table mi"), cmdOK(0, false, true))
1076}
1077
1078func TestStoredProcedures(t *testing.T) {
1079	myConnect(t, true, 0)
1080	query("DROP PROCEDURE pr")
1081	query("DROP TABLE p")
1082	checkResult(t,
1083		query(
1084			`CREATE TABLE p (
1085				id INT PRIMARY KEY AUTO_INCREMENT,
1086				txt VARCHAR(8)
1087			)`,
1088		),
1089		cmdOK(0, false, true),
1090	)
1091	_, err := my.Start(
1092		`CREATE PROCEDURE pr (IN i INT)
1093		BEGIN
1094			INSERT p VALUES (0, "aaa");
1095			SELECT * FROM p;
1096			SELECT i * id FROM p;
1097		END`,
1098	)
1099	checkErr(t, err, nil)
1100
1101	res, err := my.Start("CALL pr(3)")
1102	checkErr(t, err, nil)
1103
1104	rows, err := res.GetRows()
1105	checkErr(t, err, nil)
1106	if len(rows) != 1 || len(rows[0]) != 2 || rows[0].Int(0) != 1 || rows[0].Str(1) != "aaa" {
1107		t.Fatalf("Bad result set: %+v", rows)
1108	}
1109
1110	res, err = res.NextResult()
1111	checkErr(t, err, nil)
1112
1113	rows, err = res.GetRows()
1114	checkErr(t, err, nil)
1115	if len(rows) != 1 || len(rows[0]) != 1 || rows[0].Int(0) != 3 {
1116		t.Fatalf("Bad result set: %+v", rows)
1117	}
1118
1119	res, err = res.NextResult()
1120	checkErr(t, err, nil)
1121	if !res.StatusOnly() {
1122		t.Fatalf("Result includes resultset at end of procedure: %+v", res)
1123	}
1124
1125	_, err = my.Start("DROP PROCEDURE pr")
1126	checkErr(t, err, nil)
1127
1128	checkResult(t, query("DROP TABLE p"), cmdOK(0, false, true))
1129}
1130
// Benchmarks
1132
// check is the benchmarks' error handler: a non-nil err is printed and the
// process exits with status 1. A nil err is ignored.
func check(err error) {
	if err == nil {
		return
	}
	fmt.Println(err)
	os.Exit(1)
}
1139
1140func BenchmarkInsertSelect(b *testing.B) {
1141	b.StopTimer()
1142
1143	my := New(conn[0], conn[1], conn[2], user, passwd, dbname)
1144	check(my.Connect())
1145
1146	my.Start("drop table b") // Drop test table if exists
1147
1148	_, err := my.Start("create table b (s varchar(40), i int)")
1149	check(err)
1150
1151	for ii := 0; ii < 10000; ii++ {
1152		_, err := my.Start("insert b values ('%d-%d-%d', %d)", ii, ii, ii, ii)
1153		check(err)
1154	}
1155
1156	b.StartTimer()
1157
1158	for ii := 0; ii < b.N; ii++ {
1159		res, err := my.Start("select * from b")
1160		check(err)
1161		for {
1162			row, err := res.GetRow()
1163			check(err)
1164			if row == nil {
1165				break
1166			}
1167		}
1168	}
1169
1170	b.StopTimer()
1171
1172	_, err = my.Start("drop table b")
1173	check(err)
1174	check(my.Close())
1175}
1176
1177func BenchmarkPreparedInsertSelect(b *testing.B) {
1178	b.StopTimer()
1179
1180	my := New(conn[0], conn[1], conn[2], user, passwd, dbname)
1181	check(my.Connect())
1182
1183	my.Start("drop table b") // Drop test table if exists
1184
1185	_, err := my.Start("create table b (s varchar(40), i int)")
1186	check(err)
1187
1188	ins, err := my.Prepare("insert b values (?, ?)")
1189	check(err)
1190
1191	sel, err := my.Prepare("select * from b")
1192	check(err)
1193
1194	for ii := 0; ii < 10000; ii++ {
1195		_, err := ins.Run(fmt.Sprintf("%d-%d-%d", ii, ii, ii), ii)
1196		check(err)
1197	}
1198
1199	b.StartTimer()
1200
1201	for ii := 0; ii < b.N; ii++ {
1202		res, err := sel.Run()
1203		check(err)
1204		for {
1205			row, err := res.GetRow()
1206			check(err)
1207			if row == nil {
1208				break
1209			}
1210		}
1211	}
1212
1213	b.StopTimer()
1214
1215	_, err = my.Start("drop table b")
1216	check(err)
1217	check(my.Close())
1218}
1219