1package mssql 2 3import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "fmt" 8 "math" 9 "reflect" 10 "strconv" 11 "strings" 12 "time" 13) 14 15type Bulk struct { 16 // ctx is used only for AddRow and Done methods. 17 // This could be removed if AddRow and Done accepted 18 // a ctx field as well, which is available with the 19 // database/sql call. 20 ctx context.Context 21 22 cn *Conn 23 metadata []columnStruct 24 bulkColumns []columnStruct 25 columnsName []string 26 tablename string 27 numRows int 28 29 headerSent bool 30 Options BulkOptions 31 Debug bool 32} 33type BulkOptions struct { 34 CheckConstraints bool 35 FireTriggers bool 36 KeepNulls bool 37 KilobytesPerBatch int 38 RowsPerBatch int 39 Order []string 40 Tablock bool 41} 42 43type DataValue interface{} 44 45func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { 46 b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns} 47 b.Debug = false 48 return &b 49} 50 51func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) { 52 b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns} 53 b.Debug = false 54 return &b 55} 56 57func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { 58 //get table columns info 59 err = b.getMetadata(ctx) 60 if err != nil { 61 return err 62 } 63 64 //match the columns 65 for _, colname := range b.columnsName { 66 var bulkCol *columnStruct 67 68 for _, m := range b.metadata { 69 if m.ColName == colname { 70 bulkCol = &m 71 break 72 } 73 } 74 if bulkCol != nil { 75 76 if bulkCol.ti.TypeId == typeUdt { 77 //send udt as binary 78 bulkCol.ti.TypeId = typeBigVarBin 79 } 80 b.bulkColumns = append(b.bulkColumns, *bulkCol) 81 b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId) 82 } else { 83 return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename) 84 } 85 } 86 87 //create the bulk command 88 89 //columns definitions 90 var col_defs bytes.Buffer 91 for i, col := range b.bulkColumns { 92 if i != 0 { 93 col_defs.WriteString(", ") 94 } 95 col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti)) 96 } 97 98 //options 99 var with_opts []string 100 101 if b.Options.CheckConstraints { 102 with_opts = append(with_opts, "CHECK_CONSTRAINTS") 103 } 104 if b.Options.FireTriggers { 105 with_opts = append(with_opts, "FIRE_TRIGGERS") 106 } 107 if b.Options.KeepNulls { 108 with_opts = append(with_opts, "KEEP_NULLS") 109 } 110 if b.Options.KilobytesPerBatch > 0 { 111 with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch)) 112 } 113 if b.Options.RowsPerBatch > 0 { 114 with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch)) 115 } 116 if len(b.Options.Order) > 0 { 117 with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ","))) 118 } 119 if b.Options.Tablock { 120 with_opts = append(with_opts, "TABLOCK") 121 } 122 var with_part string 123 if len(with_opts) > 0 { 124 with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ",")) 125 } 126 127 query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part) 128 129 stmt, err := b.cn.PrepareContext(ctx, query) 130 if err != nil { 131 return fmt.Errorf("Prepare failed: %s", err.Error()) 132 } 133 b.dlogf(query) 134 135 _, err = stmt.(*Stmt).ExecContext(ctx, nil) 136 if err != nil { 137 return err 138 } 139 140 b.headerSent = true 141 142 var buf = b.cn.sess.buf 143 buf.BeginPacket(packBulkLoadBCP, false) 144 145 // Send the columns metadata. 146 columnMetadata := b.createColMetadata() 147 _, err = buf.Write(columnMetadata) 148 149 return 150} 151 152// AddRow immediately writes the row to the destination table. 153// The arguments are the row values in the order they were specified. 154func (b *Bulk) AddRow(row []interface{}) (err error) { 155 if !b.headerSent { 156 err = b.sendBulkCommand(b.ctx) 157 if err != nil { 158 return 159 } 160 } 161 162 if len(row) != len(b.bulkColumns) { 163 return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d", 164 len(row), len(b.bulkColumns)) 165 } 166 167 bytes, err := b.makeRowData(row) 168 if err != nil { 169 return 170 } 171 172 _, err = b.cn.sess.buf.Write(bytes) 173 if err != nil { 174 return 175 } 176 177 b.numRows = b.numRows + 1 178 return 179} 180 181func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { 182 buf := new(bytes.Buffer) 183 buf.WriteByte(byte(tokenRow)) 184 185 var logcol bytes.Buffer 186 for i, col := range b.bulkColumns { 187 188 if b.Debug { 189 logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i])) 190 } 191 param, err := b.makeParam(row[i], col) 192 if err != nil { 193 return nil, fmt.Errorf("bulkcopy: %s", err.Error()) 194 } 195 196 if col.ti.Writer == nil { 197 return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x", 198 col.ColName, col.ti.TypeId) 199 } 200 err = col.ti.Writer(buf, param.ti, param.buffer) 201 if err != nil { 202 return nil, fmt.Errorf("bulkcopy: %s", err.Error()) 203 } 204 } 205 206 b.dlogf("row[%d] %s\n", b.numRows, logcol.String()) 207 208 return buf.Bytes(), nil 209} 210 211func (b *Bulk) Done() (rowcount int64, err error) { 212 if b.headerSent == false { 213 //no rows had been sent 214 return 0, nil 215 } 216 var buf = b.cn.sess.buf 217 buf.WriteByte(byte(tokenDone)) 218 219 binary.Write(buf, binary.LittleEndian, uint16(doneFinal)) 220 binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd 221 222 if b.cn.sess.loginAck.TDSVersion >= verTDS72 { 223 binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0 224 } else { 225 binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0 226 } 227 228 buf.FinishPacket() 229 230 tokchan := make(chan tokenStruct, 5) 231 go processResponse(b.ctx, b.cn.sess, tokchan, nil) 232 233 var rowCount int64 234 for token := range tokchan { 235 switch token := token.(type) { 236 case doneStruct: 237 if token.Status&doneCount != 0 { 238 rowCount = int64(token.RowCount) 239 } 240 if token.isError() { 241 return 0, token.getError() 242 } 243 case error: 244 return 0, b.cn.checkBadConn(token) 245 } 246 } 247 return rowCount, nil 248} 249 250func (b *Bulk) createColMetadata() []byte { 251 buf := new(bytes.Buffer) 252 buf.WriteByte(byte(tokenColMetadata)) // token 253 binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count 254 255 for i, col := range b.bulkColumns { 256 257 if b.cn.sess.loginAck.TDSVersion >= verTDS72 { 258 binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0? 259 } else { 260 binary.Write(buf, binary.LittleEndian, uint16(col.UserType)) 261 } 262 binary.Write(buf, binary.LittleEndian, uint16(col.Flags)) 263 264 writeTypeInfo(buf, &b.bulkColumns[i].ti) 265 266 if col.ti.TypeId == typeNText || 267 col.ti.TypeId == typeText || 268 col.ti.TypeId == typeImage { 269 270 tablename_ucs2 := str2ucs2(b.tablename) 271 binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2)) 272 buf.Write(tablename_ucs2) 273 } 274 colname_ucs2 := str2ucs2(col.ColName) 275 buf.WriteByte(uint8(len(colname_ucs2) / 2)) 276 buf.Write(colname_ucs2) 277 } 278 279 return buf.Bytes() 280} 281 282func (b *Bulk) getMetadata(ctx context.Context) (err error) { 283 stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON") 284 if err != nil { 285 return 286 } 287 288 _, err = stmt.ExecContext(ctx, nil) 289 if err != nil { 290 return 291 } 292 293 // Get columns info. 294 stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) 295 if err != nil { 296 return 297 } 298 rows, err := stmt.QueryContext(ctx, nil) 299 if err != nil { 300 return fmt.Errorf("get columns info failed: %v", err) 301 } 302 b.metadata = rows.(*Rows).cols 303 304 if b.Debug { 305 for _, col := range b.metadata { 306 b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n", 307 col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec, 308 col.Flags, col.ti.Collation.LcidAndFlags) 309 } 310 } 311 312 return rows.Close() 313} 314 315func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) { 316 res.ti.Size = col.ti.Size 317 res.ti.TypeId = col.ti.TypeId 318 319 if val == nil { 320 res.ti.Size = 0 321 return 322 } 323 324 switch col.ti.TypeId { 325 326 case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN: 327 var intvalue int64 328 329 switch val := val.(type) { 330 case int: 331 intvalue = int64(val) 332 case int32: 333 intvalue = int64(val) 334 case int64: 335 intvalue = val 336 default: 337 err = fmt.Errorf("mssql: invalid type for int column") 338 return 339 } 340 341 res.buffer = make([]byte, res.ti.Size) 342 if col.ti.Size == 1 { 343 res.buffer[0] = byte(intvalue) 344 } else if col.ti.Size == 2 { 345 binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue)) 346 } else if col.ti.Size == 4 { 347 binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue)) 348 } else if col.ti.Size == 8 { 349 binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue)) 350 } 351 case typeFlt4, typeFlt8, typeFltN: 352 var floatvalue float64 353 354 switch val := val.(type) { 355 case float32: 356 floatvalue = float64(val) 357 case float64: 358 floatvalue = val 359 case int: 360 floatvalue = float64(val) 361 case int64: 362 floatvalue = float64(val) 363 default: 364 err = fmt.Errorf("mssql: invalid type for float column: %s", val) 365 return 366 } 367 368 if col.ti.Size == 4 { 369 res.buffer = make([]byte, 4) 370 binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue))) 371 } else if col.ti.Size == 8 { 372 res.buffer = make([]byte, 8) 373 binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue)) 374 } 375 case typeNVarChar, typeNText, typeNChar: 376 377 switch val := val.(type) { 378 case string: 379 res.buffer = str2ucs2(val) 380 case []byte: 381 res.buffer = val 382 default: 383 err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val) 384 return 385 } 386 res.ti.Size = len(res.buffer) 387 388 case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar: 389 switch val := val.(type) { 390 case string: 391 res.buffer = []byte(val) 392 case []byte: 393 res.buffer = val 394 default: 395 err = fmt.Errorf("mssql: invalid type for varchar column: %s", val) 396 return 397 } 398 res.ti.Size = len(res.buffer) 399 400 case typeBit, typeBitN: 401 if reflect.TypeOf(val).Kind() != reflect.Bool { 402 err = fmt.Errorf("mssql: invalid type for bit column: %s", val) 403 return 404 } 405 res.ti.TypeId = typeBitN 406 res.ti.Size = 1 407 res.buffer = make([]byte, 1) 408 if val.(bool) { 409 res.buffer[0] = 1 410 } 411 case typeDateTime2N: 412 switch val := val.(type) { 413 case time.Time: 414 res.buffer = encodeDateTime2(val, int(col.ti.Scale)) 415 res.ti.Size = len(res.buffer) 416 default: 417 err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val) 418 return 419 } 420 case typeDateTimeOffsetN: 421 switch val := val.(type) { 422 case time.Time: 423 res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale)) 424 res.ti.Size = len(res.buffer) 425 426 default: 427 err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %s", val) 428 return 429 } 430 case typeDateN: 431 switch val := val.(type) { 432 case time.Time: 433 res.buffer = encodeDate(val) 434 res.ti.Size = len(res.buffer) 435 default: 436 err = fmt.Errorf("mssql: invalid type for date column: %s", val) 437 return 438 } 439 case typeDateTime, typeDateTimeN, typeDateTim4: 440 switch val := val.(type) { 441 case time.Time: 442 if col.ti.Size == 4 { 443 res.buffer = encodeDateTim4(val) 444 res.ti.Size = len(res.buffer) 445 } else if col.ti.Size == 8 { 446 res.buffer = encodeDateTime(val) 447 res.ti.Size = len(res.buffer) 448 } else { 449 err = fmt.Errorf("mssql: invalid size of column") 450 } 451 452 default: 453 err = fmt.Errorf("mssql: invalid type for datetime column: %s", val) 454 } 455 456 // case typeMoney, typeMoney4, typeMoneyN: 457 case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: 458 var value float64 459 switch v := val.(type) { 460 case int: 461 value = float64(v) 462 case int8: 463 value = float64(v) 464 case int16: 465 value = float64(v) 466 case int32: 467 value = float64(v) 468 case int64: 469 value = float64(v) 470 case float32: 471 value = float64(v) 472 case float64: 473 value = v 474 case string: 475 if value, err = strconv.ParseFloat(v, 64); err != nil { 476 return res, fmt.Errorf("bulk: unable to convert string to float: %v", err) 477 } 478 default: 479 return res, fmt.Errorf("unknown value for decimal: %#v", v) 480 } 481 482 perc := col.ti.Prec 483 scale := col.ti.Scale 484 var dec Decimal 485 dec, err = Float64ToDecimalScale(value, scale) 486 if err != nil { 487 return res, err 488 } 489 dec.prec = perc 490 491 var length byte 492 switch { 493 case perc <= 9: 494 length = 4 495 case perc <= 19: 496 length = 8 497 case perc <= 28: 498 length = 12 499 default: 500 length = 16 501 } 502 503 buf := make([]byte, length+1) 504 // first byte length written by typeInfo.writer 505 res.ti.Size = int(length) + 1 506 // second byte sign 507 if value < 0 { 508 buf[0] = 0 509 } else { 510 buf[0] = 1 511 } 512 513 ub := dec.UnscaledBytes() 514 l := len(ub) 515 if l > int(length) { 516 err = fmt.Errorf("decimal out of range: %s", dec) 517 return res, err 518 } 519 // reverse the bytes 520 for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 { 521 buf[i] = ub[j] 522 } 523 res.buffer = buf 524 case typeBigVarBin, typeBigBinary: 525 switch val := val.(type) { 526 case []byte: 527 res.ti.Size = len(val) 528 res.buffer = val 529 default: 530 err = fmt.Errorf("mssql: invalid type for Binary column: %s", val) 531 return 532 } 533 case typeGuid: 534 switch val := val.(type) { 535 case []byte: 536 res.ti.Size = len(val) 537 res.buffer = val 538 default: 539 err = fmt.Errorf("mssql: invalid type for Guid column: %s", val) 540 return 541 } 542 543 default: 544 err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId) 545 } 546 return 547 548} 549 550func (b *Bulk) dlogf(format string, v ...interface{}) { 551 if b.Debug { 552 b.cn.sess.log.Printf(format, v...) 553 } 554} 555