1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"strconv"
	"time"

	"cloud.google.com/go/civil"
	"cloud.google.com/go/internal/fields"
	"github.com/golang/protobuf/proto"
	proto3 "github.com/golang/protobuf/ptypes/struct"
	sppb "google.golang.org/genproto/googleapis/spanner/v1"
	"google.golang.org/grpc/codes"
)
35
// nullString is returned by the String methods of NullableValues when the
// underlying database value is null.
const nullString = "<null>"
// commitTimestampPlaceholderString holds the literal text
// "spanner.commit_timestamp()".
// NOTE(review): presumably this is the placeholder sent to Cloud Spanner in
// place of CommitTimestamp; its use is not visible in this part of the file.
const commitTimestampPlaceholderString = "spanner.commit_timestamp()"

var (
	// CommitTimestamp is a special value used to tell Cloud Spanner to insert
	// the commit timestamp of the transaction into a column. It can be used in
	// a Mutation, or directly used in InsertStruct or InsertMap. See
	// ExampleCommitTimestamp. This is just a placeholder and the actual value
	// stored in this variable has no meaning.
	CommitTimestamp = commitTimestamp
	// commitTimestamp is the sentinel behind CommitTimestamp: the Unix epoch
	// placed in a fixed zone named "CommitTimestamp placeholder" with an
	// offset of 0xDB seconds.
	commitTimestamp = time.Unix(0, 0).In(time.FixedZone("CommitTimestamp placeholder", 0xDB))

	// jsonNullBytes is the JSON encoding of null. The MarshalJSON methods in
	// this file return it for values that are not Valid, and the
	// UnmarshalJSON methods compare incoming payloads against it.
	jsonNullBytes = []byte("null")
)
52
// NullableValue is the interface implemented by all null value wrapper types
// (NullInt64, NullString, NullFloat64, NullBool, NullTime and NullDate in
// this file).
type NullableValue interface {
	// IsNull returns true if the underlying database value is null.
	IsNull() bool
}
58
// NullInt64 is a nullable wrapper for a Cloud Spanner INT64 value.
type NullInt64 struct {
	// Int64 is the wrapped value; it is only meaningful when Valid is true.
	Int64 int64
	// Valid is true if Int64 is not NULL.
	Valid bool
}

// IsNull implements NullableValue.IsNull for NullInt64. It reports true
// exactly when Valid is false.
func (n NullInt64) IsNull() bool {
	return !n.Valid
}
69
70// String implements Stringer.String for NullInt64
71func (n NullInt64) String() string {
72	if !n.Valid {
73		return nullString
74	}
75	return fmt.Sprintf("%v", n.Int64)
76}
77
78// MarshalJSON implements json.Marshaler.MarshalJSON for NullInt64.
79func (n NullInt64) MarshalJSON() ([]byte, error) {
80	if n.Valid {
81		return []byte(fmt.Sprintf("%v", n.Int64)), nil
82	}
83	return jsonNullBytes, nil
84}
85
86// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullInt64.
87func (n *NullInt64) UnmarshalJSON(payload []byte) error {
88	if payload == nil {
89		return fmt.Errorf("payload should not be nil")
90	}
91	if bytes.Equal(payload, jsonNullBytes) {
92		n.Int64 = int64(0)
93		n.Valid = false
94		return nil
95	}
96	num, err := strconv.ParseInt(string(payload), 10, 64)
97	if err != nil {
98		return fmt.Errorf("payload cannot be converted to int64: got %v", string(payload))
99	}
100	n.Int64 = num
101	n.Valid = true
102	return nil
103}
104
// NullString is a nullable wrapper for a Cloud Spanner STRING value.
type NullString struct {
	// StringVal is the wrapped value; it is only meaningful when Valid is
	// true.
	StringVal string
	// Valid is true if StringVal is not NULL.
	Valid bool
}

// IsNull implements NullableValue.IsNull for NullString. It reports true
// exactly when Valid is false.
func (n NullString) IsNull() bool {
	return !n.Valid
}
115
116// String implements Stringer.String for NullString
117func (n NullString) String() string {
118	if !n.Valid {
119		return nullString
120	}
121	return n.StringVal
122}
123
124// MarshalJSON implements json.Marshaler.MarshalJSON for NullString.
125func (n NullString) MarshalJSON() ([]byte, error) {
126	if n.Valid {
127		return []byte(fmt.Sprintf("%q", n.StringVal)), nil
128	}
129	return jsonNullBytes, nil
130}
131
132// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullString.
133func (n *NullString) UnmarshalJSON(payload []byte) error {
134	if payload == nil {
135		return fmt.Errorf("payload should not be nil")
136	}
137	if bytes.Equal(payload, jsonNullBytes) {
138		n.StringVal = ""
139		n.Valid = false
140		return nil
141	}
142	payload, err := trimDoubleQuotes(payload)
143	if err != nil {
144		return err
145	}
146	n.StringVal = string(payload)
147	n.Valid = true
148	return nil
149}
150
// NullFloat64 is a nullable wrapper for a Cloud Spanner FLOAT64 value.
type NullFloat64 struct {
	// Float64 is the wrapped value; it is only meaningful when Valid is
	// true.
	Float64 float64
	// Valid is true if Float64 is not NULL.
	Valid bool
}

// IsNull implements NullableValue.IsNull for NullFloat64. It reports true
// exactly when Valid is false.
func (n NullFloat64) IsNull() bool {
	return !n.Valid
}
161
162// String implements Stringer.String for NullFloat64
163func (n NullFloat64) String() string {
164	if !n.Valid {
165		return nullString
166	}
167	return fmt.Sprintf("%v", n.Float64)
168}
169
170// MarshalJSON implements json.Marshaler.MarshalJSON for NullFloat64.
171func (n NullFloat64) MarshalJSON() ([]byte, error) {
172	if n.Valid {
173		return []byte(fmt.Sprintf("%v", n.Float64)), nil
174	}
175	return jsonNullBytes, nil
176}
177
178// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat64.
179func (n *NullFloat64) UnmarshalJSON(payload []byte) error {
180	if payload == nil {
181		return fmt.Errorf("payload should not be nil")
182	}
183	if bytes.Equal(payload, jsonNullBytes) {
184		n.Float64 = float64(0)
185		n.Valid = false
186		return nil
187	}
188	num, err := strconv.ParseFloat(string(payload), 64)
189	if err != nil {
190		return fmt.Errorf("payload cannot be converted to float64: got %v", string(payload))
191	}
192	n.Float64 = num
193	n.Valid = true
194	return nil
195}
196
// NullBool is a nullable wrapper for a Cloud Spanner BOOL value.
type NullBool struct {
	// Bool is the wrapped value; it is only meaningful when Valid is true.
	Bool bool
	// Valid is true if Bool is not NULL.
	Valid bool
}

// IsNull implements NullableValue.IsNull for NullBool. It reports true
// exactly when Valid is false.
func (n NullBool) IsNull() bool {
	return !n.Valid
}
207
208// String implements Stringer.String for NullBool
209func (n NullBool) String() string {
210	if !n.Valid {
211		return nullString
212	}
213	return fmt.Sprintf("%v", n.Bool)
214}
215
216// MarshalJSON implements json.Marshaler.MarshalJSON for NullBool.
217func (n NullBool) MarshalJSON() ([]byte, error) {
218	if n.Valid {
219		return []byte(fmt.Sprintf("%v", n.Bool)), nil
220	}
221	return jsonNullBytes, nil
222}
223
224// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullBool.
225func (n *NullBool) UnmarshalJSON(payload []byte) error {
226	if payload == nil {
227		return fmt.Errorf("payload should not be nil")
228	}
229	if bytes.Equal(payload, jsonNullBytes) {
230		n.Bool = false
231		n.Valid = false
232		return nil
233	}
234	b, err := strconv.ParseBool(string(payload))
235	if err != nil {
236		return fmt.Errorf("payload cannot be converted to bool: got %v", string(payload))
237	}
238	n.Bool = b
239	n.Valid = true
240	return nil
241}
242
// NullTime is a nullable wrapper for a Cloud Spanner TIMESTAMP value.
type NullTime struct {
	// Time is the wrapped value; it is only meaningful when Valid is true.
	Time time.Time
	// Valid is true if Time is not NULL.
	Valid bool
}

