1package dtls 2 3import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "sync" 8 "testing" 9 "time" 10 11 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 12 "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 13 "github.com/pion/dtls/v2/pkg/protocol/alert" 14 "github.com/pion/dtls/v2/pkg/protocol/handshake" 15 "github.com/pion/dtls/v2/pkg/protocol/recordlayer" 16 "github.com/pion/logging" 17 "github.com/pion/transport/test" 18) 19 20const nonZeroRetransmitInterval = 100 * time.Millisecond 21 22// Test that writes to the key log are in the correct format and only applies 23// when a key log writer is given. 24func TestWriteKeyLog(t *testing.T) { 25 var buf bytes.Buffer 26 cfg := handshakeConfig{ 27 keyLogWriter: &buf, 28 } 29 cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) 30 31 // Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret> 32 // https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format 33 want := "LABEL aabbcc ddeeff\n" 34 if buf.String() != want { 35 t.Fatalf("Got %s want %s", buf.String(), want) 36 } 37 38 // no key log writer = no writes 39 cfg = handshakeConfig{} 40 cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) 41} 42 43func TestHandshaker(t *testing.T) { 44 // Check for leaking routines 45 report := test.CheckRoutines(t) 46 defer report() 47 48 loggerFactory := logging.NewDefaultLoggerFactory() 49 logger := loggerFactory.NewLogger("dtls") 50 51 cipherSuites, err := parseCipherSuites(nil, nil, true, false) 52 if err != nil { 53 t.Fatal(err) 54 } 55 clientCert, err := selfsign.GenerateSelfSigned() 56 if err != nil { 57 t.Fatal(err) 58 } 59 60 genFilters := map[string]func() (packetFilter, packetFilter, func(t *testing.T)){ 61 "PassThrough": func() (packetFilter, packetFilter, func(t *testing.T)) { 62 return nil, nil, nil 63 }, 64 "HelloVerifyRequestLost": func() (packetFilter, packetFilter, func(t *testing.T)) { 65 var ( 66 cntHelloVerifyRequest = 0 67 cntClientHelloNoCookie = 0 68 ) 69 const helloVerifyDrop = 5 70 return func(p *packet) bool { 71 h, ok := p.record.Content.(*handshake.Handshake) 72 if !ok { 73 return true 74 } 75 if hmch, ok := h.Message.(*handshake.MessageClientHello); ok { 76 if len(hmch.Cookie) == 0 { 77 cntClientHelloNoCookie++ 78 } 79 } 80 return true 81 }, 82 func(p *packet) bool { 83 h, ok := p.record.Content.(*handshake.Handshake) 84 if !ok { 85 return true 86 } 87 if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok { 88 cntHelloVerifyRequest++ 89 return cntHelloVerifyRequest > helloVerifyDrop 90 } 91 return true 92 }, 93 func(t *testing.T) { 94 if cntHelloVerifyRequest != helloVerifyDrop+1 { 95 t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest) 96 } 97 if cntClientHelloNoCookie != cntHelloVerifyRequest { 98 t.Errorf( 99 "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times", 100 cntHelloVerifyRequest, cntClientHelloNoCookie, 101 ) 102 } 103 } 104 }, 105 } 106 107 for name, filters := range genFilters { 108 f1, f2, report := filters() 109 t.Run(name, func(t *testing.T) { 110 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 111 defer cancel() 112 113 if report != nil { 114 defer report(t) 115 } 116 117 ca, cb := flightTestPipe(ctx, f1, f2) 118 ca.state.isClient = true 119 120 var wg sync.WaitGroup 121 wg.Add(2) 122 123 ctxCliFinished, cancelCli := context.WithCancel(ctx) 124 ctxSrvFinished, cancelSrv := context.WithCancel(ctx) 125 go func() { 126 defer wg.Done() 127 cfg := &handshakeConfig{ 128 localCipherSuites: cipherSuites, 129 localCertificates: []tls.Certificate{clientCert}, 130 localSignatureSchemes: signaturehash.Algorithms(), 131 insecureSkipVerify: true, 132 log: logger, 133 onFlightState: func(f flightVal, s handshakeState) { 134 if s == handshakeFinished { 135 cancelCli() 136 } 137 }, 138 retransmitInterval: nonZeroRetransmitInterval, 139 } 140 141 fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1) 142 switch err := fsm.Run(ctx, ca, handshakePreparing); err { 143 case context.Canceled: 144 case context.DeadlineExceeded: 145 t.Error("Timeout") 146 default: 147 t.Error(err) 148 } 149 }() 150 151 go func() { 152 defer wg.Done() 153 cfg := &handshakeConfig{ 154 localCipherSuites: cipherSuites, 155 localCertificates: []tls.Certificate{clientCert}, 156 localSignatureSchemes: signaturehash.Algorithms(), 157 insecureSkipVerify: true, 158 log: logger, 159 onFlightState: func(f flightVal, s handshakeState) { 160 if s == handshakeFinished { 161 cancelSrv() 162 } 163 }, 164 retransmitInterval: nonZeroRetransmitInterval, 165 } 166 167 fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0) 168 switch err := fsm.Run(ctx, cb, handshakePreparing); err { 169 case context.Canceled: 170 case context.DeadlineExceeded: 171 t.Error("Timeout") 172 default: 173 t.Error(err) 174 } 175 }() 176 177 <-ctxCliFinished.Done() 178 <-ctxSrvFinished.Done() 179 180 cancel() 181 wg.Wait() 182 }) 183 } 184} 185 186type packetFilter func(*packet) bool 187 188func flightTestPipe(ctx context.Context, filter1 packetFilter, filter2 packetFilter) (*flightTestConn, *flightTestConn) { 189 ca := newHandshakeCache() 190 cb := newHandshakeCache() 191 chA := make(chan chan struct{}) 192 chB := make(chan chan struct{}) 193 return &flightTestConn{ 194 handshakeCache: ca, 195 otherEndCache: cb, 196 recv: chA, 197 otherEndRecv: chB, 198 done: ctx.Done(), 199 filter: filter1, 200 }, &flightTestConn{ 201 handshakeCache: cb, 202 otherEndCache: ca, 203 recv: chB, 204 otherEndRecv: chA, 205 done: ctx.Done(), 206 filter: filter2, 207 } 208} 209 210type flightTestConn struct { 211 state State 212 handshakeCache *handshakeCache 213 recv chan chan struct{} 214 done <-chan struct{} 215 epoch uint16 216 217 filter packetFilter 218 219 otherEndCache *handshakeCache 220 otherEndRecv chan chan struct{} 221} 222 223func (c *flightTestConn) recvHandshake() <-chan chan struct{} { 224 return c.recv 225} 226 227func (c *flightTestConn) setLocalEpoch(epoch uint16) { 228 c.epoch = epoch 229} 230 231func (c *flightTestConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { 232 return nil 233} 234 235func (c *flightTestConn) writePackets(ctx context.Context, pkts []*packet) error { 236 for _, p := range pkts { 237 if c.filter != nil && !c.filter(p) { 238 continue 239 } 240 if h, ok := p.record.Content.(*handshake.Handshake); ok { 241 handshakeRaw, err := p.record.Marshal() 242 if err != nil { 243 return err 244 } 245 246 c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) 247 248 content, err := h.Message.Marshal() 249 if err != nil { 250 return err 251 } 252 h.Header.Length = uint32(len(content)) 253 h.Header.FragmentLength = uint32(len(content)) 254 hdr, err := h.Header.Marshal() 255 if err != nil { 256 return err 257 } 258 c.otherEndCache.push( 259 append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) 260 } 261 } 262 go func() { 263 select { 264 case c.otherEndRecv <- make(chan struct{}): 265 case <-c.done: 266 } 267 }() 268 269 // Avoid deadlock on JS/WASM environment due to context switch problem. 270 time.Sleep(10 * time.Millisecond) 271 272 return nil 273} 274 275func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error { 276 return nil 277} 278