1package graphql_test 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "testing" 9 10 "github.com/graphql-go/graphql" 11 "github.com/graphql-go/graphql/gqlerrors" 12 "github.com/graphql-go/graphql/testutil" 13) 14 15func tinit(t *testing.T) graphql.Schema { 16 schema, err := graphql.NewSchema(graphql.SchemaConfig{ 17 Query: graphql.NewObject(graphql.ObjectConfig{ 18 Name: "Type", 19 Fields: graphql.Fields{ 20 "a": &graphql.Field{ 21 Type: graphql.String, 22 Resolve: func(p graphql.ResolveParams) (interface{}, error) { 23 return "foo", nil 24 }, 25 }, 26 "erred": &graphql.Field{ 27 Type: graphql.String, 28 Resolve: func(p graphql.ResolveParams) (interface{}, error) { 29 return "", errors.New("ooops") 30 }, 31 }, 32 }, 33 }), 34 }) 35 if err != nil { 36 t.Fatalf("Error in schema %v", err.Error()) 37 } 38 return schema 39} 40 41func TestExtensionInitPanic(t *testing.T) { 42 ext := newtestExt("testExt") 43 ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context { 44 if true { 45 panic(errors.New("test error")) 46 } 47 return ctx 48 } 49 50 schema := tinit(t) 51 query := `query Example { a }` 52 schema.AddExtensions(ext) 53 54 result := graphql.Do(graphql.Params{ 55 Schema: schema, 56 RequestString: query, 57 }) 58 59 expected := &graphql.Result{ 60 Data: nil, 61 Errors: []gqlerrors.FormattedError{ 62 gqlerrors.FormatError(fmt.Errorf("%s.Init: %v", ext.Name(), errors.New("test error"))), 63 }, 64 } 65 if !reflect.DeepEqual(expected, result) { 66 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 67 } 68} 69 70func TestExtensionParseDidStartPanic(t *testing.T) { 71 ext := newtestExt("testExt") 72 ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 73 if true { 74 panic(errors.New("test error")) 75 } 76 return ctx, func(err error) { 77 78 } 79 } 80 81 schema := tinit(t) 82 query := `query Example { a }` 83 schema.AddExtensions(ext) 84 85 result := graphql.Do(graphql.Params{ 86 Schema: schema, 87 RequestString: query, 88 }) 89 90 expected := &graphql.Result{ 91 Data: nil, 92 Errors: []gqlerrors.FormattedError{ 93 gqlerrors.FormatError(fmt.Errorf("%s.ParseDidStart: %v", ext.Name(), errors.New("test error"))), 94 }, 95 } 96 if !reflect.DeepEqual(expected, result) { 97 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 98 } 99} 100 101func TestExtensionParseFinishFuncPanic(t *testing.T) { 102 ext := newtestExt("testExt") 103 ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 104 return ctx, func(err error) { 105 panic(errors.New("test error")) 106 } 107 } 108 109 schema := tinit(t) 110 query := `query Example { a }` 111 schema.AddExtensions(ext) 112 113 result := graphql.Do(graphql.Params{ 114 Schema: schema, 115 RequestString: query, 116 }) 117 118 expected := &graphql.Result{ 119 Data: nil, 120 Errors: []gqlerrors.FormattedError{ 121 gqlerrors.FormatError(fmt.Errorf("%s.ParseFinishFunc: %v", ext.Name(), errors.New("test error"))), 122 }, 123 } 124 if !reflect.DeepEqual(expected, result) { 125 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 126 } 127} 128 129func TestExtensionValidationDidStartPanic(t *testing.T) { 130 ext := newtestExt("testExt") 131 ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 132 if true { 133 panic(errors.New("test error")) 134 } 135 return ctx, func([]gqlerrors.FormattedError) { 136 137 } 138 } 139 140 schema := tinit(t) 141 query := `query Example { a }` 142 schema.AddExtensions(ext) 143 144 result := graphql.Do(graphql.Params{ 145 Schema: schema, 146 RequestString: query, 147 }) 148 149 expected := &graphql.Result{ 150 Data: nil, 151 Errors: []gqlerrors.FormattedError{ 152 gqlerrors.FormatError(fmt.Errorf("%s.ValidationDidStart: %v", ext.Name(), errors.New("test error"))), 153 }, 154 } 155 if !reflect.DeepEqual(expected, result) { 156 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 157 } 158} 159 160func TestExtensionValidationFinishFuncPanic(t *testing.T) { 161 ext := newtestExt("testExt") 162 ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 163 return ctx, func([]gqlerrors.FormattedError) { 164 panic(errors.New("test error")) 165 } 166 } 167 168 schema := tinit(t) 169 query := `query Example { a }` 170 schema.AddExtensions(ext) 171 172 result := graphql.Do(graphql.Params{ 173 Schema: schema, 174 RequestString: query, 175 }) 176 177 expected := &graphql.Result{ 178 Data: nil, 179 Errors: []gqlerrors.FormattedError{ 180 gqlerrors.FormatError(fmt.Errorf("%s.ValidationFinishFunc: %v", ext.Name(), errors.New("test error"))), 181 }, 182 } 183 if !reflect.DeepEqual(expected, result) { 184 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 185 } 186} 187 188func TestExtensionExecutionDidStartPanic(t *testing.T) { 189 ext := newtestExt("testExt") 190 ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 191 if true { 192 panic(errors.New("test error")) 193 } 194 return ctx, func(r *graphql.Result) { 195 196 } 197 } 198 199 schema := tinit(t) 200 query := `query Example { a }` 201 schema.AddExtensions(ext) 202 203 result := graphql.Do(graphql.Params{ 204 Schema: schema, 205 RequestString: query, 206 }) 207 208 expected := &graphql.Result{ 209 Data: nil, 210 Errors: []gqlerrors.FormattedError{ 211 gqlerrors.FormatError(fmt.Errorf("%s.ExecutionDidStart: %v", ext.Name(), errors.New("test error"))), 212 }, 213 } 214 if !reflect.DeepEqual(expected, result) { 215 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 216 } 217} 218 219func TestExtensionExecutionFinishFuncPanic(t *testing.T) { 220 ext := newtestExt("testExt") 221 ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 222 return ctx, func(r *graphql.Result) { 223 panic(errors.New("test error")) 224 } 225 } 226 227 schema := tinit(t) 228 query := `query Example { a }` 229 schema.AddExtensions(ext) 230 231 result := graphql.Do(graphql.Params{ 232 Schema: schema, 233 RequestString: query, 234 }) 235 236 expected := &graphql.Result{ 237 Data: map[string]interface{}{ 238 "a": "foo", 239 }, 240 Errors: []gqlerrors.FormattedError{ 241 gqlerrors.FormatError(fmt.Errorf("%s.ExecutionFinishFunc: %v", ext.Name(), errors.New("test error"))), 242 }, 243 } 244 245 if !reflect.DeepEqual(expected, result) { 246 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 247 } 248} 249 250func TestExtensionResolveFieldDidStartPanic(t *testing.T) { 251 ext := newtestExt("testExt") 252 ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { 253 if true { 254 panic(errors.New("test error")) 255 } 256 return ctx, func(v interface{}, err error) { 257 258 } 259 } 260 261 schema := tinit(t) 262 query := `query Example { a }` 263 schema.AddExtensions(ext) 264 265 result := graphql.Do(graphql.Params{ 266 Schema: schema, 267 RequestString: query, 268 }) 269 270 expected := &graphql.Result{ 271 Data: map[string]interface{}{ 272 "a": "foo", 273 }, 274 Errors: []gqlerrors.FormattedError{ 275 gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldDidStart: %v", ext.Name(), errors.New("test error"))), 276 }, 277 } 278 279 if !reflect.DeepEqual(expected, result) { 280 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 281 } 282} 283 284func TestExtensionResolveFieldFinishFuncPanic(t *testing.T) { 285 ext := newtestExt("testExt") 286 ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { 287 return ctx, func(v interface{}, err error) { 288 panic(errors.New("test error")) 289 } 290 } 291 292 schema := tinit(t) 293 query := `query Example { a }` 294 schema.AddExtensions(ext) 295 296 result := graphql.Do(graphql.Params{ 297 Schema: schema, 298 RequestString: query, 299 }) 300 301 expected := &graphql.Result{ 302 Data: map[string]interface{}{ 303 "a": "foo", 304 }, 305 Errors: []gqlerrors.FormattedError{ 306 gqlerrors.FormatError(fmt.Errorf("%s.ResolveFieldFinishFunc: %v", ext.Name(), errors.New("test error"))), 307 }, 308 } 309 310 if !reflect.DeepEqual(expected, result) { 311 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 312 } 313} 314 315func TestExtensionResolveFieldFinishFuncAfterError(t *testing.T) { 316 var fnErrs int 317 ext := newtestExt("testExt") 318 ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { 319 return ctx, func(v interface{}, err error) { 320 if err != nil { 321 fnErrs++ 322 } 323 } 324 } 325 326 schema := tinit(t) 327 query := `query Example { erred }` 328 schema.AddExtensions(ext) 329 330 result := graphql.Do(graphql.Params{ 331 Schema: schema, 332 RequestString: query, 333 }) 334 335 if resErrs := len(result.Errors); resErrs != 1 { 336 t.Errorf("Incorrect number of returned result errors: %d", resErrs) 337 } 338 339 if fnErrs != 1 { 340 t.Errorf("Incorrect number of errors captured: %d", fnErrs) 341 } 342} 343 344func TestExtensionGetResultPanic(t *testing.T) { 345 ext := newtestExt("testExt") 346 ext.getResultFn = func(context.Context) interface{} { 347 if true { 348 panic(errors.New("test error")) 349 } 350 return nil 351 } 352 ext.hasResultFn = func() bool { 353 return true 354 } 355 356 schema := tinit(t) 357 query := `query Example { a }` 358 schema.AddExtensions(ext) 359 360 result := graphql.Do(graphql.Params{ 361 Schema: schema, 362 RequestString: query, 363 }) 364 365 expected := &graphql.Result{ 366 Data: map[string]interface{}{ 367 "a": "foo", 368 }, 369 Errors: []gqlerrors.FormattedError{ 370 gqlerrors.FormatError(fmt.Errorf("%s.GetResult: %v", ext.Name(), errors.New("test error"))), 371 }, 372 Extensions: make(map[string]interface{}), 373 } 374 375 if !reflect.DeepEqual(expected, result) { 376 t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) 377 } 378} 379 380func newtestExt(name string) *testExt { 381 ext := &testExt{ 382 name: name, 383 } 384 if ext.initFn == nil { 385 ext.initFn = func(ctx context.Context, p *graphql.Params) context.Context { 386 return ctx 387 } 388 } 389 if ext.parseDidStartFn == nil { 390 ext.parseDidStartFn = func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 391 return ctx, func(err error) { 392 393 } 394 } 395 } 396 if ext.validationDidStartFn == nil { 397 ext.validationDidStartFn = func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 398 return ctx, func([]gqlerrors.FormattedError) { 399 400 } 401 } 402 } 403 if ext.executionDidStartFn == nil { 404 ext.executionDidStartFn = func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 405 return ctx, func(r *graphql.Result) { 406 407 } 408 } 409 } 410 if ext.resolveFieldDidStartFn == nil { 411 ext.resolveFieldDidStartFn = func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { 412 return ctx, func(v interface{}, err error) { 413 414 } 415 } 416 } 417 if ext.hasResultFn == nil { 418 ext.hasResultFn = func() bool { 419 return false 420 } 421 } 422 if ext.getResultFn == nil { 423 ext.getResultFn = func(context.Context) interface{} { 424 return nil 425 } 426 } 427 return ext 428} 429 430type testExt struct { 431 name string 432 initFn func(ctx context.Context, p *graphql.Params) context.Context 433 hasResultFn func() bool 434 getResultFn func(context.Context) interface{} 435 parseDidStartFn func(ctx context.Context) (context.Context, graphql.ParseFinishFunc) 436 validationDidStartFn func(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) 437 executionDidStartFn func(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) 438 resolveFieldDidStartFn func(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) 439} 440 441func (t *testExt) Init(ctx context.Context, p *graphql.Params) context.Context { 442 return t.initFn(ctx, p) 443} 444 445func (t *testExt) Name() string { 446 return t.name 447} 448 449func (t *testExt) HasResult() bool { 450 return t.hasResultFn() 451} 452 453func (t *testExt) GetResult(ctx context.Context) interface{} { 454 return t.getResultFn(ctx) 455} 456 457func (t *testExt) ParseDidStart(ctx context.Context) (context.Context, graphql.ParseFinishFunc) { 458 return t.parseDidStartFn(ctx) 459} 460 461func (t *testExt) ValidationDidStart(ctx context.Context) (context.Context, graphql.ValidationFinishFunc) { 462 return t.validationDidStartFn(ctx) 463} 464 465func (t *testExt) ExecutionDidStart(ctx context.Context) (context.Context, graphql.ExecutionFinishFunc) { 466 return t.executionDidStartFn(ctx) 467} 468 469func (t *testExt) ResolveFieldDidStart(ctx context.Context, i *graphql.ResolveInfo) (context.Context, graphql.ResolveFieldFinishFunc) { 470 return t.resolveFieldDidStartFn(ctx, i) 471} 472