// IsNull implements NullableValue.IsNull for NullTime. It reports true
// exactly when Valid is false.
func (n NullTime) IsNull() bool {
	return !n.Valid
}
253
254// String implements Stringer.String for NullTime
255func (n NullTime) String() string {
256	if !n.Valid {
257		return nullString
258	}
259	return n.Time.Format(time.RFC3339Nano)
260}
261
262// MarshalJSON implements json.Marshaler.MarshalJSON for NullTime.
263func (n NullTime) MarshalJSON() ([]byte, error) {
264	if n.Valid {
265		return []byte(fmt.Sprintf("%q", n.String())), nil
266	}
267	return jsonNullBytes, nil
268}
269
270// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullTime.
271func (n *NullTime) UnmarshalJSON(payload []byte) error {
272	if payload == nil {
273		return fmt.Errorf("payload should not be nil")
274	}
275	if bytes.Equal(payload, jsonNullBytes) {
276		n.Time = time.Time{}
277		n.Valid = false
278		return nil
279	}
280	payload, err := trimDoubleQuotes(payload)
281	if err != nil {
282		return err
283	}
284	s := string(payload)
285	t, err := time.Parse(time.RFC3339Nano, s)
286	if err != nil {
287		return fmt.Errorf("payload cannot be converted to time.Time: got %v", string(payload))
288	}
289	n.Time = t
290	n.Valid = true
291	return nil
292}
293
294// NullDate represents a Cloud Spanner DATE that may be null.
295type NullDate struct {
296	Date  civil.Date
297	Valid bool // Valid is true if Date is not NULL.
298}
299
300// IsNull implements NullableValue.IsNull for NullDate.
301func (n NullDate) IsNull() bool {
302	return !n.Valid
303}
304
305// String implements Stringer.String for NullDate
306func (n NullDate) String() string {
307	if !n.Valid {
308		return nullString
309	}
310	return n.Date.String()
311}
312
313// MarshalJSON implements json.Marshaler.MarshalJSON for NullDate.
314func (n NullDate) MarshalJSON() ([]byte, error) {
315	if n.Valid {
316		return []byte(fmt.Sprintf("%q", n.String())), nil
317	}
318	return jsonNullBytes, nil
319}
320
321// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullDate.
322func (n *NullDate) UnmarshalJSON(payload []byte) error {
323	if payload == nil {
324		return fmt.Errorf("payload should not be nil")
325	}
326	if bytes.Equal(payload, jsonNullBytes) {
327		n.Date = civil.Date{}
328		n.Valid = false
329		return nil
330	}
331	payload, err := trimDoubleQuotes(payload)
332	if err != nil {
333		return err
334	}
335	s := string(payload)
336	t, err := civil.ParseDate(s)
337	if err != nil {
338		return fmt.Errorf("payload cannot be converted to civil.Date: got %v", string(payload))
339	}
340	n.Date = t
341	n.Valid = true
342	return nil
343}
344
// NullRow represents a Cloud Spanner STRUCT that may be NULL.
// See also the document for Row.
// Note that NullRow is not a valid Cloud Spanner column Type.
//
// In this file it appears as the element type of the *[]NullRow destination
// accepted by decodeValue for ARRAY<STRUCT> values.
type NullRow struct {
	Row   Row
	Valid bool // Valid is true if Row is not NULL.
}
352
// GenericColumnValue represents the generic encoded value and type of the
// column.  See google.spanner.v1.ResultSet proto for details.  This can be
// useful for proxying query results when the result types are not known in
// advance.
//
// If you populate a GenericColumnValue from a row using Row.Column or related
// methods, do not modify the contents of Type and Value.
type GenericColumnValue struct {
	// Type is the Cloud Spanner type descriptor passed to decodeValue when
	// the value is decoded.
	Type *sppb.Type
	// Value is the protobuf-encoded column value passed to decodeValue when
	// the value is decoded.
	Value *proto3.Value
}
364
365// Decode decodes a GenericColumnValue. The ptr argument should be a pointer
366// to a Go value that can accept v.
367func (v GenericColumnValue) Decode(ptr interface{}) error {
368	return decodeValue(v.Value, v.Type, ptr)
369}
370
371// NewGenericColumnValue creates a GenericColumnValue from Go value that is
372// valid for Cloud Spanner.
373func newGenericColumnValue(v interface{}) (*GenericColumnValue, error) {
374	value, typ, err := encodeValue(v)
375	if err != nil {
376		return nil, err
377	}
378	return &GenericColumnValue{Value: value, Type: typ}, nil
379}
380
381// errTypeMismatch returns error for destination not having a compatible type
382// with source Cloud Spanner type.
383func errTypeMismatch(srcCode, elCode sppb.TypeCode, dst interface{}) error {
384	s := srcCode.String()
385	if srcCode == sppb.TypeCode_ARRAY {
386		s = fmt.Sprintf("%v[%v]", srcCode, elCode)
387	}
388	return spannerErrorf(codes.InvalidArgument, "type %T cannot be used for decoding %s", dst, s)
389}
390
391// errNilSpannerType returns error for nil Cloud Spanner type in decoding.
392func errNilSpannerType() error {
393	return spannerErrorf(codes.FailedPrecondition, "unexpected nil Cloud Spanner data type in decoding")
394}
395
396// errNilSrc returns error for decoding from nil proto value.
397func errNilSrc() error {
398	return spannerErrorf(codes.FailedPrecondition, "unexpected nil Cloud Spanner value in decoding")
399}
400
401// errNilDst returns error for decoding into nil interface{}.
402func errNilDst(dst interface{}) error {
403	return spannerErrorf(codes.InvalidArgument, "cannot decode into nil type %T", dst)
404}
405
406// errNilArrElemType returns error for input Cloud Spanner data type being a array but without a
407// non-nil array element type.
408func errNilArrElemType(t *sppb.Type) error {
409	return spannerErrorf(codes.FailedPrecondition, "array type %v is with nil array element type", t)
410}
411
412func errUnsupportedEmbeddedStructFields(fname string) error {
413	return spannerErrorf(codes.InvalidArgument, "Embedded field: %s. Embedded and anonymous fields are not allowed "+
414		"when converting Go structs to Cloud Spanner STRUCT values. To create a STRUCT value with an "+
415		"unnamed field, use a `spanner:\"\"` field tag.", fname)
416}
417
418// errDstNotForNull returns error for decoding a SQL NULL value into a destination which doesn't
419// support NULL values.
420func errDstNotForNull(dst interface{}) error {
421	return spannerErrorf(codes.InvalidArgument, "destination %T cannot support NULL SQL values", dst)
422}
423
424// errBadEncoding returns error for decoding wrongly encoded types.
425func errBadEncoding(v *proto3.Value, err error) error {
426	return spannerErrorf(codes.FailedPrecondition, "%v wasn't correctly encoded: <%v>", v, err)
427}
428
429func parseNullTime(v *proto3.Value, p *NullTime, code sppb.TypeCode, isNull bool) error {
430	if p == nil {
431		return errNilDst(p)
432	}
433	if code != sppb.TypeCode_TIMESTAMP {
434		return errTypeMismatch(code, sppb.TypeCode_TYPE_CODE_UNSPECIFIED, p)
435	}
436	if isNull {
437		*p = NullTime{}
438		return nil
439	}
440	x, err := getStringValue(v)
441	if err != nil {
442		return err
443	}
444	y, err := time.Parse(time.RFC3339Nano, x)
445	if err != nil {
446		return errBadEncoding(v, err)
447	}
448	p.Valid = true
449	p.Time = y
450	return nil
451}
452
453// decodeValue decodes a protobuf Value into a pointer to a Go value, as
454// specified by sppb.Type.
455func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
456	if v == nil {
457		return errNilSrc()
458	}
459	if t == nil {
460		return errNilSpannerType()
461	}
462	code := t.Code
463	acode := sppb.TypeCode_TYPE_CODE_UNSPECIFIED
464	if code == sppb.TypeCode_ARRAY {
465		if t.ArrayElementType == nil {
466			return errNilArrElemType(t)
467		}
468		acode = t.ArrayElementType.Code
469	}
470	_, isNull := v.Kind.(*proto3.Value_NullValue)
471
472	// Do the decoding based on the type of ptr.
473	switch p := ptr.(type) {
474	case nil:
475		return errNilDst(nil)
476	case *string:
477		if p == nil {
478			return errNilDst(p)
479		}
480		if code != sppb.TypeCode_STRING {
481			return errTypeMismatch(code, acode, ptr)
482		}
483		if isNull {
484			return errDstNotForNull(ptr)
485		}
486		x, err := getStringValue(v)
487		if err != nil {
488			return err
489		}
490		*p = x
491	case *NullString, **string:
492		if p == nil {
493			return errNilDst(p)
494		}
495		if code != sppb.TypeCode_STRING {
496			return errTypeMismatch(code, acode, ptr)
497		}
498		if isNull {
499			switch sp := ptr.(type) {
500			case *NullString:
501				*sp = NullString{}
502			case **string:
503				*sp = nil
504			}
505			break
506		}
507		x, err := getStringValue(v)
508		if err != nil {
509			return err
510		}
511		switch sp := ptr.(type) {
512		case *NullString:
513			sp.Valid = true
514			sp.StringVal = x
515		case **string:
516			*sp = &x
517		}
518	case *[]NullString, *[]*string:
519		if p == nil {
520			return errNilDst(p)
521		}
522		if acode != sppb.TypeCode_STRING {
523			return errTypeMismatch(code, acode, ptr)
524		}
525		if isNull {
526			switch sp := ptr.(type) {
527			case *[]NullString:
528				*sp = nil
529			case *[]*string:
530				*sp = nil
531			}
532			break
533		}
534		x, err := getListValue(v)
535		if err != nil {
536			return err
537		}
538		switch sp := ptr.(type) {
539		case *[]NullString:
540			y, err := decodeNullStringArray(x)
541			if err != nil {
542				return err
543			}
544			*sp = y
545		case *[]*string:
546			y, err := decodeStringPointerArray(x)
547			if err != nil {
548				return err
549			}
550			*sp = y
551		}
552	case *[]string:
553		if p == nil {
554			return errNilDst(p)
555		}
556		if acode != sppb.TypeCode_STRING {
557			return errTypeMismatch(code, acode, ptr)
558		}
559		if isNull {
560			*p = nil
561			break
562		}
563		x, err := getListValue(v)
564		if err != nil {
565			return err
566		}
567		y, err := decodeStringArray(x)
568		if err != nil {
569			return err
570		}
571		*p = y
572	case *[]byte:
573		if p == nil {
574			return errNilDst(p)
575		}
576		if code != sppb.TypeCode_BYTES {
577			return errTypeMismatch(code, acode, ptr)
578		}
579		if isNull {
580			*p = nil
581			break
582		}
583		x, err := getStringValue(v)
584		if err != nil {
585			return err
586		}
587		y, err := base64.StdEncoding.DecodeString(x)
588		if err != nil {
589			return errBadEncoding(v, err)
590		}
591		*p = y
592	case *[][]byte:
593		if p == nil {
594			return errNilDst(p)
595		}
596		if acode != sppb.TypeCode_BYTES {
597			return errTypeMismatch(code, acode, ptr)
598		}
599		if isNull {
600			*p = nil
601			break
602		}
603		x, err := getListValue(v)
604		if err != nil {
605			return err
606		}
607		y, err := decodeByteArray(x)
608		if err != nil {
609			return err
610		}
611		*p = y
612	case *int64:
613		if p == nil {
614			return errNilDst(p)
615		}
616		if code != sppb.TypeCode_INT64 {
617			return errTypeMismatch(code, acode, ptr)
618		}
619		if isNull {
620			return errDstNotForNull(ptr)
621		}
622		x, err := getStringValue(v)
623		if err != nil {
624			return err
625		}
626		y, err := strconv.ParseInt(x, 10, 64)
627		if err != nil {
628			return errBadEncoding(v, err)
629		}
630		*p = y
631	case *NullInt64, **int64:
632		if p == nil {
633			return errNilDst(p)
634		}
635		if code != sppb.TypeCode_INT64 {
636			return errTypeMismatch(code, acode, ptr)
637		}
638		if isNull {
639			switch sp := ptr.(type) {
640			case *NullInt64:
641				*sp = NullInt64{}
642			case **int64:
643				*sp = nil
644			}
645			break
646		}
647		x, err := getStringValue(v)
648		if err != nil {
649			return err
650		}
651		y, err := strconv.ParseInt(x, 10, 64)
652		if err != nil {
653			return errBadEncoding(v, err)
654		}
655		switch sp := ptr.(type) {
656		case *NullInt64:
657			sp.Valid = true
658			sp.Int64 = y
659		case **int64:
660			*sp = &y
661		}
662	case *[]NullInt64, *[]*int64:
663		if p == nil {
664			return errNilDst(p)
665		}
666		if acode != sppb.TypeCode_INT64 {
667			return errTypeMismatch(code, acode, ptr)
668		}
669		if isNull {
670			switch sp := ptr.(type) {
671			case *[]NullInt64:
672				*sp = nil
673			case *[]*int64:
674				*sp = nil
675			}
676			break
677		}
678		x, err := getListValue(v)
679		if err != nil {
680			return err
681		}
682		switch sp := ptr.(type) {
683		case *[]NullInt64:
684			y, err := decodeNullInt64Array(x)
685			if err != nil {
686				return err
687			}
688			*sp = y
689		case *[]*int64:
690			y, err := decodeInt64PointerArray(x)
691			if err != nil {
692				return err
693			}
694			*sp = y
695		}
696	case *[]int64:
697		if p == nil {
698			return errNilDst(p)
699		}
700		if acode != sppb.TypeCode_INT64 {
701			return errTypeMismatch(code, acode, ptr)
702		}
703		if isNull {
704			*p = nil
705			break
706		}
707		x, err := getListValue(v)
708		if err != nil {
709			return err
710		}
711		y, err := decodeInt64Array(x)
712		if err != nil {
713			return err
714		}
715		*p = y
716	case *bool:
717		if p == nil {
718			return errNilDst(p)
719		}
720		if code != sppb.TypeCode_BOOL {
721			return errTypeMismatch(code, acode, ptr)
722		}
723		if isNull {
724			return errDstNotForNull(ptr)
725		}
726		x, err := getBoolValue(v)
727		if err != nil {
728			return err
729		}
730		*p = x
731	case *NullBool, **bool:
732		if p == nil {
733			return errNilDst(p)
734		}
735		if code != sppb.TypeCode_BOOL {
736			return errTypeMismatch(code, acode, ptr)
737		}
738		if isNull {
739			switch sp := ptr.(type) {
740			case *NullBool:
741				*sp = NullBool{}
742			case **bool:
743				*sp = nil
744			}
745			break
746		}
747		x, err := getBoolValue(v)
748		if err != nil {
749			return err
750		}
751		switch sp := ptr.(type) {
752		case *NullBool:
753			sp.Valid = true
754			sp.Bool = x
755		case **bool:
756			*sp = &x
757		}
758	case *[]NullBool, *[]*bool:
759		if p == nil {
760			return errNilDst(p)
761		}
762		if acode != sppb.TypeCode_BOOL {
763			return errTypeMismatch(code, acode, ptr)
764		}
765		if isNull {
766			switch sp := ptr.(type) {
767			case *[]NullBool:
768				*sp = nil
769			case *[]*bool:
770				*sp = nil
771			}
772			break
773		}
774		x, err := getListValue(v)
775		if err != nil {
776			return err
777		}
778		switch sp := ptr.(type) {
779		case *[]NullBool:
780			y, err := decodeNullBoolArray(x)
781			if err != nil {
782				return err
783			}
784			*sp = y
785		case *[]*bool:
786			y, err := decodeBoolPointerArray(x)
787			if err != nil {
788				return err
789			}
790			*sp = y
791		}
792	case *[]bool:
793		if p == nil {
794			return errNilDst(p)
795		}
796		if acode != sppb.TypeCode_BOOL {
797			return errTypeMismatch(code, acode, ptr)
798		}
799		if isNull {
800			*p = nil
801			break
802		}
803		x, err := getListValue(v)
804		if err != nil {
805			return err
806		}
807		y, err := decodeBoolArray(x)
808		if err != nil {
809			return err
810		}
811		*p = y
812	case *float64:
813		if p == nil {
814			return errNilDst(p)
815		}
816		if code != sppb.TypeCode_FLOAT64 {
817			return errTypeMismatch(code, acode, ptr)
818		}
819		if isNull {
820			return errDstNotForNull(ptr)
821		}
822		x, err := getFloat64Value(v)
823		if err != nil {
824			return err
825		}
826		*p = x
827	case *NullFloat64, **float64:
828		if p == nil {
829			return errNilDst(p)
830		}
831		if code != sppb.TypeCode_FLOAT64 {
832			return errTypeMismatch(code, acode, ptr)
833		}
834		if isNull {
835			switch sp := ptr.(type) {
836			case *NullFloat64:
837				*sp = NullFloat64{}
838			case **float64:
839				*sp = nil
840			}
841			break
842		}
843		x, err := getFloat64Value(v)
844		if err != nil {
845			return err
846		}
847		switch sp := ptr.(type) {
848		case *NullFloat64:
849			sp.Valid = true
850			sp.Float64 = x
851		case **float64:
852			*sp = &x
853		}
854	case *[]NullFloat64, *[]*float64:
855		if p == nil {
856			return errNilDst(p)
857		}
858		if acode != sppb.TypeCode_FLOAT64 {
859			return errTypeMismatch(code, acode, ptr)
860		}
861		if isNull {
862			switch sp := ptr.(type) {
863			case *[]NullFloat64:
864				*sp = nil
865			case *[]*float64:
866				*sp = nil
867			}
868			break
869		}
870		x, err := getListValue(v)
871		if err != nil {
872			return err
873		}
874		switch sp := ptr.(type) {
875		case *[]NullFloat64:
876			y, err := decodeNullFloat64Array(x)
877			if err != nil {
878				return err
879			}
880			*sp = y
881		case *[]*float64:
882			y, err := decodeFloat64PointerArray(x)
883			if err != nil {
884				return err
885			}
886			*sp = y
887		}
888	case *[]float64:
889		if p == nil {
890			return errNilDst(p)
891		}
892		if acode != sppb.TypeCode_FLOAT64 {
893			return errTypeMismatch(code, acode, ptr)
894		}
895		if isNull {
896			*p = nil
897			break
898		}
899		x, err := getListValue(v)
900		if err != nil {
901			return err
902		}
903		y, err := decodeFloat64Array(x)
904		if err != nil {
905			return err
906		}
907		*p = y
908	case *time.Time:
909		var nt NullTime
910		if isNull {
911			return errDstNotForNull(ptr)
912		}
913		err := parseNullTime(v, &nt, code, isNull)
914		if err != nil {
915			return err
916		}
917		*p = nt.Time
918	case *NullTime:
919		err := parseNullTime(v, p, code, isNull)
920		if err != nil {
921			return err
922		}
923	case **time.Time:
924		var nt NullTime
925		if isNull {
926			*p = nil
927			break
928		}
929		err := parseNullTime(v, &nt, code, isNull)
930		if err != nil {
931			return err
932		}
933		*p = &nt.Time
934	case *[]NullTime, *[]*time.Time:
935		if p == nil {
936			return errNilDst(p)
937		}
938		if acode != sppb.TypeCode_TIMESTAMP {
939			return errTypeMismatch(code, acode, ptr)
940		}
941		if isNull {
942			switch sp := ptr.(type) {
943			case *[]NullTime:
944				*sp = nil
945			case *[]*time.Time:
946				*sp = nil
947			}
948			break
949		}
950		x, err := getListValue(v)
951		if err != nil {
952			return err
953		}
954		switch sp := ptr.(type) {
955		case *[]NullTime:
956			y, err := decodeNullTimeArray(x)
957			if err != nil {
958				return err
959			}
960			*sp = y
961		case *[]*time.Time:
962			y, err := decodeTimePointerArray(x)
963			if err != nil {
964				return err
965			}
966			*sp = y
967		}
968	case *[]time.Time:
969		if p == nil {
970			return errNilDst(p)
971		}
972		if acode != sppb.TypeCode_TIMESTAMP {
973			return errTypeMismatch(code, acode, ptr)
974		}
975		if isNull {
976			*p = nil
977			break
978		}
979		x, err := getListValue(v)
980		if err != nil {
981			return err
982		}
983		y, err := decodeTimeArray(x)
984		if err != nil {
985			return err
986		}
987		*p = y
988	case *civil.Date:
989		if p == nil {
990			return errNilDst(p)
991		}
992		if code != sppb.TypeCode_DATE {
993			return errTypeMismatch(code, acode, ptr)
994		}
995		if isNull {
996			return errDstNotForNull(ptr)
997		}
998		x, err := getStringValue(v)
999		if err != nil {
1000			return err
1001		}
1002		y, err := civil.ParseDate(x)
1003		if err != nil {
1004			return errBadEncoding(v, err)
1005		}
1006		*p = y
1007	case *NullDate, **civil.Date:
1008		if p == nil {
1009			return errNilDst(p)
1010		}
1011		if code != sppb.TypeCode_DATE {
1012			return errTypeMismatch(code, acode, ptr)
1013		}
1014		if isNull {
1015			switch sp := ptr.(type) {
1016			case *NullDate:
1017				*sp = NullDate{}
1018			case **civil.Date:
1019				*sp = nil
1020			}
1021			break
1022		}
1023		x, err := getStringValue(v)
1024		if err != nil {
1025			return err
1026		}
1027		y, err := civil.ParseDate(x)
1028		if err != nil {
1029			return errBadEncoding(v, err)
1030		}
1031		switch sp := ptr.(type) {
1032		case *NullDate:
1033			sp.Valid = true
1034			sp.Date = y
1035		case **civil.Date:
1036			*sp = &y
1037		}
1038	case *[]NullDate, *[]*civil.Date:
1039		if p == nil {
1040			return errNilDst(p)
1041		}
1042		if acode != sppb.TypeCode_DATE {
1043			return errTypeMismatch(code, acode, ptr)
1044		}
1045		if isNull {
1046			switch sp := ptr.(type) {
1047			case *[]NullDate:
1048				*sp = nil
1049			case *[]*civil.Date:
1050				*sp = nil
1051			}
1052			break
1053		}
1054		x, err := getListValue(v)
1055		if err != nil {
1056			return err
1057		}
1058		switch sp := ptr.(type) {
1059		case *[]NullDate:
1060			y, err := decodeNullDateArray(x)
1061			if err != nil {
1062				return err
1063			}
1064			*sp = y
1065		case *[]*civil.Date:
1066			y, err := decodeDatePointerArray(x)
1067			if err != nil {
1068				return err
1069			}
1070			*sp = y
1071		}
1072	case *[]civil.Date:
1073		if p == nil {
1074			return errNilDst(p)
1075		}
1076		if acode != sppb.TypeCode_DATE {
1077			return errTypeMismatch(code, acode, ptr)
1078		}
1079		if isNull {
1080			*p = nil
1081			break
1082		}
1083		x, err := getListValue(v)
1084		if err != nil {
1085			return err
1086		}
1087		y, err := decodeDateArray(x)
1088		if err != nil {
1089			return err
1090		}
1091		*p = y
1092	case *[]NullRow:
1093		if p == nil {
1094			return errNilDst(p)
1095		}
1096		if acode != sppb.TypeCode_STRUCT {
1097			return errTypeMismatch(code, acode, ptr)
1098		}
1099		if isNull {
1100			*p = nil
1101			break
1102		}
1103		x, err := getListValue(v)
1104		if err != nil {
1105			return err
1106		}
1107		y, err := decodeRowArray(t.ArrayElementType.StructType, x)
1108		if err != nil {
1109			return err
1110		}
1111		*p = y
1112	case *GenericColumnValue:
1113		*p = GenericColumnValue{Type: t, Value: v}
1114	default:
1115		// Check if the pointer is a variant of a base type.
1116		decodableType := getDecodableSpannerType(ptr)
1117		if decodableType != spannerTypeUnknown {
1118			if isNull && !decodableType.supportsNull() {
1119				return errDstNotForNull(ptr)
1120			}
1121			return decodableType.decodeValueToCustomType(v, t, acode, ptr)
1122		}
1123
1124		// Check if the proto encoding is for an array of structs.
1125		if !(code == sppb.TypeCode_ARRAY && acode == sppb.TypeCode_STRUCT) {
1126			return errTypeMismatch(code, acode, ptr)
1127		}
1128		vp := reflect.ValueOf(p)
1129		if !vp.IsValid() {
1130			return errNilDst(p)
1131		}
1132		if !isPtrStructPtrSlice(vp.Type()) {
1133			// The container is not a pointer to a struct pointer slice.
1134			return errTypeMismatch(code, acode, ptr)
1135		}
1136		// Only use reflection for nil detection on slow path.
1137		// Also, IsNil panics on many types, so check it after the type check.
1138		if vp.IsNil() {
1139			return errNilDst(p)
1140		}
1141		if isNull {
1142			// The proto Value is encoding NULL, set the pointer to struct
1143			// slice to nil as well.
1144			vp.Elem().Set(reflect.Zero(vp.Elem().Type()))
1145			break
1146		}
1147		x, err := getListValue(v)
1148		if err != nil {
1149			return err
1150		}
1151		if err = decodeStructArray(t.ArrayElementType.StructType, x, p); err != nil {
1152			return err
1153		}
1154	}
1155	return nil
1156}
1157
// decodableSpannerType enumerates the categories of Go destination types
// that a value from a Spanner database can be decoded into.
type decodableSpannerType uint

