1package mssql
2
3import (
4	"bytes"
5	"context"
6	"encoding/binary"
7	"fmt"
8	"math"
9	"reflect"
10	"strconv"
11	"strings"
12	"time"
13)
14
// Bulk holds the state of one bulk-copy (INSERT BULK) operation on a Conn.
// It is created by CreateBulk or CreateBulkContext. Rows are streamed with
// AddRow, and the operation is finished with Done.
type Bulk struct {
	// ctx is used only for AddRow and Done methods.
	// This could be removed if AddRow and Done accepted
	// a ctx field as well, which is available with the
	// database/sql call.
	ctx context.Context

	cn          *Conn          // connection the bulk load is sent on
	metadata    []columnStruct // all columns returned for "select * from <tablename>"
	bulkColumns []columnStruct // entries of metadata matching columnsName, in columnsName order
	columnsName []string       // destination column names supplied by the caller
	tablename   string         // destination table; interpolated into SQL text verbatim
	numRows     int            // rows successfully written by AddRow (used in debug output)

	headerSent bool        // true once the INSERT BULK command has been executed
	Options    BulkOptions // WITH (...) options; read when the first AddRow sends the command
	Debug      bool        // when true, dlogf writes diagnostics to the session logger
}
// BulkOptions selects the options placed in the WITH (...) clause of the
// INSERT BULK command issued by a Bulk. A zero value of any field leaves the
// corresponding option out of the clause.
type BulkOptions struct {
	CheckConstraints  bool     // adds CHECK_CONSTRAINTS
	FireTriggers      bool     // adds FIRE_TRIGGERS
	KeepNulls         bool     // adds KEEP_NULLS
	KilobytesPerBatch int      // adds KILOBYTES_PER_BATCH = n when > 0
	RowsPerBatch      int      // adds ROWS_PER_BATCH = n when > 0
	Order             []string // adds ORDER(a,b,...) when non-empty; entries are joined verbatim
	Tablock           bool     // adds TABLOCK
}
42
// DataValue is an arbitrary Go value destined for a bulk-copy column. Which
// dynamic types are accepted depends on the column type (see makeParam).
type DataValue interface{}
44
45func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
46	b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns}
47	b.Debug = false
48	return &b
49}
50
51func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) {
52	b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns}
53	b.Debug = false
54	return &b
55}
56
57func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
58	//get table columns info
59	err = b.getMetadata(ctx)
60	if err != nil {
61		return err
62	}
63
64	//match the columns
65	for _, colname := range b.columnsName {
66		var bulkCol *columnStruct
67
68		for _, m := range b.metadata {
69			if m.ColName == colname {
70				bulkCol = &m
71				break
72			}
73		}
74		if bulkCol != nil {
75
76			if bulkCol.ti.TypeId == typeUdt {
77				//send udt as binary
78				bulkCol.ti.TypeId = typeBigVarBin
79			}
80			b.bulkColumns = append(b.bulkColumns, *bulkCol)
81			b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
82		} else {
83			return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
84		}
85	}
86
87	//create the bulk command
88
89	//columns definitions
90	var col_defs bytes.Buffer
91	for i, col := range b.bulkColumns {
92		if i != 0 {
93			col_defs.WriteString(", ")
94		}
95		col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
96	}
97
98	//options
99	var with_opts []string
100
101	if b.Options.CheckConstraints {
102		with_opts = append(with_opts, "CHECK_CONSTRAINTS")
103	}
104	if b.Options.FireTriggers {
105		with_opts = append(with_opts, "FIRE_TRIGGERS")
106	}
107	if b.Options.KeepNulls {
108		with_opts = append(with_opts, "KEEP_NULLS")
109	}
110	if b.Options.KilobytesPerBatch > 0 {
111		with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
112	}
113	if b.Options.RowsPerBatch > 0 {
114		with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
115	}
116	if len(b.Options.Order) > 0 {
117		with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
118	}
119	if b.Options.Tablock {
120		with_opts = append(with_opts, "TABLOCK")
121	}
122	var with_part string
123	if len(with_opts) > 0 {
124		with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
125	}
126
127	query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
128
129	stmt, err := b.cn.PrepareContext(ctx, query)
130	if err != nil {
131		return fmt.Errorf("Prepare failed: %s", err.Error())
132	}
133	b.dlogf(query)
134
135	_, err = stmt.(*Stmt).ExecContext(ctx, nil)
136	if err != nil {
137		return err
138	}
139
140	b.headerSent = true
141
142	var buf = b.cn.sess.buf
143	buf.BeginPacket(packBulkLoadBCP, false)
144
145	// Send the columns metadata.
146	columnMetadata := b.createColMetadata()
147	_, err = buf.Write(columnMetadata)
148
149	return
150}
151
152// AddRow immediately writes the row to the destination table.
153// The arguments are the row values in the order they were specified.
154func (b *Bulk) AddRow(row []interface{}) (err error) {
155	if !b.headerSent {
156		err = b.sendBulkCommand(b.ctx)
157		if err != nil {
158			return
159		}
160	}
161
162	if len(row) != len(b.bulkColumns) {
163		return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
164			len(row), len(b.bulkColumns))
165	}
166
167	bytes, err := b.makeRowData(row)
168	if err != nil {
169		return
170	}
171
172	_, err = b.cn.sess.buf.Write(bytes)
173	if err != nil {
174		return
175	}
176
177	b.numRows = b.numRows + 1
178	return
179}
180
181func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
182	buf := new(bytes.Buffer)
183	buf.WriteByte(byte(tokenRow))
184
185	var logcol bytes.Buffer
186	for i, col := range b.bulkColumns {
187
188		if b.Debug {
189			logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
190		}
191		param, err := b.makeParam(row[i], col)
192		if err != nil {
193			return nil, fmt.Errorf("bulkcopy: %s", err.Error())
194		}
195
196		if col.ti.Writer == nil {
197			return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
198				col.ColName, col.ti.TypeId)
199		}
200		err = col.ti.Writer(buf, param.ti, param.buffer)
201		if err != nil {
202			return nil, fmt.Errorf("bulkcopy: %s", err.Error())
203		}
204	}
205
206	b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
207
208	return buf.Bytes(), nil
209}
210
211func (b *Bulk) Done() (rowcount int64, err error) {
212	if b.headerSent == false {
213		//no rows had been sent
214		return 0, nil
215	}
216	var buf = b.cn.sess.buf
217	buf.WriteByte(byte(tokenDone))
218
219	binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
220	binary.Write(buf, binary.LittleEndian, uint16(0)) //     curcmd
221
222	if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
223		binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
224	} else {
225		binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
226	}
227
228	buf.FinishPacket()
229
230	tokchan := make(chan tokenStruct, 5)
231	go processResponse(b.ctx, b.cn.sess, tokchan, nil)
232
233	var rowCount int64
234	for token := range tokchan {
235		switch token := token.(type) {
236		case doneStruct:
237			if token.Status&doneCount != 0 {
238				rowCount = int64(token.RowCount)
239			}
240			if token.isError() {
241				return 0, token.getError()
242			}
243		case error:
244			return 0, b.cn.checkBadConn(token)
245		}
246	}
247	return rowCount, nil
248}
249
250func (b *Bulk) createColMetadata() []byte {
251	buf := new(bytes.Buffer)
252	buf.WriteByte(byte(tokenColMetadata))                              // token
253	binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
254
255	for i, col := range b.bulkColumns {
256
257		if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
258			binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) //  usertype, always 0?
259		} else {
260			binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
261		}
262		binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
263
264		writeTypeInfo(buf, &b.bulkColumns[i].ti)
265
266		if col.ti.TypeId == typeNText ||
267			col.ti.TypeId == typeText ||
268			col.ti.TypeId == typeImage {
269
270			tablename_ucs2 := str2ucs2(b.tablename)
271			binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
272			buf.Write(tablename_ucs2)
273		}
274		colname_ucs2 := str2ucs2(col.ColName)
275		buf.WriteByte(uint8(len(colname_ucs2) / 2))
276		buf.Write(colname_ucs2)
277	}
278
279	return buf.Bytes()
280}
281
282func (b *Bulk) getMetadata(ctx context.Context) (err error) {
283	stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON")
284	if err != nil {
285		return
286	}
287
288	_, err = stmt.ExecContext(ctx, nil)
289	if err != nil {
290		return
291	}
292
293	// Get columns info.
294	stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
295	if err != nil {
296		return
297	}
298	rows, err := stmt.QueryContext(ctx, nil)
299	if err != nil {
300		return fmt.Errorf("get columns info failed: %v", err)
301	}
302	b.metadata = rows.(*Rows).cols
303
304	if b.Debug {
305		for _, col := range b.metadata {
306			b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
307				col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
308				col.Flags, col.ti.Collation.LcidAndFlags)
309		}
310	}
311
312	return rows.Close()
313}
314
315func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) {
316	res.ti.Size = col.ti.Size
317	res.ti.TypeId = col.ti.TypeId
318
319	if val == nil {
320		res.ti.Size = 0
321		return
322	}
323
324	switch col.ti.TypeId {
325
326	case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
327		var intvalue int64
328
329		switch val := val.(type) {
330		case int:
331			intvalue = int64(val)
332		case int32:
333			intvalue = int64(val)
334		case int64:
335			intvalue = val
336		default:
337			err = fmt.Errorf("mssql: invalid type for int column")
338			return
339		}
340
341		res.buffer = make([]byte, res.ti.Size)
342		if col.ti.Size == 1 {
343			res.buffer[0] = byte(intvalue)
344		} else if col.ti.Size == 2 {
345			binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
346		} else if col.ti.Size == 4 {
347			binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
348		} else if col.ti.Size == 8 {
349			binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
350		}
351	case typeFlt4, typeFlt8, typeFltN:
352		var floatvalue float64
353
354		switch val := val.(type) {
355		case float32:
356			floatvalue = float64(val)
357		case float64:
358			floatvalue = val
359		case int:
360			floatvalue = float64(val)
361		case int64:
362			floatvalue = float64(val)
363		default:
364			err = fmt.Errorf("mssql: invalid type for float column: %s", val)
365			return
366		}
367
368		if col.ti.Size == 4 {
369			res.buffer = make([]byte, 4)
370			binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
371		} else if col.ti.Size == 8 {
372			res.buffer = make([]byte, 8)
373			binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
374		}
375	case typeNVarChar, typeNText, typeNChar:
376
377		switch val := val.(type) {
378		case string:
379			res.buffer = str2ucs2(val)
380		case []byte:
381			res.buffer = val
382		default:
383			err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
384			return
385		}
386		res.ti.Size = len(res.buffer)
387
388	case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
389		switch val := val.(type) {
390		case string:
391			res.buffer = []byte(val)
392		case []byte:
393			res.buffer = val
394		default:
395			err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
396			return
397		}
398		res.ti.Size = len(res.buffer)
399
400	case typeBit, typeBitN:
401		if reflect.TypeOf(val).Kind() != reflect.Bool {
402			err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
403			return
404		}
405		res.ti.TypeId = typeBitN
406		res.ti.Size = 1
407		res.buffer = make([]byte, 1)
408		if val.(bool) {
409			res.buffer[0] = 1
410		}
411	case typeDateTime2N:
412		switch val := val.(type) {
413		case time.Time:
414			res.buffer = encodeDateTime2(val, int(col.ti.Scale))
415			res.ti.Size = len(res.buffer)
416		default:
417			err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
418			return
419		}
420	case typeDateTimeOffsetN:
421		switch val := val.(type) {
422		case time.Time:
423			res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
424			res.ti.Size = len(res.buffer)
425
426		default:
427			err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %s", val)
428			return
429		}
430	case typeDateN:
431		switch val := val.(type) {
432		case time.Time:
433			res.buffer = encodeDate(val)
434			res.ti.Size = len(res.buffer)
435		default:
436			err = fmt.Errorf("mssql: invalid type for date column: %s", val)
437			return
438		}
439	case typeDateTime, typeDateTimeN, typeDateTim4:
440		switch val := val.(type) {
441		case time.Time:
442			if col.ti.Size == 4 {
443				res.buffer = encodeDateTim4(val)
444				res.ti.Size = len(res.buffer)
445			} else if col.ti.Size == 8 {
446				res.buffer = encodeDateTime(val)
447				res.ti.Size = len(res.buffer)
448			} else {
449				err = fmt.Errorf("mssql: invalid size of column")
450			}
451
452		default:
453			err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
454		}
455
456	// case typeMoney, typeMoney4, typeMoneyN:
457	case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
458		var value float64
459		switch v := val.(type) {
460		case int:
461			value = float64(v)
462		case int8:
463			value = float64(v)
464		case int16:
465			value = float64(v)
466		case int32:
467			value = float64(v)
468		case int64:
469			value = float64(v)
470		case float32:
471			value = float64(v)
472		case float64:
473			value = v
474		case string:
475			if value, err = strconv.ParseFloat(v, 64); err != nil {
476				return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
477			}
478		default:
479			return res, fmt.Errorf("unknown value for decimal: %#v", v)
480		}
481
482		perc := col.ti.Prec
483		scale := col.ti.Scale
484		var dec Decimal
485		dec, err = Float64ToDecimalScale(value, scale)
486		if err != nil {
487			return res, err
488		}
489		dec.prec = perc
490
491		var length byte
492		switch {
493		case perc <= 9:
494			length = 4
495		case perc <= 19:
496			length = 8
497		case perc <= 28:
498			length = 12
499		default:
500			length = 16
501		}
502
503		buf := make([]byte, length+1)
504		// first byte length written by typeInfo.writer
505		res.ti.Size = int(length) + 1
506		// second byte sign
507		if value < 0 {
508			buf[0] = 0
509		} else {
510			buf[0] = 1
511		}
512
513		ub := dec.UnscaledBytes()
514		l := len(ub)
515		if l > int(length) {
516			err = fmt.Errorf("decimal out of range: %s", dec)
517			return res, err
518		}
519		// reverse the bytes
520		for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
521			buf[i] = ub[j]
522		}
523		res.buffer = buf
524	case typeBigVarBin, typeBigBinary:
525		switch val := val.(type) {
526		case []byte:
527			res.ti.Size = len(val)
528			res.buffer = val
529		default:
530			err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
531			return
532		}
533	case typeGuid:
534		switch val := val.(type) {
535		case []byte:
536			res.ti.Size = len(val)
537			res.buffer = val
538		default:
539			err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
540			return
541		}
542
543	default:
544		err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
545	}
546	return
547
548}
549
550func (b *Bulk) dlogf(format string, v ...interface{}) {
551	if b.Debug {
552		b.cn.sess.log.Printf(format, v...)
553	}
554}
555