1package gorm 2 3import ( 4 "database/sql" 5 "database/sql/driver" 6 "reflect" 7 "strings" 8 "time" 9 10 "gorm.io/gorm/schema" 11) 12 13func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { 14 if db.Statement.Schema != nil { 15 for idx, name := range columns { 16 if field := db.Statement.Schema.LookUpField(name); field != nil { 17 values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() 18 continue 19 } 20 values[idx] = new(interface{}) 21 } 22 } else if len(columnTypes) > 0 { 23 for idx, columnType := range columnTypes { 24 if columnType.ScanType() != nil { 25 values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() 26 } else { 27 values[idx] = new(interface{}) 28 } 29 } 30 } else { 31 for idx := range columns { 32 values[idx] = new(interface{}) 33 } 34 } 35} 36 37func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { 38 for idx, column := range columns { 39 if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { 40 mapValue[column] = reflectValue.Interface() 41 if valuer, ok := mapValue[column].(driver.Valuer); ok { 42 mapValue[column], _ = valuer.Value() 43 } else if b, ok := mapValue[column].(sql.RawBytes); ok { 44 mapValue[column] = string(b) 45 } 46 } else { 47 mapValue[column] = nil 48 } 49 } 50} 51 52func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { 53 for idx, column := range columns { 54 if sch == nil { 55 values[idx] = reflectValue.Interface() 56 } else if field := sch.LookUpField(column); field != nil && field.Readable { 57 values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() 58 } else if names := strings.Split(column, "__"); len(names) > 1 { 59 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 60 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 61 values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() 62 continue 63 } 64 } 65 values[idx] = &sql.RawBytes{} 66 } else if len(columns) == 1 { 67 sch = nil 68 values[idx] = reflectValue.Interface() 69 } else { 70 values[idx] = &sql.RawBytes{} 71 } 72 } 73 74 db.RowsAffected++ 75 db.AddError(rows.Scan(values...)) 76 77 if sch != nil { 78 for idx, column := range columns { 79 if field := sch.LookUpField(column); field != nil && field.Readable { 80 field.Set(reflectValue, values[idx]) 81 } else if names := strings.Split(column, "__"); len(names) > 1 { 82 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 83 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 84 relValue := rel.Field.ReflectValueOf(reflectValue) 85 value := reflect.ValueOf(values[idx]).Elem() 86 87 if relValue.Kind() == reflect.Ptr && relValue.IsNil() { 88 if value.IsNil() { 89 continue 90 } 91 relValue.Set(reflect.New(relValue.Type().Elem())) 92 } 93 94 field.Set(relValue, values[idx]) 95 } 96 } 97 } 98 } 99 } 100} 101 102type ScanMode uint8 103 104const ( 105 ScanInitialized ScanMode = 1 << 0 106 ScanUpdate = 1 << 1 107 ScanOnConflictDoNothing = 1 << 2 108) 109 110func Scan(rows *sql.Rows, db *DB, mode ScanMode) { 111 var ( 112 columns, _ = rows.Columns() 113 values = make([]interface{}, len(columns)) 114 initialized = mode&ScanInitialized != 0 115 update = mode&ScanUpdate != 0 116 onConflictDonothing = mode&ScanOnConflictDoNothing != 0 117 ) 118 119 db.RowsAffected = 0 120 121 switch dest := db.Statement.Dest.(type) { 122 case map[string]interface{}, *map[string]interface{}: 123 if initialized || rows.Next() { 124 columnTypes, _ := rows.ColumnTypes() 125 prepareValues(values, db, columnTypes, columns) 126 127 db.RowsAffected++ 128 db.AddError(rows.Scan(values...)) 129 130 mapValue, ok := dest.(map[string]interface{}) 131 if !ok { 132 if v, ok := dest.(*map[string]interface{}); ok { 133 mapValue = *v 134 } 135 } 136 scanIntoMap(mapValue, values, columns) 137 } 138 case *[]map[string]interface{}, []map[string]interface{}: 139 columnTypes, _ := rows.ColumnTypes() 140 for initialized || rows.Next() { 141 prepareValues(values, db, columnTypes, columns) 142 143 initialized = false 144 db.RowsAffected++ 145 db.AddError(rows.Scan(values...)) 146 147 mapValue := map[string]interface{}{} 148 scanIntoMap(mapValue, values, columns) 149 if values, ok := dest.([]map[string]interface{}); ok { 150 values = append(values, mapValue) 151 } else if values, ok := dest.(*[]map[string]interface{}); ok { 152 *values = append(*values, mapValue) 153 } 154 } 155 case *int, *int8, *int16, *int32, *int64, 156 *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, 157 *float32, *float64, 158 *bool, *string, *time.Time, 159 *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, 160 *sql.NullBool, *sql.NullString, *sql.NullTime: 161 for initialized || rows.Next() { 162 initialized = false 163 db.RowsAffected++ 164 db.AddError(rows.Scan(dest)) 165 } 166 default: 167 var ( 168 fields = make([]*schema.Field, len(columns)) 169 joinFields [][2]*schema.Field 170 sch = db.Statement.Schema 171 reflectValue = db.Statement.ReflectValue 172 ) 173 174 if reflectValue.Kind() == reflect.Interface { 175 reflectValue = reflectValue.Elem() 176 } 177 178 reflectValueType := reflectValue.Type() 179 switch reflectValueType.Kind() { 180 case reflect.Array, reflect.Slice: 181 reflectValueType = reflectValueType.Elem() 182 } 183 isPtr := reflectValueType.Kind() == reflect.Ptr 184 if isPtr { 185 reflectValueType = reflectValueType.Elem() 186 } 187 188 if sch != nil { 189 if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { 190 sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) 191 } 192 193 for idx, column := range columns { 194 if field := sch.LookUpField(column); field != nil && field.Readable { 195 fields[idx] = field 196 } else if names := strings.Split(column, "__"); len(names) > 1 { 197 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 198 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 199 fields[idx] = field 200 201 if len(joinFields) == 0 { 202 joinFields = make([][2]*schema.Field, len(columns)) 203 } 204 joinFields[idx] = [2]*schema.Field{rel.Field, field} 205 continue 206 } 207 } 208 values[idx] = &sql.RawBytes{} 209 } else { 210 values[idx] = &sql.RawBytes{} 211 } 212 } 213 214 if len(columns) == 1 { 215 // isPluck 216 if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner 217 reflectValueType.Kind() != reflect.Struct || // is not struct 218 sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time 219 sch = nil 220 } 221 } 222 } 223 224 switch reflectValue.Kind() { 225 case reflect.Slice, reflect.Array: 226 var elem reflect.Value 227 228 if !update || reflectValue.Len() == 0 { 229 update = false 230 db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) 231 } 232 233 for initialized || rows.Next() { 234 BEGIN: 235 initialized = false 236 237 if update { 238 if int(db.RowsAffected) >= reflectValue.Len() { 239 return 240 } 241 elem = reflectValue.Index(int(db.RowsAffected)) 242 if onConflictDonothing { 243 for _, field := range fields { 244 if _, ok := field.ValueOf(elem); !ok { 245 db.RowsAffected++ 246 goto BEGIN 247 } 248 } 249 } 250 } else { 251 elem = reflect.New(reflectValueType) 252 } 253 254 db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) 255 256 if !update { 257 if isPtr { 258 reflectValue = reflect.Append(reflectValue, elem) 259 } else { 260 reflectValue = reflect.Append(reflectValue, elem.Elem()) 261 } 262 } 263 } 264 265 if !update { 266 db.Statement.ReflectValue.Set(reflectValue) 267 } 268 case reflect.Struct, reflect.Ptr: 269 if initialized || rows.Next() { 270 db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) 271 } 272 default: 273 db.AddError(rows.Scan(dest)) 274 } 275 } 276 277 if err := rows.Err(); err != nil && err != db.Error { 278 db.AddError(err) 279 } 280 281 if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { 282 db.AddError(ErrRecordNotFound) 283 } 284} 285