1// +build !js
2
3package webrtc
4
5import (
6	"crypto/ecdsa"
7	"crypto/elliptic"
8	"crypto/rand"
9	"crypto/tls"
10	"crypto/x509"
11	"errors"
12	"fmt"
13	"strings"
14	"sync"
15	"sync/atomic"
16	"time"
17
18	"github.com/pion/dtls/v2"
19	"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
20	"github.com/pion/logging"
21	"github.com/pion/srtp/v2"
22	"github.com/pion/webrtc/v3/internal/mux"
23	"github.com/pion/webrtc/v3/internal/util"
24	"github.com/pion/webrtc/v3/pkg/rtcerr"
25)
26
27// DTLSTransport allows an application access to information about the DTLS
28// transport over which RTP and RTCP packets are sent and received by
29// RTPSender and RTPReceiver, as well other data such as SCTP packets sent
30// and received by data channels.
31type DTLSTransport struct {
32	lock sync.RWMutex
33
34	iceTransport          *ICETransport
35	certificates          []Certificate
36	remoteParameters      DTLSParameters
37	remoteCertificate     []byte
38	state                 DTLSTransportState
39	srtpProtectionProfile srtp.ProtectionProfile
40
41	onStateChangeHandler func(DTLSTransportState)
42
43	conn *dtls.Conn
44
45	srtpSession, srtcpSession   atomic.Value
46	srtpEndpoint, srtcpEndpoint *mux.Endpoint
47	simulcastStreams            []*srtp.ReadStreamSRTP
48	srtpReady                   chan struct{}
49
50	dtlsMatcher mux.MatchFunc
51
52	api *API
53	log logging.LeveledLogger
54}
55
56// NewDTLSTransport creates a new DTLSTransport.
57// This constructor is part of the ORTC API. It is not
58// meant to be used together with the basic WebRTC API.
59func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
60	t := &DTLSTransport{
61		iceTransport: transport,
62		api:          api,
63		state:        DTLSTransportStateNew,
64		dtlsMatcher:  mux.MatchDTLS,
65		srtpReady:    make(chan struct{}),
66		log:          api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
67	}
68
69	if len(certificates) > 0 {
70		now := time.Now()
71		for _, x509Cert := range certificates {
72			if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
73				return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
74			}
75			t.certificates = append(t.certificates, x509Cert)
76		}
77	} else {
78		sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
79		if err != nil {
80			return nil, &rtcerr.UnknownError{Err: err}
81		}
82		certificate, err := GenerateCertificate(sk)
83		if err != nil {
84			return nil, err
85		}
86		t.certificates = []Certificate{*certificate}
87	}
88
89	return t, nil
90}
91
92// ICETransport returns the currently-configured *ICETransport or nil
93// if one has not been configured
94func (t *DTLSTransport) ICETransport() *ICETransport {
95	t.lock.RLock()
96	defer t.lock.RUnlock()
97	return t.iceTransport
98}
99
100// onStateChange requires the caller holds the lock
101func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
102	t.state = state
103	handler := t.onStateChangeHandler
104	if handler != nil {
105		handler(state)
106	}
107}
108
109// OnStateChange sets a handler that is fired when the DTLS
110// connection state changes.
111func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
112	t.lock.Lock()
113	defer t.lock.Unlock()
114	t.onStateChangeHandler = f
115}
116
117// State returns the current dtls transport state.
118func (t *DTLSTransport) State() DTLSTransportState {
119	t.lock.RLock()
120	defer t.lock.RUnlock()
121	return t.state
122}
123
124// GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
125func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
126	fingerprints := []DTLSFingerprint{}
127
128	for _, c := range t.certificates {
129		prints, err := c.GetFingerprints()
130		if err != nil {
131			return DTLSParameters{}, err
132		}
133
134		fingerprints = append(fingerprints, prints...)
135	}
136
137	return DTLSParameters{
138		Role:         DTLSRoleAuto, // always returns the default role
139		Fingerprints: fingerprints,
140	}, nil
141}
142
143// GetRemoteCertificate returns the certificate chain in use by the remote side
144// returns an empty list prior to selection of the remote certificate
145func (t *DTLSTransport) GetRemoteCertificate() []byte {
146	t.lock.RLock()
147	defer t.lock.RUnlock()
148	return t.remoteCertificate
149}
150
151func (t *DTLSTransport) startSRTP() error {
152	srtpConfig := &srtp.Config{
153		Profile:       t.srtpProtectionProfile,
154		BufferFactory: t.api.settingEngine.BufferFactory,
155		LoggerFactory: t.api.settingEngine.LoggerFactory,
156	}
157	if t.api.settingEngine.replayProtection.SRTP != nil {
158		srtpConfig.RemoteOptions = append(
159			srtpConfig.RemoteOptions,
160			srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
161		)
162	}
163
164	if t.api.settingEngine.disableSRTPReplayProtection {
165		srtpConfig.RemoteOptions = append(
166			srtpConfig.RemoteOptions,
167			srtp.SRTPNoReplayProtection(),
168		)
169	}
170
171	if t.api.settingEngine.replayProtection.SRTCP != nil {
172		srtpConfig.RemoteOptions = append(
173			srtpConfig.RemoteOptions,
174			srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
175		)
176	}
177
178	if t.api.settingEngine.disableSRTCPReplayProtection {
179		srtpConfig.RemoteOptions = append(
180			srtpConfig.RemoteOptions,
181			srtp.SRTCPNoReplayProtection(),
182		)
183	}
184
185	connState := t.conn.ConnectionState()
186	err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
187	if err != nil {
188		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
189	}
190
191	srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
192	if err != nil {
193		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
194	}
195
196	srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
197	if err != nil {
198		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
199	}
200
201	t.srtpSession.Store(srtpSession)
202	t.srtcpSession.Store(srtcpSession)
203	close(t.srtpReady)
204	return nil
205}
206
207func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
208	if value := t.srtpSession.Load(); value != nil {
209		return value.(*srtp.SessionSRTP), nil
210	}
211
212	return nil, errDtlsTransportNotStarted
213}
214
215func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
216	if value := t.srtcpSession.Load(); value != nil {
217		return value.(*srtp.SessionSRTCP), nil
218	}
219
220	return nil, errDtlsTransportNotStarted
221}
222
223func (t *DTLSTransport) role() DTLSRole {
224	// If remote has an explicit role use the inverse
225	switch t.remoteParameters.Role {
226	case DTLSRoleClient:
227		return DTLSRoleServer
228	case DTLSRoleServer:
229		return DTLSRoleClient
230	default:
231	}
232
233	// If SettingEngine has an explicit role
234	switch t.api.settingEngine.answeringDTLSRole {
235	case DTLSRoleServer:
236		return DTLSRoleServer
237	case DTLSRoleClient:
238		return DTLSRoleClient
239	default:
240	}
241
242	// Remote was auto and no explicit role was configured via SettingEngine
243	if t.iceTransport.Role() == ICERoleControlling {
244		return DTLSRoleServer
245	}
246	return defaultDtlsRoleAnswer
247}
248
249// Start DTLS transport negotiation with the parameters of the remote DTLS transport
250func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
251	// Take lock and prepare connection, we must not hold the lock
252	// when connecting
253	prepareTransport := func() (DTLSRole, *dtls.Config, error) {
254		t.lock.Lock()
255		defer t.lock.Unlock()
256
257		if err := t.ensureICEConn(); err != nil {
258			return DTLSRole(0), nil, err
259		}
260
261		if t.state != DTLSTransportStateNew {
262			return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
263		}
264
265		t.srtpEndpoint = t.iceTransport.NewEndpoint(mux.MatchSRTP)
266		t.srtcpEndpoint = t.iceTransport.NewEndpoint(mux.MatchSRTCP)
267		t.remoteParameters = remoteParameters
268
269		cert := t.certificates[0]
270		t.onStateChange(DTLSTransportStateConnecting)
271
272		return t.role(), &dtls.Config{
273			Certificates: []tls.Certificate{
274				{
275					Certificate: [][]byte{cert.x509Cert.Raw},
276					PrivateKey:  cert.privateKey,
277				},
278			},
279			SRTPProtectionProfiles: []dtls.SRTPProtectionProfile{dtls.SRTP_AEAD_AES_128_GCM, dtls.SRTP_AES128_CM_HMAC_SHA1_80},
280			ClientAuth:             dtls.RequireAnyClientCert,
281			LoggerFactory:          t.api.settingEngine.LoggerFactory,
282			InsecureSkipVerify:     true,
283		}, nil
284	}
285
286	var dtlsConn *dtls.Conn
287	dtlsEndpoint := t.iceTransport.NewEndpoint(mux.MatchDTLS)
288	role, dtlsConfig, err := prepareTransport()
289	if err != nil {
290		return err
291	}
292
293	if t.api.settingEngine.replayProtection.DTLS != nil {
294		dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
295	}
296
297	// Connect as DTLS Client/Server, function is blocking and we
298	// must not hold the DTLSTransport lock
299	if role == DTLSRoleClient {
300		dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig)
301	} else {
302		dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig)
303	}
304
305	// Re-take the lock, nothing beyond here is blocking
306	t.lock.Lock()
307	defer t.lock.Unlock()
308
309	if err != nil {
310		t.onStateChange(DTLSTransportStateFailed)
311		return err
312	}
313
314	srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
315	if !ok {
316		t.onStateChange(DTLSTransportStateFailed)
317		return ErrNoSRTPProtectionProfile
318	}
319
320	switch srtpProfile {
321	case dtls.SRTP_AEAD_AES_128_GCM:
322		t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
323	case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
324		t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
325	default:
326		t.onStateChange(DTLSTransportStateFailed)
327		return ErrNoSRTPProtectionProfile
328	}
329
330	if t.api.settingEngine.disableCertificateFingerprintVerification {
331		return nil
332	}
333
334	// Check the fingerprint if a certificate was exchanged
335	remoteCerts := dtlsConn.ConnectionState().PeerCertificates
336	if len(remoteCerts) == 0 {
337		t.onStateChange(DTLSTransportStateFailed)
338		return errNoRemoteCertificate
339	}
340	t.remoteCertificate = remoteCerts[0]
341
342	parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
343	if err != nil {
344		if closeErr := dtlsConn.Close(); closeErr != nil {
345			t.log.Error(err.Error())
346		}
347
348		t.onStateChange(DTLSTransportStateFailed)
349		return err
350	}
351
352	if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
353		if closeErr := dtlsConn.Close(); closeErr != nil {
354			t.log.Error(err.Error())
355		}
356
357		t.onStateChange(DTLSTransportStateFailed)
358		return err
359	}
360
361	t.conn = dtlsConn
362	t.onStateChange(DTLSTransportStateConnected)
363
364	return t.startSRTP()
365}
366
367// Stop stops and closes the DTLSTransport object.
368func (t *DTLSTransport) Stop() error {
369	t.lock.Lock()
370	defer t.lock.Unlock()
371
372	// Try closing everything and collect the errors
373	var closeErrs []error
374
375	if srtpSessionValue := t.srtpSession.Load(); srtpSessionValue != nil {
376		closeErrs = append(closeErrs, srtpSessionValue.(*srtp.SessionSRTP).Close())
377	}
378
379	if srtcpSessionValue := t.srtcpSession.Load(); srtcpSessionValue != nil {
380		closeErrs = append(closeErrs, srtcpSessionValue.(*srtp.SessionSRTCP).Close())
381	}
382
383	for i := range t.simulcastStreams {
384		closeErrs = append(closeErrs, t.simulcastStreams[i].Close())
385	}
386
387	if t.conn != nil {
388		// dtls connection may be closed on sctp close.
389		if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
390			closeErrs = append(closeErrs, err)
391		}
392	}
393	t.onStateChange(DTLSTransportStateClosed)
394	return util.FlattenErrs(closeErrs)
395}
396
397func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
398	for _, fp := range t.remoteParameters.Fingerprints {
399		hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
400		if err != nil {
401			return err
402		}
403
404		remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
405		if err != nil {
406			return err
407		}
408
409		if strings.EqualFold(remoteValue, fp.Value) {
410			return nil
411		}
412	}
413
414	return errNoMatchingCertificateFingerprint
415}
416
417func (t *DTLSTransport) ensureICEConn() error {
418	if t.iceTransport == nil || t.iceTransport.State() == ICETransportStateNew {
419		return errICEConnectionNotStarted
420	}
421
422	return nil
423}
424
425func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {
426	t.lock.Lock()
427	defer t.lock.Unlock()
428
429	t.simulcastStreams = append(t.simulcastStreams, s)
430}
431