1package gorm
2
3import (
4	"database/sql"
5	"database/sql/driver"
6	"reflect"
7	"strings"
8	"time"
9
10	"gorm.io/gorm/schema"
11)
12
13func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
14	if db.Statement.Schema != nil {
15		for idx, name := range columns {
16			if field := db.Statement.Schema.LookUpField(name); field != nil {
17				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
18				continue
19			}
20			values[idx] = new(interface{})
21		}
22	} else if len(columnTypes) > 0 {
23		for idx, columnType := range columnTypes {
24			if columnType.ScanType() != nil {
25				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
26			} else {
27				values[idx] = new(interface{})
28			}
29		}
30	} else {
31		for idx := range columns {
32			values[idx] = new(interface{})
33		}
34	}
35}
36
// scanIntoMap stores the scanned values in mapValue, keyed by column name.
//
// Each values[i] is dereferenced up to two pointer levels. If nothing valid
// remains (a nil entry or a nil inner pointer), the column maps to nil.
// A driver.Valuer is replaced by the result of its Value method; the error
// from Value is deliberately ignored. sql.RawBytes is copied into a string,
// and any other value is stored as is.
//
// mapValue must be non-nil: writing into a nil map panics.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i, column := range columns {
		rv := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[i])))
		if !rv.IsValid() {
			mapValue[column] = nil
			continue
		}

		switch v := rv.Interface().(type) {
		case driver.Valuer:
			mapValue[column], _ = v.Value()
		case sql.RawBytes:
			// Copy out of the driver-owned buffer.
			mapValue[column] = string(v)
		default:
			mapValue[column] = v
		}
	}
}
51
52func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
53	for idx, column := range columns {
54		if sch == nil {
55			values[idx] = reflectValue.Interface()
56		} else if field := sch.LookUpField(column); field != nil && field.Readable {
57			values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
58		} else if names := strings.Split(column, "__"); len(names) > 1 {
59			if rel, ok := sch.Relationships.Relations[names[0]]; ok {
60				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
61					values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
62					continue
63				}
64			}
65			values[idx] = &sql.RawBytes{}
66		} else if len(columns) == 1 {
67			sch = nil
68			values[idx] = reflectValue.Interface()
69		} else {
70			values[idx] = &sql.RawBytes{}
71		}
72	}
73
74	db.RowsAffected++
75	db.AddError(rows.Scan(values...))
76
77	if sch != nil {
78		for idx, column := range columns {
79			if field := sch.LookUpField(column); field != nil && field.Readable {
80				field.Set(reflectValue, values[idx])
81			} else if names := strings.Split(column, "__"); len(names) > 1 {
82				if rel, ok := sch.Relationships.Relations[names[0]]; ok {
83					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
84						relValue := rel.Field.ReflectValueOf(reflectValue)
85						value := reflect.ValueOf(values[idx]).Elem()
86
87						if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
88							if value.IsNil() {
89								continue
90							}
91							relValue.Set(reflect.New(relValue.Type().Elem()))
92						}
93
94						field.Set(relValue, values[idx])
95					}
96				}
97			}
98		}
99	}
100}
101
// ScanMode is a set of bit flags that adjusts how Scan consumes rows.
type ScanMode uint8

