1package gorm
2
3import (
4	"database/sql"
5	"fmt"
6	"reflect"
7	"strconv"
8	"strings"
9)
10
11// Dialect interface contains behaviors that differ across SQL database
12type Dialect interface {
13	// GetName get dialect's name
14	GetName() string
15
16	// SetDB set db for dialect
17	SetDB(db SQLCommon)
18
19	// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
20	BindVar(i int) string
21	// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
22	Quote(key string) string
23	// DataTypeOf return data's sql type
24	DataTypeOf(field *StructField) string
25
26	// HasIndex check has index or not
27	HasIndex(tableName string, indexName string) bool
28	// HasForeignKey check has foreign key or not
29	HasForeignKey(tableName string, foreignKeyName string) bool
30	// RemoveIndex remove index
31	RemoveIndex(tableName string, indexName string) error
32	// HasTable check has table or not
33	HasTable(tableName string) bool
34	// HasColumn check has column or not
35	HasColumn(tableName string, columnName string) bool
36	// ModifyColumn modify column's type
37	ModifyColumn(tableName string, columnName string, typ string) error
38
39	// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
40	LimitAndOffsetSQL(limit, offset interface{}) (string, error)
41	// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
42	SelectFromDummyTable() string
43	// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
44	LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
45	// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
46	LastInsertIDReturningSuffix(tableName, columnName string) string
47	// DefaultValueStr
48	DefaultValueStr() string
49
50	// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
51	BuildKeyName(kind, tableName string, fields ...string) string
52
53	// NormalizeIndexAndColumn returns valid index name and column name depending on each dialect
54	NormalizeIndexAndColumn(indexName, columnName string) (string, string)
55
56	// CurrentDatabase return current database name
57	CurrentDatabase() string
58}
59
60var dialectsMap = map[string]Dialect{}
61
62func newDialect(name string, db SQLCommon) Dialect {
63	if value, ok := dialectsMap[name]; ok {
64		dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
65		dialect.SetDB(db)
66		return dialect
67	}
68
69	fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
70	commontDialect := &commonDialect{}
71	commontDialect.SetDB(db)
72	return commontDialect
73}
74
75// RegisterDialect register new dialect
76func RegisterDialect(name string, dialect Dialect) {
77	dialectsMap[name] = dialect
78}
79
80// GetDialect gets the dialect for the specified dialect name
81func GetDialect(name string) (dialect Dialect, ok bool) {
82	dialect, ok = dialectsMap[name]
83	return
84}
85
86// ParseFieldStructForDialect get field's sql data type
87var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
88	// Get redirected field type
89	var (
90		reflectType = field.Struct.Type
91		dataType, _ = field.TagSettingsGet("TYPE")
92	)
93
94	for reflectType.Kind() == reflect.Ptr {
95		reflectType = reflectType.Elem()
96	}
97
98	// Get redirected field value
99	fieldValue = reflect.Indirect(reflect.New(reflectType))
100
101	if gormDataType, ok := fieldValue.Interface().(interface {
102		GormDataType(Dialect) string
103	}); ok {
104		dataType = gormDataType.GormDataType(dialect)
105	}
106
107	// Get scanner's real value
108	if dataType == "" {
109		var getScannerValue func(reflect.Value)
110		getScannerValue = func(value reflect.Value) {
111			fieldValue = value
112			if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
113				getScannerValue(fieldValue.Field(0))
114			}
115		}
116		getScannerValue(fieldValue)
117	}
118
119	// Default Size
120	if num, ok := field.TagSettingsGet("SIZE"); ok {
121		size, _ = strconv.Atoi(num)
122	} else {
123		size = 255
124	}
125
126	// Default type from tag setting
127	notNull, _ := field.TagSettingsGet("NOT NULL")
128	unique, _ := field.TagSettingsGet("UNIQUE")
129	additionalType = notNull + " " + unique
130	if value, ok := field.TagSettingsGet("DEFAULT"); ok {
131		additionalType = additionalType + " DEFAULT " + value
132	}
133
134	if value, ok := field.TagSettingsGet("COMMENT"); ok {
135		additionalType = additionalType + " COMMENT " + value
136	}
137
138	return fieldValue, dataType, size, strings.TrimSpace(additionalType)
139}
140
141func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) {
142	if strings.Contains(tableName, ".") {
143		splitStrings := strings.SplitN(tableName, ".", 2)
144		return splitStrings[0], splitStrings[1]
145	}
146	return dialect.CurrentDatabase(), tableName
147}
148