1package pq 2 3import ( 4 "database/sql/driver" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "sync" 9) 10 11var ( 12 errCopyInClosed = errors.New("pq: copyin statement has already been closed") 13 errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") 14 errCopyToNotSupported = errors.New("pq: COPY TO is not supported") 15 errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") 16 errCopyInProgress = errors.New("pq: COPY in progress") 17) 18 19// CopyIn creates a COPY FROM statement which can be prepared with 20// Tx.Prepare(). The target table should be visible in search_path. 21func CopyIn(table string, columns ...string) string { 22 stmt := "COPY " + QuoteIdentifier(table) + " (" 23 for i, col := range columns { 24 if i != 0 { 25 stmt += ", " 26 } 27 stmt += QuoteIdentifier(col) 28 } 29 stmt += ") FROM STDIN" 30 return stmt 31} 32 33// CopyInSchema creates a COPY FROM statement which can be prepared with 34// Tx.Prepare(). 35func CopyInSchema(schema, table string, columns ...string) string { 36 stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " (" 37 for i, col := range columns { 38 if i != 0 { 39 stmt += ", " 40 } 41 stmt += QuoteIdentifier(col) 42 } 43 stmt += ") FROM STDIN" 44 return stmt 45} 46 47type copyin struct { 48 cn *conn 49 buffer []byte 50 rowData chan []byte 51 done chan bool 52 53 closed bool 54 55 sync.Mutex // guards err 56 err error 57} 58 59const ciBufferSize = 64 * 1024 60 61// flush buffer before the buffer is filled up and needs reallocation 62const ciBufferFlushSize = 63 * 1024 63 64func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { 65 if !cn.isInTransaction() { 66 return nil, errCopyNotSupportedOutsideTxn 67 } 68 69 ci := ©in{ 70 cn: cn, 71 buffer: make([]byte, 0, ciBufferSize), 72 rowData: make(chan []byte), 73 done: make(chan bool, 1), 74 } 75 // add CopyData identifier + 4 bytes for message length 76 ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0) 77 78 b := cn.writeBuf('Q') 79 b.string(q) 80 cn.send(b) 81 82awaitCopyInResponse: 83 for { 84 t, r := cn.recv1() 85 switch t { 86 case 'G': 87 if r.byte() != 0 { 88 err = errBinaryCopyNotSupported 89 break awaitCopyInResponse 90 } 91 go ci.resploop() 92 return ci, nil 93 case 'H': 94 err = errCopyToNotSupported 95 break awaitCopyInResponse 96 case 'E': 97 err = parseError(r) 98 case 'Z': 99 if err == nil { 100 ci.setBad() 101 errorf("unexpected ReadyForQuery in response to COPY") 102 } 103 cn.processReadyForQuery(r) 104 return nil, err 105 default: 106 ci.setBad() 107 errorf("unknown response for copy query: %q", t) 108 } 109 } 110 111 // something went wrong, abort COPY before we return 112 b = cn.writeBuf('f') 113 b.string(err.Error()) 114 cn.send(b) 115 116 for { 117 t, r := cn.recv1() 118 switch t { 119 case 'c', 'C', 'E': 120 case 'Z': 121 // correctly aborted, we're done 122 cn.processReadyForQuery(r) 123 return nil, err 124 default: 125 ci.setBad() 126 errorf("unknown response for CopyFail: %q", t) 127 } 128 } 129} 130 131func (ci *copyin) flush(buf []byte) { 132 // set message length (without message identifier) 133 binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) 134 135 _, err := ci.cn.c.Write(buf) 136 if err != nil { 137 panic(err) 138 } 139} 140 141func (ci *copyin) resploop() { 142 for { 143 var r readBuf 144 t, err := ci.cn.recvMessage(&r) 145 if err != nil { 146 ci.setBad() 147 ci.setError(err) 148 ci.done <- true 149 return 150 } 151 switch t { 152 case 'C': 153 // complete 154 case 'N': 155 // NoticeResponse 156 case 'Z': 157 ci.cn.processReadyForQuery(&r) 158 ci.done <- true 159 return 160 case 'E': 161 err := parseError(&r) 162 ci.setError(err) 163 default: 164 ci.setBad() 165 ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) 166 ci.done <- true 167 return 168 } 169 } 170} 171 172func (ci *copyin) setBad() { 173 ci.Lock() 174 ci.cn.bad = true 175 ci.Unlock() 176} 177 178func (ci *copyin) isBad() bool { 179 ci.Lock() 180 b := ci.cn.bad 181 ci.Unlock() 182 return b 183} 184 185func (ci *copyin) isErrorSet() bool { 186 ci.Lock() 187 isSet := (ci.err != nil) 188 ci.Unlock() 189 return isSet 190} 191 192// setError() sets ci.err if one has not been set already. Caller must not be 193// holding ci.Mutex. 194func (ci *copyin) setError(err error) { 195 ci.Lock() 196 if ci.err == nil { 197 ci.err = err 198 } 199 ci.Unlock() 200} 201 202func (ci *copyin) NumInput() int { 203 return -1 204} 205 206func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { 207 return nil, ErrNotSupported 208} 209 210// Exec inserts values into the COPY stream. The insert is asynchronous 211// and Exec can return errors from previous Exec calls to the same 212// COPY stmt. 213// 214// You need to call Exec(nil) to sync the COPY stream and to get any 215// errors from pending data, since Stmt.Close() doesn't return errors 216// to the user. 217func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { 218 if ci.closed { 219 return nil, errCopyInClosed 220 } 221 222 if ci.isBad() { 223 return nil, driver.ErrBadConn 224 } 225 defer ci.cn.errRecover(&err) 226 227 if ci.isErrorSet() { 228 return nil, ci.err 229 } 230 231 if len(v) == 0 { 232 return nil, ci.Close() 233 } 234 235 numValues := len(v) 236 for i, value := range v { 237 ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value) 238 if i < numValues-1 { 239 ci.buffer = append(ci.buffer, '\t') 240 } 241 } 242 243 ci.buffer = append(ci.buffer, '\n') 244 245 if len(ci.buffer) > ciBufferFlushSize { 246 ci.flush(ci.buffer) 247 // reset buffer, keep bytes for message identifier and length 248 ci.buffer = ci.buffer[:5] 249 } 250 251 return driver.RowsAffected(0), nil 252} 253 254func (ci *copyin) Close() (err error) { 255 if ci.closed { // Don't do anything, we're already closed 256 return nil 257 } 258 ci.closed = true 259 260 if ci.isBad() { 261 return driver.ErrBadConn 262 } 263 defer ci.cn.errRecover(&err) 264 265 if len(ci.buffer) > 0 { 266 ci.flush(ci.buffer) 267 } 268 // Avoid touching the scratch buffer as resploop could be using it. 269 err = ci.cn.sendSimpleMessage('c') 270 if err != nil { 271 return err 272 } 273 274 <-ci.done 275 ci.cn.inCopy = false 276 277 if ci.isErrorSet() { 278 err = ci.err 279 return err 280 } 281 return nil 282} 283