1package mysql
2
3import (
4	"context"
5	"database/sql"
6	"fmt"
7	"math"
8	"strings"
9	"time"
10
11	_ "github.com/go-sql-driver/mysql"
12	"gorm.io/gorm"
13	"gorm.io/gorm/callbacks"
14	"gorm.io/gorm/clause"
15	"gorm.io/gorm/logger"
16	"gorm.io/gorm/migrator"
17	"gorm.io/gorm/schema"
18)
19
// Config holds the options of a MySQL Dialector (see Open and New). A
// Dialector keeps a pointer to its Config, and Apply/Initialize fill in
// defaults and version-derived flags on that same Config.
type Config struct {
	// DriverName is the database/sql driver name passed to sql.Open; Initialize
	// sets it to "mysql" when it is empty.
	DriverName string

	// DSN is the data source name passed to sql.Open when Conn is nil.
	DSN string

	// Conn, when non-nil, is used as the connection pool as-is instead of
	// opening one from DriverName and DSN.
	Conn gorm.ConnPool

	// SkipInitializeWithVersion skips the SELECT VERSION() probe in Initialize
	// that sets DisableDatetimePrecision and the DontSupport* flags.
	SkipInitializeWithVersion bool

	// DefaultStringSize, when > 0, is the varchar length used for string
	// fields that declare no size.
	DefaultStringSize uint

	// DefaultDatetimePrecision is the fractional-second precision used for
	// time columns without an explicit precision and for the clock installed
	// by Apply. nil means the package default (3).
	DefaultDatetimePrecision *int

	// DisableDatetimePrecision stops time columns without an explicit
	// precision from receiving DefaultDatetimePrecision. Initialize sets it
	// for "5." server versions other than 5.6.x and 5.7.x.
	DisableDatetimePrecision bool

	// DontSupportRenameIndex is set by Initialize for MariaDB and for MySQL 5.x
	// servers other than 5.7.x. Presumably consulted by the Migrator (not
	// visible here) — NOTE(review): confirm.
	DontSupportRenameIndex bool

	// DontSupportRenameColumn is set by Initialize for MariaDB and MySQL 5.x
	// servers. Presumably consulted by the Migrator — NOTE(review): confirm.
	DontSupportRenameColumn bool

	// DontSupportForShareClause makes ClauseBuilders render FOR SHARE locking
	// as LOCK IN SHARE MODE. Initialize sets it for MariaDB and MySQL 5.x.
	DontSupportForShareClause bool
}
32
// Dialector is the MySQL gorm.Dialector returned by Open and New. It embeds
// *Config, so the option fields are promoted onto the dialector, and every
// value-receiver method of Dialector reads and may update the same Config.
type Dialector struct {
	*Config
}
36
// CreateClauses are the clause names that Initialize passes to
// callbacks.RegisterDefaultCallbacks as the create clauses. This dialect's
// ClauseBuilders render "ON CONFLICT" as ON DUPLICATE KEY UPDATE.
var CreateClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}

// UpdateClauses are the clause names that Initialize passes to
// callbacks.RegisterDefaultCallbacks as the update clauses.
var UpdateClauses = []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"}

// DeleteClauses are the clause names that Initialize passes to
// callbacks.RegisterDefaultCallbacks as the delete clauses.
var DeleteClauses = []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"}

