1package self_test 2 3import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "net" 8 "sync" 9 10 quic "github.com/lucas-clemente/quic-go" 11 12 . "github.com/onsi/ginkgo" 13 . "github.com/onsi/gomega" 14) 15 16type clientSessionCache struct { 17 mutex sync.Mutex 18 cache map[string]*tls.ClientSessionState 19 20 gets chan<- string 21 puts chan<- string 22} 23 24func newClientSessionCache(gets, puts chan<- string) *clientSessionCache { 25 return &clientSessionCache{ 26 cache: make(map[string]*tls.ClientSessionState), 27 gets: gets, 28 puts: puts, 29 } 30} 31 32var _ tls.ClientSessionCache = &clientSessionCache{} 33 34func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) { 35 c.gets <- sessionKey 36 c.mutex.Lock() 37 session, ok := c.cache[sessionKey] 38 c.mutex.Unlock() 39 return session, ok 40} 41 42func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { 43 c.puts <- sessionKey 44 c.mutex.Lock() 45 c.cache[sessionKey] = cs 46 c.mutex.Unlock() 47} 48 49var _ = Describe("TLS session resumption", func() { 50 It("uses session resumption", func() { 51 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) 52 Expect(err).ToNot(HaveOccurred()) 53 defer server.Close() 54 55 gets := make(chan string, 100) 56 puts := make(chan string, 100) 57 cache := newClientSessionCache(gets, puts) 58 tlsConf := getTLSClientConfig() 59 tlsConf.ClientSessionCache = cache 60 sess, err := quic.DialAddr( 61 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 62 tlsConf, 63 nil, 64 ) 65 Expect(err).ToNot(HaveOccurred()) 66 var sessionKey string 67 Eventually(puts).Should(Receive(&sessionKey)) 68 Expect(sess.ConnectionState().DidResume).To(BeFalse()) 69 70 serverSess, err := server.Accept(context.Background()) 71 Expect(err).ToNot(HaveOccurred()) 72 Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) 73 74 sess, err = quic.DialAddr( 75 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 76 tlsConf, 77 nil, 78 ) 79 Expect(err).ToNot(HaveOccurred()) 80 Expect(gets).To(Receive(Equal(sessionKey))) 81 Expect(sess.ConnectionState().DidResume).To(BeTrue()) 82 83 serverSess, err = server.Accept(context.Background()) 84 Expect(err).ToNot(HaveOccurred()) 85 Expect(serverSess.ConnectionState().DidResume).To(BeTrue()) 86 }) 87 88 It("doesn't use session resumption, if the config disables it", func() { 89 sConf := getTLSConfig() 90 sConf.SessionTicketsDisabled = true 91 server, err := quic.ListenAddr("localhost:0", sConf, nil) 92 Expect(err).ToNot(HaveOccurred()) 93 defer server.Close() 94 95 gets := make(chan string, 100) 96 puts := make(chan string, 100) 97 cache := newClientSessionCache(gets, puts) 98 tlsConf := getTLSClientConfig() 99 tlsConf.ClientSessionCache = cache 100 sess, err := quic.DialAddr( 101 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 102 tlsConf, 103 nil, 104 ) 105 Expect(err).ToNot(HaveOccurred()) 106 Consistently(puts).ShouldNot(Receive()) 107 Expect(sess.ConnectionState().DidResume).To(BeFalse()) 108 109 serverSess, err := server.Accept(context.Background()) 110 Expect(err).ToNot(HaveOccurred()) 111 Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) 112 113 sess, err = quic.DialAddr( 114 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 115 tlsConf, 116 nil, 117 ) 118 Expect(err).ToNot(HaveOccurred()) 119 Expect(sess.ConnectionState().DidResume).To(BeFalse()) 120 121 serverSess, err = server.Accept(context.Background()) 122 Expect(err).ToNot(HaveOccurred()) 123 Expect(serverSess.ConnectionState().DidResume).To(BeFalse()) 124 }) 125}) 126