1package schema_test
2
3import (
4	"fmt"
5	"reflect"
6	"strings"
7	"testing"
8
9	"gorm.io/gorm/schema"
10	"gorm.io/gorm/utils/tests"
11)
12
13func checkSchema(t *testing.T, s *schema.Schema, v schema.Schema, primaryFields []string) {
14	t.Run("CheckSchema/"+s.Name, func(t *testing.T) {
15		tests.AssertObjEqual(t, s, v, "Name", "Table")
16
17		for idx, field := range primaryFields {
18			var found bool
19			for _, f := range s.PrimaryFields {
20				if f.Name == field {
21					found = true
22				}
23			}
24
25			if idx == 0 {
26				if field != s.PrioritizedPrimaryField.Name {
27					t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name)
28				}
29			}
30
31			if !found {
32				t.Errorf("schema %v failed to found primary key: %v", s, field)
33			}
34		}
35	})
36}
37
38func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) {
39	t.Run("CheckField/"+f.Name, func(t *testing.T) {
40		if fc != nil {
41			fc(f)
42		}
43
44		if f.TagSettings == nil {
45			if f.Tag != "" {
46				f.TagSettings = schema.ParseTagSetting(f.Tag.Get("gorm"), ";")
47			} else {
48				f.TagSettings = map[string]string{}
49			}
50		}
51
52		parsedField, ok := s.FieldsByDBName[f.DBName]
53		if !ok {
54			parsedField, ok = s.FieldsByName[f.Name]
55		}
56
57		if !ok {
58			t.Errorf("schema %v failed to look up field with name %v", s, f.Name)
59		} else {
60			tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings")
61
62			if f.DBName != "" {
63				if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field {
64					t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
65				}
66			}
67
68			for _, name := range []string{f.DBName, f.Name} {
69				if name != "" {
70					if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) {
71						t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName)
72					}
73				}
74			}
75
76			if f.PrimaryKey {
77				var found bool
78				for _, primaryField := range s.PrimaryFields {
79					if primaryField == parsedField {
80						found = true
81					}
82				}
83
84				if !found {
85					t.Errorf("schema %v doesn't include field %v", s, f.Name)
86				}
87			}
88		}
89	})
90}
91
92type Relation struct {
93	Name        string
94	Type        schema.RelationshipType
95	Schema      string
96	FieldSchema string
97	Polymorphic Polymorphic
98	JoinTable   JoinTable
99	References  []Reference
100}
101
// Polymorphic holds the expected polymorphic settings of a relation. These are
// the names of the polymorphic id and type fields, and the polymorphic value.
type Polymorphic struct {
	ID, Type, Value string
}
107
108type JoinTable struct {
109	Name   string
110	Table  string
111	Fields []schema.Field
112}
113
// Reference is one expected reference of a relation.
//
// PrimaryKey and PrimarySchema name the primary key field and its schema.
// They are only compared when the parsed reference has a primary key.
// ForeignKey and ForeignSchema name the foreign key field and its schema.
// PrimaryValue and OwnPrimaryKey must match the parsed reference exactly.
type Reference struct {
	PrimaryKey, PrimarySchema string
	ForeignKey, ForeignSchema string
	PrimaryValue              string
	OwnPrimaryKey             bool
}
122
123func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
124	t.Run("CheckRelation/"+relation.Name, func(t *testing.T) {
125		if r, ok := s.Relationships.Relations[relation.Name]; ok {
126			if r.Name != relation.Name {
127				t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name)
128			}
129
130			if r.Type != relation.Type {
131				t.Errorf("schema %v relation name expects %v, but got %v", s, r.Type, relation.Type)
132			}
133
134			if r.Schema.Name != relation.Schema {
135				t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name)
136			}
137
138			if r.FieldSchema.Name != relation.FieldSchema {
139				t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name)
140			}
141
142			if r.Polymorphic != nil {
143				if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID {
144					t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name)
145				}
146
147				if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type {
148					t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name)
149				}
150
151				if r.Polymorphic.Value != relation.Polymorphic.Value {
152					t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value)
153				}
154			}
155
156			if r.JoinTable != nil {
157				if r.JoinTable.Name != relation.JoinTable.Name {
158					t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name)
159				}
160
161				if r.JoinTable.Table != relation.JoinTable.Table {
162					t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table)
163				}
164
165				for _, f := range relation.JoinTable.Fields {
166					checkSchemaField(t, r.JoinTable, &f, nil)
167				}
168			}
169
170			if len(relation.References) != len(r.References) {
171				t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References))
172			}
173
174			for _, ref := range relation.References {
175				var found bool
176				for _, rf := range r.References {
177					if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && (rf.OwnPrimaryKey == ref.OwnPrimaryKey) {
178						found = true
179					}
180				}
181
182				if !found {
183					var refs []string
184					for _, rf := range r.References {
185						var primaryKey, primaryKeySchema string
186						if rf.PrimaryKey != nil {
187							primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name
188						}
189						refs = append(refs, fmt.Sprintf(
190							"{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}",
191							primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey,
192						))
193					}
194					t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", "))
195				}
196			}
197		} else {
198			t.Errorf("schema %v failed to find relations by name %v", s, relation.Name)
199		}
200	})
201}
202
203func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
204	for k, v := range values {
205		t.Run("CheckField/"+k, func(t *testing.T) {
206			fv, _ := s.FieldsByDBName[k].ValueOf(value)
207			tests.AssertEqual(t, v, fv)
208		})
209	}
210}
211