// defaultDatetimePrecision is the fractional-second precision used whenever
// Config.DefaultDatetimePrecision is nil.
var defaultDatetimePrecision = 3
47
48func Open(dsn string) gorm.Dialector {
49	return &Dialector{Config: &Config{DSN: dsn}}
50}
51
52func New(config Config) gorm.Dialector {
53	return &Dialector{Config: &config}
54}
55
56func (dialector Dialector) Name() string {
57	return "mysql"
58}
59
60// NowFunc return now func
61func (dialector Dialector) NowFunc(n int) func() time.Time {
62	return func() time.Time {
63		round := time.Second / time.Duration(math.Pow10(n))
64		return time.Now().Local().Round(round)
65	}
66}
67
68func (dialector Dialector) Apply(config *gorm.Config) error {
69	if config.NowFunc == nil {
70		if dialector.DefaultDatetimePrecision == nil {
71			dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
72		}
73
74		// while maintaining the readability of the code, separate the business logic from
75		// the general part and leave it to the function to do it here.
76		config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
77	}
78
79	return nil
80}
81
82func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
83	ctx := context.Background()
84
85	// register callbacks
86	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
87		CreateClauses: CreateClauses,
88		UpdateClauses: UpdateClauses,
89		DeleteClauses: DeleteClauses,
90	})
91
92	if dialector.DriverName == "" {
93		dialector.DriverName = "mysql"
94	}
95
96	if dialector.DefaultDatetimePrecision == nil {
97		dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
98	}
99
100	if dialector.Conn != nil {
101		db.ConnPool = dialector.Conn
102	} else {
103		db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
104		if err != nil {
105			return err
106		}
107	}
108
109	if !dialector.Config.SkipInitializeWithVersion {
110		var version string
111		err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&version)
112		if err != nil {
113			return err
114		}
115
116		if strings.Contains(version, "MariaDB") {
117			dialector.Config.DontSupportRenameIndex = true
118			dialector.Config.DontSupportRenameColumn = true
119			dialector.Config.DontSupportForShareClause = true
120		} else if strings.HasPrefix(version, "5.6.") {
121			dialector.Config.DontSupportRenameIndex = true
122			dialector.Config.DontSupportRenameColumn = true
123			dialector.Config.DontSupportForShareClause = true
124		} else if strings.HasPrefix(version, "5.7.") {
125			dialector.Config.DontSupportRenameColumn = true
126			dialector.Config.DontSupportForShareClause = true
127		} else if strings.HasPrefix(version, "5.") {
128			dialector.Config.DisableDatetimePrecision = true
129			dialector.Config.DontSupportRenameIndex = true
130			dialector.Config.DontSupportRenameColumn = true
131			dialector.Config.DontSupportForShareClause = true
132		}
133	}
134
135	for k, v := range dialector.ClauseBuilders() {
136		db.ClauseBuilders[k] = v
137	}
138	return
139}
140
// Clause names under which ClauseBuilders registers MySQL-specific
// clause.ClauseBuilder overrides.
const (
	// ClauseOnConflict is the clause.ClauseBuilder key for the ON CONFLICT clause.
	ClauseOnConflict = "ON CONFLICT"
	// ClauseValues is the clause.ClauseBuilder key for the VALUES clause.
	ClauseValues = "VALUES"
	// ClauseFor is the clause.ClauseBuilder key for the FOR (row locking) clause.
	ClauseFor = "FOR"
)
149
150func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
151	clauseBuilders := map[string]clause.ClauseBuilder{
152		ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
153			onConflict, ok := c.Expression.(clause.OnConflict)
154			if !ok {
155				c.Build(builder)
156				return
157			}
158
159			builder.WriteString("ON DUPLICATE KEY UPDATE ")
160			if len(onConflict.DoUpdates) == 0 {
161				if s := builder.(*gorm.Statement).Schema; s != nil {
162					var column clause.Column
163					onConflict.DoNothing = false
164
165					if s.PrioritizedPrimaryField != nil {
166						column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
167					} else if len(s.DBNames) > 0 {
168						column = clause.Column{Name: s.DBNames[0]}
169					}
170
171					if column.Name != "" {
172						onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
173					}
174				}
175			}
176
177			for idx, assignment := range onConflict.DoUpdates {
178				if idx > 0 {
179					builder.WriteByte(',')
180				}
181
182				builder.WriteQuoted(assignment.Column)
183				builder.WriteByte('=')
184				if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
185					column.Table = ""
186					builder.WriteString("VALUES(")
187					builder.WriteQuoted(column)
188					builder.WriteByte(')')
189				} else {
190					builder.AddVar(builder, assignment.Value)
191				}
192			}
193		},
194		ClauseValues: func(c clause.Clause, builder clause.Builder) {
195			if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
196				builder.WriteString("VALUES()")
197				return
198			}
199			c.Build(builder)
200		},
201	}
202
203	if dialector.Config.DontSupportForShareClause {
204		clauseBuilders[ClauseFor] = func(c clause.Clause, builder clause.Builder) {
205			if values, ok := c.Expression.(clause.Locking); ok && strings.EqualFold(values.Strength, "SHARE") {
206				builder.WriteString("LOCK IN SHARE MODE")
207				return
208			}
209			c.Build(builder)
210		}
211	}
212
213	return clauseBuilders
214}
215
216func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
217	return clause.Expr{SQL: "DEFAULT"}
218}
219
220func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
221	return Migrator{
222		Migrator: migrator.Migrator{
223			Config: migrator.Config{
224				DB:        db,
225				Dialector: dialector,
226			},
227		},
228		Dialector: dialector,
229	}
230}
231
232func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
233	writer.WriteByte('?')
234}
235
// QuoteTo writes str to writer as a backtick-quoted MySQL identifier.
//
// The input is scanned byte by byte, and each '.'-separated part is quoted on
// its own, so users.name is written as `users`.`name`. Backticks are treated
// as follows:
//   - A part whose first byte is a backtick is "self-quoted". Its own opening
//     and closing backticks are reused rather than escaped, so `users` is
//     written unchanged. A '.' seen before that part's closing backtick stays
//     inside the identifier, so `a.b` remains a single name.
//   - Any other backtick inside a part is escaped by doubling it, so a`b is
//     written as `a``b`.
//
// State kept across bytes:
//   - underQuoted: the opening backtick of the current part has been written.
//   - selfQuoted: the current part began with a backtick in the input.
//   - continuousBacktick: backticks read but not yet written.
//   - shiftDelimiter: bytes consumed in the current part ('.' resets it).
//
// NOTE(review): this is an order-sensitive state machine. The int8 counters
// wrap after 127 bytes in one part, which is presumably fine because MySQL
// identifiers are at most 64 characters. Confirm before reusing it for longer
// strings.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		underQuoted, selfQuoted bool
		continuousBacktick      int8
		shiftDelimiter          int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '`':
			continuousBacktick++
			// Two consecutive input backticks pass through as one escaped pair.
			if continuousBacktick == 2 {
				writer.WriteString("``")
				continuousBacktick = 0
			}
		case '.':
			// End the current part, unless we are inside a self-quoted part
			// whose closing backtick has not been seen yet. In that case the
			// '.' belongs to the name.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteString("`")
			}
			writer.WriteByte(v)
			// Skip shiftDelimiter++ so the next part starts counting at 0.
			continue
		default:
			// First non-backtick byte of the part (every earlier byte is a
			// still-pending backtick): open the quote. A pending backtick here
			// marks the part as self-quoted and is consumed as the opener.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteByte('`')
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Any remaining pending backticks are literal; escape each one.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString("``")
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// A trailing backtick in a part that was not self-quoted is a literal and
	// is escaped. For a self-quoted part it is the closing quote written below.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString("``")
	}
	writer.WriteString("`")
}
283
284func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
285	return logger.ExplainSQL(sql, nil, `'`, vars...)
286}
287
288func (dialector Dialector) DataTypeOf(field *schema.Field) string {
289	switch field.DataType {
290	case schema.Bool:
291		return "boolean"
292	case schema.Int, schema.Uint:
293		return dialector.getSchemaIntAndUnitType(field)
294	case schema.Float:
295		return dialector.getSchemaFloatType(field)
296	case schema.String:
297		return dialector.getSchemaStringType(field)
298	case schema.Time:
299		return dialector.getSchemaTimeType(field)
300	case schema.Bytes:
301		return dialector.getSchemaBytesType(field)
302	}
303
304	return string(field.DataType)
305}
306
307func (dialector Dialector) getSchemaFloatType(field *schema.Field) string {
308	if field.Precision > 0 {
309		return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
310	}
311
312	if field.Size <= 32 {
313		return "float"
314	}
315
316	return "double"
317}
318
319func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
320	size := field.Size
321	if size == 0 {
322		if dialector.DefaultStringSize > 0 {
323			size = int(dialector.DefaultStringSize)
324		} else {
325			hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
326			// TEXT, GEOMETRY or JSON column can't have a default value
327			if field.PrimaryKey || field.HasDefaultValue || hasIndex {
328				size = 191 // utf8mb4
329			}
330		}
331	}
332
333	if size >= 65536 && size <= int(math.Pow(2, 24)) {
334		return "mediumtext"
335	}
336
337	if size > int(math.Pow(2, 24)) || size <= 0 {
338		return "longtext"
339	}
340
341	return fmt.Sprintf("varchar(%d)", size)
342}
343
344func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
345	precision := ""
346	if !dialector.DisableDatetimePrecision && field.Precision == 0 {
347		field.Precision = *dialector.DefaultDatetimePrecision
348	}
349
350	if field.Precision > 0 {
351		precision = fmt.Sprintf("(%d)", field.Precision)
352	}
353
354	if field.NotNull || field.PrimaryKey {
355		return "datetime" + precision
356	}
357	return "datetime" + precision + " NULL"
358}
359
360func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
361	if field.Size > 0 && field.Size < 65536 {
362		return fmt.Sprintf("varbinary(%d)", field.Size)
363	}
364
365	if field.Size >= 65536 && field.Size <= int(math.Pow(2, 24)) {
366		return "mediumblob"
367	}
368
369	return "longblob"
370}
371
372func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
373	sqlType := "bigint"
374	switch {
375	case field.Size <= 8:
376		sqlType = "tinyint"
377	case field.Size <= 16:
378		sqlType = "smallint"
379	case field.Size <= 24:
380		sqlType = "mediumint"
381	case field.Size <= 32:
382		sqlType = "int"
383	}
384
385	if field.DataType == schema.Uint {
386		sqlType += " unsigned"
387	}
388
389	if field.AutoIncrement {
390		sqlType += " AUTO_INCREMENT"
391	}
392
393	return sqlType
394}
395
396func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
397	tx.Exec("SAVEPOINT " + name)
398	return nil
399}
400
401func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
402	tx.Exec("ROLLBACK TO SAVEPOINT " + name)
403	return nil
404}
405