1package graphql
2
3import (
4	"context"
5	"fmt"
6
7	"github.com/graphql-go/graphql/gqlerrors"
8)
9
10type (
11	// ParseFinishFunc is called when the parse of the query is done
12	ParseFinishFunc func(error)
13	// parseFinishFuncHandler handles the call of all the ParseFinishFuncs from the extenisons
14	parseFinishFuncHandler func(error) []gqlerrors.FormattedError
15
16	// ValidationFinishFunc is called when the Validation of the query is finished
17	ValidationFinishFunc func([]gqlerrors.FormattedError)
18	// validationFinishFuncHandler responsible for the call of all the ValidationFinishFuncs
19	validationFinishFuncHandler func([]gqlerrors.FormattedError) []gqlerrors.FormattedError
20
21	// ExecutionFinishFunc is called when the execution is done
22	ExecutionFinishFunc func(*Result)
23	// executionFinishFuncHandler calls all the ExecutionFinishFuncs from each extension
24	executionFinishFuncHandler func(*Result) []gqlerrors.FormattedError
25
26	// ResolveFieldFinishFunc is called with the result of the ResolveFn and the error it returned
27	ResolveFieldFinishFunc func(interface{}, error)
28	// resolveFieldFinishFuncHandler calls the resolveFieldFinishFns for all the extensions
29	resolveFieldFinishFuncHandler func(interface{}, error) []gqlerrors.FormattedError
30)
31
32// Extension is an interface for extensions in graphql
33type Extension interface {
34	// Init is used to help you initialize the extension
35	Init(context.Context, *Params) context.Context
36
37	// Name returns the name of the extension (make sure it's custom)
38	Name() string
39
40	// ParseDidStart is being called before starting the parse
41	ParseDidStart(context.Context) (context.Context, ParseFinishFunc)
42
43	// ValidationDidStart is called just before the validation begins
44	ValidationDidStart(context.Context) (context.Context, ValidationFinishFunc)
45
46	// ExecutionDidStart notifies about the start of the execution
47	ExecutionDidStart(context.Context) (context.Context, ExecutionFinishFunc)
48
49	// ResolveFieldDidStart notifies about the start of the resolving of a field
50	ResolveFieldDidStart(context.Context, *ResolveInfo) (context.Context, ResolveFieldFinishFunc)
51
52	// HasResult returns if the extension wants to add data to the result
53	HasResult() bool
54
55	// GetResult returns the data that the extension wants to add to the result
56	GetResult(context.Context) interface{}
57}
58
59// handleExtensionsInits handles all the init functions for all the extensions in the schema
60func handleExtensionsInits(p *Params) gqlerrors.FormattedErrors {
61	errs := gqlerrors.FormattedErrors{}
62	for _, ext := range p.Schema.extensions {
63		func() {
64			// catch panic from an extension init fn
65			defer func() {
66				if r := recover(); r != nil {
67					errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), r.(error))))
68				}
69			}()
70			// update context
71			p.Context = ext.Init(p.Context, p)
72		}()
73	}
74	return errs
75}
76
77// handleExtensionsParseDidStart runs the ParseDidStart functions for each extension
78func handleExtensionsParseDidStart(p *Params) ([]gqlerrors.FormattedError, parseFinishFuncHandler) {
79	fs := map[string]ParseFinishFunc{}
80	errs := gqlerrors.FormattedErrors{}
81	for _, ext := range p.Schema.extensions {
82		var (
83			ctx      context.Context
84			finishFn ParseFinishFunc
85		)
86		// catch panic from an extension's parseDidStart functions
87		func() {
88			defer func() {
89				if r := recover(); r != nil {
90					errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), r.(error))))
91				}
92			}()
93			ctx, finishFn = ext.ParseDidStart(p.Context)
94			// update context
95			p.Context = ctx
96			fs[ext.Name()] = finishFn
97		}()
98	}
99	return errs, func(err error) []gqlerrors.FormattedError {
100		errs := gqlerrors.FormattedErrors{}
101		for name, fn := range fs {
102			func() {
103				// catch panic from a finishFn
104				defer func() {
105					if r := recover(); r != nil {
106						errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", name, r.(error))))
107					}
108				}()
109				fn(err)
110			}()
111		}
112		return errs
113	}
114}
115
116// handleExtensionsValidationDidStart notifies the extensions about the start of the validation process
117func handleExtensionsValidationDidStart(p *Params) ([]gqlerrors.FormattedError, validationFinishFuncHandler) {
118	fs := map[string]ValidationFinishFunc{}
119	errs := gqlerrors.FormattedErrors{}
120	for _, ext := range p.Schema.extensions {
121		var (
122			ctx      context.Context
123			finishFn ValidationFinishFunc
124		)
125		// catch panic from an extension's validationDidStart function
126		func() {
127			defer func() {
128				if r := recover(); r != nil {
129					errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), r.(error))))
130				}
131			}()
132			ctx, finishFn = ext.ValidationDidStart(p.Context)
133			// update context
134			p.Context = ctx
135			fs[ext.Name()] = finishFn
136		}()
137	}
138	return errs, func(errs []gqlerrors.FormattedError) []gqlerrors.FormattedError {
139		extErrs := gqlerrors.FormattedErrors{}
140		for name, finishFn := range fs {
141			func() {
142				// catch panic from a finishFn
143				defer func() {
144					if r := recover(); r != nil {
145						extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", name, r.(error))))
146					}
147				}()
148				finishFn(errs)
149			}()
150		}
151		return extErrs
152	}
153}
154
155// handleExecutionDidStart handles the ExecutionDidStart functions
156func handleExtensionsExecutionDidStart(p *ExecuteParams) ([]gqlerrors.FormattedError, executionFinishFuncHandler) {
157	fs := map[string]ExecutionFinishFunc{}
158	errs := gqlerrors.FormattedErrors{}
159	for _, ext := range p.Schema.extensions {
160		var (
161			ctx      context.Context
162			finishFn ExecutionFinishFunc
163		)
164		// catch panic from an extension's executionDidStart function
165		func() {
166			defer func() {
167				if r := recover(); r != nil {
168					errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ExecutionDidStart: %v", ext.Name(), r.(error))))
169				}
170			}()
171			ctx, finishFn = ext.ExecutionDidStart(p.Context)
172			// update context
173			p.Context = ctx
174			fs[ext.Name()] = finishFn
175		}()
176	}
177	return errs, func(result *Result) []gqlerrors.FormattedError {
178		extErrs := gqlerrors.FormattedErrors{}
179		for name, finishFn := range fs {
180			func() {
181				// catch panic from a finishFn
182				defer func() {
183					if r := recover(); r != nil {
184						extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ExecutionFinishFunc: %v", name, r.(error))))
185					}
186				}()
187				finishFn(result)
188			}()
189		}
190		return extErrs
191	}
192}
193
194// handleResolveFieldDidStart handles the notification of the extensions about the start of a resolve function
195func handleExtensionsResolveFieldDidStart(exts []Extension, p *executionContext, i *ResolveInfo) ([]gqlerrors.FormattedError, resolveFieldFinishFuncHandler) {
196	fs := map[string]ResolveFieldFinishFunc{}
197	errs := gqlerrors.FormattedErrors{}
198	for _, ext := range p.Schema.extensions {
199		var (
200			ctx      context.Context
201			finishFn ResolveFieldFinishFunc
202		)
203		// catch panic from an extension's resolveFieldDidStart function
204		func() {
205			defer func() {
206				if r := recover(); r != nil {
207					errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldDidStart: %v", ext.Name(), r.(error))))
208				}
209			}()
210			ctx, finishFn = ext.ResolveFieldDidStart(p.Context, i)
211			// update context
212			p.Context = ctx
213			fs[ext.Name()] = finishFn
214		}()
215	}
216	return errs, func(val interface{}, err error) []gqlerrors.FormattedError {
217		extErrs := gqlerrors.FormattedErrors{}
218		for name, finishFn := range fs {
219			func() {
220				// catch panic from a finishFn
221				defer func() {
222					if r := recover(); r != nil {
223						extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldFinishFunc: %v", name, r.(error))))
224					}
225				}()
226				finishFn(val, err)
227			}()
228		}
229		return extErrs
230	}
231}
232
233func addExtensionResults(p *ExecuteParams, result *Result) {
234	if len(p.Schema.extensions) != 0 {
235		for _, ext := range p.Schema.extensions {
236			func() {
237				defer func() {
238					if r := recover(); r != nil {
239						result.Errors = append(result.Errors, gqlerrors.FormatError(fmt.Errorf("%s.GetResult: %v", ext.Name(), r.(error))))
240					}
241				}()
242				if ext.HasResult() {
243					if result.Extensions == nil {
244						result.Extensions = make(map[string]interface{})
245					}
246					result.Extensions[ext.Name()] = ext.GetResult(p.Context)
247				}
248			}()
249		}
250	}
251}
252