// The values below are assigned by iota in declaration order, starting with
// spannerTypeUnknown at zero. decodeValue treats spannerTypeUnknown as "not a
// variant of a base type".
const (
	spannerTypeUnknown decodableSpannerType = iota
	spannerTypeInvalid
	spannerTypeNonNullString
	spannerTypeByteArray
	spannerTypeNonNullInt64
	spannerTypeNonNullBool
	spannerTypeNonNullFloat64
	spannerTypeNonNullTime
	spannerTypeNonNullDate
	spannerTypeNullString
	spannerTypeNullInt64
	spannerTypeNullBool
	spannerTypeNullFloat64
	spannerTypeNullTime
	spannerTypeNullDate
	spannerTypeArrayOfNonNullString
	spannerTypeArrayOfByteArray
	spannerTypeArrayOfNonNullInt64
	spannerTypeArrayOfNonNullBool
	spannerTypeArrayOfNonNullFloat64
	spannerTypeArrayOfNonNullTime
	spannerTypeArrayOfNonNullDate
	spannerTypeArrayOfNullString
	spannerTypeArrayOfNullInt64
	spannerTypeArrayOfNullBool
	spannerTypeArrayOfNullFloat64
	spannerTypeArrayOfNullTime
	spannerTypeArrayOfNullDate
)

