1package mssql
2
3import (
4	"context"
5	"database/sql/driver"
6	"encoding/json"
7	"errors"
8)
9
// copyin is the driver.Stmt returned by prepareCopyIn for a CopyIn query.
// Each Exec with arguments appends one row to bulkcopy; an Exec with no
// arguments finalizes the copy and marks the statement closed.
type copyin struct {
	cn       *Conn // connection the statement was prepared on
	bulkcopy *Bulk // bulk-copy state that rows are added to
	closed   bool  // set once the copy has been finalized; later Execs fail
}
15
// serializableBulkConfig is the JSON payload that CopyIn appends after the
// "INSERTBULK " prefix and that prepareCopyIn decodes back. The fields carry
// no json tags, so they are encoded under their Go field names; their order
// here fixes the key order in the string CopyIn returns.
type serializableBulkConfig struct {
	TableName   string      // destination table handed to CreateBulkContext
	ColumnsName []string    // column list handed to CreateBulkContext (presumably in row-value order — confirm)
	Options     BulkOptions // copied verbatim onto Bulk.Options
}
21
22func (d *Driver) OpenConnection(dsn string) (*Conn, error) {
23	return d.open(context.Background(), dsn)
24}
25
26func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) {
27	config_json := query[11:]
28
29	bulkconfig := serializableBulkConfig{}
30	err = json.Unmarshal([]byte(config_json), &bulkconfig)
31	if err != nil {
32		return
33	}
34
35	bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName)
36	bulkcopy.Options = bulkconfig.Options
37
38	ci := &copyin{
39		cn:       c,
40		bulkcopy: bulkcopy,
41	}
42
43	return ci, nil
44}
45
46func CopyIn(table string, options BulkOptions, columns ...string) string {
47	bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns}
48
49	config_json, err := json.Marshal(bulkconfig)
50	if err != nil {
51		panic(err)
52	}
53
54	stmt := "INSERTBULK " + string(config_json)
55
56	return stmt
57}
58
59func (ci *copyin) NumInput() int {
60	return -1
61}
62
63func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
64	panic("should never be called")
65}
66
67func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
68	if ci.closed {
69		return nil, errors.New("copyin query is closed")
70	}
71
72	if len(v) == 0 {
73		rowCount, err := ci.bulkcopy.Done()
74		ci.closed = true
75		return driver.RowsAffected(rowCount), err
76	}
77
78	t := make([]interface{}, len(v))
79	for i, val := range v {
80		t[i] = val
81	}
82
83	err = ci.bulkcopy.AddRow(t)
84	if err != nil {
85		return
86	}
87
88	return driver.RowsAffected(0), nil
89}
90
91func (ci *copyin) Close() (err error) {
92	return nil
93}
94