1package mysql
2
3import (
4	"database/sql"
5	"fmt"
6
7	"gorm.io/gorm"
8	"gorm.io/gorm/clause"
9	"gorm.io/gorm/migrator"
10	"gorm.io/gorm/schema"
11)
12
// Migrator is the MySQL-specific schema migrator. It embeds the generic
// gorm migrator.Migrator, whose behavior is overridden by the methods in this
// file, and the MySQL Dialector, whose flags (such as DontSupportRenameColumn
// and DontSupportRenameIndex) choose between native DDL and fallback paths.
type Migrator struct {
	migrator.Migrator
	Dialector
}
17
// Column holds one row of information_schema.columns as scanned by
// Migrator.ColumnTypes. Values that may be NULL are kept as sql.Null* so a
// missing value can be told apart from a zero value.
type Column struct {
	name              string         // column_name
	nullable          sql.NullString // is_nullable
	datatype          string         // data_type
	maxLen            sql.NullInt64  // character_maximum_length
	precision         sql.NullInt64  // numeric_precision
	scale             sql.NullInt64  // numeric_scale
	datetimePrecision sql.NullInt64  // datetime_precision (not scanned when DisableDatetimePrecision is set)
}

// Name returns the column name.
func (c Column) Name() string { return c.name }

// DatabaseTypeName returns the data_type reported for the column.
func (c Column) DatabaseTypeName() string { return c.datatype }

// Length returns the character maximum length. The second result is false
// when the length was NULL, in which case the first result is 0.
func (c Column) Length() (int64, bool) {
	if !c.maxLen.Valid {
		return 0, false
	}
	return c.maxLen.Int64, true
}

// Nullable reports whether is_nullable equals "YES". The second result is
// false (and the first is false) when is_nullable itself was NULL.
func (c Column) Nullable() (bool, bool) {
	if !c.nullable.Valid {
		return false, false
	}
	return c.nullable.String == "YES", true
}

// DecimalSize returns (precision, scale, ok).
//
// Numeric precision takes priority; its scale is 0 when NULL. Otherwise the
// datetime precision is returned with a scale of 0. ok is false only when
// neither precision is available.
func (c Column) DecimalSize() (int64, int64, bool) {
	switch {
	case c.precision.Valid:
		var s int64
		if c.scale.Valid {
			s = c.scale.Int64
		}
		return c.precision.Int64, s, true
	case c.datetimePrecision.Valid:
		return c.datetimePrecision.Int64, 0, true
	default:
		return 0, 0, false
	}
}
68
69func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
70	expr := m.Migrator.FullDataTypeOf(field)
71
72	if value, ok := field.TagSettings["COMMENT"]; ok {
73		expr.SQL += " COMMENT " + m.Dialector.Explain("?", value)
74	}
75
76	return expr
77}
78
79func (m Migrator) AlterColumn(value interface{}, field string) error {
80	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
81		if field := stmt.Schema.LookUpField(field); field != nil {
82			return m.DB.Exec(
83				"ALTER TABLE ? MODIFY COLUMN ? ?",
84				clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
85			).Error
86		}
87		return fmt.Errorf("failed to look up field with name: %s", field)
88	})
89}
90
91func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
92	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
93		if !m.Dialector.DontSupportRenameColumn {
94			return m.Migrator.RenameColumn(value, oldName, newName)
95		}
96
97		var field *schema.Field
98		if f := stmt.Schema.LookUpField(oldName); f != nil {
99			oldName = f.DBName
100			field = f
101		}
102
103		if f := stmt.Schema.LookUpField(newName); f != nil {
104			newName = f.DBName
105			field = f
106		}
107
108		if field != nil {
109			return m.DB.Exec(
110				"ALTER TABLE ? CHANGE ? ? ?",
111				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName},
112				clause.Column{Name: newName}, m.FullDataTypeOf(field),
113			).Error
114		}
115
116		return fmt.Errorf("failed to look up field with name: %s", newName)
117	})
118}
119
120func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
121	if !m.Dialector.DontSupportRenameIndex {
122		return m.RunWithValue(value, func(stmt *gorm.Statement) error {
123			return m.DB.Exec(
124				"ALTER TABLE ? RENAME INDEX ? TO ?",
125				clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
126			).Error
127		})
128	}
129
130	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
131		err := m.DropIndex(value, oldName)
132		if err != nil {
133			return err
134		}
135
136		if idx := stmt.Schema.LookIndex(newName); idx == nil {
137			if idx = stmt.Schema.LookIndex(oldName); idx != nil {
138				opts := m.BuildIndexOptions(idx.Fields, stmt)
139				values := []interface{}{clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, opts}
140
141				createIndexSQL := "CREATE "
142				if idx.Class != "" {
143					createIndexSQL += idx.Class + " "
144				}
145				createIndexSQL += "INDEX ? ON ??"
146
147				if idx.Type != "" {
148					createIndexSQL += " USING " + idx.Type
149				}
150
151				return m.DB.Exec(createIndexSQL, values...).Error
152			}
153		}
154
155		return m.CreateIndex(value, newName)
156	})
157
158}
159
160func (m Migrator) DropTable(values ...interface{}) error {
161	values = m.ReorderModels(values, false)
162	tx := m.DB.Session(&gorm.Session{})
163	tx.Exec("SET FOREIGN_KEY_CHECKS = 0;")
164	for i := len(values) - 1; i >= 0; i-- {
165		if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
166			return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
167		}); err != nil {
168			return err
169		}
170	}
171	tx.Exec("SET FOREIGN_KEY_CHECKS = 1;")
172	return nil
173}
174
175func (m Migrator) DropConstraint(value interface{}, name string) error {
176	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
177		constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
178		if chk != nil {
179			return m.DB.Exec("ALTER TABLE ? DROP CHECK ?", clause.Table{Name: stmt.Table}, clause.Column{Name: chk.Name}).Error
180		}
181		if constraint != nil {
182			name = constraint.Name
183		}
184
185		return m.DB.Exec(
186			"ALTER TABLE ? DROP FOREIGN KEY ?", clause.Table{Name: table}, clause.Column{Name: name},
187		).Error
188	})
189}
190
191// ColumnTypes column types return columnTypes,error
192func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
193	columnTypes := make([]gorm.ColumnType, 0)
194	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
195		var (
196			currentDatabase = m.DB.Migrator().CurrentDatabase()
197			columnTypeSQL   = "SELECT column_name, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_scale "
198		)
199
200		if !m.DisableDatetimePrecision {
201			columnTypeSQL += ", datetime_precision "
202		}
203		columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ?"
204
205		columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, stmt.Table).Rows()
206		if rowErr != nil {
207			return rowErr
208		}
209
210		defer columns.Close()
211
212		for columns.Next() {
213			var column Column
214			var values = []interface{}{&column.name, &column.nullable, &column.datatype,
215				&column.maxLen, &column.precision, &column.scale}
216
217			if !m.DisableDatetimePrecision {
218				values = append(values, &column.datetimePrecision)
219			}
220
221			if scanErr := columns.Scan(values...); scanErr != nil {
222				return scanErr
223			}
224			columnTypes = append(columnTypes, column)
225		}
226
227		return nil
228	})
229
230	return columnTypes, err
231}
232