1package pq
2
3import (
4	"database/sql/driver"
5	"encoding/binary"
6	"errors"
7	"fmt"
8	"sync"
9)
10
var (
	// errCopyInClosed is returned by Exec once the COPY statement is closed.
	errCopyInClosed = errors.New("pq: copyin statement has already been closed")

	// errBinaryCopyNotSupported is returned when the server asks for binary COPY data.
	errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")

	// errCopyToNotSupported is returned when the statement turns out to be a COPY TO.
	errCopyToNotSupported = errors.New("pq: COPY TO is not supported")

	// errCopyNotSupportedOutsideTxn is returned when COPY is prepared outside a transaction.
	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")

	// errCopyInProgress reports that a COPY is currently in progress.
	errCopyInProgress = errors.New("pq: COPY in progress")
)
18
19// CopyIn creates a COPY FROM statement which can be prepared with
20// Tx.Prepare().  The target table should be visible in search_path.
21func CopyIn(table string, columns ...string) string {
22	stmt := "COPY " + QuoteIdentifier(table) + " ("
23	for i, col := range columns {
24		if i != 0 {
25			stmt += ", "
26		}
27		stmt += QuoteIdentifier(col)
28	}
29	stmt += ") FROM STDIN"
30	return stmt
31}
32
33// CopyInSchema creates a COPY FROM statement which can be prepared with
34// Tx.Prepare().
35func CopyInSchema(schema, table string, columns ...string) string {
36	stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
37	for i, col := range columns {
38		if i != 0 {
39			stmt += ", "
40		}
41		stmt += QuoteIdentifier(col)
42	}
43	stmt += ") FROM STDIN"
44	return stmt
45}
46
47type copyin struct {
48	cn      *conn
49	buffer  []byte
50	rowData chan []byte
51	done    chan bool
52
53	closed bool
54
55	sync.Mutex // guards err
56	err        error
57}
58
const (
	// ciBufferSize is the initial capacity of a copyin's row buffer.
	ciBufferSize = 64 * 1024

	// ciBufferFlushSize is the length past which Exec sends the buffer to the
	// server. It leaves 1 KiB of headroom below ciBufferSize, so a row that
	// fits in that headroom can be appended without reallocating the buffer.
	ciBufferFlushSize = ciBufferSize - 1024
)
63
64func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
65	if !cn.isInTransaction() {
66		return nil, errCopyNotSupportedOutsideTxn
67	}
68
69	ci := &copyin{
70		cn:      cn,
71		buffer:  make([]byte, 0, ciBufferSize),
72		rowData: make(chan []byte),
73		done:    make(chan bool, 1),
74	}
75	// add CopyData identifier + 4 bytes for message length
76	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
77
78	b := cn.writeBuf('Q')
79	b.string(q)
80	cn.send(b)
81
82awaitCopyInResponse:
83	for {
84		t, r := cn.recv1()
85		switch t {
86		case 'G':
87			if r.byte() != 0 {
88				err = errBinaryCopyNotSupported
89				break awaitCopyInResponse
90			}
91			go ci.resploop()
92			return ci, nil
93		case 'H':
94			err = errCopyToNotSupported
95			break awaitCopyInResponse
96		case 'E':
97			err = parseError(r)
98		case 'Z':
99			if err == nil {
100				ci.setBad()
101				errorf("unexpected ReadyForQuery in response to COPY")
102			}
103			cn.processReadyForQuery(r)
104			return nil, err
105		default:
106			ci.setBad()
107			errorf("unknown response for copy query: %q", t)
108		}
109	}
110
111	// something went wrong, abort COPY before we return
112	b = cn.writeBuf('f')
113	b.string(err.Error())
114	cn.send(b)
115
116	for {
117		t, r := cn.recv1()
118		switch t {
119		case 'c', 'C', 'E':
120		case 'Z':
121			// correctly aborted, we're done
122			cn.processReadyForQuery(r)
123			return nil, err
124		default:
125			ci.setBad()
126			errorf("unknown response for CopyFail: %q", t)
127		}
128	}
129}
130
131func (ci *copyin) flush(buf []byte) {
132	// set message length (without message identifier)
133	binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
134
135	_, err := ci.cn.c.Write(buf)
136	if err != nil {
137		panic(err)
138	}
139}
140
141func (ci *copyin) resploop() {
142	for {
143		var r readBuf
144		t, err := ci.cn.recvMessage(&r)
145		if err != nil {
146			ci.setBad()
147			ci.setError(err)
148			ci.done <- true
149			return
150		}
151		switch t {
152		case 'C':
153			// complete
154		case 'N':
155			// NoticeResponse
156		case 'Z':
157			ci.cn.processReadyForQuery(&r)
158			ci.done <- true
159			return
160		case 'E':
161			err := parseError(&r)
162			ci.setError(err)
163		default:
164			ci.setBad()
165			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
166			ci.done <- true
167			return
168		}
169	}
170}
171
172func (ci *copyin) setBad() {
173	ci.Lock()
174	ci.cn.bad = true
175	ci.Unlock()
176}
177
178func (ci *copyin) isBad() bool {
179	ci.Lock()
180	b := ci.cn.bad
181	ci.Unlock()
182	return b
183}
184
185func (ci *copyin) isErrorSet() bool {
186	ci.Lock()
187	isSet := (ci.err != nil)
188	ci.Unlock()
189	return isSet
190}
191
192// setError() sets ci.err if one has not been set already.  Caller must not be
193// holding ci.Mutex.
194func (ci *copyin) setError(err error) {
195	ci.Lock()
196	if ci.err == nil {
197		ci.err = err
198	}
199	ci.Unlock()
200}
201
202func (ci *copyin) NumInput() int {
203	return -1
204}
205
206func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
207	return nil, ErrNotSupported
208}
209
210// Exec inserts values into the COPY stream. The insert is asynchronous
211// and Exec can return errors from previous Exec calls to the same
212// COPY stmt.
213//
214// You need to call Exec(nil) to sync the COPY stream and to get any
215// errors from pending data, since Stmt.Close() doesn't return errors
216// to the user.
217func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
218	if ci.closed {
219		return nil, errCopyInClosed
220	}
221
222	if ci.isBad() {
223		return nil, driver.ErrBadConn
224	}
225	defer ci.cn.errRecover(&err)
226
227	if ci.isErrorSet() {
228		return nil, ci.err
229	}
230
231	if len(v) == 0 {
232		return nil, ci.Close()
233	}
234
235	numValues := len(v)
236	for i, value := range v {
237		ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
238		if i < numValues-1 {
239			ci.buffer = append(ci.buffer, '\t')
240		}
241	}
242
243	ci.buffer = append(ci.buffer, '\n')
244
245	if len(ci.buffer) > ciBufferFlushSize {
246		ci.flush(ci.buffer)
247		// reset buffer, keep bytes for message identifier and length
248		ci.buffer = ci.buffer[:5]
249	}
250
251	return driver.RowsAffected(0), nil
252}
253
254func (ci *copyin) Close() (err error) {
255	if ci.closed { // Don't do anything, we're already closed
256		return nil
257	}
258	ci.closed = true
259
260	if ci.isBad() {
261		return driver.ErrBadConn
262	}
263	defer ci.cn.errRecover(&err)
264
265	if len(ci.buffer) > 0 {
266		ci.flush(ci.buffer)
267	}
268	// Avoid touching the scratch buffer as resploop could be using it.
269	err = ci.cn.sendSimpleMessage('c')
270	if err != nil {
271		return err
272	}
273
274	<-ci.done
275	ci.cn.inCopy = false
276
277	if ci.isErrorSet() {
278		err = ci.err
279		return err
280	}
281	return nil
282}
283