1package pgx
2
3import (
4	"bytes"
5	"fmt"
6
7	"github.com/jackc/pgx/pgio"
8	"github.com/jackc/pgx/pgproto3"
9	"github.com/pkg/errors"
10)
11
12// CopyFromRows returns a CopyFromSource interface over the provided rows slice
13// making it usable by *Conn.CopyFrom.
14func CopyFromRows(rows [][]interface{}) CopyFromSource {
15	return &copyFromRows{rows: rows, idx: -1}
16}
17
// copyFromRows is a cursor over a slice of rows held in memory.
type copyFromRows struct {
	// rows holds the data to iterate over, one element per row.
	rows [][]interface{}
	// idx is the position of the current row within rows.
	idx int
}

// Next moves the cursor forward by one row and reports whether the new
// position refers to an existing row.
func (ctr *copyFromRows) Next() bool {
	ctr.idx += 1
	hasRow := ctr.idx < len(ctr.rows)
	return hasRow
}

// Values returns the row at the current cursor position. The error result is
// always nil.
func (ctr *copyFromRows) Values() ([]interface{}, error) {
	current := ctr.rows[ctr.idx]
	return current, nil
}

// Err always returns nil; iterating over an in-memory slice cannot fail.
func (ctr *copyFromRows) Err() error {
	return nil
}
35
// CopyFromSource supplies the rows that *Conn.CopyFrom sends to the server.
type CopyFromSource interface {
	// Next advances to the following row and reports whether one is
	// available; that row is then returned by Values(). It returns false once
	// the rows are exhausted or after an error has occurred.
	Next() bool

	// Values returns the values of the current row. *Conn.CopyFrom expects
	// exactly one value per column being copied.
	Values() ([]interface{}, error)

	// Err returns any error the CopyFromSource has encountered. A non-nil
	// result makes *Conn.CopyFrom abort the copy.
	Err() error
}
50
51type copyFrom struct {
52	conn          *Conn
53	tableName     Identifier
54	columnNames   []string
55	rowSrc        CopyFromSource
56	readerErrChan chan error
57}
58
59func (ct *copyFrom) readUntilReadyForQuery() {
60	for {
61		msg, err := ct.conn.rxMsg()
62		if err != nil {
63			ct.readerErrChan <- err
64			close(ct.readerErrChan)
65			return
66		}
67
68		switch msg := msg.(type) {
69		case *pgproto3.ReadyForQuery:
70			ct.conn.rxReadyForQuery(msg)
71			close(ct.readerErrChan)
72			return
73		case *pgproto3.CommandComplete:
74		case *pgproto3.ErrorResponse:
75			ct.readerErrChan <- ct.conn.rxErrorResponse(msg)
76		default:
77			err = ct.conn.processContextFreeMsg(msg)
78			if err != nil {
79				ct.readerErrChan <- ct.conn.processContextFreeMsg(msg)
80			}
81		}
82	}
83}
84
85func (ct *copyFrom) waitForReaderDone() error {
86	var err error
87	for err = range ct.readerErrChan {
88	}
89	return err
90}
91
92func (ct *copyFrom) run() (int, error) {
93	quotedTableName := ct.tableName.Sanitize()
94	cbuf := &bytes.Buffer{}
95	for i, cn := range ct.columnNames {
96		if i != 0 {
97			cbuf.WriteString(", ")
98		}
99		cbuf.WriteString(quoteIdentifier(cn))
100	}
101	quotedColumnNames := cbuf.String()
102
103	ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
104	if err != nil {
105		return 0, err
106	}
107
108	err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
109	if err != nil {
110		return 0, err
111	}
112
113	err = ct.conn.readUntilCopyInResponse()
114	if err != nil {
115		return 0, err
116	}
117
118	go ct.readUntilReadyForQuery()
119	defer ct.waitForReaderDone()
120
121	buf := ct.conn.wbuf
122	buf = append(buf, copyData)
123	sp := len(buf)
124	buf = pgio.AppendInt32(buf, -1)
125
126	buf = append(buf, "PGCOPY\n\377\r\n\000"...)
127	buf = pgio.AppendInt32(buf, 0)
128	buf = pgio.AppendInt32(buf, 0)
129
130	var sentCount int
131
132	for ct.rowSrc.Next() {
133		select {
134		case err = <-ct.readerErrChan:
135			return 0, err
136		default:
137		}
138
139		if len(buf) > 65536 {
140			pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
141			_, err = ct.conn.conn.Write(buf)
142			if err != nil {
143				ct.conn.die(err)
144				return 0, err
145			}
146
147			// Directly manipulate wbuf to reset to reuse the same buffer
148			buf = buf[0:5]
149		}
150
151		sentCount++
152
153		values, err := ct.rowSrc.Values()
154		if err != nil {
155			ct.cancelCopyIn()
156			return 0, err
157		}
158		if len(values) != len(ct.columnNames) {
159			ct.cancelCopyIn()
160			return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values))
161		}
162
163		buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
164		for i, val := range values {
165			buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val)
166			if err != nil {
167				ct.cancelCopyIn()
168				return 0, err
169			}
170
171		}
172	}
173
174	if ct.rowSrc.Err() != nil {
175		ct.cancelCopyIn()
176		return 0, ct.rowSrc.Err()
177	}
178
179	buf = pgio.AppendInt16(buf, -1) // terminate the copy stream
180	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
181
182	buf = append(buf, copyDone)
183	buf = pgio.AppendInt32(buf, 4)
184
185	_, err = ct.conn.conn.Write(buf)
186	if err != nil {
187		ct.conn.die(err)
188		return 0, err
189	}
190
191	err = ct.waitForReaderDone()
192	if err != nil {
193		return 0, err
194	}
195	return sentCount, nil
196}
197
198func (c *Conn) readUntilCopyInResponse() error {
199	for {
200		msg, err := c.rxMsg()
201		if err != nil {
202			return err
203		}
204
205		switch msg := msg.(type) {
206		case *pgproto3.CopyInResponse:
207			return nil
208		default:
209			err = c.processContextFreeMsg(msg)
210			if err != nil {
211				return err
212			}
213		}
214	}
215}
216
217func (ct *copyFrom) cancelCopyIn() error {
218	buf := ct.conn.wbuf
219	buf = append(buf, copyFail)
220	sp := len(buf)
221	buf = pgio.AppendInt32(buf, -1)
222	buf = append(buf, "client error: abort"...)
223	buf = append(buf, 0)
224	pgio.SetInt32(buf[sp:], int32(len(buf[sp:])))
225
226	_, err := ct.conn.conn.Write(buf)
227	if err != nil {
228		ct.conn.die(err)
229		return err
230	}
231
232	return nil
233}
234
235// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
236// It returns the number of rows copied and an error.
237//
238// CopyFrom requires all values use the binary format. Almost all types
239// implemented by pgx use the binary format by default. Types implementing
240// Encoder can only be used if they encode to the binary format.
241func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
242	ct := &copyFrom{
243		conn:          c,
244		tableName:     tableName,
245		columnNames:   columnNames,
246		rowSrc:        rowSrc,
247		readerErrChan: make(chan error),
248	}
249
250	return ct.run()
251}
252