1package gorm 2 3import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strconv" 8 "strings" 9) 10 11// preloadCallback used to preload associations 12func preloadCallback(scope *Scope) { 13 if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { 14 return 15 } 16 17 if ap, ok := scope.Get("gorm:auto_preload"); ok { 18 // If gorm:auto_preload IS NOT a bool then auto preload. 19 // Else if it IS a bool, use the value 20 if apb, ok := ap.(bool); !ok { 21 autoPreload(scope) 22 } else if apb { 23 autoPreload(scope) 24 } 25 } 26 27 if scope.Search.preload == nil || scope.HasError() { 28 return 29 } 30 31 var ( 32 preloadedMap = map[string]bool{} 33 fields = scope.Fields() 34 ) 35 36 for _, preload := range scope.Search.preload { 37 var ( 38 preloadFields = strings.Split(preload.schema, ".") 39 currentScope = scope 40 currentFields = fields 41 ) 42 43 for idx, preloadField := range preloadFields { 44 var currentPreloadConditions []interface{} 45 46 if currentScope == nil { 47 continue 48 } 49 50 // if not preloaded 51 if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { 52 53 // assign search conditions to last preload 54 if idx == len(preloadFields)-1 { 55 currentPreloadConditions = preload.conditions 56 } 57 58 for _, field := range currentFields { 59 if field.Name != preloadField || field.Relationship == nil { 60 continue 61 } 62 63 switch field.Relationship.Kind { 64 case "has_one": 65 currentScope.handleHasOnePreload(field, currentPreloadConditions) 66 case "has_many": 67 currentScope.handleHasManyPreload(field, currentPreloadConditions) 68 case "belongs_to": 69 currentScope.handleBelongsToPreload(field, currentPreloadConditions) 70 case "many_to_many": 71 currentScope.handleManyToManyPreload(field, currentPreloadConditions) 72 default: 73 scope.Err(errors.New("unsupported relation")) 74 } 75 76 preloadedMap[preloadKey] = true 77 break 78 } 79 80 if !preloadedMap[preloadKey] { 81 scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) 82 return 83 } 84 } 85 86 // preload next level 87 if idx < len(preloadFields)-1 { 88 currentScope = currentScope.getColumnAsScope(preloadField) 89 if currentScope != nil { 90 currentFields = currentScope.Fields() 91 } 92 } 93 } 94 } 95} 96 97func autoPreload(scope *Scope) { 98 for _, field := range scope.Fields() { 99 if field.Relationship == nil { 100 continue 101 } 102 103 if val, ok := field.TagSettingsGet("PRELOAD"); ok { 104 if preload, err := strconv.ParseBool(val); err != nil { 105 scope.Err(errors.New("invalid preload option")) 106 return 107 } else if !preload { 108 continue 109 } 110 } 111 112 scope.Search.Preload(field.Name) 113 } 114} 115 116func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { 117 var ( 118 preloadDB = scope.NewDB() 119 preloadConditions []interface{} 120 ) 121 122 for _, condition := range conditions { 123 if scopes, ok := condition.(func(*DB) *DB); ok { 124 preloadDB = scopes(preloadDB) 125 } else { 126 preloadConditions = append(preloadConditions, condition) 127 } 128 } 129 130 return preloadDB, preloadConditions 131} 132 133// handleHasOnePreload used to preload has one associations 134func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { 135 relation := field.Relationship 136 137 // get relations's primary keys 138 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) 139 if len(primaryKeys) == 0 { 140 return 141 } 142 143 // preload conditions 144 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) 145 146 // find relations 147 query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) 148 values := toQueryValues(primaryKeys) 149 if relation.PolymorphicType != "" { 150 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) 151 values = append(values, relation.PolymorphicValue) 152 } 153 154 results := makeSlice(field.Struct.Type) 155 scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) 156 157 // assign find results 158 var ( 159 resultsValue = indirect(reflect.ValueOf(results)) 160 indirectScopeValue = scope.IndirectValue() 161 ) 162 163 if indirectScopeValue.Kind() == reflect.Slice { 164 foreignValuesToResults := make(map[string]reflect.Value) 165 for i := 0; i < resultsValue.Len(); i++ { 166 result := resultsValue.Index(i) 167 foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) 168 foreignValuesToResults[foreignValues] = result 169 } 170 for j := 0; j < indirectScopeValue.Len(); j++ { 171 indirectValue := indirect(indirectScopeValue.Index(j)) 172 valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) 173 if result, found := foreignValuesToResults[valueString]; found { 174 indirectValue.FieldByName(field.Name).Set(result) 175 } 176 } 177 } else { 178 for i := 0; i < resultsValue.Len(); i++ { 179 result := resultsValue.Index(i) 180 scope.Err(field.Set(result)) 181 } 182 } 183} 184 185// handleHasManyPreload used to preload has many associations 186func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { 187 relation := field.Relationship 188 189 // get relations's primary keys 190 primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) 191 if len(primaryKeys) == 0 { 192 return 193 } 194 195 // preload conditions 196 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) 197 198 // find relations 199 query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) 200 values := toQueryValues(primaryKeys) 201 if relation.PolymorphicType != "" { 202 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) 203 values = append(values, relation.PolymorphicValue) 204 } 205 206 results := makeSlice(field.Struct.Type) 207 scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) 208 209 // assign find results 210 var ( 211 resultsValue = indirect(reflect.ValueOf(results)) 212 indirectScopeValue = scope.IndirectValue() 213 ) 214 215 if indirectScopeValue.Kind() == reflect.Slice { 216 preloadMap := make(map[string][]reflect.Value) 217 for i := 0; i < resultsValue.Len(); i++ { 218 result := resultsValue.Index(i) 219 foreignValues := getValueFromFields(result, relation.ForeignFieldNames) 220 preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) 221 } 222 223 for j := 0; j < indirectScopeValue.Len(); j++ { 224 object := indirect(indirectScopeValue.Index(j)) 225 objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) 226 f := object.FieldByName(field.Name) 227 if results, ok := preloadMap[toString(objectRealValue)]; ok { 228 f.Set(reflect.Append(f, results...)) 229 } else { 230 f.Set(reflect.MakeSlice(f.Type(), 0, 0)) 231 } 232 } 233 } else { 234 scope.Err(field.Set(resultsValue)) 235 } 236} 237 238// handleBelongsToPreload used to preload belongs to associations 239func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { 240 relation := field.Relationship 241 242 // preload conditions 243 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) 244 245 // get relations's primary keys 246 primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) 247 if len(primaryKeys) == 0 { 248 return 249 } 250 251 // find relations 252 results := makeSlice(field.Struct.Type) 253 scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) 254 255 // assign find results 256 var ( 257 resultsValue = indirect(reflect.ValueOf(results)) 258 indirectScopeValue = scope.IndirectValue() 259 ) 260 261 foreignFieldToObjects := make(map[string][]*reflect.Value) 262 if indirectScopeValue.Kind() == reflect.Slice { 263 for j := 0; j < indirectScopeValue.Len(); j++ { 264 object := indirect(indirectScopeValue.Index(j)) 265 valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) 266 foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) 267 } 268 } 269 270 for i := 0; i < resultsValue.Len(); i++ { 271 result := resultsValue.Index(i) 272 if indirectScopeValue.Kind() == reflect.Slice { 273 valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) 274 if objects, found := foreignFieldToObjects[valueString]; found { 275 for _, object := range objects { 276 object.FieldByName(field.Name).Set(result) 277 } 278 } 279 } else { 280 scope.Err(field.Set(result)) 281 } 282 } 283} 284 285// handleManyToManyPreload used to preload many to many associations 286func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { 287 var ( 288 relation = field.Relationship 289 joinTableHandler = relation.JoinTableHandler 290 fieldType = field.Struct.Type.Elem() 291 foreignKeyValue interface{} 292 foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() 293 linkHash = map[string][]reflect.Value{} 294 isPtr bool 295 ) 296 297 if fieldType.Kind() == reflect.Ptr { 298 isPtr = true 299 fieldType = fieldType.Elem() 300 } 301 302 var sourceKeys = []string{} 303 for _, key := range joinTableHandler.SourceForeignKeys() { 304 sourceKeys = append(sourceKeys, key.DBName) 305 } 306 307 // preload conditions 308 preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) 309 310 // generate query with join table 311 newScope := scope.New(reflect.New(fieldType).Interface()) 312 preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) 313 314 if len(preloadDB.search.selects) == 0 { 315 preloadDB = preloadDB.Select("*") 316 } 317 318 preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) 319 320 // preload inline conditions 321 if len(preloadConditions) > 0 { 322 preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) 323 } 324 325 rows, err := preloadDB.Rows() 326 327 if scope.Err(err) != nil { 328 return 329 } 330 defer rows.Close() 331 332 columns, _ := rows.Columns() 333 for rows.Next() { 334 var ( 335 elem = reflect.New(fieldType).Elem() 336 fields = scope.New(elem.Addr().Interface()).Fields() 337 ) 338 339 // register foreign keys in join tables 340 var joinTableFields []*Field 341 for _, sourceKey := range sourceKeys { 342 joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) 343 } 344 345 scope.scan(rows, columns, append(fields, joinTableFields...)) 346 347 scope.New(elem.Addr().Interface()). 348 InstanceSet("gorm:skip_query_callback", true). 349 callCallbacks(scope.db.parent.callbacks.queries) 350 351 var foreignKeys = make([]interface{}, len(sourceKeys)) 352 // generate hashed forkey keys in join table 353 for idx, joinTableField := range joinTableFields { 354 if !joinTableField.Field.IsNil() { 355 foreignKeys[idx] = joinTableField.Field.Elem().Interface() 356 } 357 } 358 hashedSourceKeys := toString(foreignKeys) 359 360 if isPtr { 361 linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) 362 } else { 363 linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) 364 } 365 } 366 367 if err := rows.Err(); err != nil { 368 scope.Err(err) 369 } 370 371 // assign find results 372 var ( 373 indirectScopeValue = scope.IndirectValue() 374 fieldsSourceMap = map[string][]reflect.Value{} 375 foreignFieldNames = []string{} 376 ) 377 378 for _, dbName := range relation.ForeignFieldNames { 379 if field, ok := scope.FieldByName(dbName); ok { 380 foreignFieldNames = append(foreignFieldNames, field.Name) 381 } 382 } 383 384 if indirectScopeValue.Kind() == reflect.Slice { 385 for j := 0; j < indirectScopeValue.Len(); j++ { 386 object := indirect(indirectScopeValue.Index(j)) 387 key := toString(getValueFromFields(object, foreignFieldNames)) 388 fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) 389 } 390 } else if indirectScopeValue.IsValid() { 391 key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) 392 fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) 393 } 394 395 for source, fields := range fieldsSourceMap { 396 for _, f := range fields { 397 //If not 0 this means Value is a pointer and we already added preloaded models to it 398 if f.Len() != 0 { 399 continue 400 } 401 402 v := reflect.MakeSlice(f.Type(), 0, 0) 403 if len(linkHash[source]) > 0 { 404 v = reflect.Append(f, linkHash[source]...) 405 } 406 407 f.Set(v) 408 } 409 } 410} 411