1package resolvable
2
3import (
4	"context"
5	"fmt"
6	"reflect"
7	"strings"
8
9	"github.com/graph-gophers/graphql-go/internal/common"
10	"github.com/graph-gophers/graphql-go/internal/exec/packer"
11	"github.com/graph-gophers/graphql-go/internal/schema"
12)
13
// Schema is a parsed GraphQL schema bound to the root resolver that serves it.
//
// Query, Mutation and Subscription hold the executable trees built by
// ApplyResolver for the schema's "query", "mutation" and "subscription" entry
// points. Each is nil when that entry point is not defined, or when no
// resolver was supplied. Resolver holds the root resolver as a reflect.Value.
type Schema struct {
	*Meta
	schema.Schema
	Query        Resolvable
	Mutation     Resolvable
	Subscription Resolvable
	Resolver     reflect.Value
}
22
// Resolvable is the closed set of executable node kinds built by this package:
// *Object, *List and *Scalar. The unexported marker method means no type
// outside this package can implement it.
type Resolvable interface {
	isResolvable()
}
26
// Object is the executable form of a GraphQL object, interface or union type,
// built for one particular Go resolver type.
//
// Name is the GraphQL type name. Fields maps each GraphQL field name to the way
// it is resolved on the Go type. For interfaces and unions, TypeAssertions maps
// each possible concrete type name to the "To<TypeName>" conversion that narrows
// the resolver to that type.
type Object struct {
	Name           string
	Fields         map[string]*Field
	TypeAssertions map[string]*TypeAssertion
}
32
// Field describes how a single GraphQL field is resolved on a Go resolver type.
//
// A field is resolved in one of two ways:
//   - by calling the resolver method at MethodIndex, or
//   - when MethodIndex is -1 (field resolvers enabled and no matching method),
//     by reading the struct field at the reflect index path FieldIndex.
//
// The remaining members:
//   - TypeName is the GraphQL type declaring the field.
//   - HasContext is set when the method takes a context.Context before any
//     other parameter.
//   - HasError is set when the method returns a trailing error.
//   - ArgsPacker, when non-nil, packs the field's GraphQL arguments into the
//     method's argument parameter.
//   - ValueExec is the executable form of the resolved value's type.
//   - TraceLabel has the form "GraphQL field: <Type>.<field>".
type Field struct {
	schema.Field
	TypeName    string
	MethodIndex int
	FieldIndex  []int
	HasContext  bool
	HasError    bool
	ArgsPacker  *packer.StructPacker
	ValueExec   Resolvable
	TraceLabel  string
}
44
45func (f *Field) UseMethodResolver() bool {
46	return len(f.FieldIndex) == 0
47}
48
// TypeAssertion records how to narrow an interface or union resolver to one of
// its possible concrete types. MethodIndex identifies the resolver's
// "To<TypeName>" method, which is checked to return two values: the narrowed
// value and a success flag. TypeExec is the executable form of the narrowed
// value's type.
type TypeAssertion struct {
	MethodIndex int
	TypeExec    Resolvable
}

// List is the executable form of a GraphQL list type. Elem is the executable
// form of the list's element type.
type List struct {
	Elem Resolvable
}

// Scalar is the executable form of GraphQL scalar and enum types. It carries
// no data.
type Scalar struct{}
59
// isResolvable marks *Object, *List and *Scalar as the only Resolvable
// implementations.
func (*Object) isResolvable() {}
func (*List) isResolvable()   {}
func (*Scalar) isResolvable() {}
63
// ApplyResolver binds resolver to the parsed schema s and returns the
// executable Schema.
//
// A nil resolver yields a Schema that carries only s and its metadata. It has
// no executable entry points and no validation is performed.
//
// Otherwise the dynamic Go type of resolver is checked against every entry
// point that s defines, in the order query, mutation, subscription. The first
// mismatch found is returned as an error.
func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) {
	if resolver == nil {
		return &Schema{Meta: newMeta(s), Schema: *s}, nil
	}

	b := newBuilder(s)
	rootType := reflect.TypeOf(resolver)

	var query, mutation, subscription Resolvable
	roots := []struct {
		entryPoint string
		exec       *Resolvable
	}{
		{"query", &query},
		{"mutation", &mutation},
		{"subscription", &subscription},
	}
	for _, root := range roots {
		t, defined := s.EntryPoints[root.entryPoint]
		if !defined {
			continue
		}
		// The built exec is written through root.exec later, by b.finish.
		if err := b.assignExec(root.exec, t, rootType); err != nil {
			return nil, err
		}
	}

	if err := b.finish(); err != nil {
		return nil, err
	}

	return &Schema{
		Meta:         newMeta(s),
		Schema:       *s,
		Resolver:     reflect.ValueOf(resolver),
		Query:        query,
		Mutation:     mutation,
		Subscription: subscription,
	}, nil
}
104
// execBuilder holds the state used while turning a schema plus a Go resolver
// type into a tree of Resolvables:
//   - schema is the schema being bound.
//   - resMap memoizes one entry per (GraphQL type, Go type) pair.
//   - packerBuilder builds the argument packers for resolver methods.
type execBuilder struct {
	schema        *schema.Schema
	resMap        map[typePair]*resMapEntry
	packerBuilder *packer.Builder
}

