1package gorm
2
3import (
4	"fmt"
5	"strings"
6)
7
8// Define callbacks for creating
9func init() {
10	DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
11	DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
12	DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
13	DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
14	DefaultCallback.Create().Register("gorm:create", createCallback)
15	DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
16	DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
17	DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
18	DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
19}
20
21// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
22func beforeCreateCallback(scope *Scope) {
23	if !scope.HasError() {
24		scope.CallMethod("BeforeSave")
25	}
26	if !scope.HasError() {
27		scope.CallMethod("BeforeCreate")
28	}
29}
30
31// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
32func updateTimeStampForCreateCallback(scope *Scope) {
33	if !scope.HasError() {
34		now := scope.db.nowFunc()
35
36		if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
37			if createdAtField.IsBlank {
38				createdAtField.Set(now)
39			}
40		}
41
42		if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
43			if updatedAtField.IsBlank {
44				updatedAtField.Set(now)
45			}
46		}
47	}
48}
49
50// createCallback the callback used to insert data into database
51func createCallback(scope *Scope) {
52	if !scope.HasError() {
53		defer scope.trace(scope.db.nowFunc())
54
55		var (
56			columns, placeholders        []string
57			blankColumnsWithDefaultValue []string
58		)
59
60		for _, field := range scope.Fields() {
61			if scope.changeableField(field) {
62				if field.IsNormal && !field.IsIgnored {
63					if field.IsBlank && field.HasDefaultValue {
64						blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
65						scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
66					} else if !field.IsPrimaryKey || !field.IsBlank {
67						columns = append(columns, scope.Quote(field.DBName))
68						placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
69					}
70				} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
71					for _, foreignKey := range field.Relationship.ForeignDBNames {
72						if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
73							columns = append(columns, scope.Quote(foreignField.DBName))
74							placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
75						}
76					}
77				}
78			}
79		}
80
81		var (
82			returningColumn = "*"
83			quotedTableName = scope.QuotedTableName()
84			primaryField    = scope.PrimaryField()
85			extraOption     string
86			insertModifier  string
87		)
88
89		if str, ok := scope.Get("gorm:insert_option"); ok {
90			extraOption = fmt.Sprint(str)
91		}
92		if str, ok := scope.Get("gorm:insert_modifier"); ok {
93			insertModifier = strings.ToUpper(fmt.Sprint(str))
94			if insertModifier == "INTO" {
95				insertModifier = ""
96			}
97		}
98
99		if primaryField != nil {
100			returningColumn = scope.Quote(primaryField.DBName)
101		}
102
103		lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
104
105		if len(columns) == 0 {
106			scope.Raw(fmt.Sprintf(
107				"INSERT %v INTO %v %v%v%v",
108				addExtraSpaceIfExist(insertModifier),
109				quotedTableName,
110				scope.Dialect().DefaultValueStr(),
111				addExtraSpaceIfExist(extraOption),
112				addExtraSpaceIfExist(lastInsertIDReturningSuffix),
113			))
114		} else {
115			scope.Raw(fmt.Sprintf(
116				"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
117				addExtraSpaceIfExist(insertModifier),
118				scope.QuotedTableName(),
119				strings.Join(columns, ","),
120				strings.Join(placeholders, ","),
121				addExtraSpaceIfExist(extraOption),
122				addExtraSpaceIfExist(lastInsertIDReturningSuffix),
123			))
124		}
125
126		// execute create sql
127		if lastInsertIDReturningSuffix == "" || primaryField == nil {
128			if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
129				// set rows affected count
130				scope.db.RowsAffected, _ = result.RowsAffected()
131
132				// set primary value to primary field
133				if primaryField != nil && primaryField.IsBlank {
134					if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
135						scope.Err(primaryField.Set(primaryValue))
136					}
137				}
138			}
139		} else {
140			if primaryField.Field.CanAddr() {
141				if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
142					primaryField.IsBlank = false
143					scope.db.RowsAffected = 1
144				}
145			} else {
146				scope.Err(ErrUnaddressable)
147			}
148		}
149	}
150}
151
152// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
153func forceReloadAfterCreateCallback(scope *Scope) {
154	if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
155		db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
156		for _, field := range scope.Fields() {
157			if field.IsPrimaryKey && !field.IsBlank {
158				db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
159			}
160		}
161		db.Scan(scope.Value)
162	}
163}
164
165// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
166func afterCreateCallback(scope *Scope) {
167	if !scope.HasError() {
168		scope.CallMethod("AfterCreate")
169	}
170	if !scope.HasError() {
171		scope.CallMethod("AfterSave")
172	}
173}
174