1package pgx 2 3import ( 4 "database/sql/driver" 5 "fmt" 6 "reflect" 7 8 "github.com/jackc/pgtype" 9) 10 11type extendedQueryBuilder struct { 12 paramValues [][]byte 13 paramValueBytes []byte 14 paramFormats []int16 15 resultFormats []int16 16 17 resetCount int 18} 19 20func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { 21 f := chooseParameterFormatCode(ci, oid, arg) 22 eqb.paramFormats = append(eqb.paramFormats, f) 23 24 v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) 25 if err != nil { 26 return err 27 } 28 eqb.paramValues = append(eqb.paramValues, v) 29 30 return nil 31} 32 33func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { 34 eqb.resultFormats = append(eqb.resultFormats, f) 35} 36 37func (eqb *extendedQueryBuilder) Reset() { 38 eqb.paramValues = eqb.paramValues[0:0] 39 eqb.paramValueBytes = eqb.paramValueBytes[0:0] 40 eqb.paramFormats = eqb.paramFormats[0:0] 41 eqb.resultFormats = eqb.resultFormats[0:0] 42 43 eqb.resetCount++ 44 45 // Every so often shrink our reserved memory if it is abnormally high 46 if eqb.resetCount%128 == 0 { 47 if cap(eqb.paramValues) > 64 { 48 eqb.paramValues = make([][]byte, 0, cap(eqb.paramValues)/2) 49 } 50 51 if cap(eqb.paramValueBytes) > 256 { 52 eqb.paramValueBytes = make([]byte, 0, cap(eqb.paramValueBytes)/2) 53 } 54 55 if cap(eqb.paramFormats) > 64 { 56 eqb.paramFormats = make([]int16, 0, cap(eqb.paramFormats)/2) 57 } 58 if cap(eqb.resultFormats) > 64 { 59 eqb.resultFormats = make([]int16, 0, cap(eqb.resultFormats)/2) 60 } 61 } 62 63} 64 65func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { 66 if arg == nil { 67 return nil, nil 68 } 69 70 refVal := reflect.ValueOf(arg) 71 argIsPtr := refVal.Kind() == reflect.Ptr 72 73 if argIsPtr && refVal.IsNil() { 74 return nil, nil 75 } 76 77 if eqb.paramValueBytes == nil { 78 eqb.paramValueBytes = make([]byte, 0, 128) 79 } 80 81 var err error 82 var buf []byte 83 pos := len(eqb.paramValueBytes) 84 85 if arg, ok := arg.(string); ok { 86 return []byte(arg), nil 87 } 88 89 if formatCode == TextFormatCode { 90 if arg, ok := arg.(pgtype.TextEncoder); ok { 91 buf, err = arg.EncodeText(ci, eqb.paramValueBytes) 92 if err != nil { 93 return nil, err 94 } 95 if buf == nil { 96 return nil, nil 97 } 98 eqb.paramValueBytes = buf 99 return eqb.paramValueBytes[pos:], nil 100 } 101 } else if formatCode == BinaryFormatCode { 102 if arg, ok := arg.(pgtype.BinaryEncoder); ok { 103 buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) 104 if err != nil { 105 return nil, err 106 } 107 if buf == nil { 108 return nil, nil 109 } 110 eqb.paramValueBytes = buf 111 return eqb.paramValueBytes[pos:], nil 112 } 113 } 114 115 if argIsPtr { 116 // We have already checked that arg is not pointing to nil, 117 // so it is safe to dereference here. 118 arg = refVal.Elem().Interface() 119 return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) 120 } 121 122 if dt, ok := ci.DataTypeForOID(oid); ok { 123 value := dt.Value 124 err := value.Set(arg) 125 if err != nil { 126 { 127 if arg, ok := arg.(driver.Valuer); ok { 128 v, err := callValuerValue(arg) 129 if err != nil { 130 return nil, err 131 } 132 return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) 133 } 134 } 135 136 return nil, err 137 } 138 139 return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) 140 } 141 142 // There is no data type registered for the destination OID, but maybe there is data type registered for the arg 143 // type. If so use it's text encoder (if available). 144 if dt, ok := ci.DataTypeForValue(arg); ok { 145 value := dt.Value 146 if textEncoder, ok := value.(pgtype.TextEncoder); ok { 147 err := value.Set(arg) 148 if err != nil { 149 return nil, err 150 } 151 152 buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) 153 if err != nil { 154 return nil, err 155 } 156 if buf == nil { 157 return nil, nil 158 } 159 eqb.paramValueBytes = buf 160 return eqb.paramValueBytes[pos:], nil 161 } 162 } 163 164 if strippedArg, ok := stripNamedType(&refVal); ok { 165 return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) 166 } 167 return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) 168} 169