1package dtls
2
3import (
4	"bytes"
5	"crypto/rand"
6	"crypto/tls"
7	"crypto/x509"
8	"errors"
9	"fmt"
10	"net"
11	"testing"
12	"time"
13
14	"github.com/pion/transport/test"
15)
16
// Seems too strict for our implementation at this point
18// func TestNetTest(t *testing.T) {
19// 	lim := test.TimeOut(time.Minute*1 + time.Second*10)
20// 	defer lim.Stop()
21//
22// 	nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
23// 		c1, c2, err = pipeMemory()
24// 		if err != nil {
25// 			return nil, nil, nil, err
26// 		}
27// 		stop = func() {
28// 			c1.Close()
29// 			c2.Close()
30// 		}
31// 		return
32// 	})
33// }
34
35func TestStressDuplex(t *testing.T) {
36	// Limit runtime in case of deadlocks
37	lim := test.TimeOut(time.Second * 20)
38	defer lim.Stop()
39
40	// Check for leaking routines
41	report := test.CheckRoutines(t)
42	defer report()
43
44	// Run the test
45	stressDuplex(t)
46}
47
48func stressDuplex(t *testing.T) {
49	ca, cb, err := pipeMemory()
50	if err != nil {
51		t.Fatal(err)
52	}
53
54	defer func() {
55		err = ca.Close()
56		if err != nil {
57			t.Fatal(err)
58		}
59		err = cb.Close()
60		if err != nil {
61			t.Fatal(err)
62		}
63	}()
64
65	opt := test.Options{
66		MsgSize:  2048,
67		MsgCount: 100,
68	}
69
70	err = test.StressDuplex(ca, cb, opt)
71	if err != nil {
72		t.Fatal(err)
73	}
74}
75
76func pipeMemory() (*Conn, *Conn, error) {
77	// In memory pipe
78	ca, cb := net.Pipe()
79
80	type result struct {
81		c   *Conn
82		err error
83	}
84
85	c := make(chan result)
86
87	// Setup client
88	go func() {
89		client, err := testClient(ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
90		c <- result{client, err}
91	}()
92
93	// Setup server
94	server, err := testServer(cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
95	if err != nil {
96		return nil, nil, err
97	}
98
99	// Receive client
100	res := <-c
101	if res.err != nil {
102		return nil, nil, res.err
103	}
104
105	return res.c, server, nil
106}
107
108func testClient(c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
109	if generateCertificate {
110		clientCert, err := GenerateSelfSigned()
111		if err != nil {
112			return nil, err
113		}
114		cfg.Certificates = []tls.Certificate{clientCert}
115	}
116	cfg.InsecureSkipVerify = true
117	return Client(c, cfg)
118}
119
120func testServer(c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
121	if generateCertificate {
122		serverCert, err := GenerateSelfSigned()
123		if err != nil {
124			return nil, err
125		}
126		cfg.Certificates = []tls.Certificate{serverCert}
127	}
128	return Server(c, cfg)
129}
130
131func TestHandshakeWithAlert(t *testing.T) {
132	alertErr := errors.New("alert: Alert LevelFatal: InsufficientSecurity")
133	// Limit runtime in case of deadlocks
134	lim := test.TimeOut(time.Second * 20)
135	defer lim.Stop()
136
137	clientErr := make(chan error, 1)
138
139	ca, cb := net.Pipe()
140	go func() {
141		conf := &Config{
142			CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
143		}
144
145		_, err := testClient(ca, conf, true)
146		clientErr <- err
147	}()
148
149	config := &Config{
150		CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
151	}
152
153	if _, err := testServer(cb, config, true); err != errCipherSuiteNoIntersection {
154		t.Fatalf("TestHandshakeWithAlert: Client error exp(%v) failed(%v)", errCipherSuiteNoIntersection, err)
155	}
156
157	if err := <-clientErr; err.Error() != alertErr.Error() {
158		t.Fatalf("TestHandshakeWithAlert: Client error exp(%v) failed(%v)", alertErr, err)
159	}
160}
161
162func TestExportKeyingMaterial(t *testing.T) {
163	var rand [28]byte
164	exportLabel := "EXTRACTOR-dtls_srtp"
165
166	expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b}
167	expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77}
168
169	c := &Conn{
170		state: State{
171			localRandom:  handshakeRandom{time.Unix(500, 0), rand},
172			remoteRandom: handshakeRandom{time.Unix(1000, 0), rand},
173			cipherSuite:  &cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
174		},
175	}
176	c.setLocalEpoch(0)
177
178	_, err := c.ExportKeyingMaterial(exportLabel, nil, 0)
179	if err != errHandshakeInProgress {
180		t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
181	}
182
183	c.setLocalEpoch(1)
184	_, err = c.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
185	if err != errContextUnsupported {
186		t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
187	}
188
189	for k := range invalidKeyingLabels {
190		_, err = c.ExportKeyingMaterial(k, nil, 0)
191		if err != errReservedExportKeyingMaterial {
192			t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
193		}
194	}
195
196	keyingMaterial, err := c.ExportKeyingMaterial(exportLabel, nil, 10)
197	if err != nil {
198		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
199	} else if !bytes.Equal(keyingMaterial, expectedServerKey) {
200		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial)
201	}
202
203	c.state.isClient = true
204	keyingMaterial, err = c.ExportKeyingMaterial(exportLabel, nil, 10)
205	if err != nil {
206		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
207	} else if !bytes.Equal(keyingMaterial, expectedClientKey) {
208		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
209	}
210}
211
212func TestPSK(t *testing.T) {
213	// Limit runtime in case of deadlocks
214	lim := test.TimeOut(time.Second * 20)
215	defer lim.Stop()
216
217	for _, test := range []struct {
218		Name           string
219		ServerIdentity []byte
220	}{
221		{
222			Name:           "Server identity specified",
223			ServerIdentity: []byte("Test Identity"),
224		},
225		{
226			Name:           "Server identity nil",
227			ServerIdentity: nil,
228		},
229	} {
230
231		clientIdentity := []byte("Client Identity")
232		clientErr := make(chan error, 1)
233
234		ca, cb := net.Pipe()
235		go func() {
236			conf := &Config{
237				PSK: func(hint []byte) ([]byte, error) {
238					if !bytes.Equal(test.ServerIdentity, hint) { // nolint
239						return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) // nolint
240					}
241
242					return []byte{0xAB, 0xC1, 0x23}, nil
243				},
244				PSKIdentityHint: clientIdentity,
245				CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
246			}
247
248			_, err := testClient(ca, conf, false)
249			clientErr <- err
250		}()
251
252		config := &Config{
253			PSK: func(hint []byte) ([]byte, error) {
254				if !bytes.Equal(clientIdentity, hint) {
255					return nil, fmt.Errorf("TestPSK: Server got invalid identity expected(% 02x) actual(% 02x)", clientIdentity, hint)
256				}
257				return []byte{0xAB, 0xC1, 0x23}, nil
258			},
259			PSKIdentityHint: test.ServerIdentity,
260			CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
261		}
262
263		if _, err := testServer(cb, config, false); err != nil {
264			t.Fatalf("TestPSK: Server failed(%v)", err)
265		}
266
267		if err := <-clientErr; err != nil {
268			t.Fatal(err)
269		}
270	}
271}
272
273func TestPSKHintFail(t *testing.T) {
274	serverAlertError := errors.New("alert: Alert LevelFatal: InternalError")
275	pskRejected := errors.New("PSK Rejected")
276
277	// Limit runtime in case of deadlocks
278	lim := test.TimeOut(time.Second * 20)
279	defer lim.Stop()
280
281	clientErr := make(chan error, 1)
282
283	ca, cb := net.Pipe()
284	go func() {
285		conf := &Config{
286			PSK: func(hint []byte) ([]byte, error) {
287				return nil, pskRejected
288			},
289			PSKIdentityHint: []byte{},
290			CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
291		}
292
293		_, err := testClient(ca, conf, false)
294		clientErr <- err
295	}()
296
297	config := &Config{
298		PSK: func(hint []byte) ([]byte, error) {
299			return nil, pskRejected
300		},
301		PSKIdentityHint: []byte{},
302		CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
303	}
304
305	if _, err := testServer(cb, config, false); err.Error() != serverAlertError.Error() {
306		t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err)
307	}
308
309	if err := <-clientErr; err != pskRejected {
310		t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err)
311	}
312}
313
314func TestClientTimeout(t *testing.T) {
315	// Limit runtime in case of deadlocks
316	lim := test.TimeOut(time.Second * 20)
317	defer lim.Stop()
318
319	clientErr := make(chan error, 1)
320
321	ca, _ := net.Pipe()
322	go func() {
323		conf := &Config{
324			ConnectTimeout: ConnectTimeoutOption(1 * time.Second),
325		}
326
327		_, err := testClient(ca, conf, true)
328		clientErr <- err
329	}()
330
331	// no server!
332
333	if err := <-clientErr; err != errConnectTimeout {
334		t.Fatalf("TestClientTimeout: Client error exp(%v) failed(%v)", errConnectTimeout, err)
335	}
336}
337
338func TestSRTPConfiguration(t *testing.T) {
339	for _, test := range []struct {
340		Name            string
341		ClientSRTP      []SRTPProtectionProfile
342		ServerSRTP      []SRTPProtectionProfile
343		ExpectedProfile SRTPProtectionProfile
344		WantClientError error
345		WantServerError error
346	}{
347		{
348			Name:            "No SRTP in use",
349			ClientSRTP:      nil,
350			ServerSRTP:      nil,
351			ExpectedProfile: 0,
352			WantClientError: nil,
353			WantServerError: nil,
354		},
355		{
356			Name:            "SRTP both ends",
357			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
358			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
359			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
360			WantClientError: nil,
361			WantServerError: nil,
362		},
363		{
364			Name:            "SRTP client only",
365			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
366			ServerSRTP:      nil,
367			ExpectedProfile: 0,
368			WantClientError: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
369			WantServerError: errServerNoMatchingSRTPProfile,
370		},
371		{
372			Name:            "SRTP server only",
373			ClientSRTP:      nil,
374			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
375			ExpectedProfile: 0,
376			WantClientError: nil,
377			WantServerError: nil,
378		},
379	} {
380		ca, cb := net.Pipe()
381		type result struct {
382			c   *Conn
383			err error
384		}
385		c := make(chan result)
386
387		go func() {
388			client, err := testClient(ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
389			c <- result{client, err}
390		}()
391
392		server, err := testServer(cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
393		if err != nil || test.WantServerError != nil {
394			if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
395				t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
396			}
397		}
398
399		res := <-c
400		if res.err != nil || test.WantClientError != nil {
401			if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
402				t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
403			}
404		}
405
406		actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
407		if actualClientSRTP != test.ExpectedProfile {
408			t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
409		}
410
411		actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
412		if actualServerSRTP != test.ExpectedProfile {
413			t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
414		}
415	}
416}
417
418func TestClientCertificate(t *testing.T) {
419	srvCert, err := GenerateSelfSigned()
420	if err != nil {
421		t.Fatal(err)
422	}
423	srvCAPool := x509.NewCertPool()
424	srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0])
425	if err != nil {
426		t.Fatal(err)
427	}
428	srvCAPool.AddCert(srvCertificate)
429
430	cert, err := GenerateSelfSigned()
431	if err != nil {
432		t.Fatal(err)
433	}
434	certificate, err := x509.ParseCertificate(cert.Certificate[0])
435	if err != nil {
436		t.Fatal(err)
437	}
438	caPool := x509.NewCertPool()
439	caPool.AddCert(certificate)
440
441	t.Parallel()
442	tests := map[string]struct {
443		clientCfg *Config
444		serverCfg *Config
445		wantErr   bool
446	}{
447		"NoClientCert": {
448			clientCfg: &Config{RootCAs: srvCAPool},
449			serverCfg: &Config{
450				Certificates: []tls.Certificate{srvCert},
451				ClientAuth:   NoClientCert,
452				ClientCAs:    caPool,
453			},
454		},
455		"NoClientCert_cert": {
456			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
457			serverCfg: &Config{
458				Certificates: []tls.Certificate{srvCert},
459				ClientAuth:   RequireAnyClientCert,
460			},
461		},
462		"RequestClientCert_cert": {
463			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
464			serverCfg: &Config{
465				Certificates: []tls.Certificate{srvCert},
466				ClientAuth:   RequestClientCert,
467			},
468		},
469		"RequestClientCert_no_cert": {
470			clientCfg: &Config{RootCAs: srvCAPool},
471			serverCfg: &Config{
472				Certificates: []tls.Certificate{srvCert},
473				ClientAuth:   RequestClientCert,
474				ClientCAs:    caPool,
475			},
476		},
477		"RequireAnyClientCert": {
478			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
479			serverCfg: &Config{
480				Certificates: []tls.Certificate{srvCert},
481				ClientAuth:   RequireAnyClientCert,
482			},
483		},
484		"RequireAnyClientCert_error": {
485			clientCfg: &Config{RootCAs: srvCAPool},
486			serverCfg: &Config{
487				Certificates: []tls.Certificate{srvCert},
488				ClientAuth:   RequireAnyClientCert,
489			},
490			wantErr: true,
491		},
492		"VerifyClientCertIfGiven_no_cert": {
493			clientCfg: &Config{RootCAs: srvCAPool},
494			serverCfg: &Config{
495				Certificates: []tls.Certificate{srvCert},
496				ClientAuth:   VerifyClientCertIfGiven,
497				ClientCAs:    caPool,
498			},
499		},
500		"VerifyClientCertIfGiven_cert": {
501			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
502			serverCfg: &Config{
503				Certificates: []tls.Certificate{srvCert},
504				ClientAuth:   VerifyClientCertIfGiven,
505				ClientCAs:    caPool,
506			},
507		},
508		"VerifyClientCertIfGiven_error": {
509			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
510			serverCfg: &Config{
511				Certificates: []tls.Certificate{srvCert},
512				ClientAuth:   VerifyClientCertIfGiven,
513			},
514			wantErr: true,
515		},
516		"RequireAndVerifyClientCert": {
517			clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
518			serverCfg: &Config{
519				Certificates: []tls.Certificate{srvCert},
520				ClientAuth:   RequireAndVerifyClientCert,
521				ClientCAs:    caPool,
522			},
523		},
524	}
525	for name, tt := range tests {
526		tt := tt
527		t.Run(name, func(t *testing.T) {
528			ca, cb := net.Pipe()
529			type result struct {
530				c   *Conn
531				err error
532			}
533			c := make(chan result)
534
535			go func() {
536				client, err := Client(ca, tt.clientCfg)
537				c <- result{client, err}
538			}()
539
540			server, err := Server(cb, tt.serverCfg)
541			res := <-c
542
543			if tt.wantErr {
544				if err != nil {
545					// Error expected, test succeeded
546					return
547				}
548				t.Error("Error expected")
549			}
550			if err != nil {
551				t.Errorf("TestClientCertificate: Server failed(%v)", err)
552			}
553
554			if res.err != nil {
555				t.Errorf("TestClientCertificate: Client failed(%v)", res.err)
556			}
557			actualClientCert := server.RemoteCertificate()
558			if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
559				if actualClientCert == nil {
560					t.Errorf("TestClientCertificate: Client did not provide a certificate")
561				}
562
563				if len(actualClientCert) != len(tt.clientCfg.Certificates[0].Certificate) || !bytes.Equal(tt.clientCfg.Certificates[0].Certificate[0], actualClientCert[0]) {
564					t.Errorf("TestClientCertificate: Client certificate was not communicated correctly")
565				}
566			}
567			if tt.serverCfg.ClientAuth == NoClientCert {
568				if actualClientCert != nil {
569					t.Errorf("TestClientCertificate: Client certificate wasn't expected")
570				}
571			}
572
573			actualServerCert := res.c.RemoteCertificate()
574			if actualServerCert == nil {
575				t.Errorf("TestClientCertificate: Server did not provide a certificate")
576			}
577
578			if len(actualServerCert) != len(tt.serverCfg.Certificates[0].Certificate) || !bytes.Equal(tt.serverCfg.Certificates[0].Certificate[0], actualServerCert[0]) {
579				t.Errorf("TestClientCertificate: Server certificate was not communicated correctly")
580			}
581		})
582	}
583}
584
585func TestExtendedMasterSecret(t *testing.T) {
586	t.Parallel()
587	tests := map[string]struct {
588		clientCfg         *Config
589		serverCfg         *Config
590		expectedClientErr error
591		expectedServerErr error
592	}{
593		"Request_Request_ExtendedMasterSecret": {
594			clientCfg: &Config{
595				ExtendedMasterSecret: RequestExtendedMasterSecret,
596			},
597			serverCfg: &Config{
598				ExtendedMasterSecret: RequestExtendedMasterSecret,
599			},
600			expectedClientErr: nil,
601			expectedServerErr: nil,
602		},
603		"Request_Require_ExtendedMasterSecret": {
604			clientCfg: &Config{
605				ExtendedMasterSecret: RequestExtendedMasterSecret,
606			},
607			serverCfg: &Config{
608				ExtendedMasterSecret: RequireExtendedMasterSecret,
609			},
610			expectedClientErr: nil,
611			expectedServerErr: nil,
612		},
613		"Request_Disable_ExtendedMasterSecret": {
614			clientCfg: &Config{
615				ExtendedMasterSecret: RequestExtendedMasterSecret,
616			},
617			serverCfg: &Config{
618				ExtendedMasterSecret: DisableExtendedMasterSecret,
619			},
620			expectedClientErr: nil,
621			expectedServerErr: nil,
622		},
623		"Require_Request_ExtendedMasterSecret": {
624			clientCfg: &Config{
625				ExtendedMasterSecret: RequireExtendedMasterSecret,
626			},
627			serverCfg: &Config{
628				ExtendedMasterSecret: RequestExtendedMasterSecret,
629			},
630			expectedClientErr: nil,
631			expectedServerErr: nil,
632		},
633		"Require_Require_ExtendedMasterSecret": {
634			clientCfg: &Config{
635				ExtendedMasterSecret: RequireExtendedMasterSecret,
636			},
637			serverCfg: &Config{
638				ExtendedMasterSecret: RequireExtendedMasterSecret,
639			},
640			expectedClientErr: nil,
641			expectedServerErr: nil,
642		},
643		"Require_Disable_ExtendedMasterSecret": {
644			clientCfg: &Config{
645				ExtendedMasterSecret: RequireExtendedMasterSecret,
646			},
647			serverCfg: &Config{
648				ExtendedMasterSecret: DisableExtendedMasterSecret,
649			},
650			expectedClientErr: errClientRequiredButNoServerEMS,
651			expectedServerErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
652		},
653		"Disable_Request_ExtendedMasterSecret": {
654			clientCfg: &Config{
655				ExtendedMasterSecret: DisableExtendedMasterSecret,
656			},
657			serverCfg: &Config{
658				ExtendedMasterSecret: RequestExtendedMasterSecret,
659			},
660			expectedClientErr: nil,
661			expectedServerErr: nil,
662		},
663		"Disable_Require_ExtendedMasterSecret": {
664			clientCfg: &Config{
665				ExtendedMasterSecret: DisableExtendedMasterSecret,
666			},
667			serverCfg: &Config{
668				ExtendedMasterSecret: RequireExtendedMasterSecret,
669			},
670			expectedClientErr: fmt.Errorf("alert: Alert LevelFatal: InsufficientSecurity"),
671			expectedServerErr: errServerRequiredButNoClientEMS,
672		},
673		"Disable_Disable_ExtendedMasterSecret": {
674			clientCfg: &Config{
675				ExtendedMasterSecret: DisableExtendedMasterSecret,
676			},
677			serverCfg: &Config{
678				ExtendedMasterSecret: DisableExtendedMasterSecret,
679			},
680			expectedClientErr: nil,
681			expectedServerErr: nil,
682		},
683	}
684	for name, tt := range tests {
685		tt := tt
686		t.Run(name, func(t *testing.T) {
687			ca, cb := net.Pipe()
688			type result struct {
689				c   *Conn
690				err error
691			}
692			c := make(chan result)
693
694			go func() {
695				client, err := testClient(ca, tt.clientCfg, true)
696				c <- result{client, err}
697			}()
698
699			_, err := testServer(cb, tt.serverCfg, true)
700			res := <-c
701
702			if tt.expectedClientErr != nil {
703				if res.err.Error() != tt.expectedClientErr.Error() {
704					t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
705				}
706			}
707
708			if tt.expectedServerErr != nil {
709				if err.Error() != tt.expectedServerErr.Error() {
710					t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
711				}
712			}
713
714		})
715	}
716}
717
718func TestServerCertificate(t *testing.T) {
719	t.Parallel()
720
721	cert, err := GenerateSelfSigned()
722	if err != nil {
723		t.Fatal(err)
724	}
725	certificate, err := x509.ParseCertificate(cert.Certificate[0])
726	if err != nil {
727		t.Fatal(err)
728	}
729	caPool := x509.NewCertPool()
730	caPool.AddCert(certificate)
731
732	tests := map[string]struct {
733		clientCfg *Config
734		serverCfg *Config
735		wantErr   bool
736	}{
737		"no_ca": {
738			clientCfg: &Config{},
739			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
740			wantErr:   true,
741		},
742		"good_ca": {
743			clientCfg: &Config{RootCAs: caPool},
744			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
745		},
746		"no_ca_skip_verify": {
747			clientCfg: &Config{InsecureSkipVerify: true},
748			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
749		},
750		"good_ca_skip_verify_custom_verify_peer": {
751			clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
752			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
753				if len(chain) != 0 {
754					return errors.New("not expected chain")
755				}
756				return nil
757			}},
758		},
759		"good_ca_verify_custom_verify_peer": {
760			clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
761			serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
762				if len(chain) == 0 {
763					return errors.New("expected chain")
764				}
765				return nil
766			}},
767		},
768		"good_ca_custom_verify_peer": {
769			clientCfg: &Config{
770				RootCAs: caPool,
771				VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
772					return errors.New("wrong cert")
773				},
774			},
775			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
776			wantErr:   true,
777		},
778		"server_name": {
779			clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
780			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
781		},
782		"server_name_error": {
783			clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
784			serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
785			wantErr:   true,
786		},
787	}
788	for name, tt := range tests {
789		tt := tt
790		t.Run(name, func(t *testing.T) {
791			ca, cb := net.Pipe()
792			go func() {
793				_, _ = Server(cb, tt.serverCfg)
794			}()
795			_, err := Client(ca, tt.clientCfg)
796			if !tt.wantErr && err != nil {
797				t.Errorf("TestClientCertificate: Client failed(%v)", err)
798			}
799			if tt.wantErr && err == nil {
800				t.Fatal("Error expected")
801			}
802		})
803	}
804
805}
806
807func TestCipherSuiteConfiguration(t *testing.T) {
808	for _, test := range []struct {
809		Name               string
810		ClientCipherSuites []CipherSuiteID
811		ServerCipherSuites []CipherSuiteID
812		WantClientError    error
813		WantServerError    error
814	}{
815		{
816			Name:               "No CipherSuites specified",
817			ClientCipherSuites: nil,
818			ServerCipherSuites: nil,
819			WantClientError:    nil,
820			WantServerError:    nil,
821		},
822		{
823			Name:               "Invalid CipherSuite",
824			ClientCipherSuites: []CipherSuiteID{0x00},
825			ServerCipherSuites: []CipherSuiteID{0x00},
826			WantClientError:    errors.New("CipherSuite with id(0) is not valid"),
827			WantServerError:    errors.New("CipherSuite with id(0) is not valid"),
828		},
829		{
830			Name:               "Valid CipherSuites specified",
831			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
832			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
833			WantClientError:    nil,
834			WantServerError:    nil,
835		},
836		{
837			Name:               "CipherSuites mismatch",
838			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
839			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
840			WantClientError:    errors.New("alert: Alert LevelFatal: InsufficientSecurity"),
841			WantServerError:    errCipherSuiteNoIntersection,
842		},
843	} {
844		ca, cb := net.Pipe()
845		type result struct {
846			c   *Conn
847			err error
848		}
849		c := make(chan result)
850
851		go func() {
852			client, err := testClient(ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
853			c <- result{client, err}
854		}()
855
856		_, err := testServer(cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
857		if err != nil || test.WantServerError != nil {
858			if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
859				t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
860			}
861		}
862
863		res := <-c
864		if res.err != nil || test.WantClientError != nil {
865			if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
866				t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
867			}
868		}
869	}
870}
871
872func TestPSKConfiguration(t *testing.T) {
873	for _, test := range []struct {
874		Name                 string
875		ClientHasCertificate bool
876		ServerHasCertificate bool
877		ClientPSK            PSKCallback
878		ServerPSK            PSKCallback
879		ClientPSKIdentity    []byte
880		ServerPSKIdentity    []byte
881		WantClientError      error
882		WantServerError      error
883	}{
884		{
885			Name:                 "PSK specified",
886			ClientHasCertificate: false,
887			ServerHasCertificate: false,
888			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
889			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
890			ClientPSKIdentity:    []byte{0x00},
891			ServerPSKIdentity:    []byte{0x00},
892			WantClientError:      errNoAvailableCipherSuites,
893			WantServerError:      errNoAvailableCipherSuites,
894		},
895		{
896			Name:                 "PSK and certificate specified",
897			ClientHasCertificate: true,
898			ServerHasCertificate: true,
899			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
900			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
901			ClientPSKIdentity:    []byte{0x00},
902			ServerPSKIdentity:    []byte{0x00},
903			WantClientError:      errPSKAndCertificate,
904			WantServerError:      errPSKAndCertificate,
905		},
906		{
907			Name:                 "PSK and no identity specified",
908			ClientHasCertificate: false,
909			ServerHasCertificate: false,
910			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
911			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
912			ClientPSKIdentity:    nil,
913			ServerPSKIdentity:    nil,
914			WantClientError:      errPSKAndIdentityMustBeSetForClient,
915			WantServerError:      errNoAvailableCipherSuites,
916		},
917		{
918			Name:                 "No PSK and identity specified",
919			ClientHasCertificate: false,
920			ServerHasCertificate: false,
921			ClientPSK:            nil,
922			ServerPSK:            nil,
923			ClientPSKIdentity:    []byte{0x00},
924			ServerPSKIdentity:    []byte{0x00},
925			WantClientError:      errIdentityNoPSK,
926			WantServerError:      errServerMustHaveCertificate,
927		},
928	} {
929		ca, cb := net.Pipe()
930		type result struct {
931			c   *Conn
932			err error
933		}
934		c := make(chan result)
935
936		go func() {
937			client, err := testClient(ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate)
938			c <- result{client, err}
939		}()
940
941		_, err := testServer(cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate)
942		if err != nil || test.WantServerError != nil {
943			if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
944				t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
945			}
946		}
947
948		res := <-c
949		if res.err != nil || test.WantClientError != nil {
950			if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
951				t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
952			}
953		}
954	}
955}
956
957func TestServerTimeout(t *testing.T) {
958	// Limit runtime in case of deadlocks
959	lim := test.TimeOut(time.Second * 20)
960	defer lim.Stop()
961
962	cookie := make([]byte, 20)
963	_, err := rand.Read(cookie)
964	if err != nil {
965		t.Fatal(err)
966	}
967
968	var rand [28]byte
969	random := handshakeRandom{time.Unix(500, 0), rand}
970
971	cipherSuites := []cipherSuite{
972		&cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256{},
973		&cipherSuiteTLSEcdheRsaWithAes128GcmSha256{},
974	}
975
976	extensions := []extension{
977		&extensionSupportedSignatureAlgorithms{
978			signatureHashAlgorithms: []signatureHashAlgorithm{
979				{HashAlgorithmSHA256, signatureAlgorithmECDSA},
980				{HashAlgorithmSHA384, signatureAlgorithmECDSA},
981				{HashAlgorithmSHA512, signatureAlgorithmECDSA},
982				{HashAlgorithmSHA256, signatureAlgorithmRSA},
983				{HashAlgorithmSHA384, signatureAlgorithmRSA},
984				{HashAlgorithmSHA512, signatureAlgorithmRSA},
985			},
986		},
987		&extensionSupportedEllipticCurves{
988			ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256, namedCurveP384},
989		},
990		&extensionSupportedPointFormats{
991			pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
992		},
993	}
994
995	record := &recordLayer{
996		recordLayerHeader: recordLayerHeader{
997			sequenceNumber:  0,
998			protocolVersion: protocolVersion1_2,
999		},
1000		content: &handshake{
1001			// sequenceNumber and messageSequence line up, may need to be re-evaluated
1002			handshakeHeader: handshakeHeader{
1003				messageSequence: 0,
1004			},
1005			handshakeMessage: &handshakeMessageClientHello{
1006				version:            protocolVersion1_2,
1007				cookie:             cookie,
1008				random:             random,
1009				cipherSuites:       cipherSuites,
1010				compressionMethods: defaultCompressionMethods,
1011				extensions:         extensions,
1012			}},
1013	}
1014
1015	packet, err := record.Marshal()
1016	if err != nil {
1017		t.Fatal(err)
1018	}
1019
1020	ca, cb := net.Pipe()
1021	defer func() {
1022		err := ca.Close()
1023		if err != nil {
1024			t.Fatal(err)
1025		}
1026	}()
1027
1028	// Client reader
1029	caReadChan := make(chan []byte, 1000)
1030	go func() {
1031		for {
1032			data := make([]byte, 8192)
1033			n, err := ca.Read(data)
1034			if err != nil {
1035				return
1036			}
1037
1038			caReadChan <- data[:n]
1039		}
1040	}()
1041
1042	// Start sending ClientHello packets until server responds with first packet
1043	go func() {
1044		for {
1045			select {
1046			case <-time.After(10 * time.Millisecond):
1047				_, err := ca.Write(packet)
1048				if err != nil {
1049					return
1050				}
1051			case <-caReadChan:
1052				// Once we receive the first reply from the server, stop
1053				return
1054			}
1055		}
1056	}()
1057
1058	config := &Config{
1059		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1060		ConnectTimeout: ConnectTimeoutOption(50 * time.Millisecond),
1061		FlightInterval: 100 * time.Millisecond,
1062	}
1063
1064	if _, err := testServer(cb, config, true); err != errConnectTimeout {
1065		t.Fatalf("TestServerTimeout: Client error exp(%v) failed(%v)", errConnectTimeout, err)
1066	}
1067
1068	// Wait a little longer to ensure no additional messages have been sent by the server
1069	time.Sleep(300 * time.Millisecond)
1070	select {
1071	case msg := <-caReadChan:
1072		t.Fatalf("TestServerTimeout: Expected no additional messages from server, got: %+v", msg)
1073	default:
1074	}
1075}
1076