// Scan mode flags; combine them with bitwise OR. Every flag is typed as
// ScanMode. Only the first constant of a const group is otherwise typed,
// because a spec with its own expression does not repeat the previous type.
const (
	// ScanInitialized makes Scan read the current row before calling
	// rows.Next(), for callers that have already advanced rows.
	ScanInitialized ScanMode = 1 << 0 // 1

	// ScanUpdate makes Scan write rows into the existing elements of a
	// non-empty slice/array destination instead of rebuilding it.
	ScanUpdate ScanMode = 1 << 1 // 2

	// ScanOnConflictDoNothing, together with ScanUpdate, makes Scan skip an
	// element for which a mapped field's ValueOf reports false. The current
	// row is then applied to the next element instead.
	ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
109
110func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
111	var (
112		columns, _          = rows.Columns()
113		values              = make([]interface{}, len(columns))
114		initialized         = mode&ScanInitialized != 0
115		update              = mode&ScanUpdate != 0
116		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
117	)
118
119	db.RowsAffected = 0
120
121	switch dest := db.Statement.Dest.(type) {
122	case map[string]interface{}, *map[string]interface{}:
123		if initialized || rows.Next() {
124			columnTypes, _ := rows.ColumnTypes()
125			prepareValues(values, db, columnTypes, columns)
126
127			db.RowsAffected++
128			db.AddError(rows.Scan(values...))
129
130			mapValue, ok := dest.(map[string]interface{})
131			if !ok {
132				if v, ok := dest.(*map[string]interface{}); ok {
133					mapValue = *v
134				}
135			}
136			scanIntoMap(mapValue, values, columns)
137		}
138	case *[]map[string]interface{}, []map[string]interface{}:
139		columnTypes, _ := rows.ColumnTypes()
140		for initialized || rows.Next() {
141			prepareValues(values, db, columnTypes, columns)
142
143			initialized = false
144			db.RowsAffected++
145			db.AddError(rows.Scan(values...))
146
147			mapValue := map[string]interface{}{}
148			scanIntoMap(mapValue, values, columns)
149			if values, ok := dest.([]map[string]interface{}); ok {
150				values = append(values, mapValue)
151			} else if values, ok := dest.(*[]map[string]interface{}); ok {
152				*values = append(*values, mapValue)
153			}
154		}
155	case *int, *int8, *int16, *int32, *int64,
156		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
157		*float32, *float64,
158		*bool, *string, *time.Time,
159		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
160		*sql.NullBool, *sql.NullString, *sql.NullTime:
161		for initialized || rows.Next() {
162			initialized = false
163			db.RowsAffected++
164			db.AddError(rows.Scan(dest))
165		}
166	default:
167		var (
168			fields       = make([]*schema.Field, len(columns))
169			joinFields   [][2]*schema.Field
170			sch          = db.Statement.Schema
171			reflectValue = db.Statement.ReflectValue
172		)
173
174		if reflectValue.Kind() == reflect.Interface {
175			reflectValue = reflectValue.Elem()
176		}
177
178		reflectValueType := reflectValue.Type()
179		switch reflectValueType.Kind() {
180		case reflect.Array, reflect.Slice:
181			reflectValueType = reflectValueType.Elem()
182		}
183		isPtr := reflectValueType.Kind() == reflect.Ptr
184		if isPtr {
185			reflectValueType = reflectValueType.Elem()
186		}
187
188		if sch != nil {
189			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
190				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
191			}
192
193			for idx, column := range columns {
194				if field := sch.LookUpField(column); field != nil && field.Readable {
195					fields[idx] = field
196				} else if names := strings.Split(column, "__"); len(names) > 1 {
197					if rel, ok := sch.Relationships.Relations[names[0]]; ok {
198						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
199							fields[idx] = field
200
201							if len(joinFields) == 0 {
202								joinFields = make([][2]*schema.Field, len(columns))
203							}
204							joinFields[idx] = [2]*schema.Field{rel.Field, field}
205							continue
206						}
207					}
208					values[idx] = &sql.RawBytes{}
209				} else {
210					values[idx] = &sql.RawBytes{}
211				}
212			}
213
214			if len(columns) == 1 {
215				// isPluck
216				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
217					reflectValueType.Kind() != reflect.Struct || // is not struct
218					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
219					sch = nil
220				}
221			}
222		}
223
224		switch reflectValue.Kind() {
225		case reflect.Slice, reflect.Array:
226			var elem reflect.Value
227
228			if !update || reflectValue.Len() == 0 {
229				update = false
230				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
231			}
232
233			for initialized || rows.Next() {
234			BEGIN:
235				initialized = false
236
237				if update {
238					if int(db.RowsAffected) >= reflectValue.Len() {
239						return
240					}
241					elem = reflectValue.Index(int(db.RowsAffected))
242					if onConflictDonothing {
243						for _, field := range fields {
244							if _, ok := field.ValueOf(elem); !ok {
245								db.RowsAffected++
246								goto BEGIN
247							}
248						}
249					}
250				} else {
251					elem = reflect.New(reflectValueType)
252				}
253
254				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
255
256				if !update {
257					if isPtr {
258						reflectValue = reflect.Append(reflectValue, elem)
259					} else {
260						reflectValue = reflect.Append(reflectValue, elem.Elem())
261					}
262				}
263			}
264
265			if !update {
266				db.Statement.ReflectValue.Set(reflectValue)
267			}
268		case reflect.Struct, reflect.Ptr:
269			if initialized || rows.Next() {
270				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
271			}
272		default:
273			db.AddError(rows.Scan(dest))
274		}
275	}
276
277	if err := rows.Err(); err != nil && err != db.Error {
278		db.AddError(err)
279	}
280
281	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
282		db.AddError(ErrRecordNotFound)
283	}
284}
285