1package gorm
2
3import (
4	"errors"
5	"fmt"
6	"reflect"
7	"strconv"
8	"strings"
9)
10
11// preloadCallback used to preload associations
12func preloadCallback(scope *Scope) {
13	if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
14		return
15	}
16
17	if ap, ok := scope.Get("gorm:auto_preload"); ok {
18		// If gorm:auto_preload IS NOT a bool then auto preload.
19		// Else if it IS a bool, use the value
20		if apb, ok := ap.(bool); !ok {
21			autoPreload(scope)
22		} else if apb {
23			autoPreload(scope)
24		}
25	}
26
27	if scope.Search.preload == nil || scope.HasError() {
28		return
29	}
30
31	var (
32		preloadedMap = map[string]bool{}
33		fields       = scope.Fields()
34	)
35
36	for _, preload := range scope.Search.preload {
37		var (
38			preloadFields = strings.Split(preload.schema, ".")
39			currentScope  = scope
40			currentFields = fields
41		)
42
43		for idx, preloadField := range preloadFields {
44			var currentPreloadConditions []interface{}
45
46			if currentScope == nil {
47				continue
48			}
49
50			// if not preloaded
51			if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
52
53				// assign search conditions to last preload
54				if idx == len(preloadFields)-1 {
55					currentPreloadConditions = preload.conditions
56				}
57
58				for _, field := range currentFields {
59					if field.Name != preloadField || field.Relationship == nil {
60						continue
61					}
62
63					switch field.Relationship.Kind {
64					case "has_one":
65						currentScope.handleHasOnePreload(field, currentPreloadConditions)
66					case "has_many":
67						currentScope.handleHasManyPreload(field, currentPreloadConditions)
68					case "belongs_to":
69						currentScope.handleBelongsToPreload(field, currentPreloadConditions)
70					case "many_to_many":
71						currentScope.handleManyToManyPreload(field, currentPreloadConditions)
72					default:
73						scope.Err(errors.New("unsupported relation"))
74					}
75
76					preloadedMap[preloadKey] = true
77					break
78				}
79
80				if !preloadedMap[preloadKey] {
81					scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
82					return
83				}
84			}
85
86			// preload next level
87			if idx < len(preloadFields)-1 {
88				currentScope = currentScope.getColumnAsScope(preloadField)
89				if currentScope != nil {
90					currentFields = currentScope.Fields()
91				}
92			}
93		}
94	}
95}
96
97func autoPreload(scope *Scope) {
98	for _, field := range scope.Fields() {
99		if field.Relationship == nil {
100			continue
101		}
102
103		if val, ok := field.TagSettingsGet("PRELOAD"); ok {
104			if preload, err := strconv.ParseBool(val); err != nil {
105				scope.Err(errors.New("invalid preload option"))
106				return
107			} else if !preload {
108				continue
109			}
110		}
111
112		scope.Search.Preload(field.Name)
113	}
114}
115
116func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
117	var (
118		preloadDB         = scope.NewDB()
119		preloadConditions []interface{}
120	)
121
122	for _, condition := range conditions {
123		if scopes, ok := condition.(func(*DB) *DB); ok {
124			preloadDB = scopes(preloadDB)
125		} else {
126			preloadConditions = append(preloadConditions, condition)
127		}
128	}
129
130	return preloadDB, preloadConditions
131}
132
133// handleHasOnePreload used to preload has one associations
134func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
135	relation := field.Relationship
136
137	// get relations's primary keys
138	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
139	if len(primaryKeys) == 0 {
140		return
141	}
142
143	// preload conditions
144	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
145
146	// find relations
147	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
148	values := toQueryValues(primaryKeys)
149	if relation.PolymorphicType != "" {
150		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
151		values = append(values, relation.PolymorphicValue)
152	}
153
154	results := makeSlice(field.Struct.Type)
155	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
156
157	// assign find results
158	var (
159		resultsValue       = indirect(reflect.ValueOf(results))
160		indirectScopeValue = scope.IndirectValue()
161	)
162
163	if indirectScopeValue.Kind() == reflect.Slice {
164		foreignValuesToResults := make(map[string]reflect.Value)
165		for i := 0; i < resultsValue.Len(); i++ {
166			result := resultsValue.Index(i)
167			foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
168			foreignValuesToResults[foreignValues] = result
169		}
170		for j := 0; j < indirectScopeValue.Len(); j++ {
171			indirectValue := indirect(indirectScopeValue.Index(j))
172			valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
173			if result, found := foreignValuesToResults[valueString]; found {
174				indirectValue.FieldByName(field.Name).Set(result)
175			}
176		}
177	} else {
178		for i := 0; i < resultsValue.Len(); i++ {
179			result := resultsValue.Index(i)
180			scope.Err(field.Set(result))
181		}
182	}
183}
184
185// handleHasManyPreload used to preload has many associations
186func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
187	relation := field.Relationship
188
189	// get relations's primary keys
190	primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
191	if len(primaryKeys) == 0 {
192		return
193	}
194
195	// preload conditions
196	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
197
198	// find relations
199	query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
200	values := toQueryValues(primaryKeys)
201	if relation.PolymorphicType != "" {
202		query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
203		values = append(values, relation.PolymorphicValue)
204	}
205
206	results := makeSlice(field.Struct.Type)
207	scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
208
209	// assign find results
210	var (
211		resultsValue       = indirect(reflect.ValueOf(results))
212		indirectScopeValue = scope.IndirectValue()
213	)
214
215	if indirectScopeValue.Kind() == reflect.Slice {
216		preloadMap := make(map[string][]reflect.Value)
217		for i := 0; i < resultsValue.Len(); i++ {
218			result := resultsValue.Index(i)
219			foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
220			preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
221		}
222
223		for j := 0; j < indirectScopeValue.Len(); j++ {
224			object := indirect(indirectScopeValue.Index(j))
225			objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
226			f := object.FieldByName(field.Name)
227			if results, ok := preloadMap[toString(objectRealValue)]; ok {
228				f.Set(reflect.Append(f, results...))
229			} else {
230				f.Set(reflect.MakeSlice(f.Type(), 0, 0))
231			}
232		}
233	} else {
234		scope.Err(field.Set(resultsValue))
235	}
236}
237
238// handleBelongsToPreload used to preload belongs to associations
239func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
240	relation := field.Relationship
241
242	// preload conditions
243	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
244
245	// get relations's primary keys
246	primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
247	if len(primaryKeys) == 0 {
248		return
249	}
250
251	// find relations
252	results := makeSlice(field.Struct.Type)
253	scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
254
255	// assign find results
256	var (
257		resultsValue       = indirect(reflect.ValueOf(results))
258		indirectScopeValue = scope.IndirectValue()
259	)
260
261	foreignFieldToObjects := make(map[string][]*reflect.Value)
262	if indirectScopeValue.Kind() == reflect.Slice {
263		for j := 0; j < indirectScopeValue.Len(); j++ {
264			object := indirect(indirectScopeValue.Index(j))
265			valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
266			foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
267		}
268	}
269
270	for i := 0; i < resultsValue.Len(); i++ {
271		result := resultsValue.Index(i)
272		if indirectScopeValue.Kind() == reflect.Slice {
273			valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
274			if objects, found := foreignFieldToObjects[valueString]; found {
275				for _, object := range objects {
276					object.FieldByName(field.Name).Set(result)
277				}
278			}
279		} else {
280			scope.Err(field.Set(result))
281		}
282	}
283}
284
285// handleManyToManyPreload used to preload many to many associations
286func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
287	var (
288		relation         = field.Relationship
289		joinTableHandler = relation.JoinTableHandler
290		fieldType        = field.Struct.Type.Elem()
291		foreignKeyValue  interface{}
292		foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type()
293		linkHash         = map[string][]reflect.Value{}
294		isPtr            bool
295	)
296
297	if fieldType.Kind() == reflect.Ptr {
298		isPtr = true
299		fieldType = fieldType.Elem()
300	}
301
302	var sourceKeys = []string{}
303	for _, key := range joinTableHandler.SourceForeignKeys() {
304		sourceKeys = append(sourceKeys, key.DBName)
305	}
306
307	// preload conditions
308	preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
309
310	// generate query with join table
311	newScope := scope.New(reflect.New(fieldType).Interface())
312	preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
313
314	if len(preloadDB.search.selects) == 0 {
315		preloadDB = preloadDB.Select("*")
316	}
317
318	preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
319
320	// preload inline conditions
321	if len(preloadConditions) > 0 {
322		preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
323	}
324
325	rows, err := preloadDB.Rows()
326
327	if scope.Err(err) != nil {
328		return
329	}
330	defer rows.Close()
331
332	columns, _ := rows.Columns()
333	for rows.Next() {
334		var (
335			elem   = reflect.New(fieldType).Elem()
336			fields = scope.New(elem.Addr().Interface()).Fields()
337		)
338
339		// register foreign keys in join tables
340		var joinTableFields []*Field
341		for _, sourceKey := range sourceKeys {
342			joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
343		}
344
345		scope.scan(rows, columns, append(fields, joinTableFields...))
346
347		scope.New(elem.Addr().Interface()).
348			InstanceSet("gorm:skip_query_callback", true).
349			callCallbacks(scope.db.parent.callbacks.queries)
350
351		var foreignKeys = make([]interface{}, len(sourceKeys))
352		// generate hashed forkey keys in join table
353		for idx, joinTableField := range joinTableFields {
354			if !joinTableField.Field.IsNil() {
355				foreignKeys[idx] = joinTableField.Field.Elem().Interface()
356			}
357		}
358		hashedSourceKeys := toString(foreignKeys)
359
360		if isPtr {
361			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
362		} else {
363			linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
364		}
365	}
366
367	if err := rows.Err(); err != nil {
368		scope.Err(err)
369	}
370
371	// assign find results
372	var (
373		indirectScopeValue = scope.IndirectValue()
374		fieldsSourceMap    = map[string][]reflect.Value{}
375		foreignFieldNames  = []string{}
376	)
377
378	for _, dbName := range relation.ForeignFieldNames {
379		if field, ok := scope.FieldByName(dbName); ok {
380			foreignFieldNames = append(foreignFieldNames, field.Name)
381		}
382	}
383
384	if indirectScopeValue.Kind() == reflect.Slice {
385		for j := 0; j < indirectScopeValue.Len(); j++ {
386			object := indirect(indirectScopeValue.Index(j))
387			key := toString(getValueFromFields(object, foreignFieldNames))
388			fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
389		}
390	} else if indirectScopeValue.IsValid() {
391		key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
392		fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
393	}
394
395	for source, fields := range fieldsSourceMap {
396		for _, f := range fields {
397			//If not 0 this means Value is a pointer and we already added preloaded models to it
398			if f.Len() != 0 {
399				continue
400			}
401
402			v := reflect.MakeSlice(f.Type(), 0, 0)
403			if len(linkHash[source]) > 0 {
404				v = reflect.Append(f, linkHash[source]...)
405			}
406
407			f.Set(v)
408		}
409	}
410}
411