1package mysql 2 3import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "math" 8 "strings" 9 "time" 10 11 _ "github.com/go-sql-driver/mysql" 12 "gorm.io/gorm" 13 "gorm.io/gorm/callbacks" 14 "gorm.io/gorm/clause" 15 "gorm.io/gorm/logger" 16 "gorm.io/gorm/migrator" 17 "gorm.io/gorm/schema" 18) 19 20type Config struct { 21 DriverName string 22 DSN string 23 Conn gorm.ConnPool 24 SkipInitializeWithVersion bool 25 DefaultStringSize uint 26 DefaultDatetimePrecision *int 27 DisableDatetimePrecision bool 28 DontSupportRenameIndex bool 29 DontSupportRenameColumn bool 30 DontSupportForShareClause bool 31} 32 33type Dialector struct { 34 *Config 35} 36 37var ( 38 // CreateClauses create clauses 39 CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"} 40 // UpdateClauses update clauses 41 UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"} 42 // DeleteClauses delete clauses 43 DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"} 44 45 defaultDatetimePrecision = 3 46) 47 48func Open(dsn string) gorm.Dialector { 49 return &Dialector{Config: &Config{DSN: dsn}} 50} 51 52func New(config Config) gorm.Dialector { 53 return &Dialector{Config: &config} 54} 55 56func (dialector Dialector) Name() string { 57 return "mysql" 58} 59 60// NowFunc return now func 61func (dialector Dialector) NowFunc(n int) func() time.Time { 62 return func() time.Time { 63 round := time.Second / time.Duration(math.Pow10(n)) 64 return time.Now().Local().Round(round) 65 } 66} 67 68func (dialector Dialector) Apply(config *gorm.Config) error { 69 if config.NowFunc == nil { 70 if dialector.DefaultDatetimePrecision == nil { 71 dialector.DefaultDatetimePrecision = &defaultDatetimePrecision 72 } 73 74 // while maintaining the readability of the code, separate the business logic from 75 // the general part and leave it to the function to do it here. 76 config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision) 77 } 78 79 return nil 80} 81 82func (dialector Dialector) Initialize(db *gorm.DB) (err error) { 83 ctx := context.Background() 84 85 // register callbacks 86 callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ 87 CreateClauses: CreateClauses, 88 UpdateClauses: UpdateClauses, 89 DeleteClauses: DeleteClauses, 90 }) 91 92 if dialector.DriverName == "" { 93 dialector.DriverName = "mysql" 94 } 95 96 if dialector.DefaultDatetimePrecision == nil { 97 dialector.DefaultDatetimePrecision = &defaultDatetimePrecision 98 } 99 100 if dialector.Conn != nil { 101 db.ConnPool = dialector.Conn 102 } else { 103 db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN) 104 if err != nil { 105 return err 106 } 107 } 108 109 if !dialector.Config.SkipInitializeWithVersion { 110 var version string 111 err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version) 112 if err != nil { 113 return err 114 } 115 116 if strings.Contains(version, "MariaDB") { 117 dialector.Config.DontSupportRenameIndex = true 118 dialector.Config.DontSupportRenameColumn = true 119 dialector.Config.DontSupportForShareClause = true 120 } else if strings.HasPrefix(version, "5.6.") { 121 dialector.Config.DontSupportRenameIndex = true 122 dialector.Config.DontSupportRenameColumn = true 123 dialector.Config.DontSupportForShareClause = true 124 } else if strings.HasPrefix(version, "5.7.") { 125 dialector.Config.DontSupportRenameColumn = true 126 dialector.Config.DontSupportForShareClause = true 127 } else if strings.HasPrefix(version, "5.") { 128 dialector.Config.DisableDatetimePrecision = true 129 dialector.Config.DontSupportRenameIndex = true 130 dialector.Config.DontSupportRenameColumn = true 131 dialector.Config.DontSupportForShareClause = true 132 } 133 } 134 135 for k, v := range dialector.ClauseBuilders() { 136 db.ClauseBuilders[k] = v 137 } 138 return 139} 140 141const ( 142 // ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key 143 ClauseOnConflict = "ON CONFLICT" 144 // ClauseValues for clause.ClauseBuilder VALUES key 145 ClauseValues = "VALUES" 146 // ClauseValues for clause.ClauseBuilder FOR key 147 ClauseFor = "FOR" 148) 149 150func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { 151 clauseBuilders := map[string]clause.ClauseBuilder{ 152 ClauseOnConflict: func(c clause.Clause, builder clause.Builder) { 153 onConflict, ok := c.Expression.(clause.OnConflict) 154 if !ok { 155 c.Build(builder) 156 return 157 } 158 159 builder.WriteString("ON DUPLICATE KEY UPDATE ") 160 if len(onConflict.DoUpdates) == 0 { 161 if s := builder.(*gorm.Statement).Schema; s != nil { 162 var column clause.Column 163 onConflict.DoNothing = false 164 165 if s.PrioritizedPrimaryField != nil { 166 column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} 167 } else if len(s.DBNames) > 0 { 168 column = clause.Column{Name: s.DBNames[0]} 169 } 170 171 if column.Name != "" { 172 onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} 173 } 174 } 175 } 176 177 for idx, assignment := range onConflict.DoUpdates { 178 if idx > 0 { 179 builder.WriteByte(',') 180 } 181 182 builder.WriteQuoted(assignment.Column) 183 builder.WriteByte('=') 184 if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" { 185 column.Table = "" 186 builder.WriteString("VALUES(") 187 builder.WriteQuoted(column) 188 builder.WriteByte(')') 189 } else { 190 builder.AddVar(builder, assignment.Value) 191 } 192 } 193 }, 194 ClauseValues: func(c clause.Clause, builder clause.Builder) { 195 if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { 196 builder.WriteString("VALUES()") 197 return 198 } 199 c.Build(builder) 200 }, 201 } 202 203 if dialector.Config.DontSupportForShareClause { 204 clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) { 205 if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") { 206 builder.WriteString("LOCK IN SHARE MODE") 207 return 208 } 209 c.Build(builder) 210 } 211 } 212 213 return clauseBuilders 214} 215 216func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { 217 return clause.Expr{SQL: "DEFAULT"} 218} 219 220func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { 221 return Migrator{ 222 Migrator: migrator.Migrator{ 223 Config: migrator.Config{ 224 DB: db, 225 Dialector: dialector, 226 }, 227 }, 228 Dialector: dialector, 229 } 230} 231 232func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 233 writer.WriteByte('?') 234} 235 236func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { 237 var ( 238 underQuoted, selfQuoted bool 239 continuousBacktick int8 240 shiftDelimiter int8 241 ) 242 243 for _, v := range []byte(str) { 244 switch v { 245 case '`': 246 continuousBacktick++ 247 if continuousBacktick == 2 { 248 writer.WriteString("``") 249 continuousBacktick = 0 250 } 251 case '.': 252 if continuousBacktick > 0 || !selfQuoted { 253 shiftDelimiter = 0 254 underQuoted = false 255 continuousBacktick = 0 256 writer.WriteString("`") 257 } 258 writer.WriteByte(v) 259 continue 260 default: 261 if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { 262 writer.WriteByte('`') 263 underQuoted = true 264 if selfQuoted = continuousBacktick > 0; selfQuoted { 265 continuousBacktick -= 1 266 } 267 } 268 269 for ; continuousBacktick > 0; continuousBacktick -= 1 { 270 writer.WriteString("``") 271 } 272 273 writer.WriteByte(v) 274 } 275 shiftDelimiter++ 276 } 277 278 if continuousBacktick > 0 && !selfQuoted { 279 writer.WriteString("``") 280 } 281 writer.WriteString("`") 282} 283 284func (dialector Dialector) Explain(sql string, vars ...interface{}) string { 285 return logger.ExplainSQL(sql, nil, `'`, vars...) 286} 287 288func (dialector Dialector) DataTypeOf(field *schema.Field) string { 289 switch field.DataType { 290 case schema.Bool: 291 return "boolean" 292 case schema.Int, schema.Uint: 293 return dialector.getSchemaIntAndUnitType(field) 294 case schema.Float: 295 return dialector.getSchemaFloatType(field) 296 case schema.String: 297 return dialector.getSchemaStringType(field) 298 case schema.Time: 299 return dialector.getSchemaTimeType(field) 300 case schema.Bytes: 301 return dialector.getSchemaBytesType(field) 302 } 303 304 return string(field.DataType) 305} 306 307func (dialector Dialector) getSchemaFloatType(field *schema.Field) string { 308 if field.Precision > 0 { 309 return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale) 310 } 311 312 if field.Size <= 32 { 313 return "float" 314 } 315 316 return "double" 317} 318 319func (dialector Dialector) getSchemaStringType(field *schema.Field) string { 320 size := field.Size 321 if size == 0 { 322 if dialector.DefaultStringSize > 0 { 323 size = int(dialector.DefaultStringSize) 324 } else { 325 hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" 326 // TEXT, GEOMETRY or JSON column can't have a default value 327 if field.PrimaryKey || field.HasDefaultValue || hasIndex { 328 size = 191 // utf8mb4 329 } 330 } 331 } 332 333 if size >= 65536 && size <= int(math.Pow(2, 24)) { 334 return "mediumtext" 335 } 336 337 if size > int(math.Pow(2, 24)) || size <= 0 { 338 return "longtext" 339 } 340 341 return fmt.Sprintf("varchar(%d)", size) 342} 343 344func (dialector Dialector) getSchemaTimeType(field *schema.Field) string { 345 precision := "" 346 if !dialector.DisableDatetimePrecision && field.Precision == 0 { 347 field.Precision = *dialector.DefaultDatetimePrecision 348 } 349 350 if field.Precision > 0 { 351 precision = fmt.Sprintf("(%d)", field.Precision) 352 } 353 354 if field.NotNull || field.PrimaryKey { 355 return "datetime" + precision 356 } 357 return "datetime" + precision + " NULL" 358} 359 360func (dialector Dialector) getSchemaBytesType(field *schema.Field) string { 361 if field.Size > 0 && field.Size < 65536 { 362 return fmt.Sprintf("varbinary(%d)", field.Size) 363 } 364 365 if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) { 366 return "mediumblob" 367 } 368 369 return "longblob" 370} 371 372func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string { 373 sqlType := "bigint" 374 switch { 375 case field.Size <= 8: 376 sqlType = "tinyint" 377 case field.Size <= 16: 378 sqlType = "smallint" 379 case field.Size <= 24: 380 sqlType = "mediumint" 381 case field.Size <= 32: 382 sqlType = "int" 383 } 384 385 if field.DataType == schema.Uint { 386 sqlType += " unsigned" 387 } 388 389 if field.AutoIncrement { 390 sqlType += " AUTO_INCREMENT" 391 } 392 393 return sqlType 394} 395 396func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { 397 tx.Exec("SAVEPOINT " + name) 398 return nil 399} 400 401func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { 402 tx.Exec("ROLLBACK TO SAVEPOINT " + name) 403 return nil 404} 405