// typePair is the memoization key of execBuilder.resMap. One Resolvable is
// built per combination of GraphQL type and Go resolver type.
type typePair struct {
	graphQLType  common.Type
	resolverType reflect.Type
}

// resMapEntry holds the Resolvable built for a typePair and every location
// that must receive it. execBuilder.finish writes exec into each target only
// after all entries have been built. This lets a recursive type refer to an
// entry whose exec is still under construction.
type resMapEntry struct {
	exec    Resolvable
	targets []*Resolvable
}
120
121func newBuilder(s *schema.Schema) *execBuilder {
122	return &execBuilder{
123		schema:        s,
124		resMap:        make(map[typePair]*resMapEntry),
125		packerBuilder: packer.NewBuilder(),
126	}
127}
128
129func (b *execBuilder) finish() error {
130	for _, entry := range b.resMap {
131		for _, target := range entry.targets {
132			*target = entry.exec
133		}
134	}
135
136	return b.packerBuilder.Finish()
137}
138
139func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type) error {
140	k := typePair{t, resolverType}
141	ref, ok := b.resMap[k]
142	if !ok {
143		ref = &resMapEntry{}
144		b.resMap[k] = ref
145		var err error
146		ref.exec, err = b.makeExec(t, resolverType)
147		if err != nil {
148			return err
149		}
150	}
151	ref.targets = append(ref.targets, target)
152	return nil
153}
154
// makeExec constructs the Resolvable for GraphQL type t backed by Go type
// resolverType.
//
// Objects, interfaces and unions are delegated to makeObjectExec, which does
// its own nullability check. For every other type, a nullable GraphQL type
// must be backed by a Go pointer, which is dereferenced before the remaining
// checks:
//   - Scalars are validated by makeScalarExec.
//   - Lists must be backed by slices, and their element type is resolved
//     recursively.
//   - Enums need nothing beyond the pointer check.
//
// Any other type is a programming error and panics.
func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolvable, error) {
	inner, nonNull := unwrapNonNull(t)

	switch ot := inner.(type) {
	case *schema.Object:
		return b.makeObjectExec(ot.Name, ot.Fields, nil, nonNull, resolverType)
	case *schema.Interface:
		return b.makeObjectExec(ot.Name, ot.Fields, ot.PossibleTypes, nonNull, resolverType)
	case *schema.Union:
		return b.makeObjectExec(ot.Name, nil, ot.PossibleTypes, nonNull, resolverType)
	}

	goType := resolverType
	if !nonNull {
		if goType.Kind() != reflect.Ptr {
			return nil, fmt.Errorf("%s is not a pointer", goType)
		}
		goType = goType.Elem()
	}

	switch vt := inner.(type) {
	case *schema.Scalar:
		return makeScalarExec(vt, goType)
	case *schema.Enum:
		return &Scalar{}, nil
	case *common.List:
		if goType.Kind() != reflect.Slice {
			return nil, fmt.Errorf("%s is not a slice", goType)
		}
		list := &List{}
		if err := b.assignExec(&list.Elem, vt.OfType, goType.Elem()); err != nil {
			return nil, err
		}
		return list, nil
	default:
		panic("invalid type: " + vt.String())
	}
}
198
// makeScalarExec verifies that Go type resolverType can back the GraphQL
// scalar t.
//
// The built-in scalars Int, Float, String and Boolean must be backed by exactly
// int32, float64, string and bool. Any other Go type qualifies only when a
// pointer to it implements packer.Unmarshaler and its ImplementsGraphQLType
// method accepts t's name.
func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, error) {
	ptr := reflect.New(resolverType).Interface()

	// Name of the built-in GraphQL scalar that this exact Go type backs, if any.
	var builtin string
	switch ptr.(type) {
	case *int32:
		builtin = "Int"
	case *float64:
		builtin = "Float"
	case *string:
		builtin = "String"
	case *bool:
		builtin = "Boolean"
	}

	var compatible bool
	if builtin != "" {
		compatible = builtin == t.Name
	} else if u, ok := ptr.(packer.Unmarshaler); ok {
		compatible = u.ImplementsGraphQLType(t.Name)
	}

	if !compatible {
		return nil, fmt.Errorf("can not use %s as %s", resolverType, t.Name)
	}
	return &Scalar{}, nil
}
218
// makeObjectExec builds the executable form of the GraphQL object, interface
// or union named typeName, backed by Go type resolverType.
//
// Each entry in fields must be resolvable in one of two ways:
//   - by a method of resolverType whose name matches case-insensitively,
//     ignoring underscores, or
//   - when the schema enables UseFieldResolvers and the dereferenced resolver
//     type is a struct, by a struct field of that type.
//
// For each entry in possibleTypes, resolverType must have a "To<TypeName>"
// method returning (value, bool). This requirement is skipped when field
// resolvers are enabled and resolverType is an interface. A nullable
// (nonNull == false) type must be backed by a pointer or an interface.
func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object,
	nonNull bool, resolverType reflect.Type) (*Object, error) {
	if !nonNull {
		if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface {
			return nil, fmt.Errorf("%s is not a pointer or interface", resolverType)
		}
	}

	// Methods listed on an interface type take no receiver parameter.
	// Methods of concrete types take the receiver as their first input.
	methodHasReceiver := resolverType.Kind() != reflect.Interface

	rt := unwrapPtr(resolverType)
	// Struct fields can only back GraphQL fields when the dereferenced resolver
	// is a struct. reflect.Type.NumField panics for any other kind (e.g. an
	// interface resolver), so a missing method must then be reported as an
	// error rather than crash inside findField.
	useFieldResolvers := b.schema.UseFieldResolvers && rt.Kind() == reflect.Struct
	fieldsCount := fieldCount(rt, map[string]int{})

	fieldExecs := make(map[string]*Field)
	for _, f := range fields {
		var fieldIndex []int
		methodIndex := findMethod(resolverType, f.Name)
		if useFieldResolvers && methodIndex == -1 {
			if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 {
				return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name)
			}
			fieldIndex = findField(rt, f.Name, []int{})
		}
		if methodIndex == -1 && len(fieldIndex) == 0 {
			hint := ""
			if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 {
				hint = " (hint: the method exists on the pointer type)"
			}
			return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint)
		}

		var (
			m  reflect.Method
			sf reflect.StructField
			// Go name of the backing method or struct field, used for error context.
			resolverName string
		)
		if methodIndex != -1 {
			m = resolverType.Method(methodIndex)
			resolverName = m.Name
		} else {
			sf = rt.FieldByIndex(fieldIndex)
			resolverName = sf.Name
		}
		fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver)
		if err != nil {
			return nil, fmt.Errorf("%s\n\tused by (%s).%s", err, resolverType, resolverName)
		}
		fieldExecs[f.Name] = fe
	}

	// Check type assertions when
	//	1) using method resolvers
	//	2) Or resolver is not an interface type
	typeAssertions := make(map[string]*TypeAssertion)
	if !b.schema.UseFieldResolvers || resolverType.Kind() != reflect.Interface {
		for _, impl := range possibleTypes {
			toName := "To" + impl.Name
			methodIndex := findMethod(resolverType, toName)
			if methodIndex == -1 {
				return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, toName, impl.Name)
			}
			toType := resolverType.Method(methodIndex).Type
			// The second result is the success flag, so it must be a bool.
			if toType.NumOut() != 2 || toType.Out(1).Kind() != reflect.Bool {
				return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, toName)
			}
			a := &TypeAssertion{
				MethodIndex: methodIndex,
			}
			if err := b.assignExec(&a.TypeExec, impl, toType.Out(0)); err != nil {
				return nil, err
			}
			typeAssertions[impl.Name] = a
		}
	}

	return &Object{
		Name:           typeName,
		Fields:         fieldExecs,
		TypeAssertions: typeAssertions,
	}, nil
}
292
293var contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
294var errorType = reflect.TypeOf((*error)(nil)).Elem()
295
// makeFieldExec validates how field f of GraphQL type typeName is resolved and
// returns its executable description.
//
// For a method resolver (methodIndex != -1), the inputs of m must follow the
// receiver, which is present when methodHasReceiver is true. They are, in
// order:
//   - an optional context.Context,
//   - exactly one parameter for the field's arguments, if f declares any,
//   - and nothing else.
//
// m must return the value, optionally followed by an error. When the
// subscription root type's method returns a channel, the channel's element type
// is used as the value type.
//
// For a struct-field resolver (methodIndex == -1), the value type is sf's type.
func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField,
	methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) {

	var (
		argsPacker *packer.StructPacker
		hasError   bool
		hasContext bool
	)

	isMethod := methodIndex != -1
	if isMethod {
		numIn := m.Type.NumIn()
		next := 0 // index of the next input parameter still to be accounted for
		if methodHasReceiver {
			next++ // a concrete type's method takes its receiver first
		}

		if next < numIn && m.Type.In(next) == contextType {
			hasContext = true
			next++
		}

		if len(f.Args) > 0 {
			if next >= numIn {
				return nil, fmt.Errorf("must have parameter for field arguments")
			}
			var err error
			if argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, m.Type.In(next)); err != nil {
				return nil, err
			}
			next++
		}

		if next < numIn {
			return nil, fmt.Errorf("too many parameters")
		}

		switch numOut := m.Type.NumOut(); {
		case numOut < 1:
			return nil, fmt.Errorf("too few return values")
		case numOut > 2:
			return nil, fmt.Errorf("too many return values")
		case numOut == 2:
			hasError = true
			if m.Type.Out(1) != errorType {
				return nil, fmt.Errorf(`must have "error" as its last return value`)
			}
		}
	}

	var valueType reflect.Type
	if isMethod {
		valueType = m.Type.Out(0)
		sub, isSubscriptionSchema := b.schema.EntryPoints["subscription"]
		if isSubscriptionSchema && typeName == sub.TypeName() && valueType.Kind() == reflect.Chan {
			valueType = valueType.Elem()
		}
	} else {
		valueType = sf.Type
	}

	fe := &Field{
		Field:       *f,
		TypeName:    typeName,
		MethodIndex: methodIndex,
		FieldIndex:  fieldIndex,
		HasContext:  hasContext,
		HasError:    hasError,
		ArgsPacker:  argsPacker,
		TraceLabel:  fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name),
	}
	if err := b.assignExec(&fe.ValueExec, f.Type, valueType); err != nil {
		return nil, err
	}
	return fe, nil
}
378
// findMethod returns the index, in t's method list, of the first method whose
// name equals name case-insensitively once underscores are removed from both.
// For example, "user_id" matches a method named UserID. It returns -1 when no
// method matches.
func findMethod(t reflect.Type, name string) int {
	// The normalized target name does not change between iterations; compute it once.
	want := stripUnderscore(name)
	for i := 0; i < t.NumMethod(); i++ {
		if strings.EqualFold(want, stripUnderscore(t.Method(i).Name)) {
			return i
		}
	}
	return -1
}
387
// findField searches struct type t for a field whose name equals name
// case-insensitively, ignoring underscores. It returns index extended with the
// field's reflect index path, suitable for FieldByIndex.
//
// Fields are visited in declaration order. The contents of an anonymous
// (embedded) struct field are searched before that field's own name is
// compared. When nothing matches, index is returned unchanged.
func findField(t reflect.Type, name string, index []int) []int {
	want := stripUnderscore(name)
	for i, n := 0, t.NumField(); i < n; i++ {
		f := t.Field(i)

		if f.Anonymous && f.Type.Kind() == reflect.Struct {
			// The recursive path starts with i, so a length above 1 means a nested hit.
			if nested := findField(f.Type, name, []int{i}); len(nested) > 1 {
				return append(index, nested...)
			}
		}

		if strings.EqualFold(want, stripUnderscore(f.Name)) {
			return append(index, i)
		}
	}
	return index
}
406
// fieldCount helps resolve ambiguity when more than one embedded struct contains
// fields with the same name.
//
// It tallies how many struct fields of t carry each name, keyed by the name
// lower-cased with underscores removed. Fields of anonymous (embedded) struct
// members are counted recursively at any depth, so a name appearing more than
// once ends up with a count above 1. The embedded struct members themselves are
// not counted.
//
// count is updated in place and returned. A nil count is allocated on demand.
// nil is returned when t is not a struct type.
func fieldCount(t reflect.Type, count map[string]int) map[string]int {
	if t.Kind() != reflect.Struct {
		return nil
	}
	if count == nil {
		count = make(map[string]int)
	}

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if field.Type.Kind() == reflect.Struct && field.Anonymous {
			count = fieldCount(field.Type, count)
			continue
		}
		// A missing key reads as 0, so incrementing needs no prior initialization.
		count[strings.ToLower(stripUnderscore(field.Name))]++
	}

	return count
}
429
430func unwrapNonNull(t common.Type) (common.Type, bool) {
431	if nn, ok := t.(*common.NonNull); ok {
432		return nn.OfType, true
433	}
434	return t, false
435}
436
// stripUnderscore returns s with every underscore removed, so that names such
// as "user_id" and "UserID" compare equal under case folding.
func stripUnderscore(s string) string {
	return strings.ReplaceAll(s, "_", "")
}
440
// unwrapPtr removes a single level of pointer indirection from t. A non-pointer
// type is returned unchanged.
func unwrapPtr(t reflect.Type) reflect.Type {
	if t.Kind() != reflect.Ptr {
		return t
	}
	return t.Elem()
}
447