1package qtls
2
3import (
4	"bytes"
5	"fmt"
6	"net"
7	"testing"
8	"time"
9)
10
11type recordLayer struct {
12	in  <-chan []byte
13	out chan<- []byte
14
15	alertSent alert
16}
17
18func (r *recordLayer) SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) {
19}
20func (r *recordLayer) SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) {
21}
22func (r *recordLayer) ReadHandshakeMessage() ([]byte, error) { return <-r.in, nil }
23func (r *recordLayer) WriteRecord(b []byte) (int, error)     { r.out <- b; return len(b), nil }
24func (r *recordLayer) SendAlert(a uint8)                     { r.alertSent = alert(a) }
25
// exportedKey records one key installation (a SetReadKey or SetWriteKey call)
// reported by recordLayerWithKeys. Tests use it to check the order in which
// keys are installed and to compare the keys derived by client and server.
type exportedKey struct {
	typ           string // "read" or "write"
	encLevel      EncryptionLevel
	suite         *CipherSuiteTLS13
	trafficSecret []byte
}
32
33func compareExportedKeys(t *testing.T, k1, k2 *exportedKey) {
34	if k1.encLevel != k2.encLevel || k1.suite.ID != k2.suite.ID || !bytes.Equal(k1.trafficSecret, k2.trafficSecret) {
35		t.Fatal("mismatching keys")
36	}
37}
38
39type recordLayerWithKeys struct {
40	in  <-chan []byte
41	out chan<- interface{}
42}
43
44func (r *recordLayerWithKeys) SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) {
45	r.out <- &exportedKey{typ: "read", encLevel: encLevel, suite: suite, trafficSecret: trafficSecret}
46}
47func (r *recordLayerWithKeys) SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) {
48	r.out <- &exportedKey{typ: "write", encLevel: encLevel, suite: suite, trafficSecret: trafficSecret}
49}
50func (r *recordLayerWithKeys) ReadHandshakeMessage() ([]byte, error) { return <-r.in, nil }
51func (r *recordLayerWithKeys) WriteRecord(b []byte) (int, error)     { r.out <- b; return len(b), nil }
52func (r *recordLayerWithKeys) SendAlert(uint8)                       {}
53
// unusedConn is a net.Conn placeholder for connections whose handshake
// messages travel through an AlternativeRecordLayer instead of the socket.
// Read and Write panic, because any such call means the TLS stack bypassed
// the alternative record layer. RemoteAddr returns the configured address;
// the remaining methods are no-ops.
//
// All methods use pointer receivers, so *unusedConn (the form every caller
// constructs) is the type that implements net.Conn.
type unusedConn struct {
	remoteAddr net.Addr // returned by RemoteAddr; may be nil
}

var _ net.Conn = (*unusedConn)(nil)

