1package pgx
2
3import (
4	"database/sql/driver"
5	"fmt"
6	"reflect"
7
8	"github.com/jackc/pgtype"
9)
10
// extendedQueryBuilder accumulates the parameter values and the parameter and
// result format codes for one extended-protocol query. It is meant to be
// reused across queries via Reset so its buffers are recycled.
type extendedQueryBuilder struct {
	// paramValues holds one encoded value per appended parameter; a nil entry
	// is stored for nil arguments and nil pointers.
	paramValues [][]byte
	// paramValueBytes is a scratch buffer that parameter encoders append to;
	// entries of paramValues may be subslices of it.
	paramValueBytes []byte
	// paramFormats holds one format code per appended parameter, and
	// resultFormats the format codes appended via AppendResultFormat.
	paramFormats, resultFormats []int16

	// resetCount counts Reset calls; it drives the periodic buffer shrinking.
	resetCount int
}
19
20func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error {
21	f := chooseParameterFormatCode(ci, oid, arg)
22	eqb.paramFormats = append(eqb.paramFormats, f)
23
24	v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg)
25	if err != nil {
26		return err
27	}
28	eqb.paramValues = append(eqb.paramValues, v)
29
30	return nil
31}
32
33func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) {
34	eqb.resultFormats = append(eqb.resultFormats, f)
35}
36
37func (eqb *extendedQueryBuilder) Reset() {
38	eqb.paramValues = eqb.paramValues[0:0]
39	eqb.paramValueBytes = eqb.paramValueBytes[0:0]
40	eqb.paramFormats = eqb.paramFormats[0:0]
41	eqb.resultFormats = eqb.resultFormats[0:0]
42
43	eqb.resetCount++
44
45	// Every so often shrink our reserved memory if it is abnormally high
46	if eqb.resetCount%128 == 0 {
47		if cap(eqb.paramValues) > 64 {
48			eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2)
49		}
50
51		if cap(eqb.paramValueBytes) > 256 {
52			eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2)
53		}
54
55		if cap(eqb.paramFormats) > 64 {
56			eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2)
57		}
58		if cap(eqb.resultFormats) > 64 {
59			eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2)
60		}
61	}
62
63}
64
65func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) {
66	if arg == nil {
67		return nil, nil
68	}
69
70	refVal := reflect.ValueOf(arg)
71	argIsPtr := refVal.Kind() == reflect.Ptr
72
73	if argIsPtr && refVal.IsNil() {
74		return nil, nil
75	}
76
77	if eqb.paramValueBytes == nil {
78		eqb.paramValueBytes = make([]byte, 0, 128)
79	}
80
81	var err error
82	var buf []byte
83	pos := len(eqb.paramValueBytes)
84
85	if arg, ok := arg.(string); ok {
86		return []byte(arg), nil
87	}
88
89	if formatCode == TextFormatCode {
90		if arg, ok := arg.(pgtype.TextEncoder); ok {
91			buf, err = arg.EncodeText(ci, eqb.paramValueBytes)
92			if err != nil {
93				return nil, err
94			}
95			if buf == nil {
96				return nil, nil
97			}
98			eqb.paramValueBytes = buf
99			return eqb.paramValueBytes[pos:], nil
100		}
101	} else if formatCode == BinaryFormatCode {
102		if arg, ok := arg.(pgtype.BinaryEncoder); ok {
103			buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes)
104			if err != nil {
105				return nil, err
106			}
107			if buf == nil {
108				return nil, nil
109			}
110			eqb.paramValueBytes = buf
111			return eqb.paramValueBytes[pos:], nil
112		}
113	}
114
115	if argIsPtr {
116		// We have already checked that arg is not pointing to nil,
117		// so it is safe to dereference here.
118		arg = refVal.Elem().Interface()
119		return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg)
120	}
121
122	if dt, ok := ci.DataTypeForOID(oid); ok {
123		value := dt.Value
124		err := value.Set(arg)
125		if err != nil {
126			{
127				if arg, ok := arg.(driver.Valuer); ok {
128					v, err := callValuerValue(arg)
129					if err != nil {
130						return nil, err
131					}
132					return eqb.encodeExtendedParamValue(ci, oid, formatCode, v)
133				}
134			}
135
136			return nil, err
137		}
138
139		return eqb.encodeExtendedParamValue(ci, oid, formatCode, value)
140	}
141
142	// There is no data type registered for the destination OID, but maybe there is data type registered for the arg
143	// type. If so use it's text encoder (if available).
144	if dt, ok := ci.DataTypeForValue(arg); ok {
145		value := dt.Value
146		if textEncoder, ok := value.(pgtype.TextEncoder); ok {
147			err := value.Set(arg)
148			if err != nil {
149				return nil, err
150			}
151
152			buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes)
153			if err != nil {
154				return nil, err
155			}
156			if buf == nil {
157				return nil, nil
158			}
159			eqb.paramValueBytes = buf
160			return eqb.paramValueBytes[pos:], nil
161		}
162	}
163
164	if strippedArg, ok := stripNamedType(&refVal); ok {
165		return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg)
166	}
167	return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
168}
169