1// Copyright (c) 2012-present The upper.io/db authors. All rights reserved. 2// 3// Permission is hereby granted, free of charge, to any person obtaining 4// a copy of this software and associated documentation files (the 5// "Software"), to deal in the Software without restriction, including 6// without limitation the rights to use, copy, modify, merge, publish, 7// distribute, sublicense, and/or sell copies of the Software, and to 8// permit persons to whom the Software is furnished to do so, subject to 9// the following conditions: 10// 11// The above copyright notice and this permission notice shall be 12// included in all copies or substantial portions of the Software. 13// 14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 22package mssql 23 24import ( 25 db "upper.io/db.v3" 26 "upper.io/db.v3/internal/sqladapter" 27 "upper.io/db.v3/lib/sqlbuilder" 28) 29 30// table is the actual implementation of a collection. 31type table struct { 32 sqladapter.BaseCollection // Leveraged by sqladapter 33 34 d *database 35 name string 36 37 hasIdentityColumn *bool 38} 39 40var ( 41 _ = sqladapter.Collection(&table{}) 42 _ = db.Collection(&table{}) 43) 44 45// newTable binds *table with sqladapter. 46func newTable(d *database, name string) *table { 47 t := &table{ 48 name: name, 49 d: d, 50 } 51 t.BaseCollection = sqladapter.NewBaseCollection(t) 52 return t 53} 54 55func (t *table) Name() string { 56 return t.name 57} 58 59func (t *table) Database() sqladapter.Database { 60 return t.d 61} 62 63// Insert inserts an item (map or struct) into the collection. 64func (t *table) Insert(item interface{}) (interface{}, error) { 65 columnNames, columnValues, err := sqlbuilder.Map(item, nil) 66 if err != nil { 67 return nil, err 68 } 69 70 pKey := t.BaseCollection.PrimaryKeys() 71 72 var hasKeys bool 73 for i := range columnNames { 74 for j := 0; j < len(pKey); j++ { 75 if pKey[j] == columnNames[i] { 76 if columnValues[i] != nil { 77 hasKeys = true 78 break 79 } 80 } 81 } 82 } 83 84 if hasKeys { 85 if t.hasIdentityColumn == nil { 86 var hasIdentityColumn bool 87 var identityColumns int 88 89 row, err := t.d.QueryRow("SELECT COUNT(1) FROM sys.identity_columns WHERE OBJECT_NAME(object_id) = ?", t.Name()) 90 if err != nil { 91 return nil, err 92 } 93 94 err = row.Scan(&identityColumns) 95 if err != nil { 96 return nil, err 97 } 98 99 if identityColumns > 0 { 100 hasIdentityColumn = true 101 } 102 103 t.hasIdentityColumn = &hasIdentityColumn 104 } 105 106 if *t.hasIdentityColumn { 107 _, err = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " ON") 108 if err != nil { 109 return nil, err 110 } 111 defer func() { 112 _, _ = t.d.Exec("SET IDENTITY_INSERT " + t.Name() + " OFF") 113 }() 114 } 115 } 116 117 q := t.d.InsertInto(t.Name()). 118 Columns(columnNames...). 119 Values(columnValues...) 120 121 if len(pKey) < 1 { 122 _, err = q.Exec() 123 if err != nil { 124 return nil, err 125 } 126 return nil, nil 127 } 128 129 q = q.Returning(pKey...) 130 131 var keyMap db.Cond 132 if err = q.Iterator().One(&keyMap); err != nil { 133 return nil, err 134 } 135 136 // The IDSetter interface does not match, look for another interface match. 137 if len(keyMap) == 1 { 138 return keyMap[pKey[0]], nil 139 } 140 141 // This was a compound key and no interface matched it, let's return a map. 142 return keyMap, nil 143} 144