func (*unusedConn) Read([]byte) (int, error)         { panic("unexpected call to Read()") }
func (*unusedConn) Write([]byte) (int, error)        { panic("unexpected call to Write()") }
func (*unusedConn) Close() error                     { return nil }
func (*unusedConn) LocalAddr() net.Addr              { return &net.TCPAddr{} }
func (c *unusedConn) RemoteAddr() net.Addr           { return c.remoteAddr }
func (*unusedConn) SetDeadline(time.Time) error      { return nil }
func (*unusedConn) SetReadDeadline(time.Time) error  { return nil }
func (*unusedConn) SetWriteDeadline(time.Time) error { return nil }
68
69func TestAlternativeRecordLayer(t *testing.T) {
70	sIn := make(chan []byte, 10)
71	sOut := make(chan interface{}, 10)
72	defer close(sOut)
73	cIn := make(chan []byte, 10)
74	cOut := make(chan interface{}, 10)
75	defer close(cOut)
76
77	serverEvents := make(chan interface{}, 100)
78	go func() {
79		for {
80			c, ok := <-sOut
81			if !ok {
82				return
83			}
84			serverEvents <- c
85			if b, ok := c.([]byte); ok {
86				cIn <- b
87			}
88		}
89	}()
90
91	clientEvents := make(chan interface{}, 100)
92	go func() {
93		for {
94			c, ok := <-cOut
95			if !ok {
96				return
97			}
98			clientEvents <- c
99			if b, ok := c.([]byte); ok {
100				sIn <- b
101			}
102		}
103	}()
104
105	errChan := make(chan error)
106	go func() {
107		extraConf := &ExtraConfig{
108			AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut},
109		}
110		tlsConn := Server(&unusedConn{}, testConfig, extraConf)
111		defer tlsConn.Close()
112		errChan <- tlsConn.Handshake()
113	}()
114
115	extraConf := &ExtraConfig{
116		AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut},
117	}
118	tlsConn := Client(&unusedConn{}, testConfig, extraConf)
119	defer tlsConn.Close()
120	if err := tlsConn.Handshake(); err != nil {
121		t.Fatalf("Handshake failed: %s", err)
122	}
123
124	// Handshakes completed. Now check that events were received in the correct order.
125	var clientHandshakeReadKey, clientHandshakeWriteKey *exportedKey
126	var clientApplicationReadKey, clientApplicationWriteKey *exportedKey
127	for i := 0; i <= 5; i++ {
128		ev := <-clientEvents
129		switch i {
130		case 0:
131			if ev.([]byte)[0] != typeClientHello {
132				t.Fatalf("expected ClientHello")
133			}
134		case 1:
135			keyEv := ev.(*exportedKey)
136			if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake {
137				t.Fatalf("expected the handshake write key")
138			}
139			clientHandshakeWriteKey = keyEv
140		case 2:
141			keyEv := ev.(*exportedKey)
142			if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake {
143				t.Fatalf("expected the handshake read key")
144			}
145			clientHandshakeReadKey = keyEv
146		case 3:
147			keyEv := ev.(*exportedKey)
148			if keyEv.typ != "read" || keyEv.encLevel != EncryptionApplication {
149				t.Fatalf("expected the application read key")
150			}
151			clientApplicationReadKey = keyEv
152		case 4:
153			if ev.([]byte)[0] != typeFinished {
154				t.Fatalf("expected Finished")
155			}
156		case 5:
157			keyEv := ev.(*exportedKey)
158			if keyEv.typ != "write" || keyEv.encLevel != EncryptionApplication {
159				t.Fatalf("expected the application write key")
160			}
161			clientApplicationWriteKey = keyEv
162		}
163	}
164	if len(clientEvents) > 0 {
165		t.Fatal("didn't expect any more client events")
166	}
167
168	for i := 0; i <= 8; i++ {
169		ev := <-serverEvents
170		switch i {
171		case 0:
172			if ev.([]byte)[0] != typeServerHello {
173				t.Fatalf("expected ServerHello")
174			}
175		case 1:
176			keyEv := ev.(*exportedKey)
177			if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake {
178				t.Fatalf("expected the handshake read key")
179			}
180			compareExportedKeys(t, clientHandshakeWriteKey, keyEv)
181		case 2:
182			keyEv := ev.(*exportedKey)
183			if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake {
184				t.Fatalf("expected the handshake write key")
185			}
186			compareExportedKeys(t, clientHandshakeReadKey, keyEv)
187		case 3:
188			if ev.([]byte)[0] != typeEncryptedExtensions {
189				t.Fatalf("expected EncryptedExtensions")
190			}
191		case 4:
192			if ev.([]byte)[0] != typeCertificate {
193				t.Fatalf("expected Certificate")
194			}
195		case 5:
196			if ev.([]byte)[0] != typeCertificateVerify {
197				t.Fatalf("expected CertificateVerify")
198			}
199		case 6:
200			if ev.([]byte)[0] != typeFinished {
201				t.Fatalf("expected Finished")
202			}
203		case 7:
204			keyEv := ev.(*exportedKey)
205			if keyEv.typ != "write" || keyEv.encLevel != EncryptionApplication {
206				t.Fatalf("expected the application write key")
207			}
208			compareExportedKeys(t, clientApplicationReadKey, keyEv)
209		case 8:
210			keyEv := ev.(*exportedKey)
211			if keyEv.typ != "read" || keyEv.encLevel != EncryptionApplication {
212				t.Fatalf("expected the application read key")
213			}
214			compareExportedKeys(t, clientApplicationWriteKey, keyEv)
215		}
216	}
217	if len(serverEvents) > 0 {
218		t.Fatal("didn't expect any more server events")
219	}
220}
221
222func TestErrorOnOldTLSVersions(t *testing.T) {
223	sIn := make(chan []byte, 10)
224	cIn := make(chan []byte, 10)
225	cOut := make(chan []byte, 10)
226
227	go func() {
228		for {
229			b, ok := <-cOut
230			if !ok {
231				return
232			}
233			if b[0] == typeClientHello {
234				m := new(clientHelloMsg)
235				if !m.unmarshal(b) {
236					panic("unmarshal failed")
237				}
238				m.raw = nil // need to reset, so marshal() actually marshals the changes
239				m.supportedVersions = []uint16{VersionTLS11, VersionTLS13}
240				b = m.marshal()
241			}
242			sIn <- b
243		}
244	}()
245
246	done := make(chan struct{})
247	go func() {
248		defer close(done)
249		extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: cIn, out: cOut}}
250		Client(&unusedConn{}, testConfig, extraConf).Handshake()
251	}()
252
253	serverRecordLayer := &recordLayer{in: sIn, out: cIn}
254	extraConf := &ExtraConfig{AlternativeRecordLayer: serverRecordLayer}
255	tlsConn := Server(&unusedConn{}, testConfig, extraConf)
256	defer tlsConn.Close()
257	err := tlsConn.Handshake()
258	if err == nil || err.Error() != "tls: client offered old TLS version 0x302" {
259		t.Fatal("expected the server to error when the client offers old versions")
260	}
261	if serverRecordLayer.alertSent != alertProtocolVersion {
262		t.Fatal("expected a protocol version alert to be sent")
263	}
264
265	cIn <- []byte{'f'}
266	<-done
267}
268
269func TestRejectConfigWithOldMaxVersion(t *testing.T) {
270	t.Run("for the client", func(t *testing.T) {
271		config := testConfig.Clone()
272		config.MaxVersion = VersionTLS12
273		tlsConn := Client(&unusedConn{}, config, &ExtraConfig{AlternativeRecordLayer: &recordLayer{}})
274		err := tlsConn.Handshake()
275		if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" {
276			t.Errorf("expected the handshake to fail")
277		}
278	})
279
280	t.Run("for the server", func(t *testing.T) {
281		in := make(chan []byte, 10)
282		out := make(chan []byte, 10)
283
284		done := make(chan struct{})
285		go func() {
286			defer close(done)
287			Client(
288				&unusedConn{},
289				testConfig,
290				&ExtraConfig{AlternativeRecordLayer: &recordLayer{in: in, out: out}},
291			).Handshake()
292		}()
293
294		config := testConfig.Clone()
295		config.MaxVersion = VersionTLS12
296		serverRecordLayer := &recordLayer{in: out, out: in}
297		err := Server(
298			&unusedConn{},
299			config,
300			&ExtraConfig{AlternativeRecordLayer: serverRecordLayer},
301		).Handshake()
302		if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" {
303			t.Errorf("expected the handshake to fail")
304		}
305		if serverRecordLayer.alertSent != alertInternalError {
306			t.Fatal("expected an internal error alert to be sent")
307		}
308	})
309
310	t.Run("for the server (using GetConfigForClient)", func(t *testing.T) {
311		in := make(chan []byte, 10)
312		out := make(chan []byte, 10)
313
314		done := make(chan struct{})
315		go func() {
316			defer close(done)
317			Client(
318				&unusedConn{},
319				testConfig,
320				&ExtraConfig{AlternativeRecordLayer: &recordLayer{in: in, out: out}},
321			).Handshake()
322		}()
323
324		config := testConfig.Clone()
325		config.GetConfigForClient = func(*ClientHelloInfo) (*Config, error) {
326			conf := testConfig.Clone()
327			conf.MaxVersion = VersionTLS12
328			return conf, nil
329		}
330		serverRecordLayer := &recordLayer{in: out, out: in}
331		err := Server(
332			&unusedConn{},
333			config,
334			&ExtraConfig{AlternativeRecordLayer: serverRecordLayer},
335		).Handshake()
336		if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" {
337			t.Errorf("expected the handshake to fail")
338		}
339		if serverRecordLayer.alertSent != alertInternalError {
340			t.Fatal("expected an internal error alert to be sent")
341		}
342	})
343}
344
345func TestForbiddenZeroRTT(t *testing.T) {
346	// run the first handshake to get a session ticket
347	clientConn, serverConn := localPipe(t)
348	errChan := make(chan error, 1)
349	go func() {
350		tlsConn := Server(serverConn, testConfig.Clone(), nil)
351		defer tlsConn.Close()
352		err := tlsConn.Handshake()
353		errChan <- err
354		if err != nil {
355			return
356		}
357		tlsConn.Write([]byte{0})
358	}()
359
360	clientConfig := testConfig.Clone()
361	clientConfig.ClientSessionCache = NewLRUClientSessionCache(10)
362	tlsConn := Client(clientConn, clientConfig, nil)
363	if err := tlsConn.Handshake(); err != nil {
364		t.Fatalf("first handshake failed: %s", err)
365	}
366	tlsConn.Read([]byte{0}) // make sure to read the session ticket
367	tlsConn.Close()
368	if err := <-errChan; err != nil {
369		t.Fatalf("first handshake failed: %s", err)
370	}
371
372	sIn := make(chan []byte, 10)
373	cIn := make(chan []byte, 10)
374	cOut := make(chan []byte, 10)
375
376	go func() {
377		for {
378			b, ok := <-cOut
379			if !ok {
380				return
381			}
382			if b[0] == typeClientHello {
383				msg := &clientHelloMsg{}
384				if ok := msg.unmarshal(b); !ok {
385					panic("unmarshaling failed")
386				}
387				msg.earlyData = true
388				msg.raw = nil
389				b = msg.marshal()
390			}
391			sIn <- b
392		}
393	}()
394
395	done := make(chan struct{})
396	go func() {
397		defer close(done)
398		extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: cIn, out: cOut}}
399		Client(&unusedConn{remoteAddr: clientConn.RemoteAddr()}, clientConfig, extraConf).Handshake()
400	}()
401
402	config := testConfig.Clone()
403	config.MinVersion = VersionTLS13
404	serverRecordLayer := &recordLayer{in: sIn, out: cIn}
405	extraConf := &ExtraConfig{AlternativeRecordLayer: serverRecordLayer}
406	tlsConn = Server(&unusedConn{}, config, extraConf)
407	err := tlsConn.Handshake()
408	if err == nil {
409		t.Fatal("expected handshake to fail")
410	}
411	if err.Error() != "tls: client sent unexpected early data" {
412		t.Fatalf("expected early data error")
413	}
414	if serverRecordLayer.alertSent != alertUnsupportedExtension {
415		t.Fatal("expected an unsupported extension alert to be sent")
416	}
417	cIn <- []byte{0} // make the client handshake error
418	<-done
419}
420
421func TestZeroRTTKeys(t *testing.T) {
422	// run the first handshake to get a session ticket
423	clientConn, serverConn := localPipe(t)
424	errChan := make(chan error, 1)
425	go func() {
426		extraConf := &ExtraConfig{MaxEarlyData: 1000}
427		tlsConn := Server(serverConn, testConfig, extraConf)
428		defer tlsConn.Close()
429		err := tlsConn.Handshake()
430		errChan <- err
431		if err != nil {
432			return
433		}
434		tlsConn.Write([]byte{0})
435	}()
436
437	clientConfig := testConfig.Clone()
438	clientConfig.ClientSessionCache = NewLRUClientSessionCache(10)
439	tlsConn := Client(clientConn, clientConfig, nil)
440	if err := tlsConn.Handshake(); err != nil {
441		t.Fatalf("first handshake failed: %s", err)
442	}
443	tlsConn.Read([]byte{0}) // make sure to read the session ticket
444	tlsConn.Close()
445	if err := <-errChan; err != nil {
446		t.Fatalf("first handshake failed: %s", err)
447	}
448
449	sIn := make(chan []byte, 10)
450	sOut := make(chan interface{}, 10)
451	defer close(sOut)
452	cIn := make(chan []byte, 10)
453	cOut := make(chan interface{}, 10)
454	defer close(cOut)
455
456	var serverEarlyData bool
457	var serverExportedKey *exportedKey
458	go func() {
459		for {
460			c, ok := <-sOut
461			if !ok {
462				return
463			}
464			if b, ok := c.([]byte); ok {
465				if b[0] == typeEncryptedExtensions {
466					var msg encryptedExtensionsMsg
467					if ok := msg.unmarshal(b); !ok {
468						panic("failed to unmarshal EncryptedExtensions")
469					}
470					serverEarlyData = msg.earlyData
471				}
472				cIn <- b
473			}
474			if k, ok := c.(*exportedKey); ok && k.encLevel == Encryption0RTT {
475				serverExportedKey = k
476			}
477		}
478	}()
479
480	var clientEarlyData bool
481	var clientExportedKey *exportedKey
482	go func() {
483		for {
484			c, ok := <-cOut
485			if !ok {
486				return
487			}
488			if b, ok := c.([]byte); ok {
489				if b[0] == typeClientHello {
490					var msg clientHelloMsg
491					if ok := msg.unmarshal(b); !ok {
492						panic("failed to unmarshal ClientHello")
493					}
494					clientEarlyData = msg.earlyData
495				}
496				sIn <- b
497			}
498			if k, ok := c.(*exportedKey); ok && k.encLevel == Encryption0RTT {
499				clientExportedKey = k
500			}
501		}
502	}()
503
504	errChan = make(chan error)
505	go func() {
506		extraConf := &ExtraConfig{
507			AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut},
508			MaxEarlyData:           1,
509			Accept0RTT:             func([]byte) bool { return true },
510		}
511		tlsConn := Server(&unusedConn{}, testConfig, extraConf)
512		defer tlsConn.Close()
513		errChan <- tlsConn.Handshake()
514	}()
515
516	extraConf := &ExtraConfig{
517		AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut},
518		Enable0RTT:             true,
519	}
520	tlsConn = Client(&unusedConn{remoteAddr: clientConn.RemoteAddr()}, clientConfig, extraConf)
521	defer tlsConn.Close()
522	if err := tlsConn.Handshake(); err != nil {
523		t.Fatalf("Handshake failed: %s", err)
524	}
525	if err := <-errChan; err != nil {
526		t.Fatalf("Handshake failed: %s", err)
527	}
528
529	if !clientEarlyData {
530		t.Fatal("expected the client to offer early data")
531	}
532	if !serverEarlyData {
533		t.Fatal("expected the server to offer early data")
534	}
535	compareExportedKeys(t, clientExportedKey, serverExportedKey)
536}
537
// TestEncodeIntoSessionTicket checks that application data passed to the
// server's GetSessionTicket is handed to the server's Accept0RTT callback
// when the client later resumes that session with 0-RTT.
func TestEncodeIntoSessionTicket(t *testing.T) {
	// The same remote address is used for both client connections.
	// NOTE(review): presumably so the resumed client finds the cached session; confirm.
	raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
	// Both handshakes share these channels: servers read sIn and write sOut,
	// clients read sOut and write sIn.
	sIn := make(chan []byte, 10)
	sOut := make(chan []byte, 10)

	// do a first handshake and encode a "foobar" into the session ticket
	errChan := make(chan error, 1)
	stChan := make(chan []byte, 1)
	go func() {
		extraConf := &ExtraConfig{
			AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
			MaxEarlyData:           1,
		}
		server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf)
		defer server.Close()
		err := server.Handshake()
		if err != nil {
			errChan <- err
			return
		}
		st, err := server.GetSessionTicket([]byte("foobar"))
		if err != nil {
			errChan <- err
			return
		}
		stChan <- st
		errChan <- nil
	}()

	clientConf := testConfig.Clone()
	clientConf.ClientSessionCache = NewLRUClientSessionCache(10)
	extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}}
	client := Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf)
	if err := client.Handshake(); err != nil {
		t.Fatalf("first handshake failed %s", err)
	}
	if err := <-errChan; err != nil {
		t.Fatalf("first handshake failed %s", err)
	}
	// Deliver the session ticket to the client as a post-handshake message.
	sOut <- <-stChan
	if err := client.HandlePostHandshakeMessage(); err != nil {
		t.Fatalf("handling the session ticket failed: %s", err)
	}
	client.Close()

	// Resume the session with 0-RTT. The server's Accept0RTT callback forwards
	// the data it is given, which should be the "foobar" from the ticket.
	dataChan := make(chan []byte, 1)
	errChan = make(chan error, 1)
	go func() {
		extraConf := &ExtraConfig{
			AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
			MaxEarlyData:           1,
			Accept0RTT: func(data []byte) bool {
				dataChan <- data
				return true
			},
		}
		server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf)
		defer server.Close()
		errChan <- server.Handshake()
	}()

	// Reuse the first client's ExtraConfig, now with 0-RTT enabled.
	extraConf2 := extraConf.Clone()
	extraConf2.Enable0RTT = true
	client = Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf2)
	if err := client.Handshake(); err != nil {
		t.Fatalf("second handshake failed %s", err)
	}
	defer client.Close()
	if err := <-errChan; err != nil {
		t.Fatalf("second handshake failed %s", err)
	}
	// Accept0RTT must have been called during the server handshake.
	if len(dataChan) != 1 {
		t.Fatal("expected to receive application data")
	}
	if data := <-dataChan; !bytes.Equal(data, []byte("foobar")) {
		t.Fatalf("expected to receive a foobar, got %s", string(data))
	}
}
616
// TestZeroRTTRejection resumes a session with 0-RTT enabled on the client,
// while the server's Accept0RTT callback either rejects (doReject) or accepts
// the early data. It checks that the client's Rejected0RTT callback fires
// exactly when the server rejects, and that both sides report Used0RTT only
// when the early data was accepted.
func TestZeroRTTRejection(t *testing.T) {
	for _, doReject := range []bool{true, false} {
		t.Run(fmt.Sprintf("doing reject: %t", doReject), func(t *testing.T) {
			raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
			sIn := make(chan []byte, 10)
			sOut := make(chan []byte, 10)

			// do a first handshake to obtain a session ticket (no application
			// data is encoded into it)
			errChan := make(chan error, 1)
			go func() {
				extraConf := &ExtraConfig{
					AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
					MaxEarlyData:           1,
				}
				server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf)
				defer server.Close()
				err := server.Handshake()
				if err != nil {
					errChan <- err
					return
				}
				st, err := server.GetSessionTicket(nil)
				if err != nil {
					errChan <- err
					return
				}
				// Deliver the ticket to the client as a post-handshake message.
				sOut <- st
				errChan <- nil
			}()

			conf := testConfig.Clone()
			conf.ClientSessionCache = NewLRUClientSessionCache(10)
			extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}}
			client := Client(&unusedConn{remoteAddr: raddr}, conf, extraConf)
			if err := client.Handshake(); err != nil {
				t.Fatalf("first handshake failed %s", err)
			}
			if err := <-errChan; err != nil {
				t.Fatalf("first handshake failed %s", err)
			}
			if err := client.HandlePostHandshakeMessage(); err != nil {
				t.Fatalf("handling the session ticket failed: %s", err)
			}
			client.Close()

			// now dial the second connection
			errChan = make(chan error, 1)
			connStateChan := make(chan ConnectionStateWith0RTT, 1)
			go func() {
				extraConf := &ExtraConfig{
					AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
					MaxEarlyData:           1,
					Accept0RTT:             func(data []byte) bool { return !doReject },
				}
				server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf)
				defer server.Close()
				errChan <- server.Handshake()
				connStateChan <- server.ConnectionStateWith0RTT()
			}()

			// Resume with 0-RTT and record whether the client was told that the
			// early data was rejected.
			extraConf2 := extraConf.Clone()
			extraConf2.Enable0RTT = true
			var rejected bool
			extraConf2.Rejected0RTT = func() { rejected = true }
			client = Client(&unusedConn{remoteAddr: raddr}, conf, extraConf2)
			if err := client.Handshake(); err != nil {
				t.Fatalf("second handshake failed %s", err)
			}
			defer client.Close()
			if err := <-errChan; err != nil {
				t.Fatalf("second handshake failed %s", err)
			}
			if rejected != doReject {
				t.Fatal("wrong rejection")
			}
			if client.ConnectionStateWith0RTT().Used0RTT == doReject {
				t.Fatal("wrong connection state on the client")
			}
			if (<-connStateChan).Used0RTT == doReject {
				t.Fatal("wrong connection state on the server")
			}
		})
	}
}
701
// TestZeroRTTALPN checks how the ALPN protocol affects 0-RTT on resumption.
// 0-RTT is accepted when the resumed connection uses the same ALPN protocol
// as the original one. It is rejected when the protocol changes, even though
// the server's Accept0RTT callback always returns true.
func TestZeroRTTALPN(t *testing.T) {
	// run performs a first handshake with proto1 as the only ALPN protocol on
	// both sides, then resumes with 0-RTT using proto2. It checks that 0-RTT
	// was rejected on the client and server exactly when expectReject is set.
	run := func(t *testing.T, proto1, proto2 string, expectReject bool) {
		raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
		sIn := make(chan []byte, 10)
		sOut := make(chan []byte, 10)

		// do a first handshake to obtain a session ticket (no application data
		// is encoded into it)
		errChan := make(chan error, 1)
		go func() {
			serverConf := testConfig.Clone()
			serverConf.NextProtos = []string{proto1}
			extraConf := &ExtraConfig{
				AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
				MaxEarlyData:           1,
			}
			server := Server(&unusedConn{remoteAddr: raddr}, serverConf, extraConf)
			defer server.Close()
			err := server.Handshake()
			if err != nil {
				errChan <- err
				return
			}
			st, err := server.GetSessionTicket(nil)
			if err != nil {
				errChan <- err
				return
			}
			// Deliver the ticket to the client as a post-handshake message.
			sOut <- st
			errChan <- nil
		}()

		clientConf := testConfig.Clone()
		clientConf.NextProtos = []string{proto1}
		clientConf.ClientSessionCache = NewLRUClientSessionCache(10)
		extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}}
		client := Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf)
		if err := client.Handshake(); err != nil {
			t.Fatalf("first handshake failed %s", err)
		}
		if err := <-errChan; err != nil {
			t.Fatalf("first handshake failed %s", err)
		}
		if err := client.HandlePostHandshakeMessage(); err != nil {
			t.Fatalf("handling the session ticket failed: %s", err)
		}
		client.Close()

		// now dial the second connection
		errChan = make(chan error, 1)
		connStateChan := make(chan ConnectionStateWith0RTT, 1)
		go func() {
			serverConf := testConfig.Clone()
			serverConf.NextProtos = []string{proto2}
			extraConf := &ExtraConfig{
				AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut},
				Accept0RTT:             func([]byte) bool { return true },
				MaxEarlyData:           1,
			}
			server := Server(&unusedConn{remoteAddr: raddr}, serverConf, extraConf)
			defer server.Close()
			errChan <- server.Handshake()
			connStateChan <- server.ConnectionStateWith0RTT()
		}()

		// Reuse the first client's Config and ExtraConfig for the resumption,
		// now offering proto2 and enabling 0-RTT.
		clientConf.NextProtos = []string{proto2}
		extraConf.Enable0RTT = true
		var rejected bool
		extraConf.Rejected0RTT = func() { rejected = true }
		client = Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf)
		if err := client.Handshake(); err != nil {
			t.Fatalf("second handshake failed %s", err)
		}
		defer client.Close()
		if err := <-errChan; err != nil {
			t.Fatalf("second handshake failed %s", err)
		}
		if expectReject {
			if !rejected {
				t.Fatal("expected 0-RTT to be rejected")
			}
			if client.ConnectionStateWith0RTT().Used0RTT {
				t.Fatal("expected 0-RTT to be rejected")
			}
			if (<-connStateChan).Used0RTT {
				t.Fatal("expected 0-RTT to be rejected")
			}
		} else {
			if rejected {
				t.Fatal("didn't expect 0-RTT to be rejected")
			}
			if !client.ConnectionStateWith0RTT().Used0RTT {
				t.Fatal("didn't expect 0-RTT to be rejected")
			}
			if !(<-connStateChan).Used0RTT {
				t.Fatal("didn't expect 0-RTT to be rejected")
			}
		}
	}

	t.Run("with the same alpn", func(t *testing.T) {
		run(t, "proto1", "proto1", false)
	})
	t.Run("with different alpn", func(t *testing.T) {
		run(t, "proto1", "proto2", true)
	})
}
808