1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Type conversions for Scan.
6
7package sql
8
9import (
10	"database/sql/driver"
11	"errors"
12	"fmt"
13	"reflect"
14	"strconv"
15	"time"
16	"unicode"
17	"unicode/utf8"
18)
19
// errNilPtr is returned by the Scan conversion helpers when the destination
// is a nil pointer. Its text is embedded in more descriptive errors upstream.
var (
	errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
)
21
// describeNamedValue returns a short human-readable label for nv, used in
// error messages: `with name "x"` for named arguments, or `$N` (the 1-based
// ordinal) for positional ones.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
28
// validateNamedValueName reports an error when a non-empty argument name does
// not start with a Unicode letter. An empty name (a positional argument) is
// always accepted.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
39
40// ccChecker wraps the driver.ColumnConverter and allows it to be used
41// as if it were a NamedValueChecker. If the driver ColumnConverter
42// is not present then the NamedValueChecker will return driver.ErrSkip.
43type ccChecker struct {
44	cci  driver.ColumnConverter
45	want int
46}
47
48func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
49	if c.cci == nil {
50		return driver.ErrSkip
51	}
52	// The column converter shouldn't be called on any index
53	// it isn't expecting. The final error will be thrown
54	// in the argument converter loop.
55	index := nv.Ordinal - 1
56	if c.want <= index {
57		return nil
58	}
59
60	// First, see if the value itself knows how to convert
61	// itself to a driver type. For example, a NullString
62	// struct changing into a string or nil.
63	if vr, ok := nv.Value.(driver.Valuer); ok {
64		sv, err := callValuerValue(vr)
65		if err != nil {
66			return err
67		}
68		if !driver.IsValue(sv) {
69			return fmt.Errorf("non-subset type %T returned from Value", sv)
70		}
71		nv.Value = sv
72	}
73
74	// Second, ask the column to sanity check itself. For
75	// example, drivers might use this to make sure that
76	// an int64 values being inserted into a 16-bit
77	// integer field is in range (before getting
78	// truncated), or that a nil can't go into a NOT NULL
79	// column before going across the network to get the
80	// same error.
81	var err error
82	arg := nv.Value
83	nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
84	if err != nil {
85		return err
86	}
87	if !driver.IsValue(nv.Value) {
88		return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
89	}
90	return nil
91}
92
// defaultCheckNamedValue runs driver.DefaultParameterConverter on nv.Value and
// stores the result back into nv.Value, giving it the same signature as
// driver.NamedValueChecker.CheckNamedValue. nv.Value is overwritten even when
// the conversion fails (it then holds whatever ConvertValue returned).
func defaultCheckNamedValue(nv *driver.NamedValue) error {
	converted, err := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
100
101// driverArgsConnLocked converts arguments from callers of Stmt.Exec and
102// Stmt.Query into driver Values.
103//
104// The statement ds may be nil, if no statement is available.
105//
106// ci must be locked.
107func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
108	nvargs := make([]driver.NamedValue, len(args))
109
110	// -1 means the driver doesn't know how to count the number of
111	// placeholders, so we won't sanity check input here and instead let the
112	// driver deal with errors.
113	want := -1
114
115	var si driver.Stmt
116	var cc ccChecker
117	if ds != nil {
118		si = ds.si
119		want = ds.si.NumInput()
120		cc.want = want
121	}
122
123	// Check all types of interfaces from the start.
124	// Drivers may opt to use the NamedValueChecker for special
125	// argument types, then return driver.ErrSkip to pass it along
126	// to the column converter.
127	nvc, ok := si.(driver.NamedValueChecker)
128	if !ok {
129		nvc, ok = ci.(driver.NamedValueChecker)
130	}
131	cci, ok := si.(driver.ColumnConverter)
132	if ok {
133		cc.cci = cci
134	}
135
136	// Loop through all the arguments, checking each one.
137	// If no error is returned simply increment the index
138	// and continue. However if driver.ErrRemoveArgument
139	// is returned the argument is not included in the query
140	// argument list.
141	var err error
142	var n int
143	for _, arg := range args {
144		nv := &nvargs[n]
145		if np, ok := arg.(NamedArg); ok {
146			if err = validateNamedValueName(np.Name); err != nil {
147				return nil, err
148			}
149			arg = np.Value
150			nv.Name = np.Name
151		}
152		nv.Ordinal = n + 1
153		nv.Value = arg
154
155		// Checking sequence has four routes:
156		// A: 1. Default
157		// B: 1. NamedValueChecker 2. Column Converter 3. Default
158		// C: 1. NamedValueChecker 3. Default
159		// D: 1. Column Converter 2. Default
160		//
161		// The only time a Column Converter is called is first
162		// or after NamedValueConverter. If first it is handled before
163		// the nextCheck label. Thus for repeats tries only when the
164		// NamedValueConverter is selected should the Column Converter
165		// be used in the retry.
166		checker := defaultCheckNamedValue
167		nextCC := false
168		switch {
169		case nvc != nil:
170			nextCC = cci != nil
171			checker = nvc.CheckNamedValue
172		case cci != nil:
173			checker = cc.CheckNamedValue
174		}
175
176	nextCheck:
177		err = checker(nv)
178		switch err {
179		case nil:
180			n++
181			continue
182		case driver.ErrRemoveArgument:
183			nvargs = nvargs[:len(nvargs)-1]
184			continue
185		case driver.ErrSkip:
186			if nextCC {
187				nextCC = false
188				checker = cc.CheckNamedValue
189			} else {
190				checker = defaultCheckNamedValue
191			}
192			goto nextCheck
193		default:
194			return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
195		}
196	}
197
198	// Check the length of arguments after conversion to allow for omitted
199	// arguments.
200	if want != -1 && len(nvargs) != want {
201		return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
202	}
203
204	return nvargs, nil
205
206}
207
208// convertAssign is the same as convertAssignRows, but without the optional
209// rows argument.
210func convertAssign(dest, src any) error {
211	return convertAssignRows(dest, src, nil)
212}
213
214// convertAssignRows copies to dest the value in src, converting it if possible.
215// An error is returned if the copy would result in loss of information.
216// dest should be a pointer type. If rows is passed in, the rows will
217// be used as the parent for any cursor values converted from a
218// driver.Rows to a *Rows.
219func convertAssignRows(dest, src any, rows *Rows) error {
220	// Common cases, without reflect.
221	switch s := src.(type) {
222	case string:
223		switch d := dest.(type) {
224		case *string:
225			if d == nil {
226				return errNilPtr
227			}
228			*d = s
229			return nil
230		case *[]byte:
231			if d == nil {
232				return errNilPtr
233			}
234			*d = []byte(s)
235			return nil
236		case *RawBytes:
237			if d == nil {
238				return errNilPtr
239			}
240			*d = append((*d)[:0], s...)
241			return nil
242		}
243	case []byte:
244		switch d := dest.(type) {
245		case *string:
246			if d == nil {
247				return errNilPtr
248			}
249			*d = string(s)
250			return nil
251		case *any:
252			if d == nil {
253				return errNilPtr
254			}
255			*d = cloneBytes(s)
256			return nil
257		case *[]byte:
258			if d == nil {
259				return errNilPtr
260			}
261			*d = cloneBytes(s)
262			return nil
263		case *RawBytes:
264			if d == nil {
265				return errNilPtr
266			}
267			*d = s
268			return nil
269		}
270	case time.Time:
271		switch d := dest.(type) {
272		case *time.Time:
273			*d = s
274			return nil
275		case *string:
276			*d = s.Format(time.RFC3339Nano)
277			return nil
278		case *[]byte:
279			if d == nil {
280				return errNilPtr
281			}
282			*d = []byte(s.Format(time.RFC3339Nano))
283			return nil
284		case *RawBytes:
285			if d == nil {
286				return errNilPtr
287			}
288			*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
289			return nil
290		}
291	case decimalDecompose:
292		switch d := dest.(type) {
293		case decimalCompose:
294			return d.Compose(s.Decompose(nil))
295		}
296	case nil:
297		switch d := dest.(type) {
298		case *any:
299			if d == nil {
300				return errNilPtr
301			}
302			*d = nil
303			return nil
304		case *[]byte:
305			if d == nil {
306				return errNilPtr
307			}
308			*d = nil
309			return nil
310		case *RawBytes:
311			if d == nil {
312				return errNilPtr
313			}
314			*d = nil
315			return nil
316		}
317	// The driver is returning a cursor the client may iterate over.
318	case driver.Rows:
319		switch d := dest.(type) {
320		case *Rows:
321			if d == nil {
322				return errNilPtr
323			}
324			if rows == nil {
325				return errors.New("invalid context to convert cursor rows, missing parent *Rows")
326			}
327			rows.closemu.Lock()
328			*d = Rows{
329				dc:          rows.dc,
330				releaseConn: func(error) {},
331				rowsi:       s,
332			}
333			// Chain the cancel function.
334			parentCancel := rows.cancel
335			rows.cancel = func() {
336				// When Rows.cancel is called, the closemu will be locked as well.
337				// So we can access rs.lasterr.
338				d.close(rows.lasterr)
339				if parentCancel != nil {
340					parentCancel()
341				}
342			}
343			rows.closemu.Unlock()
344			return nil
345		}
346	}
347
348	var sv reflect.Value
349
350	switch d := dest.(type) {
351	case *string:
352		sv = reflect.ValueOf(src)
353		switch sv.Kind() {
354		case reflect.Bool,
355			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
356			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
357			reflect.Float32, reflect.Float64:
358			*d = asString(src)
359			return nil
360		}
361	case *[]byte:
362		sv = reflect.ValueOf(src)
363		if b, ok := asBytes(nil, sv); ok {
364			*d = b
365			return nil
366		}
367	case *RawBytes:
368		sv = reflect.ValueOf(src)
369		if b, ok := asBytes([]byte(*d)[:0], sv); ok {
370			*d = RawBytes(b)
371			return nil
372		}
373	case *bool:
374		bv, err := driver.Bool.ConvertValue(src)
375		if err == nil {
376			*d = bv.(bool)
377		}
378		return err
379	case *any:
380		*d = src
381		return nil
382	}
383
384	if scanner, ok := dest.(Scanner); ok {
385		return scanner.Scan(src)
386	}
387
388	dpv := reflect.ValueOf(dest)
389	if dpv.Kind() != reflect.Pointer {
390		return errors.New("destination not a pointer")
391	}
392	if dpv.IsNil() {
393		return errNilPtr
394	}
395
396	if !sv.IsValid() {
397		sv = reflect.ValueOf(src)
398	}
399
400	dv := reflect.Indirect(dpv)
401	if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
402		switch b := src.(type) {
403		case []byte:
404			dv.Set(reflect.ValueOf(cloneBytes(b)))
405		default:
406			dv.Set(sv)
407		}
408		return nil
409	}
410
411	if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
412		dv.Set(sv.Convert(dv.Type()))
413		return nil
414	}
415
416	// The following conversions use a string value as an intermediate representation
417	// to convert between various numeric types.
418	//
419	// This also allows scanning into user defined types such as "type Int int64".
420	// For symmetry, also check for string destination types.
421	switch dv.Kind() {
422	case reflect.Pointer:
423		if src == nil {
424			dv.Set(reflect.Zero(dv.Type()))
425			return nil
426		}
427		dv.Set(reflect.New(dv.Type().Elem()))
428		return convertAssignRows(dv.Interface(), src, rows)
429	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
430		if src == nil {
431			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
432		}
433		s := asString(src)
434		i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
435		if err != nil {
436			err = strconvErr(err)
437			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
438		}
439		dv.SetInt(i64)
440		return nil
441	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
442		if src == nil {
443			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
444		}
445		s := asString(src)
446		u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
447		if err != nil {
448			err = strconvErr(err)
449			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
450		}
451		dv.SetUint(u64)
452		return nil
453	case reflect.Float32, reflect.Float64:
454		if src == nil {
455			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
456		}
457		s := asString(src)
458		f64, err := strconv.ParseFloat(s, dv.Type().Bits())
459		if err != nil {
460			err = strconvErr(err)
461			return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
462		}
463		dv.SetFloat(f64)
464		return nil
465	case reflect.String:
466		if src == nil {
467			return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
468		}
469		switch v := src.(type) {
470		case string:
471			dv.SetString(v)
472			return nil
473		case []byte:
474			dv.SetString(string(v))
475			return nil
476		}
477	}
478
479	return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
480}
481
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange) so messages do not repeat the input.
// Any other error is returned unchanged.
func strconvErr(err error) error {
	numErr, isNumErr := err.(*strconv.NumError)
	if !isNumErr {
		return err
	}
	return numErr.Err
}
488
// cloneBytes returns an independent copy of b with len and cap equal to
// len(b). A nil input stays nil; an empty non-nil input yields an empty
// non-nil slice.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}
497
// asString renders src as text.
//
// string and []byte values are returned directly. Bools, integers, unsigned
// integers and floats are matched by reflect.Kind, so named types such as
// "type Int int64" are covered too, and are formatted with strconv.
// Everything else, including nil, falls back to fmt's %v formatting.
func asString(src any) string {
	if str, isStr := src.(string); isStr {
		return str
	}
	if raw, isBytes := src.([]byte); isBytes {
		return string(raw)
	}

	val := reflect.ValueOf(src)
	switch kind := val.Kind(); kind {
	case reflect.Bool:
		return strconv.FormatBool(val.Bool())
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.FormatFloat(val.Float(), 'g', -1, bitSize)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(val.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
520
// asBytes appends the textual form of rv to buf. ok reports whether rv's kind
// is one it can format: integer, unsigned integer, float, bool or string.
// For any other kind (including an invalid Value) it returns (nil, false).
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch kind := rv.Kind(); kind {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, bitSize), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	}
	return nil, false
}
539
// valuerReflectType is the reflect.Type of the driver.Valuer interface itself
// (not of any implementation of it).
var valuerReflectType = func() reflect.Type {
	var ptr *driver.Valuer
	return reflect.TypeOf(ptr).Elem()
}()

// callValuerValue returns vr.Value(), with one exception:
// when vr is a nil pointer whose element type itself implements
// driver.Valuer, Value is an auto-generated pointer method. Calling it would
// panic at runtime in the panicwrap method, so the value is treated as nil
// instead (Issue 8415).
//
// This lets people implement driver.Value on value types and still use nil
// pointers to those types to mean nil/NULL, just like string/*string.
//
// This function is mirrored in the database/sql/driver package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	nilPtrToValueReceiver := rv.Kind() == reflect.Pointer &&
		rv.IsNil() &&
		rv.Type().Elem().Implements(valuerReflectType)
	if nilPtrToValueReceiver {
		return nil, nil
	}
	return vr.Value()
}
561
type (
	// decimal is satisfied by types that can both break a decimal value into
	// its parts and rebuild it from them.
	//
	// The parts are:
	//   - negative: a sign flag;
	//   - form: one of finite=0, infinite=1, NaN=2;
	//   - coefficient (the significand): an unsigned, base-2, big-endian integer
	//     stored as a []byte, most significant byte first (coefficient[0]);
	//   - exponent: an int32.
	//
	// A finite value is "decimal = (neg) coefficient * 10 ^ exponent". A zero
	// length coefficient is a zero value. For non-finite forms the coefficient
	// and exponent should be ignored.
	//
	// The negative flag may be set for any form, but implementations need not
	// respect it for non-finite forms. Implementations that do not distinguish
	// negative from positive zero or NaN should ignore the flag without error.
	// An implementation without Infinity may turn it into NaN without error.
	// Setting a value larger than the implementation supports must return an
	// error, as must setting NaN or Infinity when neither is supported.
	//
	// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870
	decimal interface {
		decimalDecompose
		decimalCompose
	}

	// decimalDecompose is implemented by decimal values that can expose their parts.
	decimalDecompose interface {
		// Decompose returns the internal decimal state in parts. When buf has
		// enough capacity it may be reused (and returned) as the coefficient,
		// with its length set to fit.
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose is implemented by decimal values that can be set from parts.
	decimalCompose interface {
		// Compose sets the internal decimal value from parts, returning an
		// error when the value cannot be represented.
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
600