1package graphql 2 3import ( 4 "context" 5 "fmt" 6 7 "github.com/graphql-go/graphql/gqlerrors" 8) 9 10type ( 11 // ParseFinishFunc is called when the parse of the query is done 12 ParseFinishFunc func(error) 13 // parseFinishFuncHandler handles the call of all the ParseFinishFuncs from the extenisons 14 parseFinishFuncHandler func(error) []gqlerrors.FormattedError 15 16 // ValidationFinishFunc is called when the Validation of the query is finished 17 ValidationFinishFunc func([]gqlerrors.FormattedError) 18 // validationFinishFuncHandler responsible for the call of all the ValidationFinishFuncs 19 validationFinishFuncHandler func([]gqlerrors.FormattedError) []gqlerrors.FormattedError 20 21 // ExecutionFinishFunc is called when the execution is done 22 ExecutionFinishFunc func(*Result) 23 // executionFinishFuncHandler calls all the ExecutionFinishFuncs from each extension 24 executionFinishFuncHandler func(*Result) []gqlerrors.FormattedError 25 26 // ResolveFieldFinishFunc is called with the result of the ResolveFn and the error it returned 27 ResolveFieldFinishFunc func(interface{}, error) 28 // resolveFieldFinishFuncHandler calls the resolveFieldFinishFns for all the extensions 29 resolveFieldFinishFuncHandler func(interface{}, error) []gqlerrors.FormattedError 30) 31 32// Extension is an interface for extensions in graphql 33type Extension interface { 34 // Init is used to help you initialize the extension 35 Init(context.Context, *Params) context.Context 36 37 // Name returns the name of the extension (make sure it's custom) 38 Name() string 39 40 // ParseDidStart is being called before starting the parse 41 ParseDidStart(context.Context) (context.Context, ParseFinishFunc) 42 43 // ValidationDidStart is called just before the validation begins 44 ValidationDidStart(context.Context) (context.Context, ValidationFinishFunc) 45 46 // ExecutionDidStart notifies about the start of the execution 47 ExecutionDidStart(context.Context) (context.Context, ExecutionFinishFunc) 48 49 // ResolveFieldDidStart notifies about the start of the resolving of a field 50 ResolveFieldDidStart(context.Context, *ResolveInfo) (context.Context, ResolveFieldFinishFunc) 51 52 // HasResult returns if the extension wants to add data to the result 53 HasResult() bool 54 55 // GetResult returns the data that the extension wants to add to the result 56 GetResult(context.Context) interface{} 57} 58 59// handleExtensionsInits handles all the init functions for all the extensions in the schema 60func handleExtensionsInits(p *Params) gqlerrors.FormattedErrors { 61 errs := gqlerrors.FormattedErrors{} 62 for _, ext := range p.Schema.extensions { 63 func() { 64 // catch panic from an extension init fn 65 defer func() { 66 if r := recover(); r != nil { 67 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), r.(error)))) 68 } 69 }() 70 // update context 71 p.Context = ext.Init(p.Context, p) 72 }() 73 } 74 return errs 75} 76 77// handleExtensionsParseDidStart runs the ParseDidStart functions for each extension 78func handleExtensionsParseDidStart(p *Params) ([]gqlerrors.FormattedError, parseFinishFuncHandler) { 79 fs := map[string]ParseFinishFunc{} 80 errs := gqlerrors.FormattedErrors{} 81 for _, ext := range p.Schema.extensions { 82 var ( 83 ctx context.Context 84 finishFn ParseFinishFunc 85 ) 86 // catch panic from an extension's parseDidStart functions 87 func() { 88 defer func() { 89 if r := recover(); r != nil { 90 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), r.(error)))) 91 } 92 }() 93 ctx, finishFn = ext.ParseDidStart(p.Context) 94 // update context 95 p.Context = ctx 96 fs[ext.Name()] = finishFn 97 }() 98 } 99 return errs, func(err error) []gqlerrors.FormattedError { 100 errs := gqlerrors.FormattedErrors{} 101 for name, fn := range fs { 102 func() { 103 // catch panic from a finishFn 104 defer func() { 105 if r := recover(); r != nil { 106 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", name, r.(error)))) 107 } 108 }() 109 fn(err) 110 }() 111 } 112 return errs 113 } 114} 115 116// handleExtensionsValidationDidStart notifies the extensions about the start of the validation process 117func handleExtensionsValidationDidStart(p *Params) ([]gqlerrors.FormattedError, validationFinishFuncHandler) { 118 fs := map[string]ValidationFinishFunc{} 119 errs := gqlerrors.FormattedErrors{} 120 for _, ext := range p.Schema.extensions { 121 var ( 122 ctx context.Context 123 finishFn ValidationFinishFunc 124 ) 125 // catch panic from an extension's validationDidStart function 126 func() { 127 defer func() { 128 if r := recover(); r != nil { 129 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), r.(error)))) 130 } 131 }() 132 ctx, finishFn = ext.ValidationDidStart(p.Context) 133 // update context 134 p.Context = ctx 135 fs[ext.Name()] = finishFn 136 }() 137 } 138 return errs, func(errs []gqlerrors.FormattedError) []gqlerrors.FormattedError { 139 extErrs := gqlerrors.FormattedErrors{} 140 for name, finishFn := range fs { 141 func() { 142 // catch panic from a finishFn 143 defer func() { 144 if r := recover(); r != nil { 145 extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", name, r.(error)))) 146 } 147 }() 148 finishFn(errs) 149 }() 150 } 151 return extErrs 152 } 153} 154 155// handleExecutionDidStart handles the ExecutionDidStart functions 156func handleExtensionsExecutionDidStart(p *ExecuteParams) ([]gqlerrors.FormattedError, executionFinishFuncHandler) { 157 fs := map[string]ExecutionFinishFunc{} 158 errs := gqlerrors.FormattedErrors{} 159 for _, ext := range p.Schema.extensions { 160 var ( 161 ctx context.Context 162 finishFn ExecutionFinishFunc 163 ) 164 // catch panic from an extension's executionDidStart function 165 func() { 166 defer func() { 167 if r := recover(); r != nil { 168 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ExecutionDidStart: %v", ext.Name(), r.(error)))) 169 } 170 }() 171 ctx, finishFn = ext.ExecutionDidStart(p.Context) 172 // update context 173 p.Context = ctx 174 fs[ext.Name()] = finishFn 175 }() 176 } 177 return errs, func(result *Result) []gqlerrors.FormattedError { 178 extErrs := gqlerrors.FormattedErrors{} 179 for name, finishFn := range fs { 180 func() { 181 // catch panic from a finishFn 182 defer func() { 183 if r := recover(); r != nil { 184 extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ExecutionFinishFunc: %v", name, r.(error)))) 185 } 186 }() 187 finishFn(result) 188 }() 189 } 190 return extErrs 191 } 192} 193 194// handleResolveFieldDidStart handles the notification of the extensions about the start of a resolve function 195func handleExtensionsResolveFieldDidStart(exts []Extension, p *executionContext, i *ResolveInfo) ([]gqlerrors.FormattedError, resolveFieldFinishFuncHandler) { 196 fs := map[string]ResolveFieldFinishFunc{} 197 errs := gqlerrors.FormattedErrors{} 198 for _, ext := range p.Schema.extensions { 199 var ( 200 ctx context.Context 201 finishFn ResolveFieldFinishFunc 202 ) 203 // catch panic from an extension's resolveFieldDidStart function 204 func() { 205 defer func() { 206 if r := recover(); r != nil { 207 errs = append(errs, gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldDidStart: %v", ext.Name(), r.(error)))) 208 } 209 }() 210 ctx, finishFn = ext.ResolveFieldDidStart(p.Context, i) 211 // update context 212 p.Context = ctx 213 fs[ext.Name()] = finishFn 214 }() 215 } 216 return errs, func(val interface{}, err error) []gqlerrors.FormattedError { 217 extErrs := gqlerrors.FormattedErrors{} 218 for name, finishFn := range fs { 219 func() { 220 // catch panic from a finishFn 221 defer func() { 222 if r := recover(); r != nil { 223 extErrs = append(extErrs, gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldFinishFunc: %v", name, r.(error)))) 224 } 225 }() 226 finishFn(val, err) 227 }() 228 } 229 return extErrs 230 } 231} 232 233func addExtensionResults(p *ExecuteParams, result *Result) { 234 if len(p.Schema.extensions) != 0 { 235 for _, ext := range p.Schema.extensions { 236 func() { 237 defer func() { 238 if r := recover(); r != nil { 239 result.Errors = append(result.Errors, gqlerrors.FormatError(fmt.Errorf("%s.GetResult: %v", ext.Name(), r.(error)))) 240 } 241 }() 242 if ext.HasResult() { 243 if result.Extensions == nil { 244 result.Extensions = make(map[string]interface{}) 245 } 246 result.Extensions[ext.Name()] = ext.GetResult(p.Context) 247 } 248 }() 249 } 250 } 251} 252