1package self_test
2
3import (
4	"context"
5	"crypto/tls"
6	"fmt"
7	"net"
8	"sync"
9
10	quic "github.com/lucas-clemente/quic-go"
11
12	. "github.com/onsi/ginkgo"
13	. "github.com/onsi/gomega"
14)
15
// clientSessionCache is an in-memory tls.ClientSessionCache for tests. Every
// Get and Put first sends the session key on the gets or puts channel,
// respectively, so a test can observe when the TLS stack reads from or writes
// to the cache.
type clientSessionCache struct {
	mutex sync.Mutex // guards cache
	cache map[string]*tls.ClientSessionState

	gets chan<- string // receives the key passed to every Get call
	puts chan<- string // receives the key passed to every Put call
}

// newClientSessionCache returns an empty cache that reports lookups on gets and
// stores on puts. Get and Put block until their channel send succeeds, so the
// channels must be buffered or drained by the caller.
func newClientSessionCache(gets, puts chan<- string) *clientSessionCache {
	return &clientSessionCache{
		cache: make(map[string]*tls.ClientSessionState),
		gets:  gets,
		puts:  puts,
	}
}

// Compile-time check that *clientSessionCache satisfies tls.ClientSessionCache.
var _ tls.ClientSessionCache = (*clientSessionCache)(nil)

// Get reports sessionKey on c.gets and then returns the session cached under
// that key, if any.
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
	// Send before taking the lock, so a blocked send never holds the mutex.
	c.gets <- sessionKey
	c.mutex.Lock()
	defer c.mutex.Unlock()
	session, ok := c.cache[sessionKey]
	return session, ok
}

// Put reports sessionKey on c.puts and then caches cs under that key. A nil cs
// removes the entry, as the tls.ClientSessionCache contract requires.
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
	c.puts <- sessionKey
	c.mutex.Lock()
	defer c.mutex.Unlock()
	if cs == nil {
		delete(c.cache, sessionKey)
		return
	}
	c.cache[sessionKey] = cs
}
48
49var _ = Describe("TLS session resumption", func() {
50	It("uses session resumption", func() {
51		server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil)
52		Expect(err).ToNot(HaveOccurred())
53		defer server.Close()
54
55		gets := make(chan string, 100)
56		puts := make(chan string, 100)
57		cache := newClientSessionCache(gets, puts)
58		tlsConf := getTLSClientConfig()
59		tlsConf.ClientSessionCache = cache
60		sess, err := quic.DialAddr(
61			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
62			tlsConf,
63			nil,
64		)
65		Expect(err).ToNot(HaveOccurred())
66		var sessionKey string
67		Eventually(puts).Should(Receive(&sessionKey))
68		Expect(sess.ConnectionState().DidResume).To(BeFalse())
69
70		serverSess, err := server.Accept(context.Background())
71		Expect(err).ToNot(HaveOccurred())
72		Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
73
74		sess, err = quic.DialAddr(
75			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
76			tlsConf,
77			nil,
78		)
79		Expect(err).ToNot(HaveOccurred())
80		Expect(gets).To(Receive(Equal(sessionKey)))
81		Expect(sess.ConnectionState().DidResume).To(BeTrue())
82
83		serverSess, err = server.Accept(context.Background())
84		Expect(err).ToNot(HaveOccurred())
85		Expect(serverSess.ConnectionState().DidResume).To(BeTrue())
86	})
87
88	It("doesn't use session resumption, if the config disables it", func() {
89		sConf := getTLSConfig()
90		sConf.SessionTicketsDisabled = true
91		server, err := quic.ListenAddr("localhost:0", sConf, nil)
92		Expect(err).ToNot(HaveOccurred())
93		defer server.Close()
94
95		gets := make(chan string, 100)
96		puts := make(chan string, 100)
97		cache := newClientSessionCache(gets, puts)
98		tlsConf := getTLSClientConfig()
99		tlsConf.ClientSessionCache = cache
100		sess, err := quic.DialAddr(
101			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
102			tlsConf,
103			nil,
104		)
105		Expect(err).ToNot(HaveOccurred())
106		Consistently(puts).ShouldNot(Receive())
107		Expect(sess.ConnectionState().DidResume).To(BeFalse())
108
109		serverSess, err := server.Accept(context.Background())
110		Expect(err).ToNot(HaveOccurred())
111		Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
112
113		sess, err = quic.DialAddr(
114			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
115			tlsConf,
116			nil,
117		)
118		Expect(err).ToNot(HaveOccurred())
119		Expect(sess.ConnectionState().DidResume).To(BeFalse())
120
121		serverSess, err = server.Accept(context.Background())
122		Expect(err).ToNot(HaveOccurred())
123		Expect(serverSess.ConnectionState().DidResume).To(BeFalse())
124	})
125})
126