1package http3 2 3import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "io" 9 "net" 10 "net/http" 11 "time" 12 13 "github.com/golang/mock/gomock" 14 "github.com/lucas-clemente/quic-go" 15 mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" 16 "github.com/lucas-clemente/quic-go/internal/protocol" 17 "github.com/lucas-clemente/quic-go/internal/testdata" 18 "github.com/lucas-clemente/quic-go/internal/utils" 19 "github.com/marten-seemann/qpack" 20 21 . "github.com/onsi/ginkgo" 22 . "github.com/onsi/gomega" 23) 24 25type mockConn struct { 26 net.Conn 27 version protocol.VersionNumber 28} 29 30func newMockConn(version protocol.VersionNumber) net.Conn { 31 return &mockConn{version: version} 32} 33 34func (c *mockConn) GetQUICVersion() protocol.VersionNumber { 35 return c.version 36} 37 38var _ = Describe("Server", func() { 39 var ( 40 s *Server 41 origQuicListenAddr = quicListenAddr 42 ) 43 44 BeforeEach(func() { 45 s = &Server{ 46 Server: &http.Server{ 47 TLSConfig: testdata.GetTLSConfig(), 48 }, 49 logger: utils.DefaultLogger, 50 } 51 origQuicListenAddr = quicListenAddr 52 }) 53 54 AfterEach(func() { 55 quicListenAddr = origQuicListenAddr 56 }) 57 58 Context("handling requests", func() { 59 var ( 60 qpackDecoder *qpack.Decoder 61 str *mockquic.MockStream 62 sess *mockquic.MockEarlySession 63 exampleGetRequest *http.Request 64 examplePostRequest *http.Request 65 ) 66 reqContext := context.Background() 67 68 decodeHeader := func(str io.Reader) map[string][]string { 69 fields := make(map[string][]string) 70 decoder := qpack.NewDecoder(nil) 71 72 frame, err := parseNextFrame(str) 73 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 74 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 75 headersFrame := frame.(*headersFrame) 76 data := make([]byte, headersFrame.Length) 77 _, err = io.ReadFull(str, data) 78 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 79 hfs, err := decoder.DecodeFull(data) 80 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 81 for _, p := range hfs { 82 fields[p.Name] = append(fields[p.Name], p.Value) 83 } 84 return fields 85 } 86 87 encodeRequest := func(req *http.Request) []byte { 88 buf := &bytes.Buffer{} 89 str := mockquic.NewMockStream(mockCtrl) 90 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 91 return buf.Write(p) 92 }).AnyTimes() 93 closed := make(chan struct{}) 94 str.EXPECT().Close().Do(func() { close(closed) }) 95 rw := newRequestWriter(utils.DefaultLogger) 96 Expect(rw.WriteRequest(str, req, false)).To(Succeed()) 97 Eventually(closed).Should(BeClosed()) 98 return buf.Bytes() 99 } 100 101 setRequest := func(data []byte) { 102 buf := bytes.NewBuffer(data) 103 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 104 if buf.Len() == 0 { 105 return 0, io.EOF 106 } 107 return buf.Read(p) 108 }).AnyTimes() 109 } 110 111 BeforeEach(func() { 112 var err error 113 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 114 Expect(err).ToNot(HaveOccurred()) 115 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 116 Expect(err).ToNot(HaveOccurred()) 117 118 qpackDecoder = qpack.NewDecoder(nil) 119 str = mockquic.NewMockStream(mockCtrl) 120 121 sess = mockquic.NewMockEarlySession(mockCtrl) 122 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 123 sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() 124 sess.EXPECT().LocalAddr().AnyTimes() 125 }) 126 127 It("calls the HTTP handler function", func() { 128 requestChan := make(chan *http.Request, 1) 129 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 130 requestChan <- r 131 }) 132 133 setRequest(encodeRequest(exampleGetRequest)) 134 str.EXPECT().Context().Return(reqContext) 135 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 136 return len(p), nil 137 }).AnyTimes() 138 str.EXPECT().CancelRead(gomock.Any()) 139 140 Expect(s.handleRequest(sess, str, qpackDecoder, nil)).To(Equal(requestError{})) 141 var req *http.Request 142 Eventually(requestChan).Should(Receive(&req)) 143 Expect(req.Host).To(Equal("www.example.com")) 144 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 145 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 146 }) 147 148 It("returns 200 with an empty handler", func() { 149 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 150 151 responseBuf := &bytes.Buffer{} 152 setRequest(encodeRequest(exampleGetRequest)) 153 str.EXPECT().Context().Return(reqContext) 154 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 155 return responseBuf.Write(p) 156 }).AnyTimes() 157 str.EXPECT().CancelRead(gomock.Any()) 158 159 serr := s.handleRequest(sess, str, qpackDecoder, nil) 160 Expect(serr.err).ToNot(HaveOccurred()) 161 hfs := decodeHeader(responseBuf) 162 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 163 }) 164 165 It("handles a panicking handler", func() { 166 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 167 panic("foobar") 168 }) 169 170 responseBuf := &bytes.Buffer{} 171 setRequest(encodeRequest(exampleGetRequest)) 172 str.EXPECT().Context().Return(reqContext) 173 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 174 return responseBuf.Write(p) 175 }).AnyTimes() 176 str.EXPECT().CancelRead(gomock.Any()) 177 178 serr := s.handleRequest(sess, str, qpackDecoder, nil) 179 Expect(serr.err).ToNot(HaveOccurred()) 180 hfs := decodeHeader(responseBuf) 181 Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) 182 }) 183 184 Context("stream- and connection-level errors", func() { 185 var sess *mockquic.MockEarlySession 186 187 BeforeEach(func() { 188 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 189 sess = mockquic.NewMockEarlySession(mockCtrl) 190 controlStr := mockquic.NewMockStream(mockCtrl) 191 controlStr.EXPECT().Write(gomock.Any()) 192 sess.EXPECT().OpenUniStream().Return(controlStr, nil) 193 sess.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 194 sess.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 195 sess.EXPECT().RemoteAddr().Return(addr).AnyTimes() 196 sess.EXPECT().LocalAddr().AnyTimes() 197 }) 198 199 It("cancels reading when client sends a body in GET request", func() { 200 handlerCalled := make(chan struct{}) 201 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 close(handlerCalled) 203 }) 204 205 requestData := encodeRequest(exampleGetRequest) 206 buf := &bytes.Buffer{} 207 (&dataFrame{Length: 6}).Write(buf) // add a body 208 buf.Write([]byte("foobar")) 209 responseBuf := &bytes.Buffer{} 210 setRequest(append(requestData, buf.Bytes()...)) 211 done := make(chan struct{}) 212 str.EXPECT().Context().Return(reqContext) 213 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 214 return responseBuf.Write(p) 215 }).AnyTimes() 216 str.EXPECT().CancelRead(quic.ErrorCode(errorNoError)) 217 str.EXPECT().Close().Do(func() { close(done) }) 218 219 s.handleConn(sess) 220 Eventually(done).Should(BeClosed()) 221 hfs := decodeHeader(responseBuf) 222 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 223 }) 224 225 It("errors when the client sends a too large header frame", func() { 226 s.Server.MaxHeaderBytes = 20 227 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 228 Fail("Handler should not be called.") 229 }) 230 231 requestData := encodeRequest(exampleGetRequest) 232 buf := &bytes.Buffer{} 233 (&dataFrame{Length: 6}).Write(buf) // add a body 234 buf.Write([]byte("foobar")) 235 responseBuf := &bytes.Buffer{} 236 setRequest(append(requestData, buf.Bytes()...)) 237 done := make(chan struct{}) 238 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 239 return responseBuf.Write(p) 240 }).AnyTimes() 241 str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError)).Do(func(quic.ErrorCode) { close(done) }) 242 243 s.handleConn(sess) 244 Eventually(done).Should(BeClosed()) 245 }) 246 247 It("handles a request for which the client immediately resets the stream", func() { 248 handlerCalled := make(chan struct{}) 249 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 250 close(handlerCalled) 251 }) 252 253 testErr := errors.New("stream reset") 254 done := make(chan struct{}) 255 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 256 str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestIncomplete)).Do(func(quic.ErrorCode) { close(done) }) 257 258 s.handleConn(sess) 259 Consistently(handlerCalled).ShouldNot(BeClosed()) 260 }) 261 262 It("closes the connection when the first frame is not a HEADERS frame", func() { 263 handlerCalled := make(chan struct{}) 264 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 265 close(handlerCalled) 266 }) 267 268 buf := &bytes.Buffer{} 269 (&dataFrame{}).Write(buf) 270 setRequest(buf.Bytes()) 271 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 272 return len(p), nil 273 }).AnyTimes() 274 275 done := make(chan struct{}) 276 sess.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ErrorCode, _ string) { 277 Expect(code).To(Equal(quic.ErrorCode(errorFrameUnexpected))) 278 close(done) 279 }) 280 s.handleConn(sess) 281 Eventually(done).Should(BeClosed()) 282 }) 283 284 It("closes the connection when the first frame is not a HEADERS frame", func() { 285 handlerCalled := make(chan struct{}) 286 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 close(handlerCalled) 288 }) 289 290 // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, 291 // but the request will still end up larger than DefaultMaxHeaderBytes. 292 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 293 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 294 Expect(err).ToNot(HaveOccurred()) 295 setRequest(encodeRequest(req)) 296 // str.EXPECT().Context().Return(reqContext) 297 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 298 return len(p), nil 299 }).AnyTimes() 300 done := make(chan struct{}) 301 str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError)).Do(func(quic.ErrorCode) { close(done) }) 302 303 s.handleConn(sess) 304 Eventually(done).Should(BeClosed()) 305 }) 306 }) 307 308 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 309 handlerCalled := make(chan struct{}) 310 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 311 r.Body = struct { 312 io.Reader 313 io.Closer 314 }{} 315 close(handlerCalled) 316 }) 317 318 setRequest(encodeRequest(examplePostRequest)) 319 str.EXPECT().Context().Return(reqContext) 320 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 321 return len(p), nil 322 }).AnyTimes() 323 str.EXPECT().CancelRead(quic.ErrorCode(errorNoError)) 324 325 serr := s.handleRequest(sess, str, qpackDecoder, nil) 326 Expect(serr.err).ToNot(HaveOccurred()) 327 Eventually(handlerCalled).Should(BeClosed()) 328 }) 329 330 It("cancels the request context when the stream is closed", func() { 331 handlerCalled := make(chan struct{}) 332 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 333 defer GinkgoRecover() 334 Expect(r.Context().Done()).To(BeClosed()) 335 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 336 close(handlerCalled) 337 }) 338 setRequest(encodeRequest(examplePostRequest)) 339 340 reqContext, cancel := context.WithCancel(context.Background()) 341 cancel() 342 str.EXPECT().Context().Return(reqContext) 343 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 344 return len(p), nil 345 }).AnyTimes() 346 str.EXPECT().CancelRead(quic.ErrorCode(errorNoError)) 347 348 serr := s.handleRequest(sess, str, qpackDecoder, nil) 349 Expect(serr.err).ToNot(HaveOccurred()) 350 Eventually(handlerCalled).Should(BeClosed()) 351 }) 352 }) 353 354 Context("setting http headers", func() { 355 BeforeEach(func() { 356 s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.VersionDraft29}} 357 }) 358 359 expected := http.Header{ 360 "Alt-Svc": {`h3-29=":443"; ma=2592000`}, 361 } 362 363 It("sets proper headers with numeric port", func() { 364 s.Server.Addr = ":443" 365 hdr := http.Header{} 366 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 367 Expect(hdr).To(Equal(expected)) 368 }) 369 370 It("sets proper headers with full addr", func() { 371 s.Server.Addr = "127.0.0.1:443" 372 hdr := http.Header{} 373 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 374 Expect(hdr).To(Equal(expected)) 375 }) 376 377 It("sets proper headers with string port", func() { 378 s.Server.Addr = ":https" 379 hdr := http.Header{} 380 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 381 Expect(hdr).To(Equal(expected)) 382 }) 383 384 It("works multiple times", func() { 385 s.Server.Addr = ":https" 386 hdr := http.Header{} 387 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 388 Expect(hdr).To(Equal(expected)) 389 hdr = http.Header{} 390 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 391 Expect(hdr).To(Equal(expected)) 392 }) 393 394 It("works if the quic.Config sets QUIC versions", func() { 395 s.Server.Addr = ":443" 396 s.QuicConfig.Versions = []quic.VersionNumber{quic.VersionDraft32, quic.VersionDraft29} 397 hdr := http.Header{} 398 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 399 Expect(hdr).To(Equal(http.Header{"Alt-Svc": {`h3-32=":443"; ma=2592000,h3-29=":443"; ma=2592000`}})) 400 }) 401 }) 402 403 It("errors when ListenAndServe is called with s.Server nil", func() { 404 Expect((&Server{}).ListenAndServe()).To(MatchError("use of http3.Server without http.Server")) 405 }) 406 407 It("errors when ListenAndServeTLS is called with s.Server nil", func() { 408 Expect((&Server{}).ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError("use of http3.Server without http.Server")) 409 }) 410 411 It("should nop-Close() when s.server is nil", func() { 412 Expect((&Server{}).Close()).To(Succeed()) 413 }) 414 415 It("errors when ListenAndServe is called after Close", func() { 416 serv := &Server{Server: &http.Server{}} 417 Expect(serv.Close()).To(Succeed()) 418 Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) 419 }) 420 421 Context("Serve", func() { 422 origQuicListen := quicListen 423 424 AfterEach(func() { 425 quicListen = origQuicListen 426 }) 427 428 It("serves a packet conn", func() { 429 ln := mockquic.NewMockEarlyListener(mockCtrl) 430 conn := &net.UDPConn{} 431 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { 432 Expect(c).To(Equal(conn)) 433 return ln, nil 434 } 435 436 s := &Server{Server: &http.Server{}} 437 s.TLSConfig = &tls.Config{} 438 439 stopAccept := make(chan struct{}) 440 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { 441 <-stopAccept 442 return nil, errors.New("closed") 443 }) 444 done := make(chan struct{}) 445 go func() { 446 defer GinkgoRecover() 447 defer close(done) 448 s.Serve(conn) 449 }() 450 451 Consistently(done).ShouldNot(BeClosed()) 452 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 453 Expect(s.Close()).To(Succeed()) 454 Eventually(done).Should(BeClosed()) 455 }) 456 457 It("serves two packet conns", func() { 458 ln1 := mockquic.NewMockEarlyListener(mockCtrl) 459 ln2 := mockquic.NewMockEarlyListener(mockCtrl) 460 lns := make(chan quic.EarlyListener, 2) 461 lns <- ln1 462 lns <- ln2 463 conn1 := &net.UDPConn{} 464 conn2 := &net.UDPConn{} 465 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { 466 return <-lns, nil 467 } 468 469 s := &Server{Server: &http.Server{}} 470 s.TLSConfig = &tls.Config{} 471 472 stopAccept1 := make(chan struct{}) 473 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { 474 <-stopAccept1 475 return nil, errors.New("closed") 476 }) 477 stopAccept2 := make(chan struct{}) 478 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { 479 <-stopAccept2 480 return nil, errors.New("closed") 481 }) 482 483 done1 := make(chan struct{}) 484 go func() { 485 defer GinkgoRecover() 486 defer close(done1) 487 s.Serve(conn1) 488 }() 489 done2 := make(chan struct{}) 490 go func() { 491 defer GinkgoRecover() 492 defer close(done2) 493 s.Serve(conn2) 494 }() 495 496 Consistently(done1).ShouldNot(BeClosed()) 497 Expect(done2).ToNot(BeClosed()) 498 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 499 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 500 Expect(s.Close()).To(Succeed()) 501 Eventually(done1).Should(BeClosed()) 502 Eventually(done2).Should(BeClosed()) 503 }) 504 }) 505 506 Context("ListenAndServe", func() { 507 BeforeEach(func() { 508 s.Server.Addr = "localhost:0" 509 }) 510 511 AfterEach(func() { 512 Expect(s.Close()).To(Succeed()) 513 }) 514 515 checkGetConfigForClientVersions := func(conf *tls.Config) { 516 c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)}) 517 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 518 ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft29})) 519 c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft32)}) 520 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 521 ExpectWithOffset(1, c.NextProtos).To(Equal([]string{nextProtoH3Draft32})) 522 } 523 524 It("uses the quic.Config to start the QUIC server", func() { 525 conf := &quic.Config{HandshakeTimeout: time.Nanosecond} 526 var receivedConf *quic.Config 527 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (quic.EarlyListener, error) { 528 receivedConf = config 529 return nil, errors.New("listen err") 530 } 531 s.QuicConfig = conf 532 Expect(s.ListenAndServe()).To(HaveOccurred()) 533 Expect(receivedConf).To(Equal(conf)) 534 }) 535 536 It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() { 537 tlsConf := &tls.Config{ 538 ClientAuth: tls.RequireAndVerifyClientCert, 539 NextProtos: []string{"foo", "bar"}, 540 } 541 var receivedConf *tls.Config 542 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { 543 receivedConf = tlsConf 544 return nil, errors.New("listen err") 545 } 546 s.TLSConfig = tlsConf 547 Expect(s.ListenAndServe()).To(HaveOccurred()) 548 Expect(receivedConf.NextProtos).To(BeEmpty()) 549 Expect(receivedConf.ClientAuth).To(BeZero()) 550 // make sure the original tls.Config was not modified 551 Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) 552 // make sure that the config returned from the GetConfigForClient callback sets the fields of the original config 553 conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) 554 Expect(err).ToNot(HaveOccurred()) 555 Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) 556 checkGetConfigForClientVersions(receivedConf) 557 }) 558 559 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 560 var receivedConf *tls.Config 561 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { 562 receivedConf = tlsConf 563 return nil, errors.New("listen err") 564 } 565 Expect(s.ListenAndServe()).To(HaveOccurred()) 566 Expect(receivedConf).ToNot(BeNil()) 567 checkGetConfigForClientVersions(receivedConf) 568 }) 569 570 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 571 tlsConf := &tls.Config{ 572 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 573 return &tls.Config{ 574 ClientAuth: tls.RequireAndVerifyClientCert, 575 NextProtos: []string{"foo", "bar"}, 576 }, nil 577 }, 578 } 579 580 var receivedConf *tls.Config 581 quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { 582 receivedConf = conf 583 return nil, errors.New("listen err") 584 } 585 s.TLSConfig = tlsConf 586 Expect(s.ListenAndServe()).To(HaveOccurred()) 587 // check that the original config was not modified 588 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 589 Expect(err).ToNot(HaveOccurred()) 590 Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) 591 // check that the config returned by the GetConfigForClient callback uses the returned config 592 conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) 593 Expect(err).ToNot(HaveOccurred()) 594 Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) 595 checkGetConfigForClientVersions(receivedConf) 596 }) 597 598 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 599 tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}} 600 tlsConf := &tls.Config{ 601 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 602 return tlsClientConf, nil 603 }, 604 } 605 606 var receivedConf *tls.Config 607 quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { 608 receivedConf = conf 609 return nil, errors.New("listen err") 610 } 611 s.TLSConfig = tlsConf 612 Expect(s.ListenAndServe()).To(HaveOccurred()) 613 // check that the original config was not modified 614 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 615 Expect(err).ToNot(HaveOccurred()) 616 Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) 617 checkGetConfigForClientVersions(receivedConf) 618 }) 619 620 It("works if GetConfigForClient returns a nil tls.Config", func() { 621 tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }} 622 623 var receivedConf *tls.Config 624 quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { 625 receivedConf = conf 626 return nil, errors.New("listen err") 627 } 628 s.TLSConfig = tlsConf 629 Expect(s.ListenAndServe()).To(HaveOccurred()) 630 conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) 631 Expect(err).ToNot(HaveOccurred()) 632 Expect(conf).ToNot(BeNil()) 633 checkGetConfigForClientVersions(receivedConf) 634 }) 635 }) 636 637 It("closes gracefully", func() { 638 Expect(s.CloseGracefully(0)).To(Succeed()) 639 }) 640 641 It("errors when listening fails", func() { 642 testErr := errors.New("listen error") 643 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlyListener, error) { 644 return nil, testErr 645 } 646 fullpem, privkey := testdata.GetCertificatePaths() 647 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 648 }) 649}) 650