1// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2//
3// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
4//
5// This Source Code Form is subject to the terms of the Mozilla Public
6// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7// You can obtain one at http://mozilla.org/MPL/2.0/.
8
9package mysql
10
11import (
12	"context"
13	"database/sql/driver"
14	"errors"
15	"net"
16	"testing"
17)
18
19func TestInterpolateParams(t *testing.T) {
20	mc := &mysqlConn{
21		buf:              newBuffer(nil),
22		maxAllowedPacket: maxPacketSize,
23		cfg: &Config{
24			InterpolateParams: true,
25		},
26	}
27
28	q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
29	if err != nil {
30		t.Errorf("Expected err=nil, got %#v", err)
31		return
32	}
33	expected := `SELECT 42+'gopher'`
34	if q != expected {
35		t.Errorf("Expected: %q\nGot: %q", expected, q)
36	}
37}
38
39func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
40	mc := &mysqlConn{
41		buf:              newBuffer(nil),
42		maxAllowedPacket: maxPacketSize,
43		cfg: &Config{
44			InterpolateParams: true,
45		},
46	}
47
48	q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
49	if err != driver.ErrSkip {
50		t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
51	}
52}
53
54// We don't support placeholder in string literal for now.
55// https://github.com/go-sql-driver/mysql/pull/490
56func TestInterpolateParamsPlaceholderInString(t *testing.T) {
57	mc := &mysqlConn{
58		buf:              newBuffer(nil),
59		maxAllowedPacket: maxPacketSize,
60		cfg: &Config{
61			InterpolateParams: true,
62		},
63	}
64
65	q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
66	// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
67	if err != driver.ErrSkip {
68		t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
69	}
70}
71
72func TestInterpolateParamsUint64(t *testing.T) {
73	mc := &mysqlConn{
74		buf:              newBuffer(nil),
75		maxAllowedPacket: maxPacketSize,
76		cfg: &Config{
77			InterpolateParams: true,
78		},
79	}
80
81	q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)})
82	if err != nil {
83		t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q)
84	}
85	if q != "SELECT 42" {
86		t.Errorf("Expected uint64 interpolation to work, got q=%#v", q)
87	}
88}
89
90func TestCheckNamedValue(t *testing.T) {
91	value := driver.NamedValue{Value: ^uint64(0)}
92	x := &mysqlConn{}
93	err := x.CheckNamedValue(&value)
94
95	if err != nil {
96		t.Fatal("uint64 high-bit not convertible", err)
97	}
98
99	if value.Value != ^uint64(0) {
100		t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value)
101	}
102}
103
104// TestCleanCancel tests passed context is cancelled at start.
105// No packet should be sent.  Connection should keep current status.
106func TestCleanCancel(t *testing.T) {
107	mc := &mysqlConn{
108		closech: make(chan struct{}),
109	}
110	mc.startWatcher()
111	defer mc.cleanup()
112
113	ctx, cancel := context.WithCancel(context.Background())
114	cancel()
115
116	for i := 0; i < 3; i++ { // Repeat same behavior
117		err := mc.Ping(ctx)
118		if err != context.Canceled {
119			t.Errorf("expected context.Canceled, got %#v", err)
120		}
121
122		if mc.closed.IsSet() {
123			t.Error("expected mc is not closed, closed actually")
124		}
125
126		if mc.watching {
127			t.Error("expected watching is false, but true")
128		}
129	}
130}
131
132func TestPingMarkBadConnection(t *testing.T) {
133	nc := badConnection{err: errors.New("boom")}
134	ms := &mysqlConn{
135		netConn:          nc,
136		buf:              newBuffer(nc),
137		maxAllowedPacket: defaultMaxAllowedPacket,
138	}
139
140	err := ms.Ping(context.Background())
141
142	if err != driver.ErrBadConn {
143		t.Errorf("expected driver.ErrBadConn, got  %#v", err)
144	}
145}
146
147func TestPingErrInvalidConn(t *testing.T) {
148	nc := badConnection{err: errors.New("failed to write"), n: 10}
149	ms := &mysqlConn{
150		netConn:          nc,
151		buf:              newBuffer(nc),
152		maxAllowedPacket: defaultMaxAllowedPacket,
153		closech:          make(chan struct{}),
154	}
155
156	err := ms.Ping(context.Background())
157
158	if err != ErrInvalidConn {
159		t.Errorf("expected ErrInvalidConn, got  %#v", err)
160	}
161}
162
// badConnection is a net.Conn test double whose Write always reports a
// preconfigured byte count and error. Only Write and Close are overridden;
// every other method is forwarded to the embedded net.Conn, which is nil
// unless the test sets it.
type badConnection struct {
	n   int   // byte count reported by Write
	err error // error reported by Write
	net.Conn
}

// Write ignores the payload and returns the configured count and error.
func (c badConnection) Write([]byte) (int, error) {
	return c.n, c.err
}

// Close is a no-op that always reports success.
func (c badConnection) Close() error { return nil }
176