// supportsNull returns true for the Go types that can hold a null value from
// Spanner. Only the six non-null scalar categories (string, int64, bool,
// float64, time and date) report false; every other value, including byte
// arrays, all array categories and the unknown/invalid markers, reports true.
func (d decodableSpannerType) supportsNull() bool {
	nonNullScalar := d == spannerTypeNonNullString ||
		d == spannerTypeNonNullInt64 ||
		d == spannerTypeNonNullBool ||
		d == spannerTypeNonNullFloat64 ||
		d == spannerTypeNonNullTime ||
		d == spannerTypeNonNullDate
	return !nonNullScalar
}
1203
// The following list of types represent the struct types that represent a
// specific Spanner data type in Go. If a pointer to one of these types is
// passed to decodeValue, the client library will decode one column value into
// the struct. For pointers to all other struct types, the client library will
// treat it as a generic struct that should contain a field for each column in
// the result set that is being decoded.
//
// getDecodableSpannerType matches struct types against these values with
// reflect.Type.ConvertibleTo, checking them in the order listed here.

var typeOfNonNullTime = reflect.TypeOf(time.Time{})
var typeOfNonNullDate = reflect.TypeOf(civil.Date{})
var typeOfNullString = reflect.TypeOf(NullString{})
var typeOfNullInt64 = reflect.TypeOf(NullInt64{})
var typeOfNullBool = reflect.TypeOf(NullBool{})
var typeOfNullFloat64 = reflect.TypeOf(NullFloat64{})
var typeOfNullTime = reflect.TypeOf(NullTime{})
var typeOfNullDate = reflect.TypeOf(NullDate{})
1219
1220// getDecodableSpannerType returns the corresponding decodableSpannerType of
1221// the given pointer.
1222func getDecodableSpannerType(ptr interface{}) decodableSpannerType {
1223	kind := reflect.Indirect(reflect.ValueOf(ptr)).Kind()
1224	if kind == reflect.Invalid {
1225		return spannerTypeInvalid
1226	}
1227	switch kind {
1228	case reflect.Invalid:
1229		return spannerTypeInvalid
1230	case reflect.String:
1231		return spannerTypeNonNullString
1232	case reflect.Int64:
1233		return spannerTypeNonNullInt64
1234	case reflect.Bool:
1235		return spannerTypeNonNullBool
1236	case reflect.Float64:
1237		return spannerTypeNonNullFloat64
1238	case reflect.Struct:
1239		t := reflect.Indirect(reflect.ValueOf(ptr)).Type()
1240		if t.ConvertibleTo(typeOfNonNullTime) {
1241			return spannerTypeNonNullTime
1242		}
1243		if t.ConvertibleTo(typeOfNonNullDate) {
1244			return spannerTypeNonNullDate
1245		}
1246		if t.ConvertibleTo(typeOfNullString) {
1247			return spannerTypeNullString
1248		}
1249		if t.ConvertibleTo(typeOfNullInt64) {
1250			return spannerTypeNullInt64
1251		}
1252		if t.ConvertibleTo(typeOfNullBool) {
1253			return spannerTypeNullBool
1254		}
1255		if t.ConvertibleTo(typeOfNullFloat64) {
1256			return spannerTypeNullFloat64
1257		}
1258		if t.ConvertibleTo(typeOfNullTime) {
1259			return spannerTypeNullTime
1260		}
1261		if t.ConvertibleTo(typeOfNullDate) {
1262			return spannerTypeNullDate
1263		}
1264	case reflect.Slice:
1265		kind := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem().Kind()
1266		switch kind {
1267		case reflect.Invalid:
1268			return spannerTypeUnknown
1269		case reflect.String:
1270			return spannerTypeArrayOfNonNullString
1271		case reflect.Uint8:
1272			return spannerTypeByteArray
1273		case reflect.Int64:
1274			return spannerTypeArrayOfNonNullInt64
1275		case reflect.Bool:
1276			return spannerTypeArrayOfNonNullBool
1277		case reflect.Float64:
1278			return spannerTypeArrayOfNonNullFloat64
1279		case reflect.Struct:
1280			t := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem()
1281			if t.ConvertibleTo(typeOfNonNullTime) {
1282				return spannerTypeArrayOfNonNullTime
1283			}
1284			if t.ConvertibleTo(typeOfNonNullDate) {
1285				return spannerTypeArrayOfNonNullDate
1286			}
1287			if t.ConvertibleTo(typeOfNullString) {
1288				return spannerTypeArrayOfNullString
1289			}
1290			if t.ConvertibleTo(typeOfNullInt64) {
1291				return spannerTypeArrayOfNullInt64
1292			}
1293			if t.ConvertibleTo(typeOfNullBool) {
1294				return spannerTypeArrayOfNullBool
1295			}
1296			if t.ConvertibleTo(typeOfNullFloat64) {
1297				return spannerTypeArrayOfNullFloat64
1298			}
1299			if t.ConvertibleTo(typeOfNullTime) {
1300				return spannerTypeArrayOfNullTime
1301			}
1302			if t.ConvertibleTo(typeOfNullDate) {
1303				return spannerTypeArrayOfNullDate
1304			}
1305		case reflect.Slice:
1306			// The only array-of-array type that is supported is [][]byte.
1307			kind := reflect.Indirect(reflect.ValueOf(ptr)).Type().Elem().Elem().Kind()
1308			switch kind {
1309			case reflect.Uint8:
1310				return spannerTypeArrayOfByteArray
1311			}
1312		}
1313	}
1314	// Not convertible to a known base type.
1315	return spannerTypeUnknown
1316}
1317
1318// decodeValueToCustomType decodes a protobuf Value into a pointer to a Go
1319// value. It must be possible to convert the value to the type pointed to by
1320// the pointer.
1321func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb.Type, acode sppb.TypeCode, ptr interface{}) error {
1322	code := t.Code
1323	_, isNull := v.Kind.(*proto3.Value_NullValue)
1324	if dsc == spannerTypeInvalid {
1325		return errNilDst(ptr)
1326	}
1327	if isNull && !dsc.supportsNull() {
1328		return errDstNotForNull(ptr)
1329	}
1330
1331	var result interface{}
1332	switch dsc {
1333	case spannerTypeNonNullString, spannerTypeNullString:
1334		if code != sppb.TypeCode_STRING {
1335			return errTypeMismatch(code, acode, ptr)
1336		}
1337		if isNull {
1338			result = &NullString{}
1339			break
1340		}
1341		x, err := getStringValue(v)
1342		if err != nil {
1343			return err
1344		}
1345		if dsc == spannerTypeNonNullString {
1346			result = &x
1347		} else {
1348			result = &NullString{x, !isNull}
1349		}
1350	case spannerTypeByteArray:
1351		if code != sppb.TypeCode_BYTES {
1352			return errTypeMismatch(code, acode, ptr)
1353		}
1354		if isNull {
1355			result = []byte(nil)
1356			break
1357		}
1358		x, err := getStringValue(v)
1359		if err != nil {
1360			return err
1361		}
1362		y, err := base64.StdEncoding.DecodeString(x)
1363		if err != nil {
1364			return errBadEncoding(v, err)
1365		}
1366		result = y
1367	case spannerTypeNonNullInt64, spannerTypeNullInt64:
1368		if code != sppb.TypeCode_INT64 {
1369			return errTypeMismatch(code, acode, ptr)
1370		}
1371		if isNull {
1372			result = &NullInt64{}
1373			break
1374		}
1375		x, err := getStringValue(v)
1376		if err != nil {
1377			return err
1378		}
1379		y, err := strconv.ParseInt(x, 10, 64)
1380		if err != nil {
1381			return errBadEncoding(v, err)
1382		}
1383		if dsc == spannerTypeNonNullInt64 {
1384			result = &y
1385		} else {
1386			result = &NullInt64{y, !isNull}
1387		}
1388	case spannerTypeNonNullBool, spannerTypeNullBool:
1389		if code != sppb.TypeCode_BOOL {
1390			return errTypeMismatch(code, acode, ptr)
1391		}
1392		if isNull {
1393			result = &NullBool{}
1394			break
1395		}
1396		x, err := getBoolValue(v)
1397		if err != nil {
1398			return err
1399		}
1400		if dsc == spannerTypeNonNullBool {
1401			result = &x
1402		} else {
1403			result = &NullBool{x, !isNull}
1404		}
1405	case spannerTypeNonNullFloat64, spannerTypeNullFloat64:
1406		if code != sppb.TypeCode_FLOAT64 {
1407			return errTypeMismatch(code, acode, ptr)
1408		}
1409		if isNull {
1410			result = &NullFloat64{}
1411			break
1412		}
1413		x, err := getFloat64Value(v)
1414		if err != nil {
1415			return err
1416		}
1417		if dsc == spannerTypeNonNullFloat64 {
1418			result = &x
1419		} else {
1420			result = &NullFloat64{x, !isNull}
1421		}
1422	case spannerTypeNonNullTime, spannerTypeNullTime:
1423		var nt NullTime
1424		err := parseNullTime(v, &nt, code, isNull)
1425		if err != nil {
1426			return err
1427		}
1428		if dsc == spannerTypeNonNullTime {
1429			result = &nt.Time
1430		} else {
1431			result = &nt
1432		}
1433	case spannerTypeNonNullDate, spannerTypeNullDate:
1434		if code != sppb.TypeCode_DATE {
1435			return errTypeMismatch(code, acode, ptr)
1436		}
1437		if isNull {
1438			result = &NullDate{}
1439			break
1440		}
1441		x, err := getStringValue(v)
1442		if err != nil {
1443			return err
1444		}
1445		y, err := civil.ParseDate(x)
1446		if err != nil {
1447			return errBadEncoding(v, err)
1448		}
1449		if dsc == spannerTypeNonNullDate {
1450			result = &y
1451		} else {
1452			result = &NullDate{y, !isNull}
1453		}
1454	case spannerTypeArrayOfNonNullString, spannerTypeArrayOfNullString:
1455		if acode != sppb.TypeCode_STRING {
1456			return errTypeMismatch(code, acode, ptr)
1457		}
1458		if isNull {
1459			ptr = nil
1460			return nil
1461		}
1462		x, err := getListValue(v)
1463		if err != nil {
1464			return err
1465		}
1466		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, stringType(), "STRING")
1467		if err != nil {
1468			return err
1469		}
1470		result = y
1471	case spannerTypeArrayOfByteArray:
1472		if acode != sppb.TypeCode_BYTES {
1473			return errTypeMismatch(code, acode, ptr)
1474		}
1475		if isNull {
1476			ptr = nil
1477			return nil
1478		}
1479		x, err := getListValue(v)
1480		if err != nil {
1481			return err
1482		}
1483		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, bytesType(), "BYTES")
1484		if err != nil {
1485			return err
1486		}
1487		result = y
1488	case spannerTypeArrayOfNonNullInt64, spannerTypeArrayOfNullInt64:
1489		if acode != sppb.TypeCode_INT64 {
1490			return errTypeMismatch(code, acode, ptr)
1491		}
1492		if isNull {
1493			ptr = nil
1494			return nil
1495		}
1496		x, err := getListValue(v)
1497		if err != nil {
1498			return err
1499		}
1500		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, intType(), "INT64")
1501		if err != nil {
1502			return err
1503		}
1504		result = y
1505	case spannerTypeArrayOfNonNullBool, spannerTypeArrayOfNullBool:
1506		if acode != sppb.TypeCode_BOOL {
1507			return errTypeMismatch(code, acode, ptr)
1508		}
1509		if isNull {
1510			ptr = nil
1511			return nil
1512		}
1513		x, err := getListValue(v)
1514		if err != nil {
1515			return err
1516		}
1517		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, boolType(), "BOOL")
1518		if err != nil {
1519			return err
1520		}
1521		result = y
1522	case spannerTypeArrayOfNonNullFloat64, spannerTypeArrayOfNullFloat64:
1523		if acode != sppb.TypeCode_FLOAT64 {
1524			return errTypeMismatch(code, acode, ptr)
1525		}
1526		if isNull {
1527			ptr = nil
1528			return nil
1529		}
1530		x, err := getListValue(v)
1531		if err != nil {
1532			return err
1533		}
1534		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, floatType(), "FLOAT64")
1535		if err != nil {
1536			return err
1537		}
1538		result = y
1539	case spannerTypeArrayOfNonNullTime, spannerTypeArrayOfNullTime:
1540		if acode != sppb.TypeCode_TIMESTAMP {
1541			return errTypeMismatch(code, acode, ptr)
1542		}
1543		if isNull {
1544			ptr = nil
1545			return nil
1546		}
1547		x, err := getListValue(v)
1548		if err != nil {
1549			return err
1550		}
1551		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, timeType(), "TIMESTAMP")
1552		if err != nil {
1553			return err
1554		}
1555		result = y
1556	case spannerTypeArrayOfNonNullDate, spannerTypeArrayOfNullDate:
1557		if acode != sppb.TypeCode_DATE {
1558			return errTypeMismatch(code, acode, ptr)
1559		}
1560		if isNull {
1561			ptr = nil
1562			return nil
1563		}
1564		x, err := getListValue(v)
1565		if err != nil {
1566			return err
1567		}
1568		y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, dateType(), "DATE")
1569		if err != nil {
1570			return err
1571		}
1572		result = y
1573	default:
1574		// This should not be possible.
1575		return fmt.Errorf("unknown decodable type found: %v", dsc)
1576	}
1577	source := reflect.Indirect(reflect.ValueOf(result))
1578	destination := reflect.Indirect(reflect.ValueOf(ptr))
1579	destination.Set(source.Convert(destination.Type()))
1580	return nil
1581}
1582
1583// errSrvVal returns an error for getting a wrong source protobuf value in decoding.
1584func errSrcVal(v *proto3.Value, want string) error {
1585	return spannerErrorf(codes.FailedPrecondition, "cannot use %v(Kind: %T) as %s Value",
1586		v, v.GetKind(), want)
1587}
1588
1589// getStringValue returns the string value encoded in proto3.Value v whose
1590// kind is proto3.Value_StringValue.
1591func getStringValue(v *proto3.Value) (string, error) {
1592	if x, ok := v.GetKind().(*proto3.Value_StringValue); ok && x != nil {
1593		return x.StringValue, nil
1594	}
1595	return "", errSrcVal(v, "String")
1596}
1597
1598// getBoolValue returns the bool value encoded in proto3.Value v whose
1599// kind is proto3.Value_BoolValue.
1600func getBoolValue(v *proto3.Value) (bool, error) {
1601	if x, ok := v.GetKind().(*proto3.Value_BoolValue); ok && x != nil {
1602		return x.BoolValue, nil
1603	}
1604	return false, errSrcVal(v, "Bool")
1605}
1606
1607// getListValue returns the proto3.ListValue contained in proto3.Value v whose
1608// kind is proto3.Value_ListValue.
1609func getListValue(v *proto3.Value) (*proto3.ListValue, error) {
1610	if x, ok := v.GetKind().(*proto3.Value_ListValue); ok && x != nil {
1611		return x.ListValue, nil
1612	}
1613	return nil, errSrcVal(v, "List")
1614}
1615
1616// errUnexpectedNumStr returns error for decoder getting a unexpected string for
1617// representing special float values.
1618func errUnexpectedNumStr(s string) error {
1619	return spannerErrorf(codes.FailedPrecondition, "unexpected string value %q for number", s)
1620}
1621
1622// getFloat64Value returns the float64 value encoded in proto3.Value v whose
1623// kind is proto3.Value_NumberValue / proto3.Value_StringValue.
1624// Cloud Spanner uses string to encode NaN, Infinity and -Infinity.
1625func getFloat64Value(v *proto3.Value) (float64, error) {
1626	switch x := v.GetKind().(type) {
1627	case *proto3.Value_NumberValue:
1628		if x == nil {
1629			break
1630		}
1631		return x.NumberValue, nil
1632	case *proto3.Value_StringValue:
1633		if x == nil {
1634			break
1635		}
1636		switch x.StringValue {
1637		case "NaN":
1638			return math.NaN(), nil
1639		case "Infinity":
1640			return math.Inf(1), nil
1641		case "-Infinity":
1642			return math.Inf(-1), nil
1643		default:
1644			return 0, errUnexpectedNumStr(x.StringValue)
1645		}
1646	}
1647	return 0, errSrcVal(v, "Number")
1648}
1649
1650// errNilListValue returns error for unexpected nil ListValue in decoding Cloud Spanner ARRAYs.
1651func errNilListValue(sqlType string) error {
1652	return spannerErrorf(codes.FailedPrecondition, "unexpected nil ListValue in decoding %v array", sqlType)
1653}
1654
1655// errDecodeArrayElement returns error for failure in decoding single array element.
1656func errDecodeArrayElement(i int, v proto.Message, sqlType string, err error) error {
1657	var se *Error
1658	if !errorAs(err, &se) {
1659		return spannerErrorf(codes.Unknown,
1660			"cannot decode %v(array element %v) as %v, error = <%v>", v, i, sqlType, err)
1661	}
1662	se.decorate(fmt.Sprintf("cannot decode %v(array element %v) as %v", v, i, sqlType))
1663	return se
1664}
1665
1666// decodeGenericArray decodes proto3.ListValue pb into a slice which type is
1667// determined through reflection.
1668func decodeGenericArray(tp reflect.Type, pb *proto3.ListValue, t *sppb.Type, sqlType string) (interface{}, error) {
1669	if pb == nil {
1670		return nil, errNilListValue(sqlType)
1671	}
1672	a := reflect.MakeSlice(tp, len(pb.Values), len(pb.Values))
1673	for i, v := range pb.Values {
1674		if err := decodeValue(v, t, a.Index(i).Addr().Interface()); err != nil {
1675			return nil, errDecodeArrayElement(i, v, "STRING", err)
1676		}
1677	}
1678	return a.Interface(), nil
1679}
1680
1681// decodeNullStringArray decodes proto3.ListValue pb into a NullString slice.
1682func decodeNullStringArray(pb *proto3.ListValue) ([]NullString, error) {
1683	if pb == nil {
1684		return nil, errNilListValue("STRING")
1685	}
1686	a := make([]NullString, len(pb.Values))
1687	for i, v := range pb.Values {
1688		if err := decodeValue(v, stringType(), &a[i]); err != nil {
1689			return nil, errDecodeArrayElement(i, v, "STRING", err)
1690		}
1691	}
1692	return a, nil
1693}
1694
1695// decodeStringPointerArray decodes proto3.ListValue pb into a *string slice.
1696func decodeStringPointerArray(pb *proto3.ListValue) ([]*string, error) {
1697	if pb == nil {
1698		return nil, errNilListValue("STRING")
1699	}
1700	a := make([]*string, len(pb.Values))
1701	for i, v := range pb.Values {
1702		if err := decodeValue(v, stringType(), &a[i]); err != nil {
1703			return nil, errDecodeArrayElement(i, v, "STRING", err)
1704		}
1705	}
1706	return a, nil
1707}
1708
1709// decodeStringArray decodes proto3.ListValue pb into a string slice.
1710func decodeStringArray(pb *proto3.ListValue) ([]string, error) {
1711	if pb == nil {
1712		return nil, errNilListValue("STRING")
1713	}
1714	a := make([]string, len(pb.Values))
1715	st := stringType()
1716	for i, v := range pb.Values {
1717		if err := decodeValue(v, st, &a[i]); err != nil {
1718			return nil, errDecodeArrayElement(i, v, "STRING", err)
1719		}
1720	}
1721	return a, nil
1722}
1723
1724// decodeNullInt64Array decodes proto3.ListValue pb into a NullInt64 slice.
1725func decodeNullInt64Array(pb *proto3.ListValue) ([]NullInt64, error) {
1726	if pb == nil {
1727		return nil, errNilListValue("INT64")
1728	}
1729	a := make([]NullInt64, len(pb.Values))
1730	for i, v := range pb.Values {
1731		if err := decodeValue(v, intType(), &a[i]); err != nil {
1732			return nil, errDecodeArrayElement(i, v, "INT64", err)
1733		}
1734	}
1735	return a, nil
1736}
1737
1738// decodeInt64PointerArray decodes proto3.ListValue pb into a *int64 slice.
1739func decodeInt64PointerArray(pb *proto3.ListValue) ([]*int64, error) {
1740	if pb == nil {
1741		return nil, errNilListValue("INT64")
1742	}
1743	a := make([]*int64, len(pb.Values))
1744	for i, v := range pb.Values {
1745		if err := decodeValue(v, intType(), &a[i]); err != nil {
1746			return nil, errDecodeArrayElement(i, v, "INT64", err)
1747		}
1748	}
1749	return a, nil
1750}
1751
1752// decodeInt64Array decodes proto3.ListValue pb into a int64 slice.
1753func decodeInt64Array(pb *proto3.ListValue) ([]int64, error) {
1754	if pb == nil {
1755		return nil, errNilListValue("INT64")
1756	}
1757	a := make([]int64, len(pb.Values))
1758	for i, v := range pb.Values {
1759		if err := decodeValue(v, intType(), &a[i]); err != nil {
1760			return nil, errDecodeArrayElement(i, v, "INT64", err)
1761		}
1762	}
1763	return a, nil
1764}
1765
1766// decodeNullBoolArray decodes proto3.ListValue pb into a NullBool slice.
1767func decodeNullBoolArray(pb *proto3.ListValue) ([]NullBool, error) {
1768	if pb == nil {
1769		return nil, errNilListValue("BOOL")
1770	}
1771	a := make([]NullBool, len(pb.Values))
1772	for i, v := range pb.Values {
1773		if err := decodeValue(v, boolType(), &a[i]); err != nil {
1774			return nil, errDecodeArrayElement(i, v, "BOOL", err)
1775		}
1776	}
1777	return a, nil
1778}
1779
1780// decodeBoolPointerArray decodes proto3.ListValue pb into a *bool slice.
1781func decodeBoolPointerArray(pb *proto3.ListValue) ([]*bool, error) {
1782	if pb == nil {
1783		return nil, errNilListValue("BOOL")
1784	}
1785	a := make([]*bool, len(pb.Values))
1786	for i, v := range pb.Values {
1787		if err := decodeValue(v, boolType(), &a[i]); err != nil {
1788			return nil, errDecodeArrayElement(i, v, "BOOL", err)
1789		}
1790	}
1791	return a, nil
1792}
1793
1794// decodeBoolArray decodes proto3.ListValue pb into a bool slice.
1795func decodeBoolArray(pb *proto3.ListValue) ([]bool, error) {
1796	if pb == nil {
1797		return nil, errNilListValue("BOOL")
1798	}
1799	a := make([]bool, len(pb.Values))
1800	for i, v := range pb.Values {
1801		if err := decodeValue(v, boolType(), &a[i]); err != nil {
1802			return nil, errDecodeArrayElement(i, v, "BOOL", err)
1803		}
1804	}
1805	return a, nil
1806}
1807
1808// decodeNullFloat64Array decodes proto3.ListValue pb into a NullFloat64 slice.
1809func decodeNullFloat64Array(pb *proto3.ListValue) ([]NullFloat64, error) {
1810	if pb == nil {
1811		return nil, errNilListValue("FLOAT64")
1812	}
1813	a := make([]NullFloat64, len(pb.Values))
1814	for i, v := range pb.Values {
1815		if err := decodeValue(v, floatType(), &a[i]); err != nil {
1816			return nil, errDecodeArrayElement(i, v, "FLOAT64", err)
1817		}
1818	}
1819	return a, nil
1820}
1821
1822// decodeFloat64PointerArray decodes proto3.ListValue pb into a NullFloat64 slice.
1823func decodeFloat64PointerArray(pb *proto3.ListValue) ([]*float64, error) {
1824	if pb == nil {
1825		return nil, errNilListValue("FLOAT64")
1826	}
1827	a := make([]*float64, len(pb.Values))
1828	for i, v := range pb.Values {
1829		if err := decodeValue(v, floatType(), &a[i]); err != nil {
1830			return nil, errDecodeArrayElement(i, v, "FLOAT64", err)
1831		}
1832	}
1833	return a, nil
1834}
1835
1836// decodeFloat64Array decodes proto3.ListValue pb into a float64 slice.
1837func decodeFloat64Array(pb *proto3.ListValue) ([]float64, error) {
1838	if pb == nil {
1839		return nil, errNilListValue("FLOAT64")
1840	}
1841	a := make([]float64, len(pb.Values))
1842	for i, v := range pb.Values {
1843		if err := decodeValue(v, floatType(), &a[i]); err != nil {
1844			return nil, errDecodeArrayElement(i, v, "FLOAT64", err)
1845		}
1846	}
1847	return a, nil
1848}
1849
1850// decodeByteArray decodes proto3.ListValue pb into a slice of byte slice.
1851func decodeByteArray(pb *proto3.ListValue) ([][]byte, error) {
1852	if pb == nil {
1853		return nil, errNilListValue("BYTES")
1854	}
1855	a := make([][]byte, len(pb.Values))
1856	for i, v := range pb.Values {
1857		if err := decodeValue(v, bytesType(), &a[i]); err != nil {
1858			return nil, errDecodeArrayElement(i, v, "BYTES", err)
1859		}
1860	}
1861	return a, nil
1862}
1863
1864// decodeNullTimeArray decodes proto3.ListValue pb into a NullTime slice.
1865func decodeNullTimeArray(pb *proto3.ListValue) ([]NullTime, error) {
1866	if pb == nil {
1867		return nil, errNilListValue("TIMESTAMP")
1868	}
1869	a := make([]NullTime, len(pb.Values))
1870	for i, v := range pb.Values {
1871		if err := decodeValue(v, timeType(), &a[i]); err != nil {
1872			return nil, errDecodeArrayElement(i, v, "TIMESTAMP", err)
1873		}
1874	}
1875	return a, nil
1876}
1877
1878// decodeTimePointerArray decodes proto3.ListValue pb into a NullTime slice.
1879func decodeTimePointerArray(pb *proto3.ListValue) ([]*time.Time, error) {
1880	if pb == nil {
1881		return nil, errNilListValue("TIMESTAMP")
1882	}
1883	a := make([]*time.Time, len(pb.Values))
1884	for i, v := range pb.Values {
1885		if err := decodeValue(v, timeType(), &a[i]); err != nil {
1886			return nil, errDecodeArrayElement(i, v, "TIMESTAMP", err)
1887		}
1888	}
1889	return a, nil
1890}
1891
1892// decodeTimeArray decodes proto3.ListValue pb into a time.Time slice.
1893func decodeTimeArray(pb *proto3.ListValue) ([]time.Time, error) {
1894	if pb == nil {
1895		return nil, errNilListValue("TIMESTAMP")
1896	}
1897	a := make([]time.Time, len(pb.Values))
1898	for i, v := range pb.Values {
1899		if err := decodeValue(v, timeType(), &a[i]); err != nil {
1900			return nil, errDecodeArrayElement(i, v, "TIMESTAMP", err)
1901		}
1902	}
1903	return a, nil
1904}
1905
1906// decodeNullDateArray decodes proto3.ListValue pb into a NullDate slice.
1907func decodeNullDateArray(pb *proto3.ListValue) ([]NullDate, error) {
1908	if pb == nil {
1909		return nil, errNilListValue("DATE")
1910	}
1911	a := make([]NullDate, len(pb.Values))
1912	for i, v := range pb.Values {
1913		if err := decodeValue(v, dateType(), &a[i]); err != nil {
1914			return nil, errDecodeArrayElement(i, v, "DATE", err)
1915		}
1916	}
1917	return a, nil
1918}
1919
1920// decodeDatePointerArray decodes proto3.ListValue pb into a *civil.Date slice.
1921func decodeDatePointerArray(pb *proto3.ListValue) ([]*civil.Date, error) {
1922	if pb == nil {
1923		return nil, errNilListValue("DATE")
1924	}
1925	a := make([]*civil.Date, len(pb.Values))
1926	for i, v := range pb.Values {
1927		if err := decodeValue(v, dateType(), &a[i]); err != nil {
1928			return nil, errDecodeArrayElement(i, v, "DATE", err)
1929		}
1930	}
1931	return a, nil
1932}
1933
1934// decodeDateArray decodes proto3.ListValue pb into a civil.Date slice.
1935func decodeDateArray(pb *proto3.ListValue) ([]civil.Date, error) {
1936	if pb == nil {
1937		return nil, errNilListValue("DATE")
1938	}
1939	a := make([]civil.Date, len(pb.Values))
1940	for i, v := range pb.Values {
1941		if err := decodeValue(v, dateType(), &a[i]); err != nil {
1942			return nil, errDecodeArrayElement(i, v, "DATE", err)
1943		}
1944	}
1945	return a, nil
1946}
1947
1948func errNotStructElement(i int, v *proto3.Value) error {
1949	return errDecodeArrayElement(i, v, "STRUCT",
1950		spannerErrorf(codes.FailedPrecondition, "%v(type: %T) doesn't encode Cloud Spanner STRUCT", v, v))
1951}
1952
1953// decodeRowArray decodes proto3.ListValue pb into a NullRow slice according to
1954// the structural information given in sppb.StructType ty.
1955func decodeRowArray(ty *sppb.StructType, pb *proto3.ListValue) ([]NullRow, error) {
1956	if pb == nil {
1957		return nil, errNilListValue("STRUCT")
1958	}
1959	a := make([]NullRow, len(pb.Values))
1960	for i := range pb.Values {
1961		switch v := pb.Values[i].GetKind().(type) {
1962		case *proto3.Value_ListValue:
1963			a[i] = NullRow{
1964				Row: Row{
1965					fields: ty.Fields,
1966					vals:   v.ListValue.Values,
1967				},
1968				Valid: true,
1969			}
1970		// Null elements not currently supported by the server, see
1971		// https://cloud.google.com/spanner/docs/query-syntax#using-structs-with-select
1972		case *proto3.Value_NullValue:
1973			// no-op, a[i] is NullRow{} already
1974		default:
1975			return nil, errNotStructElement(i, pb.Values[i])
1976		}
1977	}
1978	return a, nil
1979}
1980
1981// errNilSpannerStructType returns error for unexpected nil Cloud Spanner STRUCT
1982// schema type in decoding.
1983func errNilSpannerStructType() error {
1984	return spannerErrorf(codes.FailedPrecondition, "unexpected nil StructType in decoding Cloud Spanner STRUCT")
1985}
1986
1987// errUnnamedField returns error for decoding a Cloud Spanner STRUCT with
1988// unnamed field into a Go struct.
1989func errUnnamedField(ty *sppb.StructType, i int) error {
1990	return spannerErrorf(codes.InvalidArgument, "unnamed field %v in Cloud Spanner STRUCT %+v", i, ty)
1991}
1992
1993// errNoOrDupGoField returns error for decoding a Cloud Spanner
1994// STRUCT into a Go struct which is either missing a field, or has duplicate
1995// fields.
1996func errNoOrDupGoField(s interface{}, f string) error {
1997	return spannerErrorf(codes.InvalidArgument, "Go struct %+v(type %T) has no or duplicate fields for Cloud Spanner STRUCT field %v", s, s, f)
1998}
1999
2000// errDupColNames returns error for duplicated Cloud Spanner STRUCT field names
2001// found in decoding a Cloud Spanner STRUCT into a Go struct.
2002func errDupSpannerField(f string, ty *sppb.StructType) error {
2003	return spannerErrorf(codes.InvalidArgument, "duplicated field name %q in Cloud Spanner STRUCT %+v", f, ty)
2004}
2005
2006// errDecodeStructField returns error for failure in decoding a single field of
2007// a Cloud Spanner STRUCT.
2008func errDecodeStructField(ty *sppb.StructType, f string, err error) error {
2009	var se *Error
2010	if !errorAs(err, &se) {
2011		return spannerErrorf(codes.Unknown,
2012			"cannot decode field %v of Cloud Spanner STRUCT %+v, error = <%v>", f, ty, err)
2013	}
2014	se.decorate(fmt.Sprintf("cannot decode field %v of Cloud Spanner STRUCT %+v", f, ty))
2015	return se
2016}
2017
2018// decodeStruct decodes proto3.ListValue pb into struct referenced by pointer
2019// ptr, according to
2020// the structural information given in sppb.StructType ty.
2021func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) error {
2022	if reflect.ValueOf(ptr).IsNil() {
2023		return errNilDst(ptr)
2024	}
2025	if ty == nil {
2026		return errNilSpannerStructType()
2027	}
2028	// t holds the structural information of ptr.
2029	t := reflect.TypeOf(ptr).Elem()
2030	// v is the actual value that ptr points to.
2031	v := reflect.ValueOf(ptr).Elem()
2032
2033	fields, err := fieldCache.Fields(t)
2034	if err != nil {
2035		return toSpannerError(err)
2036	}
2037	seen := map[string]bool{}
2038	for i, f := range ty.Fields {
2039		if f.Name == "" {
2040			return errUnnamedField(ty, i)
2041		}
2042		sf := fields.Match(f.Name)
2043		if sf == nil {
2044			return errNoOrDupGoField(ptr, f.Name)
2045		}
2046		if seen[f.Name] {
2047			// We don't allow duplicated field name.
2048			return errDupSpannerField(f.Name, ty)
2049		}
2050		// Try to decode a single field.
2051		if err := decodeValue(pb.Values[i], f.Type, v.FieldByIndex(sf.Index).Addr().Interface()); err != nil {
2052			return errDecodeStructField(ty, f.Name, err)
2053		}
2054		// Mark field f.Name as processed.
2055		seen[f.Name] = true
2056	}
2057	return nil
2058}
2059
// isPtrStructPtrSlice reports whether t is the type of a pointer to a slice
// of struct pointers, i.e. *[]*S for some struct type S.
func isPtrStructPtrSlice(t reflect.Type) bool {
	// t itself must be a pointer ...
	if t.Kind() != reflect.Ptr {
		return false
	}
	// ... that points to a slice ...
	slice := t.Elem()
	if slice.Kind() != reflect.Slice {
		return false
	}
	// ... whose elements are pointers to structs.
	elem := slice.Elem()
	return elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Struct
}
2072
2073// decodeStructArray decodes proto3.ListValue pb into struct slice referenced by
2074// pointer ptr, according to the
2075// structural information given in a sppb.StructType.
2076func decodeStructArray(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) error {
2077	if pb == nil {
2078		return errNilListValue("STRUCT")
2079	}
2080	// Type of the struct pointers stored in the slice that ptr points to.
2081	ts := reflect.TypeOf(ptr).Elem().Elem()
2082	// The slice that ptr points to, might be nil at this point.
2083	v := reflect.ValueOf(ptr).Elem()
2084	// Allocate empty slice.
2085	v.Set(reflect.MakeSlice(v.Type(), 0, len(pb.Values)))
2086	// Decode every struct in pb.Values.
2087	for i, pv := range pb.Values {
2088		// Check if pv is a NULL value.
2089		if _, isNull := pv.Kind.(*proto3.Value_NullValue); isNull {
2090			// Append a nil pointer to the slice.
2091			v.Set(reflect.Append(v, reflect.New(ts).Elem()))
2092			continue
2093		}
2094		// Allocate empty struct.
2095		s := reflect.New(ts.Elem())
2096		// Get proto3.ListValue l from proto3.Value pv.
2097		l, err := getListValue(pv)
2098		if err != nil {
2099			return errDecodeArrayElement(i, pv, "STRUCT", err)
2100		}
2101		// Decode proto3.ListValue l into struct referenced by s.Interface().
2102		if err = decodeStruct(ty, l, s.Interface()); err != nil {
2103			return errDecodeArrayElement(i, pv, "STRUCT", err)
2104		}
2105		// Append the decoded struct back into the slice.
2106		v.Set(reflect.Append(v, s))
2107	}
2108	return nil
2109}
2110
2111// errEncoderUnsupportedType returns error for not being able to encode a value
2112// of certain type.
2113func errEncoderUnsupportedType(v interface{}) error {
2114	return spannerErrorf(codes.InvalidArgument, "client doesn't support type %T", v)
2115}
2116
2117// encodeValue encodes a Go native type into a proto3.Value.
2118func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) {
2119	pb := &proto3.Value{
2120		Kind: &proto3.Value_NullValue{NullValue: proto3.NullValue_NULL_VALUE},
2121	}
2122	var pt *sppb.Type
2123	var err error
2124	switch v := v.(type) {
2125	case nil:
2126	case string:
2127		pb.Kind = stringKind(v)
2128		pt = stringType()
2129	case NullString:
2130		if v.Valid {
2131			return encodeValue(v.StringVal)
2132		}
2133		pt = stringType()
2134	case []string:
2135		if v != nil {
2136			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2137			if err != nil {
2138				return nil, nil, err
2139			}
2140		}
2141		pt = listType(stringType())
2142	case []NullString:
2143		if v != nil {
2144			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2145			if err != nil {
2146				return nil, nil, err
2147			}
2148		}
2149		pt = listType(stringType())
2150	case *string:
2151		if v != nil {
2152			return encodeValue(*v)
2153		}
2154		pt = stringType()
2155	case []*string:
2156		if v != nil {
2157			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2158			if err != nil {
2159				return nil, nil, err
2160			}
2161		}
2162		pt = listType(stringType())
2163	case []byte:
2164		if v != nil {
2165			pb.Kind = stringKind(base64.StdEncoding.EncodeToString(v))
2166		}
2167		pt = bytesType()
2168	case [][]byte:
2169		if v != nil {
2170			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2171			if err != nil {
2172				return nil, nil, err
2173			}
2174		}
2175		pt = listType(bytesType())
2176	case int:
2177		pb.Kind = stringKind(strconv.FormatInt(int64(v), 10))
2178		pt = intType()
2179	case []int:
2180		if v != nil {
2181			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2182			if err != nil {
2183				return nil, nil, err
2184			}
2185		}
2186		pt = listType(intType())
2187	case int64:
2188		pb.Kind = stringKind(strconv.FormatInt(v, 10))
2189		pt = intType()
2190	case []int64:
2191		if v != nil {
2192			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2193			if err != nil {
2194				return nil, nil, err
2195			}
2196		}
2197		pt = listType(intType())
2198	case NullInt64:
2199		if v.Valid {
2200			return encodeValue(v.Int64)
2201		}
2202		pt = intType()
2203	case []NullInt64:
2204		if v != nil {
2205			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2206			if err != nil {
2207				return nil, nil, err
2208			}
2209		}
2210		pt = listType(intType())
2211	case *int64:
2212		if v != nil {
2213			return encodeValue(*v)
2214		}
2215		pt = intType()
2216	case []*int64:
2217		if v != nil {
2218			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2219			if err != nil {
2220				return nil, nil, err
2221			}
2222		}
2223		pt = listType(intType())
2224	case bool:
2225		pb.Kind = &proto3.Value_BoolValue{BoolValue: v}
2226		pt = boolType()
2227	case []bool:
2228		if v != nil {
2229			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2230			if err != nil {
2231				return nil, nil, err
2232			}
2233		}
2234		pt = listType(boolType())
2235	case NullBool:
2236		if v.Valid {
2237			return encodeValue(v.Bool)
2238		}
2239		pt = boolType()
2240	case []NullBool:
2241		if v != nil {
2242			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2243			if err != nil {
2244				return nil, nil, err
2245			}
2246		}
2247		pt = listType(boolType())
2248	case *bool:
2249		if v != nil {
2250			return encodeValue(*v)
2251		}
2252		pt = boolType()
2253	case []*bool:
2254		if v != nil {
2255			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2256			if err != nil {
2257				return nil, nil, err
2258			}
2259		}
2260		pt = listType(boolType())
2261	case float64:
2262		pb.Kind = &proto3.Value_NumberValue{NumberValue: v}
2263		pt = floatType()
2264	case []float64:
2265		if v != nil {
2266			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2267			if err != nil {
2268				return nil, nil, err
2269			}
2270		}
2271		pt = listType(floatType())
2272	case NullFloat64:
2273		if v.Valid {
2274			return encodeValue(v.Float64)
2275		}
2276		pt = floatType()
2277	case []NullFloat64:
2278		if v != nil {
2279			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2280			if err != nil {
2281				return nil, nil, err
2282			}
2283		}
2284		pt = listType(floatType())
2285	case *float64:
2286		if v != nil {
2287			return encodeValue(*v)
2288		}
2289		pt = floatType()
2290	case []*float64:
2291		if v != nil {
2292			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2293			if err != nil {
2294				return nil, nil, err
2295			}
2296		}
2297		pt = listType(floatType())
2298	case time.Time:
2299		if v == commitTimestamp {
2300			pb.Kind = stringKind(commitTimestampPlaceholderString)
2301		} else {
2302			pb.Kind = stringKind(v.UTC().Format(time.RFC3339Nano))
2303		}
2304		pt = timeType()
2305	case []time.Time:
2306		if v != nil {
2307			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2308			if err != nil {
2309				return nil, nil, err
2310			}
2311		}
2312		pt = listType(timeType())
2313	case NullTime:
2314		if v.Valid {
2315			return encodeValue(v.Time)
2316		}
2317		pt = timeType()
2318	case []NullTime:
2319		if v != nil {
2320			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2321			if err != nil {
2322				return nil, nil, err
2323			}
2324		}
2325		pt = listType(timeType())
2326	case *time.Time:
2327		if v != nil {
2328			return encodeValue(*v)
2329		}
2330		pt = timeType()
2331	case []*time.Time:
2332		if v != nil {
2333			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2334			if err != nil {
2335				return nil, nil, err
2336			}
2337		}
2338		pt = listType(timeType())
2339	case civil.Date:
2340		pb.Kind = stringKind(v.String())
2341		pt = dateType()
2342	case []civil.Date:
2343		if v != nil {
2344			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2345			if err != nil {
2346				return nil, nil, err
2347			}
2348		}
2349		pt = listType(dateType())
2350	case NullDate:
2351		if v.Valid {
2352			return encodeValue(v.Date)
2353		}
2354		pt = dateType()
2355	case []NullDate:
2356		if v != nil {
2357			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2358			if err != nil {
2359				return nil, nil, err
2360			}
2361		}
2362		pt = listType(dateType())
2363	case *civil.Date:
2364		if v != nil {
2365			return encodeValue(*v)
2366		}
2367		pt = dateType()
2368	case []*civil.Date:
2369		if v != nil {
2370			pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })
2371			if err != nil {
2372				return nil, nil, err
2373			}
2374		}
2375		pt = listType(dateType())
2376	case GenericColumnValue:
2377		// Deep clone to ensure subsequent changes to v before
2378		// transmission don't affect our encoded value.
2379		pb = proto.Clone(v.Value).(*proto3.Value)
2380		pt = proto.Clone(v.Type).(*sppb.Type)
2381	case []GenericColumnValue:
2382		return nil, nil, errEncoderUnsupportedType(v)
2383	default:
2384		// Check if the value is a variant of a base type.
2385		decodableType := getDecodableSpannerType(v)
2386		if decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid {
2387			converted, err := convertCustomTypeValue(decodableType, v)
2388			if err != nil {
2389				return nil, nil, err
2390			}
2391			return encodeValue(converted)
2392		}
2393
2394		if !isStructOrArrayOfStructValue(v) {
2395			return nil, nil, errEncoderUnsupportedType(v)
2396		}
2397		typ := reflect.TypeOf(v)
2398
2399		// Value is a Go struct value/ptr.
2400		if (typ.Kind() == reflect.Struct) ||
2401			(typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct) {
2402			return encodeStruct(v)
2403		}
2404
2405		// Value is a slice of Go struct values/ptrs.
2406		if typ.Kind() == reflect.Slice {
2407			return encodeStructArray(v)
2408		}
2409	}
2410	return pb, pt, nil
2411}
2412
2413func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (interface{}, error) {
2414	// destination will be initialized to a base type. The input value will be
2415	// converted to this type and copied to destination.
2416	var destination reflect.Value
2417	switch sourceType {
2418	case spannerTypeInvalid:
2419		return nil, fmt.Errorf("cannot encode a value to type spannerTypeInvalid")
2420	case spannerTypeNonNullString:
2421		destination = reflect.Indirect(reflect.New(reflect.TypeOf("")))
2422	case spannerTypeNullString:
2423		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullString{})))
2424	case spannerTypeByteArray:
2425		// Return a nil array directly if the input value is nil instead of
2426		// creating an empty slice and returning that.
2427		if reflect.ValueOf(v).IsNil() {
2428			return []byte(nil), nil
2429		}
2430		destination = reflect.MakeSlice(reflect.TypeOf([]byte{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2431	case spannerTypeNonNullInt64:
2432		destination = reflect.Indirect(reflect.New(reflect.TypeOf(int64(0))))
2433	case spannerTypeNullInt64:
2434		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullInt64{})))
2435	case spannerTypeNonNullBool:
2436		destination = reflect.Indirect(reflect.New(reflect.TypeOf(false)))
2437	case spannerTypeNullBool:
2438		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullBool{})))
2439	case spannerTypeNonNullFloat64:
2440		destination = reflect.Indirect(reflect.New(reflect.TypeOf(float64(0.0))))
2441	case spannerTypeNullFloat64:
2442		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullFloat64{})))
2443	case spannerTypeNonNullTime:
2444		destination = reflect.Indirect(reflect.New(reflect.TypeOf(time.Time{})))
2445	case spannerTypeNullTime:
2446		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullTime{})))
2447	case spannerTypeNonNullDate:
2448		destination = reflect.Indirect(reflect.New(reflect.TypeOf(civil.Date{})))
2449	case spannerTypeNullDate:
2450		destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullDate{})))
2451	case spannerTypeArrayOfNonNullString:
2452		if reflect.ValueOf(v).IsNil() {
2453			return []string(nil), nil
2454		}
2455		destination = reflect.MakeSlice(reflect.TypeOf([]string{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2456	case spannerTypeArrayOfNullString:
2457		if reflect.ValueOf(v).IsNil() {
2458			return []NullString(nil), nil
2459		}
2460		destination = reflect.MakeSlice(reflect.TypeOf([]NullString{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2461	case spannerTypeArrayOfByteArray:
2462		if reflect.ValueOf(v).IsNil() {
2463			return [][]byte(nil), nil
2464		}
2465		destination = reflect.MakeSlice(reflect.TypeOf([][]byte{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2466	case spannerTypeArrayOfNonNullInt64:
2467		if reflect.ValueOf(v).IsNil() {
2468			return []int64(nil), nil
2469		}
2470		destination = reflect.MakeSlice(reflect.TypeOf([]int64{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2471	case spannerTypeArrayOfNullInt64:
2472		if reflect.ValueOf(v).IsNil() {
2473			return []NullInt64(nil), nil
2474		}
2475		destination = reflect.MakeSlice(reflect.TypeOf([]NullInt64{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2476	case spannerTypeArrayOfNonNullBool:
2477		if reflect.ValueOf(v).IsNil() {
2478			return []bool(nil), nil
2479		}
2480		destination = reflect.MakeSlice(reflect.TypeOf([]bool{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2481	case spannerTypeArrayOfNullBool:
2482		if reflect.ValueOf(v).IsNil() {
2483			return []NullBool(nil), nil
2484		}
2485		destination = reflect.MakeSlice(reflect.TypeOf([]NullBool{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2486	case spannerTypeArrayOfNonNullFloat64:
2487		if reflect.ValueOf(v).IsNil() {
2488			return []float64(nil), nil
2489		}
2490		destination = reflect.MakeSlice(reflect.TypeOf([]float64{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2491	case spannerTypeArrayOfNullFloat64:
2492		if reflect.ValueOf(v).IsNil() {
2493			return []NullFloat64(nil), nil
2494		}
2495		destination = reflect.MakeSlice(reflect.TypeOf([]NullFloat64{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2496	case spannerTypeArrayOfNonNullTime:
2497		if reflect.ValueOf(v).IsNil() {
2498			return []time.Time(nil), nil
2499		}
2500		destination = reflect.MakeSlice(reflect.TypeOf([]time.Time{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2501	case spannerTypeArrayOfNullTime:
2502		if reflect.ValueOf(v).IsNil() {
2503			return []NullTime(nil), nil
2504		}
2505		destination = reflect.MakeSlice(reflect.TypeOf([]NullTime{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2506	case spannerTypeArrayOfNonNullDate:
2507		if reflect.ValueOf(v).IsNil() {
2508			return []civil.Date(nil), nil
2509		}
2510		destination = reflect.MakeSlice(reflect.TypeOf([]civil.Date{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2511	case spannerTypeArrayOfNullDate:
2512		if reflect.ValueOf(v).IsNil() {
2513			return []NullDate(nil), nil
2514		}
2515		destination = reflect.MakeSlice(reflect.TypeOf([]NullDate{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap())
2516	default:
2517		// This should not be possible.
2518		return nil, fmt.Errorf("unknown decodable type found: %v", sourceType)
2519	}
2520	// destination has been initialized. Convert and copy the input value to
2521	// destination. That must be done per element if the input type is a slice
2522	// or an array.
2523	if destination.Kind() == reflect.Slice || destination.Kind() == reflect.Array {
2524		sourceSlice := reflect.ValueOf(v)
2525		for i := 0; i < destination.Len(); i++ {
2526			source := reflect.Indirect(sourceSlice.Index(i))
2527			destination.Index(i).Set(source.Convert(destination.Type().Elem()))
2528		}
2529	} else {
2530		source := reflect.Indirect(reflect.ValueOf(v))
2531		destination.Set(source.Convert(destination.Type()))
2532	}
2533	// Return the converted value.
2534	return destination.Interface(), nil
2535}
2536
2537// Encodes a Go struct value/ptr in v to the spanner Value and Type protos. v
2538// itself must be non-nil.
2539func encodeStruct(v interface{}) (*proto3.Value, *sppb.Type, error) {
2540	typ := reflect.TypeOf(v)
2541	val := reflect.ValueOf(v)
2542
2543	// Pointer to struct.
2544	if typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct {
2545		typ = typ.Elem()
2546		if val.IsNil() {
2547			// nil pointer to struct, representing a NULL STRUCT value. Use a
2548			// dummy value to get the type.
2549			_, st, err := encodeStruct(reflect.Zero(typ).Interface())
2550			if err != nil {
2551				return nil, nil, err
2552			}
2553			return nullProto(), st, nil
2554		}
2555		val = val.Elem()
2556	}
2557
2558	if typ.Kind() != reflect.Struct {
2559		return nil, nil, errEncoderUnsupportedType(v)
2560	}
2561
2562	stf := make([]*sppb.StructType_Field, 0, typ.NumField())
2563	stv := make([]*proto3.Value, 0, typ.NumField())
2564
2565	for i := 0; i < typ.NumField(); i++ {
2566		// If the field has a 'spanner' tag, use the value of that tag as the field name.
2567		// This is used to build STRUCT types with unnamed/duplicate fields.
2568		sf := typ.Field(i)
2569		fval := val.Field(i)
2570
2571		// Embedded fields are not allowed.
2572		if sf.Anonymous {
2573			return nil, nil, errUnsupportedEmbeddedStructFields(sf.Name)
2574		}
2575
2576		// Unexported fields are ignored.
2577		if !fval.CanInterface() {
2578			continue
2579		}
2580
2581		fname, ok := sf.Tag.Lookup("spanner")
2582		if !ok {
2583			fname = sf.Name
2584		}
2585
2586		eval, etype, err := encodeValue(fval.Interface())
2587		if err != nil {
2588			return nil, nil, err
2589		}
2590
2591		stf = append(stf, mkField(fname, etype))
2592		stv = append(stv, eval)
2593	}
2594
2595	return listProto(stv...), structType(stf...), nil
2596}
2597
2598// Encodes a slice of Go struct values/ptrs in v to the spanner Value and Type
2599// protos. v itself must be non-nil.
2600func encodeStructArray(v interface{}) (*proto3.Value, *sppb.Type, error) {
2601	etyp := reflect.TypeOf(v).Elem()
2602	sliceval := reflect.ValueOf(v)
2603
2604	// Slice of pointers to structs.
2605	if etyp.Kind() == reflect.Ptr {
2606		etyp = etyp.Elem()
2607	}
2608
2609	// Use a dummy struct value to get the element type.
2610	_, elemTyp, err := encodeStruct(reflect.Zero(etyp).Interface())
2611	if err != nil {
2612		return nil, nil, err
2613	}
2614
2615	// nil slice represents a NULL array-of-struct.
2616	if sliceval.IsNil() {
2617		return nullProto(), listType(elemTyp), nil
2618	}
2619
2620	values := make([]*proto3.Value, 0, sliceval.Len())
2621
2622	for i := 0; i < sliceval.Len(); i++ {
2623		ev, _, err := encodeStruct(sliceval.Index(i).Interface())
2624		if err != nil {
2625			return nil, nil, err
2626		}
2627		values = append(values, ev)
2628	}
2629	return listProto(values...), listType(elemTyp), nil
2630}
2631
// isStructOrArrayOfStructValue reports whether the dynamic type of v is a
// struct, a pointer to a struct, a slice of structs, or a slice of pointers to
// structs. At most one slice level and then one pointer level are stripped, in
// that order, so e.g. a pointer to a slice is not accepted. v must not be a
// nil interface value.
func isStructOrArrayOfStructValue(v interface{}) bool {
	elem := reflect.TypeOf(v)
	// Peel off the permitted wrappers, outermost first.
	for _, wrapper := range [...]reflect.Kind{reflect.Slice, reflect.Ptr} {
		if elem.Kind() == wrapper {
			elem = elem.Elem()
		}
	}
	return elem.Kind() == reflect.Struct
}
2642
2643func isSupportedMutationType(v interface{}) bool {
2644	switch v.(type) {
2645	case nil, string, *string, NullString, []string, []*string, []NullString,
2646		[]byte, [][]byte,
2647		int, []int, int64, *int64, []int64, []*int64, NullInt64, []NullInt64,
2648		bool, *bool, []bool, []*bool, NullBool, []NullBool,
2649		float64, *float64, []float64, []*float64, NullFloat64, []NullFloat64,
2650		time.Time, *time.Time, []time.Time, []*time.Time, NullTime, []NullTime,
2651		civil.Date, *civil.Date, []civil.Date, []*civil.Date, NullDate, []NullDate,
2652		GenericColumnValue:
2653		return true
2654	default:
2655		return false
2656	}
2657}
2658
2659// encodeValueArray encodes a Value array into a proto3.ListValue.
2660func encodeValueArray(vs []interface{}) (*proto3.ListValue, error) {
2661	lv := &proto3.ListValue{}
2662	lv.Values = make([]*proto3.Value, 0, len(vs))
2663	for _, v := range vs {
2664		if !isSupportedMutationType(v) {
2665			return nil, errEncoderUnsupportedType(v)
2666		}
2667		pb, _, err := encodeValue(v)
2668		if err != nil {
2669			return nil, err
2670		}
2671		lv.Values = append(lv.Values, pb)
2672	}
2673	return lv, nil
2674}
2675
2676// encodeArray assumes that all values of the array element type encode without
2677// error.
2678func encodeArray(len int, at func(int) interface{}) (*proto3.Value, error) {
2679	vs := make([]*proto3.Value, len)
2680	var err error
2681	for i := 0; i < len; i++ {
2682		vs[i], _, err = encodeValue(at(i))
2683		if err != nil {
2684			return nil, err
2685		}
2686	}
2687	return listProto(vs...), nil
2688}
2689
// spannerTagParser extracts the field name from the `spanner` key of the
// struct tag t.
//
// A tag value of "-" returns keep == false with an empty name. Any other
// value, including a missing or empty tag, returns keep == true with the tag
// value as name (so name is empty when there is no tag). other and err are
// always nil.
func spannerTagParser(t reflect.StructTag) (name string, keep bool, other interface{}, err error) {
	switch tag := t.Get("spanner"); tag {
	case "-":
		return "", false, nil, nil
	default:
		return tag, true, nil, nil
	}
}
2699
// fieldCache is the package-level cache returned by fields.NewCache, set up
// with spannerTagParser as its first argument so that struct fields are named
// and filtered according to their `spanner` tags. The remaining two NewCache
// arguments are nil (presumably selecting that package's defaults — confirm
// against cloud.google.com/go/internal/fields).
var fieldCache = fields.NewCache(spannerTagParser, nil, nil)
2701
// trimDoubleQuotes strips one leading and one trailing double quote from
// payload and returns the bytes in between as a sub-slice of payload (no
// copy is made). It returns an error if payload is shorter than two bytes or
// does not both start and end with a double quote.
func trimDoubleQuotes(payload []byte) ([]byte, error) {
	last := len(payload) - 1
	if last < 1 || payload[0] != '"' || payload[last] != '"' {
		return nil, fmt.Errorf("payload is too short or not wrapped with double quotes: got %q", string(payload))
	}
	// Drop the opening and closing quote.
	return payload[1:last], nil
}
2709