1package handshake 2 3import ( 4 "fmt" 5 6 "github.com/lucas-clemente/quic-go/internal/protocol" 7 "github.com/lucas-clemente/quic-go/internal/qtls" 8 9 . "github.com/onsi/ginkgo" 10 . "github.com/onsi/gomega" 11) 12 13var _ = Describe("TLS Extension Handler, for the server", func() { 14 var ( 15 handlerServer tlsExtensionHandler 16 handlerClient tlsExtensionHandler 17 version protocol.VersionNumber 18 ) 19 20 BeforeEach(func() { 21 version = protocol.VersionDraft29 22 }) 23 24 JustBeforeEach(func() { 25 handlerServer = newExtensionHandler( 26 []byte("foobar"), 27 protocol.PerspectiveServer, 28 version, 29 ) 30 handlerClient = newExtensionHandler( 31 []byte("raboof"), 32 protocol.PerspectiveClient, 33 version, 34 ) 35 }) 36 37 Context("for the server", func() { 38 for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { 39 v := ver 40 41 Context(fmt.Sprintf("sending, for version %s", v), func() { 42 var extensionType uint16 43 44 BeforeEach(func() { 45 version = v 46 if v == protocol.VersionDraft29 { 47 extensionType = quicTLSExtensionTypeOldDrafts 48 } else { 49 extensionType = quicTLSExtensionType 50 } 51 }) 52 53 It("only adds TransportParameters for the Encrypted Extensions", func() { 54 // test 2 other handshake types 55 Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) 56 Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) 57 }) 58 59 It("adds TransportParameters to the EncryptedExtensions message", func() { 60 exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) 61 Expect(exts).To(HaveLen(1)) 62 Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) 63 Expect(exts[0].Data).To(Equal([]byte("foobar"))) 64 }) 65 }) 66 } 67 68 Context("receiving", func() { 69 var chExts []qtls.Extension 70 71 JustBeforeEach(func() { 72 chExts = handlerClient.GetExtensions(uint8(typeClientHello)) 73 Expect(chExts).To(HaveLen(1)) 74 }) 75 76 It("sends the extension on the channel", func() { 77 go func() { 78 defer GinkgoRecover() 79 handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) 80 }() 81 82 var data []byte 83 Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) 84 Expect(data).To(Equal([]byte("raboof"))) 85 }) 86 87 It("sends nil on the channel if the extension is missing", func() { 88 go func() { 89 defer GinkgoRecover() 90 handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) 91 }() 92 93 var data []byte 94 Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) 95 Expect(data).To(BeEmpty()) 96 }) 97 98 It("ignores extensions with different code points", func() { 99 go func() { 100 defer GinkgoRecover() 101 exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} 102 handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) 103 }() 104 105 var data []byte 106 Eventually(handlerServer.TransportParameters()).Should(Receive()) 107 Expect(data).To(BeEmpty()) 108 }) 109 110 It("ignores extensions that are not sent with the ClientHello", func() { 111 done := make(chan struct{}) 112 go func() { 113 defer GinkgoRecover() 114 handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) 115 close(done) 116 }() 117 118 Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) 119 Eventually(done).Should(BeClosed()) 120 }) 121 }) 122 }) 123 124 Context("for the client", func() { 125 for _, ver := range []protocol.VersionNumber{protocol.VersionDraft29, protocol.Version1} { 126 v := ver 127 128 Context(fmt.Sprintf("sending, for version %s", v), func() { 129 var extensionType uint16 130 131 BeforeEach(func() { 132 version = v 133 if v == protocol.VersionDraft29 { 134 extensionType = quicTLSExtensionTypeOldDrafts 135 } else { 136 extensionType = quicTLSExtensionType 137 } 138 }) 139 140 It("only adds TransportParameters for the Encrypted Extensions", func() { 141 // test 2 other handshake types 142 Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) 143 Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) 144 }) 145 146 It("adds TransportParameters to the ClientHello message", func() { 147 exts := handlerClient.GetExtensions(uint8(typeClientHello)) 148 Expect(exts).To(HaveLen(1)) 149 Expect(exts[0].Type).To(BeEquivalentTo(extensionType)) 150 Expect(exts[0].Data).To(Equal([]byte("raboof"))) 151 }) 152 }) 153 } 154 155 Context("receiving", func() { 156 var chExts []qtls.Extension 157 158 JustBeforeEach(func() { 159 chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) 160 Expect(chExts).To(HaveLen(1)) 161 }) 162 163 It("sends the extension on the channel", func() { 164 go func() { 165 defer GinkgoRecover() 166 handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) 167 }() 168 169 var data []byte 170 Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) 171 Expect(data).To(Equal([]byte("foobar"))) 172 }) 173 174 It("sends nil on the channel if the extension is missing", func() { 175 go func() { 176 defer GinkgoRecover() 177 handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) 178 }() 179 180 var data []byte 181 Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) 182 Expect(data).To(BeEmpty()) 183 }) 184 185 It("ignores extensions with different code points", func() { 186 go func() { 187 defer GinkgoRecover() 188 exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} 189 handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) 190 }() 191 192 var data []byte 193 Eventually(handlerClient.TransportParameters()).Should(Receive()) 194 Expect(data).To(BeEmpty()) 195 }) 196 197 It("ignores extensions that are not sent with the EncryptedExtensions", func() { 198 done := make(chan struct{}) 199 go func() { 200 defer GinkgoRecover() 201 handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) 202 close(done) 203 }() 204 205 Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) 206 Eventually(done).Should(BeClosed()) 207 }) 208 }) 209 }) 210}) 211