1package dtls
2
3import (
4	"bytes"
5	"context"
6	"crypto/tls"
7	"sync"
8	"testing"
9	"time"
10
11	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
12	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
13	"github.com/pion/dtls/v2/pkg/protocol/alert"
14	"github.com/pion/dtls/v2/pkg/protocol/handshake"
15	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
16	"github.com/pion/logging"
17	"github.com/pion/transport/test"
18)
19
// nonZeroRetransmitInterval is the retransmitInterval given to both endpoint
// configurations in TestHandshaker.
const nonZeroRetransmitInterval time.Duration = 100 * time.Millisecond
21
22// Test that writes to the key log are in the correct format and only applies
23// when a key log writer is given.
24func TestWriteKeyLog(t *testing.T) {
25	var buf bytes.Buffer
26	cfg := handshakeConfig{
27		keyLogWriter: &buf,
28	}
29	cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
30
31	// Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret>
32	// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
33	want := "LABEL aabbcc ddeeff\n"
34	if buf.String() != want {
35		t.Fatalf("Got %s want %s", buf.String(), want)
36	}
37
38	// no key log writer = no writes
39	cfg = handshakeConfig{}
40	cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
41}
42
43func TestHandshaker(t *testing.T) {
44	// Check for leaking routines
45	report := test.CheckRoutines(t)
46	defer report()
47
48	loggerFactory := logging.NewDefaultLoggerFactory()
49	logger := loggerFactory.NewLogger("dtls")
50
51	cipherSuites, err := parseCipherSuites(nil, nil, true, false)
52	if err != nil {
53		t.Fatal(err)
54	}
55	clientCert, err := selfsign.GenerateSelfSigned()
56	if err != nil {
57		t.Fatal(err)
58	}
59
60	genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){
61		"PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) {
62			return nil, nil, nil
63		},
64		"HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) {
65			var (
66				cntHelloVerifyRequest  = 0
67				cntClientHelloNoCookie = 0
68			)
69			const helloVerifyDrop = 5
70			return func(p *packet) bool {
71					h, ok := p.record.Content.(*handshake.Handshake)
72					if !ok {
73						return true
74					}
75					if hmch, ok := h.Message.(*handshake.MessageClientHello); ok {
76						if len(hmch.Cookie) == 0 {
77							cntClientHelloNoCookie++
78						}
79					}
80					return true
81				},
82				func(p *packet) bool {
83					h, ok := p.record.Content.(*handshake.Handshake)
84					if !ok {
85						return true
86					}
87					if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok {
88						cntHelloVerifyRequest++
89						return cntHelloVerifyRequest > helloVerifyDrop
90					}
91					return true
92				},
93				func(t *testing.T) {
94					if cntHelloVerifyRequest != helloVerifyDrop+1 {
95						t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
96					}
97					if cntClientHelloNoCookie != cntHelloVerifyRequest {
98						t.Errorf(
99							"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
100							cntHelloVerifyRequest, cntClientHelloNoCookie,
101						)
102					}
103				}
104		},
105	}
106
107	for name, filters := range genFilters {
108		f1, f2, report := filters()
109		t.Run(name, func(t *testing.T) {
110			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
111			defer cancel()
112
113			if report != nil {
114				defer report(t)
115			}
116
117			ca, cb := flightTestPipe(ctx, f1, f2)
118			ca.state.isClient = true
119
120			var wg sync.WaitGroup
121			wg.Add(2)
122
123			ctxCliFinished, cancelCli := context.WithCancel(ctx)
124			ctxSrvFinished, cancelSrv := context.WithCancel(ctx)
125			go func() {
126				defer wg.Done()
127				cfg := &handshakeConfig{
128					localCipherSuites:     cipherSuites,
129					localCertificates:     []tls.Certificate{clientCert},
130					localSignatureSchemes: signaturehash.Algorithms(),
131					insecureSkipVerify:    true,
132					log:                   logger,
133					onFlightState: func(f flightVal, s handshakeState) {
134						if s == handshakeFinished {
135							cancelCli()
136						}
137					},
138					retransmitInterval: nonZeroRetransmitInterval,
139				}
140
141				fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
142				switch err := fsm.Run(ctx, ca, handshakePreparing); err {
143				case context.Canceled:
144				case context.DeadlineExceeded:
145					t.Error("Timeout")
146				default:
147					t.Error(err)
148				}
149			}()
150
151			go func() {
152				defer wg.Done()
153				cfg := &handshakeConfig{
154					localCipherSuites:     cipherSuites,
155					localCertificates:     []tls.Certificate{clientCert},
156					localSignatureSchemes: signaturehash.Algorithms(),
157					insecureSkipVerify:    true,
158					log:                   logger,
159					onFlightState: func(f flightVal, s handshakeState) {
160						if s == handshakeFinished {
161							cancelSrv()
162						}
163					},
164					retransmitInterval: nonZeroRetransmitInterval,
165				}
166
167				fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
168				switch err := fsm.Run(ctx, cb, handshakePreparing); err {
169				case context.Canceled:
170				case context.DeadlineExceeded:
171					t.Error("Timeout")
172				default:
173					t.Error(err)
174				}
175			}()
176
177			<-ctxCliFinished.Done()
178			<-ctxSrvFinished.Done()
179
180			cancel()
181			wg.Wait()
182		})
183	}
184}
185
// packetFilter decides whether a packet written by a flightTestConn is
// delivered: returning false drops it. Test filters may also use the call to
// count or inspect outgoing packets.
type packetFilter func(*packet) bool
187
188func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) {
189	ca := newHandshakeCache()
190	cb := newHandshakeCache()
191	chA := make(chan chan struct{})
192	chB := make(chan chan struct{})
193	return &flightTestConn{
194			handshakeCache: ca,
195			otherEndCache:  cb,
196			recv:           chA,
197			otherEndRecv:   chB,
198			done:           ctx.Done(),
199			filter:         filter1,
200		}, &flightTestConn{
201			handshakeCache: cb,
202			otherEndCache:  ca,
203			recv:           chB,
204			otherEndRecv:   chA,
205			done:           ctx.Done(),
206			filter:         filter2,
207		}
208}
209
// flightTestConn is an in-memory endpoint passed as the connection to fsm.Run
// in TestHandshaker. Its writePackets pushes handshake messages straight into
// this endpoint's and the peer's handshake caches, then signals the peer.
// Endpoints are created in pairs by flightTestPipe.
type flightTestConn struct {
	// state is handed to newHandshakeFSM by address; tests set isClient on it.
	state          State
	// handshakeCache holds the handshake messages as seen by this endpoint.
	handshakeCache *handshakeCache
	// recv is returned by recvHandshake; the peer's writePackets sends on it.
	recv           chan chan struct{}
	// done is the pipe context's Done channel; it aborts pending peer signals.
	done           <-chan struct{}
	// epoch stores the last value passed to setLocalEpoch.
	epoch          uint16

	// filter, if non-nil, drops every written packet for which it returns false.
	filter packetFilter

	// otherEndCache and otherEndRecv are the peer's handshakeCache and recv.
	otherEndCache *handshakeCache
	otherEndRecv  chan chan struct{}
}
222
// recvHandshake returns the channel on which the peer endpoint signals, after
// each writePackets call, that it has pushed messages into this endpoint's
// handshake cache.
func (c *flightTestConn) recvHandshake() <-chan chan struct{} {
	return c.recv
}
226
// setLocalEpoch records epoch on the connection. writePackets does not consult
// it; cached messages take their epoch from each packet's record header.
func (c *flightTestConn) setLocalEpoch(epoch uint16) {
	c.epoch = epoch
}
230
231func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
232	return nil
233}
234
235func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error {
236	for _, p := range pkts {
237		if c.filter != nil && !c.filter(p) {
238			continue
239		}
240		if h, ok := p.record.Content.(*handshake.Handshake); ok {
241			handshakeRaw, err := p.record.Marshal()
242			if err != nil {
243				return err
244			}
245
246			c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
247
248			content, err := h.Message.Marshal()
249			if err != nil {
250				return err
251			}
252			h.Header.Length = uint32(len(content))
253			h.Header.FragmentLength = uint32(len(content))
254			hdr, err := h.Header.Marshal()
255			if err != nil {
256				return err
257			}
258			c.otherEndCache.push(
259				append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
260		}
261	}
262	go func() {
263		select {
264		case c.otherEndRecv <- make(chan struct{}):
265		case <-c.done:
266		}
267	}()
268
269	// Avoid deadlock on JS/WASM environment due to context switch problem.
270	time.Sleep(10 * time.Millisecond)
271
272	return nil
273}
274
275func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error {
276	return nil
277}
278