1package gorm 2 3import ( 4 "fmt" 5 "strings" 6) 7 8// Define callbacks for creating 9func init() { 10 DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback) 11 DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback) 12 DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback) 13 DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback) 14 DefaultCallback.Create().Register("gorm:create", createCallback) 15 DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback) 16 DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback) 17 DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback) 18 DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback) 19} 20 21// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating 22func beforeCreateCallback(scope *Scope) { 23 if !scope.HasError() { 24 scope.CallMethod("BeforeSave") 25 } 26 if !scope.HasError() { 27 scope.CallMethod("BeforeCreate") 28 } 29} 30 31// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating 32func updateTimeStampForCreateCallback(scope *Scope) { 33 if !scope.HasError() { 34 now := scope.db.nowFunc() 35 36 if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { 37 if createdAtField.IsBlank { 38 createdAtField.Set(now) 39 } 40 } 41 42 if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { 43 if updatedAtField.IsBlank { 44 updatedAtField.Set(now) 45 } 46 } 47 } 48} 49 50// createCallback the callback used to insert data into database 51func createCallback(scope *Scope) { 52 if !scope.HasError() { 53 defer scope.trace(scope.db.nowFunc()) 54 55 var ( 56 columns, placeholders []string 57 blankColumnsWithDefaultValue []string 58 ) 59 60 for _, field := range scope.Fields() { 61 if scope.changeableField(field) { 62 if field.IsNormal && !field.IsIgnored { 63 if field.IsBlank && field.HasDefaultValue { 64 blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) 65 scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) 66 } else if !field.IsPrimaryKey || !field.IsBlank { 67 columns = append(columns, scope.Quote(field.DBName)) 68 placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) 69 } 70 } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { 71 for _, foreignKey := range field.Relationship.ForeignDBNames { 72 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { 73 columns = append(columns, scope.Quote(foreignField.DBName)) 74 placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface())) 75 } 76 } 77 } 78 } 79 } 80 81 var ( 82 returningColumn = "*" 83 quotedTableName = scope.QuotedTableName() 84 primaryField = scope.PrimaryField() 85 extraOption string 86 insertModifier string 87 ) 88 89 if str, ok := scope.Get("gorm:insert_option"); ok { 90 extraOption = fmt.Sprint(str) 91 } 92 if str, ok := scope.Get("gorm:insert_modifier"); ok { 93 insertModifier = strings.ToUpper(fmt.Sprint(str)) 94 if insertModifier == "INTO" { 95 insertModifier = "" 96 } 97 } 98 99 if primaryField != nil { 100 returningColumn = scope.Quote(primaryField.DBName) 101 } 102 103 lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) 104 105 if len(columns) == 0 { 106 scope.Raw(fmt.Sprintf( 107 "INSERT %v INTO %v %v%v%v", 108 addExtraSpaceIfExist(insertModifier), 109 quotedTableName, 110 scope.Dialect().DefaultValueStr(), 111 addExtraSpaceIfExist(extraOption), 112 addExtraSpaceIfExist(lastInsertIDReturningSuffix), 113 )) 114 } else { 115 scope.Raw(fmt.Sprintf( 116 "INSERT %v INTO %v (%v) VALUES (%v)%v%v", 117 addExtraSpaceIfExist(insertModifier), 118 scope.QuotedTableName(), 119 strings.Join(columns, ","), 120 strings.Join(placeholders, ","), 121 addExtraSpaceIfExist(extraOption), 122 addExtraSpaceIfExist(lastInsertIDReturningSuffix), 123 )) 124 } 125 126 // execute create sql 127 if lastInsertIDReturningSuffix == "" || primaryField == nil { 128 if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { 129 // set rows affected count 130 scope.db.RowsAffected, _ = result.RowsAffected() 131 132 // set primary value to primary field 133 if primaryField != nil && primaryField.IsBlank { 134 if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { 135 scope.Err(primaryField.Set(primaryValue)) 136 } 137 } 138 } 139 } else { 140 if primaryField.Field.CanAddr() { 141 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { 142 primaryField.IsBlank = false 143 scope.db.RowsAffected = 1 144 } 145 } else { 146 scope.Err(ErrUnaddressable) 147 } 148 } 149 } 150} 151 152// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object 153func forceReloadAfterCreateCallback(scope *Scope) { 154 if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { 155 db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) 156 for _, field := range scope.Fields() { 157 if field.IsPrimaryKey && !field.IsBlank { 158 db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) 159 } 160 } 161 db.Scan(scope.Value) 162 } 163} 164 165// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating 166func afterCreateCallback(scope *Scope) { 167 if !scope.HasError() { 168 scope.CallMethod("AfterCreate") 169 } 170 if !scope.HasError() { 171 scope.CallMethod("AfterSave") 172 } 173} 174