1// Copyright (c) 2012-present The upper.io/db authors. All rights reserved.
2//
3// Permission is hereby granted, free of charge, to any person obtaining
4// a copy of this software and associated documentation files (the
5// "Software"), to deal in the Software without restriction, including
6// without limitation the rights to use, copy, modify, merge, publish,
7// distribute, sublicense, and/or sell copies of the Software, and to
8// permit persons to whom the Software is furnished to do so, subject to
9// the following conditions:
10//
11// The above copyright notice and this permission notice shall be
12// included in all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21
22package mssql
23
24import (
25	db "upper.io/db.v3"
26	"upper.io/db.v3/internal/sqladapter"
27	"upper.io/db.v3/lib/sqlbuilder"
28)
29
30// table is the actual implementation of a collection.
31type table struct {
32	sqladapter.BaseCollection // Leveraged by sqladapter
33
34	d    *database
35	name string
36
37	hasIdentityColumn *bool
38}
39
40var (
41	_ = sqladapter.Collection(&table{})
42	_ = db.Collection(&table{})
43)
44
45// newTable binds *table with sqladapter.
46func newTable(d *database, name string) *table {
47	t := &table{
48		name: name,
49		d:    d,
50	}
51	t.BaseCollection = sqladapter.NewBaseCollection(t)
52	return t
53}
54
55func (t *table) Name() string {
56	return t.name
57}
58
59func (t *table) Database() sqladapter.Database {
60	return t.d
61}
62
63// Insert inserts an item (map or struct) into the collection.
64func (t *table) Insert(item interface{}) (interface{}, error) {
65	columnNames, columnValues, err := sqlbuilder.Map(item, nil)
66	if err != nil {
67		return nil, err
68	}
69
70	pKey := t.BaseCollection.PrimaryKeys()
71
72	var hasKeys bool
73	for i := range columnNames {
74		for j := 0; j < len(pKey); j++ {
75			if pKey[j] == columnNames[i] {
76				if columnValues[i] != nil {
77					hasKeys = true
78					break
79				}
80			}
81		}
82	}
83
84	if hasKeys {
85		if t.hasIdentityColumn == nil {
86			var hasIdentityColumn bool
87			var identityColumns int
88
89			row, err := t.d.QueryRow("SELECT COUNT(1) FROM sys.identity_columns WHERE OBJECT_NAME(object_id) = ?", t.Name())
90			if err != nil {
91				return nil, err
92			}
93
94			err = row.Scan(&identityColumns)
95			if err != nil {
96				return nil, err
97			}
98
99			if identityColumns > 0 {
100				hasIdentityColumn = true
101			}
102
103			t.hasIdentityColumn = &hasIdentityColumn
104		}
105
106		if *t.hasIdentityColumn {
107			_, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON")
108			if err != nil {
109				return nil, err
110			}
111			defer func() {
112				_, _ = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF")
113			}()
114		}
115	}
116
117	q := t.d.InsertInto(t.Name()).
118		Columns(columnNames...).
119		Values(columnValues...)
120
121	if len(pKey) < 1 {
122		_, err = q.Exec()
123		if err != nil {
124			return nil, err
125		}
126		return nil, nil
127	}
128
129	q = q.Returning(pKey...)
130
131	var keyMap db.Cond
132	if err = q.Iterator().One(&keyMap); err != nil {
133		return nil, err
134	}
135
136	// The IDSetter interface does not match, look for another interface match.
137	if len(keyMap) == 1 {
138		return keyMap[pKey[0]], nil
139	}
140
141	// This was a compound key and no interface matched it, let's return a map.
142	return keyMap, nil
143}
144