1package graphql_test
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"reflect"
8	"testing"
9
10	"github.com/graphql-go/graphql"
11	"github.com/graphql-go/graphql/gqlerrors"
12	"github.com/graphql-go/graphql/testutil"
13)
14
15func tinit(t *testing.T) graphql.Schema {
16	schema, err := graphql.NewSchema(graphql.SchemaConfig{
17		Query: graphql.NewObject(graphql.ObjectConfig{
18			Name: "Type",
19			Fields: graphql.Fields{
20				"a": &graphql.Field{
21					Type: graphql.String,
22					Resolve: func(p graphql.ResolveParams) (interface{}, error) {
23						return "foo", nil
24					},
25				},
26				"erred": &graphql.Field{
27					Type: graphql.String,
28					Resolve: func(p graphql.ResolveParams) (interface{}, error) {
29						return "", errors.New("ooops")
30					},
31				},
32			},
33		}),
34	})
35	if err != nil {
36		t.Fatalf("Error in schema %v", err.Error())
37	}
38	return schema
39}
40
41func TestExtensionInitPanic(t *testing.T) {
42	ext := newtestExt("testExt")
43	ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context {
44		if true {
45			panic(errors.New("test error"))
46		}
47		return ctx
48	}
49
50	schema := tinit(t)
51	query := `query Example { a }`
52	schema.AddExtensions(ext)
53
54	result := graphql.Do(graphql.Params{
55		Schema:        schema,
56		RequestString: query,
57	})
58
59	expected := &graphql.Result{
60		Data: nil,
61		Errors: []gqlerrors.FormattedError{
62			gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), errors.New("test error"))),
63		},
64	}
65	if !reflect.DeepEqual(expected, result) {
66		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
67	}
68}
69
70func TestExtensionParseDidStartPanic(t *testing.T) {
71	ext := newtestExt("testExt")
72	ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) {
73		if true {
74			panic(errors.New("test error"))
75		}
76		return ctx, func(err error) {
77
78		}
79	}
80
81	schema := tinit(t)
82	query := `query Example { a }`
83	schema.AddExtensions(ext)
84
85	result := graphql.Do(graphql.Params{
86		Schema:        schema,
87		RequestString: query,
88	})
89
90	expected := &graphql.Result{
91		Data: nil,
92		Errors: []gqlerrors.FormattedError{
93			gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), errors.New("test error"))),
94		},
95	}
96	if !reflect.DeepEqual(expected, result) {
97		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
98	}
99}
100
101func TestExtensionParseFinishFuncPanic(t *testing.T) {
102	ext := newtestExt("testExt")
103	ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) {
104		return ctx, func(err error) {
105			panic(errors.New("test error"))
106		}
107	}
108
109	schema := tinit(t)
110	query := `query Example { a }`
111	schema.AddExtensions(ext)
112
113	result := graphql.Do(graphql.Params{
114		Schema:        schema,
115		RequestString: query,
116	})
117
118	expected := &graphql.Result{
119		Data: nil,
120		Errors: []gqlerrors.FormattedError{
121			gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", ext.Name(), errors.New("test error"))),
122		},
123	}
124	if !reflect.DeepEqual(expected, result) {
125		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
126	}
127}
128
129func TestExtensionValidationDidStartPanic(t *testing.T) {
130	ext := newtestExt("testExt")
131	ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) {
132		if true {
133			panic(errors.New("test error"))
134		}
135		return ctx, func([]gqlerrors.FormattedError) {
136
137		}
138	}
139
140	schema := tinit(t)
141	query := `query Example { a }`
142	schema.AddExtensions(ext)
143
144	result := graphql.Do(graphql.Params{
145		Schema:        schema,
146		RequestString: query,
147	})
148
149	expected := &graphql.Result{
150		Data: nil,
151		Errors: []gqlerrors.FormattedError{
152			gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), errors.New("test error"))),
153		},
154	}
155	if !reflect.DeepEqual(expected, result) {
156		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
157	}
158}
159
160func TestExtensionValidationFinishFuncPanic(t *testing.T) {
161	ext := newtestExt("testExt")
162	ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) {
163		return ctx, func([]gqlerrors.FormattedError) {
164			panic(errors.New("test error"))
165		}
166	}
167
168	schema := tinit(t)
169	query := `query Example { a }`
170	schema.AddExtensions(ext)
171
172	result := graphql.Do(graphql.Params{
173		Schema:        schema,
174		RequestString: query,
175	})
176
177	expected := &graphql.Result{
178		Data: nil,
179		Errors: []gqlerrors.FormattedError{
180			gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", ext.Name(), errors.New("test error"))),
181		},
182	}
183	if !reflect.DeepEqual(expected, result) {
184		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
185	}
186}
187
188func TestExtensionExecutionDidStartPanic(t *testing.T) {
189	ext := newtestExt("testExt")
190	ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) {
191		if true {
192			panic(errors.New("test error"))
193		}
194		return ctx, func(r *graphql.Result) {
195
196		}
197	}
198
199	schema := tinit(t)
200	query := `query Example { a }`
201	schema.AddExtensions(ext)
202
203	result := graphql.Do(graphql.Params{
204		Schema:        schema,
205		RequestString: query,
206	})
207
208	expected := &graphql.Result{
209		Data: nil,
210		Errors: []gqlerrors.FormattedError{
211			gqlerrors.FormatError(fmt.Errorf("%s.ExecutionDidStart: %v", ext.Name(), errors.New("test error"))),
212		},
213	}
214	if !reflect.DeepEqual(expected, result) {
215		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
216	}
217}
218
219func TestExtensionExecutionFinishFuncPanic(t *testing.T) {
220	ext := newtestExt("testExt")
221	ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) {
222		return ctx, func(r *graphql.Result) {
223			panic(errors.New("test error"))
224		}
225	}
226
227	schema := tinit(t)
228	query := `query Example { a }`
229	schema.AddExtensions(ext)
230
231	result := graphql.Do(graphql.Params{
232		Schema:        schema,
233		RequestString: query,
234	})
235
236	expected := &graphql.Result{
237		Data: map[string]interface{}{
238			"a": "foo",
239		},
240		Errors: []gqlerrors.FormattedError{
241			gqlerrors.FormatError(fmt.Errorf("%s.ExecutionFinishFunc: %v", ext.Name(), errors.New("test error"))),
242		},
243	}
244
245	if !reflect.DeepEqual(expected, result) {
246		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
247	}
248}
249
250func TestExtensionResolveFieldDidStartPanic(t *testing.T) {
251	ext := newtestExt("testExt")
252	ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) {
253		if true {
254			panic(errors.New("test error"))
255		}
256		return ctx, func(v interface{}, err error) {
257
258		}
259	}
260
261	schema := tinit(t)
262	query := `query Example { a }`
263	schema.AddExtensions(ext)
264
265	result := graphql.Do(graphql.Params{
266		Schema:        schema,
267		RequestString: query,
268	})
269
270	expected := &graphql.Result{
271		Data: map[string]interface{}{
272			"a": "foo",
273		},
274		Errors: []gqlerrors.FormattedError{
275			gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldDidStart: %v", ext.Name(), errors.New("test error"))),
276		},
277	}
278
279	if !reflect.DeepEqual(expected, result) {
280		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
281	}
282}
283
284func TestExtensionResolveFieldFinishFuncPanic(t *testing.T) {
285	ext := newtestExt("testExt")
286	ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) {
287		return ctx, func(v interface{}, err error) {
288			panic(errors.New("test error"))
289		}
290	}
291
292	schema := tinit(t)
293	query := `query Example { a }`
294	schema.AddExtensions(ext)
295
296	result := graphql.Do(graphql.Params{
297		Schema:        schema,
298		RequestString: query,
299	})
300
301	expected := &graphql.Result{
302		Data: map[string]interface{}{
303			"a": "foo",
304		},
305		Errors: []gqlerrors.FormattedError{
306			gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldFinishFunc: %v", ext.Name(), errors.New("test error"))),
307		},
308	}
309
310	if !reflect.DeepEqual(expected, result) {
311		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
312	}
313}
314
315func TestExtensionResolveFieldFinishFuncAfterError(t *testing.T) {
316	var fnErrs int
317	ext := newtestExt("testExt")
318	ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) {
319		return ctx, func(v interface{}, err error) {
320			if err != nil {
321				fnErrs++
322			}
323		}
324	}
325
326	schema := tinit(t)
327	query := `query Example { erred }`
328	schema.AddExtensions(ext)
329
330	result := graphql.Do(graphql.Params{
331		Schema:        schema,
332		RequestString: query,
333	})
334
335	if resErrs := len(result.Errors); resErrs != 1 {
336		t.Errorf("Incorrect number of returned result errors: %d", resErrs)
337	}
338
339	if fnErrs != 1 {
340		t.Errorf("Incorrect number of errors captured: %d", fnErrs)
341	}
342}
343
344func TestExtensionGetResultPanic(t *testing.T) {
345	ext := newtestExt("testExt")
346	ext.getResultFn = func(context.Context) interface{} {
347		if true {
348			panic(errors.New("test error"))
349		}
350		return nil
351	}
352	ext.hasResultFn = func() bool {
353		return true
354	}
355
356	schema := tinit(t)
357	query := `query Example { a }`
358	schema.AddExtensions(ext)
359
360	result := graphql.Do(graphql.Params{
361		Schema:        schema,
362		RequestString: query,
363	})
364
365	expected := &graphql.Result{
366		Data: map[string]interface{}{
367			"a": "foo",
368		},
369		Errors: []gqlerrors.FormattedError{
370			gqlerrors.FormatError(fmt.Errorf("%s.GetResult: %v", ext.Name(), errors.New("test error"))),
371		},
372		Extensions: make(map[string]interface{}),
373	}
374
375	if !reflect.DeepEqual(expected, result) {
376		t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result))
377	}
378}
379
380func newtestExt(name string) *testExt {
381	ext := &testExt{
382		name: name,
383	}
384	if ext.initFn == nil {
385		ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context {
386			return ctx
387		}
388	}
389	if ext.parseDidStartFn == nil {
390		ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) {
391			return ctx, func(err error) {
392
393			}
394		}
395	}
396	if ext.validationDidStartFn == nil {
397		ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) {
398			return ctx, func([]gqlerrors.FormattedError) {
399
400			}
401		}
402	}
403	if ext.executionDidStartFn == nil {
404		ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) {
405			return ctx, func(r *graphql.Result) {
406
407			}
408		}
409	}
410	if ext.resolveFieldDidStartFn == nil {
411		ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) {
412			return ctx, func(v interface{}, err error) {
413
414			}
415		}
416	}
417	if ext.hasResultFn == nil {
418		ext.hasResultFn = func() bool {
419			return false
420		}
421	}
422	if ext.getResultFn == nil {
423		ext.getResultFn = func(context.Context) interface{} {
424			return nil
425		}
426	}
427	return ext
428}
429
430type testExt struct {
431	name                   string
432	initFn                 func(ctx context.Context, p *graphql.Params) context.Context
433	hasResultFn            func() bool
434	getResultFn            func(context.Context) interface{}
435	parseDidStartFn        func(ctx context.Context) (context.Context, graphql.ParseFinishFunc)
436	validationDidStartFn   func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc)
437	executionDidStartFn    func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc)
438	resolveFieldDidStartFn func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc)
439}
440
441func (t *testExt) Init(ctx context.Context, p *graphql.Params) context.Context {
442	return t.initFn(ctx, p)
443}
444
445func (t *testExt) Name() string {
446	return t.name
447}
448
449func (t *testExt) HasResult() bool {
450	return t.hasResultFn()
451}
452
453func (t *testExt) GetResult(ctx context.Context) interface{} {
454	return t.getResultFn(ctx)
455}
456
457func (t *testExt) ParseDidStart(ctx context.Context) (context.Context, graphql.ParseFinishFunc) {
458	return t.parseDidStartFn(ctx)
459}
460
461func (t *testExt) ValidationDidStart(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) {
462	return t.validationDidStartFn(ctx)
463}
464
465func (t *testExt) ExecutionDidStart(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) {
466	return t.executionDidStartFn(ctx)
467}
468
469func (t *testExt) ResolveFieldDidStart(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) {
470	return t.resolveFieldDidStartFn(ctx, i)
471}
472