1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"math"
21	"testing"
22	"time"
23
24	"cloud.google.com/go/civil"
25	"github.com/golang/protobuf/proto"
26	proto3 "github.com/golang/protobuf/ptypes/struct"
27	sppb "google.golang.org/genproto/googleapis/spanner/v1"
28)
29
30func TestConvertParams(t *testing.T) {
31	st := Statement{
32		SQL:    "SELECT id from t_foo WHERE col = @var",
33		Params: map[string]interface{}{"var": nil},
34	}
35	var (
36		t1, _ = time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
37		// Boundaries
38		t2, _ = time.Parse(time.RFC3339Nano, "0001-01-01T00:00:00.000000000Z")
39		t3, _ = time.Parse(time.RFC3339Nano, "9999-12-31T23:59:59.999999999Z")
40		d1, _ = civil.ParseDate("2016-11-15")
41		// Boundaries
42		d2, _ = civil.ParseDate("0001-01-01")
43		d3, _ = civil.ParseDate("9999-12-31")
44	)
45
46	type staticStruct struct {
47		Field int `spanner:"field"`
48	}
49
50	var (
51		s1 = staticStruct{10}
52		s2 = staticStruct{20}
53	)
54
55	for _, test := range []struct {
56		val       interface{}
57		wantField *proto3.Value
58		wantType  *sppb.Type
59	}{
60		// bool
61		{true, boolProto(true), boolType()},
62		{NullBool{true, true}, boolProto(true), boolType()},
63		{NullBool{true, false}, nullProto(), boolType()},
64		{[]bool(nil), nullProto(), listType(boolType())},
65		{[]bool{}, listProto(), listType(boolType())},
66		{[]bool{true, false}, listProto(boolProto(true), boolProto(false)), listType(boolType())},
67		{[]NullBool(nil), nullProto(), listType(boolType())},
68		{[]NullBool{}, listProto(), listType(boolType())},
69		{[]NullBool{{true, true}, {}}, listProto(boolProto(true), nullProto()), listType(boolType())},
70		// int
71		{int(1), intProto(1), intType()},
72		{[]int(nil), nullProto(), listType(intType())},
73		{[]int{}, listProto(), listType(intType())},
74		{[]int{1, 2}, listProto(intProto(1), intProto(2)), listType(intType())},
75		// int64
76		{int64(1), intProto(1), intType()},
77		{NullInt64{5, true}, intProto(5), intType()},
78		{NullInt64{5, false}, nullProto(), intType()},
79		{[]int64(nil), nullProto(), listType(intType())},
80		{[]int64{}, listProto(), listType(intType())},
81		{[]int64{1, 2}, listProto(intProto(1), intProto(2)), listType(intType())},
82		{[]NullInt64(nil), nullProto(), listType(intType())},
83		{[]NullInt64{}, listProto(), listType(intType())},
84		{[]NullInt64{{1, true}, {}}, listProto(intProto(1), nullProto()), listType(intType())},
85		// float64
86		{0.0, floatProto(0.0), floatType()},
87		{math.Inf(1), floatProto(math.Inf(1)), floatType()},
88		{math.Inf(-1), floatProto(math.Inf(-1)), floatType()},
89		{math.NaN(), floatProto(math.NaN()), floatType()},
90		{NullFloat64{2.71, true}, floatProto(2.71), floatType()},
91		{NullFloat64{1.41, false}, nullProto(), floatType()},
92		{[]float64(nil), nullProto(), listType(floatType())},
93		{[]float64{}, listProto(), listType(floatType())},
94		{[]float64{2.72, math.Inf(1)}, listProto(floatProto(2.72), floatProto(math.Inf(1))), listType(floatType())},
95		{[]NullFloat64(nil), nullProto(), listType(floatType())},
96		{[]NullFloat64{}, listProto(), listType(floatType())},
97		{[]NullFloat64{{2.72, true}, {}}, listProto(floatProto(2.72), nullProto()), listType(floatType())},
98		// string
99		{"", stringProto(""), stringType()},
100		{"foo", stringProto("foo"), stringType()},
101		{NullString{"bar", true}, stringProto("bar"), stringType()},
102		{NullString{"bar", false}, nullProto(), stringType()},
103		{[]string(nil), nullProto(), listType(stringType())},
104		{[]string{}, listProto(), listType(stringType())},
105		{[]string{"foo", "bar"}, listProto(stringProto("foo"), stringProto("bar")), listType(stringType())},
106		{[]NullString(nil), nullProto(), listType(stringType())},
107		{[]NullString{}, listProto(), listType(stringType())},
108		{[]NullString{{"foo", true}, {}}, listProto(stringProto("foo"), nullProto()), listType(stringType())},
109		// bytes
110		{[]byte{}, bytesProto([]byte{}), bytesType()},
111		{[]byte{1, 2, 3}, bytesProto([]byte{1, 2, 3}), bytesType()},
112		{[]byte(nil), nullProto(), bytesType()},
113		{[][]byte(nil), nullProto(), listType(bytesType())},
114		{[][]byte{}, listProto(), listType(bytesType())},
115		{[][]byte{{1}, []byte(nil)}, listProto(bytesProto([]byte{1}), nullProto()), listType(bytesType())},
116		// date
117		{d1, dateProto(d1), dateType()},
118		{NullDate{civil.Date{}, false}, nullProto(), dateType()},
119		{[]civil.Date(nil), nullProto(), listType(dateType())},
120		{[]civil.Date{}, listProto(), listType(dateType())},
121		{[]civil.Date{d1, d2, d3}, listProto(dateProto(d1), dateProto(d2), dateProto(d3)), listType(dateType())},
122		{[]NullDate{{d2, true}, {}}, listProto(dateProto(d2), nullProto()), listType(dateType())},
123		// timestamp
124		{t1, timeProto(t1), timeType()},
125		{NullTime{}, nullProto(), timeType()},
126		{[]time.Time(nil), nullProto(), listType(timeType())},
127		{[]time.Time{}, listProto(), listType(timeType())},
128		{[]time.Time{t1, t2, t3}, listProto(timeProto(t1), timeProto(t2), timeProto(t3)), listType(timeType())},
129		{[]NullTime{{t2, true}, {}}, listProto(timeProto(t2), nullProto()), listType(timeType())},
130		// Struct
131		{
132			s1,
133			listProto(intProto(10)),
134			structType(mkField("field", intType())),
135		},
136		{
137			(*struct {
138				F1 civil.Date `spanner:""`
139				F2 bool
140			})(nil),
141			nullProto(),
142			structType(
143				mkField("", dateType()),
144				mkField("F2", boolType())),
145		},
146		// Array-of-struct
147		{
148			[]staticStruct{s1, s2},
149			listProto(listProto(intProto(10)), listProto(intProto(20))),
150			listType(structType(mkField("field", intType()))),
151		},
152	} {
153		st.Params["var"] = test.val
154		gotParams, gotParamTypes, gotErr := st.convertParams()
155		if gotErr != nil {
156			t.Error(gotErr)
157			continue
158		}
159		gotParamField := gotParams.Fields["var"]
160		if !proto.Equal(gotParamField, test.wantField) {
161			// handle NaN
162			if test.wantType.Code == floatType().Code && proto.MarshalTextString(gotParamField) == proto.MarshalTextString(test.wantField) {
163				continue
164			}
165			t.Errorf("%#v: got %v, want %v\n", test.val, gotParamField, test.wantField)
166		}
167		gotParamType := gotParamTypes["var"]
168		if !proto.Equal(gotParamType, test.wantType) {
169			t.Errorf("%#v: got %v, want %v\n", test.val, gotParamType, test.wantField)
170		}
171	}
172
173	// Verify type error reporting.
174	for _, test := range []struct {
175		val     interface{}
176		wantErr error
177	}{
178		{
179			nil,
180			errBindParam("var", nil, errNilParam),
181		},
182	} {
183		st.Params["var"] = test.val
184		_, _, gotErr := st.convertParams()
185		if !testEqual(gotErr, test.wantErr) {
186			t.Errorf("value %#v:\ngot:  %v\nwant: %v", test.val, gotErr, test.wantErr)
187		}
188	}
189}
190
191func TestNewStatement(t *testing.T) {
192	s := NewStatement("query")
193	if got, want := s.SQL, "query"; got != want {
194		t.Errorf("got %q, want %q", got, want)
195	}
196}
197