1package pgx 2 3import ( 4 "bytes" 5 "fmt" 6 7 "github.com/jackc/pgx/pgio" 8 "github.com/jackc/pgx/pgproto3" 9 "github.com/pkg/errors" 10) 11 12// CopyFromRows returns a CopyFromSource interface over the provided rows slice 13// making it usable by *Conn.CopyFrom. 14func CopyFromRows(rows [][]interface{}) CopyFromSource { 15 return ©FromRows{rows: rows, idx: -1} 16} 17 18type copyFromRows struct { 19 rows [][]interface{} 20 idx int 21} 22 23func (ctr *copyFromRows) Next() bool { 24 ctr.idx++ 25 return ctr.idx < len(ctr.rows) 26} 27 28func (ctr *copyFromRows) Values() ([]interface{}, error) { 29 return ctr.rows[ctr.idx], nil 30} 31 32func (ctr *copyFromRows) Err() error { 33 return nil 34} 35 36// CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. 37type CopyFromSource interface { 38 // Next returns true if there is another row and makes the next row data 39 // available to Values(). When there are no more rows available or an error 40 // has occurred it returns false. 41 Next() bool 42 43 // Values returns the values for the current row. 44 Values() ([]interface{}, error) 45 46 // Err returns any error that has been encountered by the CopyFromSource. If 47 // this is not nil *Conn.CopyFrom will abort the copy. 48 Err() error 49} 50 51type copyFrom struct { 52 conn *Conn 53 tableName Identifier 54 columnNames []string 55 rowSrc CopyFromSource 56 readerErrChan chan error 57} 58 59func (ct *copyFrom) readUntilReadyForQuery() { 60 for { 61 msg, err := ct.conn.rxMsg() 62 if err != nil { 63 ct.readerErrChan <- err 64 close(ct.readerErrChan) 65 return 66 } 67 68 switch msg := msg.(type) { 69 case *pgproto3.ReadyForQuery: 70 ct.conn.rxReadyForQuery(msg) 71 close(ct.readerErrChan) 72 return 73 case *pgproto3.CommandComplete: 74 case *pgproto3.ErrorResponse: 75 ct.readerErrChan <- ct.conn.rxErrorResponse(msg) 76 default: 77 err = ct.conn.processContextFreeMsg(msg) 78 if err != nil { 79 ct.readerErrChan <- ct.conn.processContextFreeMsg(msg) 80 } 81 } 82 } 83} 84 85func (ct *copyFrom) waitForReaderDone() error { 86 var err error 87 for err = range ct.readerErrChan { 88 } 89 return err 90} 91 92func (ct *copyFrom) run() (int, error) { 93 quotedTableName := ct.tableName.Sanitize() 94 cbuf := &bytes.Buffer{} 95 for i, cn := range ct.columnNames { 96 if i != 0 { 97 cbuf.WriteString(", ") 98 } 99 cbuf.WriteString(quoteIdentifier(cn)) 100 } 101 quotedColumnNames := cbuf.String() 102 103 ps, err := ct.conn.Prepare("", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) 104 if err != nil { 105 return 0, err 106 } 107 108 err = ct.conn.sendSimpleQuery(fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) 109 if err != nil { 110 return 0, err 111 } 112 113 err = ct.conn.readUntilCopyInResponse() 114 if err != nil { 115 return 0, err 116 } 117 118 go ct.readUntilReadyForQuery() 119 defer ct.waitForReaderDone() 120 121 buf := ct.conn.wbuf 122 buf = append(buf, copyData) 123 sp := len(buf) 124 buf = pgio.AppendInt32(buf, -1) 125 126 buf = append(buf, "PGCOPY\n\377\r\n\000"...) 127 buf = pgio.AppendInt32(buf, 0) 128 buf = pgio.AppendInt32(buf, 0) 129 130 var sentCount int 131 132 for ct.rowSrc.Next() { 133 select { 134 case err = <-ct.readerErrChan: 135 return 0, err 136 default: 137 } 138 139 if len(buf) > 65536 { 140 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 141 _, err = ct.conn.conn.Write(buf) 142 if err != nil { 143 ct.conn.die(err) 144 return 0, err 145 } 146 147 // Directly manipulate wbuf to reset to reuse the same buffer 148 buf = buf[0:5] 149 } 150 151 sentCount++ 152 153 values, err := ct.rowSrc.Values() 154 if err != nil { 155 ct.cancelCopyIn() 156 return 0, err 157 } 158 if len(values) != len(ct.columnNames) { 159 ct.cancelCopyIn() 160 return 0, errors.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) 161 } 162 163 buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) 164 for i, val := range values { 165 buf, err = encodePreparedStatementArgument(ct.conn.ConnInfo, buf, ps.FieldDescriptions[i].DataType, val) 166 if err != nil { 167 ct.cancelCopyIn() 168 return 0, err 169 } 170 171 } 172 } 173 174 if ct.rowSrc.Err() != nil { 175 ct.cancelCopyIn() 176 return 0, ct.rowSrc.Err() 177 } 178 179 buf = pgio.AppendInt16(buf, -1) // terminate the copy stream 180 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 181 182 buf = append(buf, copyDone) 183 buf = pgio.AppendInt32(buf, 4) 184 185 _, err = ct.conn.conn.Write(buf) 186 if err != nil { 187 ct.conn.die(err) 188 return 0, err 189 } 190 191 err = ct.waitForReaderDone() 192 if err != nil { 193 return 0, err 194 } 195 return sentCount, nil 196} 197 198func (c *Conn) readUntilCopyInResponse() error { 199 for { 200 msg, err := c.rxMsg() 201 if err != nil { 202 return err 203 } 204 205 switch msg := msg.(type) { 206 case *pgproto3.CopyInResponse: 207 return nil 208 default: 209 err = c.processContextFreeMsg(msg) 210 if err != nil { 211 return err 212 } 213 } 214 } 215} 216 217func (ct *copyFrom) cancelCopyIn() error { 218 buf := ct.conn.wbuf 219 buf = append(buf, copyFail) 220 sp := len(buf) 221 buf = pgio.AppendInt32(buf, -1) 222 buf = append(buf, "client error: abort"...) 223 buf = append(buf, 0) 224 pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) 225 226 _, err := ct.conn.conn.Write(buf) 227 if err != nil { 228 ct.conn.die(err) 229 return err 230 } 231 232 return nil 233} 234 235// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. 236// It returns the number of rows copied and an error. 237// 238// CopyFrom requires all values use the binary format. Almost all types 239// implemented by pgx use the binary format by default. Types implementing 240// Encoder can only be used if they encode to the binary format. 241func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { 242 ct := ©From{ 243 conn: c, 244 tableName: tableName, 245 columnNames: columnNames, 246 rowSrc: rowSrc, 247 readerErrChan: make(chan error), 248 } 249 250 return ct.run() 251} 252