1package handshake
2
3import (
4	"fmt"
5
6	"github.com/lucas-clemente/quic-go/internal/protocol"
7	"github.com/lucas-clemente/quic-go/internal/qtls"
8
9	. "github.com/onsi/ginkgo"
10	. "github.com/onsi/gomega"
11)
12
13var _ = Describe("TLS Extension Handler, for the server", func() {
14	var (
15		handlerServer tlsExtensionHandler
16		handlerClient tlsExtensionHandler
17		version       protocol.VersionNumber
18	)
19
20	BeforeEach(func() {
21		version = protocol.VersionDraft29
22	})
23
24	JustBeforeEach(func() {
25		handlerServer = newExtensionHandler(
26			[]byte("foobar"),
27			protocol.PerspectiveServer,
28			version,
29		)
30		handlerClient = newExtensionHandler(
31			[]byte("raboof"),
32			protocol.PerspectiveClient,
33			version,
34		)
35	})
36
37	Context("for the server", func() {
38		for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} {
39			v := ver
40
41			Context(fmt.Sprintf("sending, for version %s", v), func() {
42				var extensionType uint16
43
44				BeforeEach(func() {
45					version = v
46					if v == protocol.VersionDraft29 {
47						extensionType = quicTLSExtensionTypeOldDrafts
48					} else {
49						extensionType = quicTLSExtensionType
50					}
51				})
52
53				It("only adds TransportParameters for the Encrypted Extensions", func() {
54					// test 2 other handshake types
55					Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
56					Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty())
57				})
58
59				It("adds TransportParameters to the EncryptedExtensions message", func() {
60					exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
61					Expect(exts).To(HaveLen(1))
62					Expect(exts[0].Type).To(BeEquivalentTo(extensionType))
63					Expect(exts[0].Data).To(Equal([]byte("foobar")))
64				})
65			})
66		}
67
68		Context("receiving", func() {
69			var chExts []qtls.Extension
70
71			JustBeforeEach(func() {
72				chExts = handlerClient.GetExtensions(uint8(typeClientHello))
73				Expect(chExts).To(HaveLen(1))
74			})
75
76			It("sends the extension on the channel", func() {
77				go func() {
78					defer GinkgoRecover()
79					handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts)
80				}()
81
82				var data []byte
83				Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
84				Expect(data).To(Equal([]byte("raboof")))
85			})
86
87			It("sends nil on the channel if the extension is missing", func() {
88				go func() {
89					defer GinkgoRecover()
90					handlerServer.ReceivedExtensions(uint8(typeClientHello), nil)
91				}()
92
93				var data []byte
94				Eventually(handlerServer.TransportParameters()).Should(Receive(&data))
95				Expect(data).To(BeEmpty())
96			})
97
98			It("ignores extensions with different code points", func() {
99				go func() {
100					defer GinkgoRecover()
101					exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
102					handlerServer.ReceivedExtensions(uint8(typeClientHello), exts)
103				}()
104
105				var data []byte
106				Eventually(handlerServer.TransportParameters()).Should(Receive())
107				Expect(data).To(BeEmpty())
108			})
109
110			It("ignores extensions that are not sent with the ClientHello", func() {
111				done := make(chan struct{})
112				go func() {
113					defer GinkgoRecover()
114					handlerServer.ReceivedExtensions(uint8(typeFinished), chExts)
115					close(done)
116				}()
117
118				Consistently(handlerServer.TransportParameters()).ShouldNot(Receive())
119				Eventually(done).Should(BeClosed())
120			})
121		})
122	})
123
124	Context("for the client", func() {
125		for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} {
126			v := ver
127
128			Context(fmt.Sprintf("sending, for version %s", v), func() {
129				var extensionType uint16
130
131				BeforeEach(func() {
132					version = v
133					if v == protocol.VersionDraft29 {
134						extensionType = quicTLSExtensionTypeOldDrafts
135					} else {
136						extensionType = quicTLSExtensionType
137					}
138				})
139
140				It("only adds TransportParameters for the Encrypted Extensions", func() {
141					// test 2 other handshake types
142					Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty())
143					Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty())
144				})
145
146				It("adds TransportParameters to the ClientHello message", func() {
147					exts := handlerClient.GetExtensions(uint8(typeClientHello))
148					Expect(exts).To(HaveLen(1))
149					Expect(exts[0].Type).To(BeEquivalentTo(extensionType))
150					Expect(exts[0].Data).To(Equal([]byte("raboof")))
151				})
152			})
153		}
154
155		Context("receiving", func() {
156			var chExts []qtls.Extension
157
158			JustBeforeEach(func() {
159				chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions))
160				Expect(chExts).To(HaveLen(1))
161			})
162
163			It("sends the extension on the channel", func() {
164				go func() {
165					defer GinkgoRecover()
166					handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts)
167				}()
168
169				var data []byte
170				Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
171				Expect(data).To(Equal([]byte("foobar")))
172			})
173
174			It("sends nil on the channel if the extension is missing", func() {
175				go func() {
176					defer GinkgoRecover()
177					handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil)
178				}()
179
180				var data []byte
181				Eventually(handlerClient.TransportParameters()).Should(Receive(&data))
182				Expect(data).To(BeEmpty())
183			})
184
185			It("ignores extensions with different code points", func() {
186				go func() {
187					defer GinkgoRecover()
188					exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}}
189					handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts)
190				}()
191
192				var data []byte
193				Eventually(handlerClient.TransportParameters()).Should(Receive())
194				Expect(data).To(BeEmpty())
195			})
196
197			It("ignores extensions that are not sent with the EncryptedExtensions", func() {
198				done := make(chan struct{})
199				go func() {
200					defer GinkgoRecover()
201					handlerClient.ReceivedExtensions(uint8(typeFinished), chExts)
202					close(done)
203				}()
204
205				Consistently(handlerClient.TransportParameters()).ShouldNot(Receive())
206				Eventually(done).Should(BeClosed())
207			})
208		})
209	})
210})
211