1package dtls
2
3import (
4	"context"
5	"crypto/x509"
6
7	"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
8	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
9	"github.com/pion/dtls/v2/pkg/crypto/prf"
10	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
11	"github.com/pion/dtls/v2/pkg/protocol"
12	"github.com/pion/dtls/v2/pkg/protocol/alert"
13	"github.com/pion/dtls/v2/pkg/protocol/extension"
14	"github.com/pion/dtls/v2/pkg/protocol/handshake"
15	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
16)
17
18func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
19	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence,
20		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
21		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
22		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
23	)
24	if !ok {
25		// No valid message received. Keep reading
26		return 0, nil, nil
27	}
28
29	// Validate type
30	var clientKeyExchange *handshake.MessageClientKeyExchange
31	if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
32		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
33	}
34
35	if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
36		state.PeerCertificates = h.Certificate
37	}
38
39	if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
40		if state.PeerCertificates == nil {
41			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
42		}
43
44		plainText := cache.pullAndMerge(
45			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
46			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
47			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
48			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
49			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
50			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
51			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
52			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
53		)
54
55		// Verify that the pair of hash algorithm and signiture is listed.
56		var validSignatureScheme bool
57		for _, ss := range cfg.localSignatureSchemes {
58			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
59				validSignatureScheme = true
60				break
61			}
62		}
63		if !validSignatureScheme {
64			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
65		}
66
67		if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
68			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
69		}
70		var chains [][]*x509.Certificate
71		var err error
72		var verified bool
73		if cfg.clientAuth >= VerifyClientCertIfGiven {
74			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
75				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
76			}
77			verified = true
78		}
79		if cfg.verifyPeerCertificate != nil {
80			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
81				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
82			}
83		}
84		state.peerCertificatesVerified = verified
85	}
86
87	if !state.cipherSuite.IsInitialized() {
88		serverRandom := state.localRandom.MarshalFixed()
89		clientRandom := state.remoteRandom.MarshalFixed()
90
91		var err error
92		var preMasterSecret []byte
93		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
94			var psk []byte
95			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
96				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
97			}
98			state.IdentityHint = clientKeyExchange.IdentityHint
99			preMasterSecret = prf.PSKPreMasterSecret(psk)
100		} else {
101			preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
102			if err != nil {
103				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
104			}
105		}
106
107		if state.extendedMasterSecret {
108			var sessionHash []byte
109			sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
110			if err != nil {
111				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
112			}
113
114			state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
115			if err != nil {
116				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
117			}
118		} else {
119			state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
120			if err != nil {
121				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
122			}
123		}
124
125		if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
126			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
127		}
128		cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
129	}
130
131	// Now, encrypted packets can be handled
132	if err := c.handleQueuedPackets(ctx); err != nil {
133		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
134	}
135
136	seq, msgs, ok = cache.fullPullMap(seq,
137		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
138	)
139	if !ok {
140		// No valid message received. Keep reading
141		return 0, nil, nil
142	}
143	state.handshakeRecvSequence = seq
144
145	if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
146		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
147	}
148
149	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
150		return flight6, nil, nil
151	}
152
153	switch cfg.clientAuth {
154	case RequireAnyClientCert:
155		if state.PeerCertificates == nil {
156			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
157		}
158	case VerifyClientCertIfGiven:
159		if state.PeerCertificates != nil && !state.peerCertificatesVerified {
160			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
161		}
162	case RequireAndVerifyClientCert:
163		if state.PeerCertificates == nil {
164			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
165		}
166		if !state.peerCertificatesVerified {
167			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
168		}
169	case NoClientCert, RequestClientCert:
170		return flight6, nil, nil
171	}
172
173	return flight6, nil, nil
174}
175
176func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
177	extensions := []extension.Extension{&extension.RenegotiationInfo{
178		RenegotiatedConnection: 0,
179	}}
180	if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
181		cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
182		extensions = append(extensions, &extension.UseExtendedMasterSecret{
183			Supported: true,
184		})
185	}
186	if state.srtpProtectionProfile != 0 {
187		extensions = append(extensions, &extension.UseSRTP{
188			ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
189		})
190	}
191	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
192		extensions = append(extensions, []extension.Extension{
193			&extension.SupportedEllipticCurves{
194				EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
195			},
196			&extension.SupportedPointFormats{
197				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
198			},
199		}...)
200	}
201
202	var pkts []*packet
203	cipherSuiteID := uint16(state.cipherSuite.ID())
204
205	pkts = append(pkts, &packet{
206		record: &recordlayer.RecordLayer{
207			Header: recordlayer.Header{
208				Version: protocol.Version1_2,
209			},
210			Content: &handshake.Handshake{
211				Message: &handshake.MessageServerHello{
212					Version:           protocol.Version1_2,
213					Random:            state.localRandom,
214					CipherSuiteID:     &cipherSuiteID,
215					CompressionMethod: defaultCompressionMethods()[0],
216					Extensions:        extensions,
217				},
218			},
219		},
220	})
221
222	switch {
223	case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
224		certificate, err := cfg.getCertificate(cfg.serverName)
225		if err != nil {
226			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
227		}
228
229		pkts = append(pkts, &packet{
230			record: &recordlayer.RecordLayer{
231				Header: recordlayer.Header{
232					Version: protocol.Version1_2,
233				},
234				Content: &handshake.Handshake{
235					Message: &handshake.MessageCertificate{
236						Certificate: certificate.Certificate,
237					},
238				},
239			},
240		})
241
242		serverRandom := state.localRandom.MarshalFixed()
243		clientRandom := state.remoteRandom.MarshalFixed()
244
245		// Find compatible signature scheme
246		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
247		if err != nil {
248			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
249		}
250
251		signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
252		if err != nil {
253			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
254		}
255		state.localKeySignature = signature
256
257		pkts = append(pkts, &packet{
258			record: &recordlayer.RecordLayer{
259				Header: recordlayer.Header{
260					Version: protocol.Version1_2,
261				},
262				Content: &handshake.Handshake{
263					Message: &handshake.MessageServerKeyExchange{
264						EllipticCurveType:  elliptic.CurveTypeNamedCurve,
265						NamedCurve:         state.namedCurve,
266						PublicKey:          state.localKeypair.PublicKey,
267						HashAlgorithm:      signatureHashAlgo.Hash,
268						SignatureAlgorithm: signatureHashAlgo.Signature,
269						Signature:          state.localKeySignature,
270					},
271				},
272			},
273		})
274
275		if cfg.clientAuth > NoClientCert {
276			pkts = append(pkts, &packet{
277				record: &recordlayer.RecordLayer{
278					Header: recordlayer.Header{
279						Version: protocol.Version1_2,
280					},
281					Content: &handshake.Handshake{
282						Message: &handshake.MessageCertificateRequest{
283							CertificateTypes:        []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
284							SignatureHashAlgorithms: cfg.localSignatureSchemes,
285						},
286					},
287				},
288			})
289		}
290	case cfg.localPSKIdentityHint != nil:
291		// To help the client in selecting which identity to use, the server
292		// can provide a "PSK identity hint" in the ServerKeyExchange message.
293		// If no hint is provided, the ServerKeyExchange message is omitted.
294		//
295		// https://tools.ietf.org/html/rfc4279#section-2
296		pkts = append(pkts, &packet{
297			record: &recordlayer.RecordLayer{
298				Header: recordlayer.Header{
299					Version: protocol.Version1_2,
300				},
301				Content: &handshake.Handshake{
302					Message: &handshake.MessageServerKeyExchange{
303						IdentityHint: cfg.localPSKIdentityHint,
304					},
305				},
306			},
307		})
308	case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
309		pkts = append(pkts, &packet{
310			record: &recordlayer.RecordLayer{
311				Header: recordlayer.Header{
312					Version: protocol.Version1_2,
313				},
314				Content: &handshake.Handshake{
315					Message: &handshake.MessageServerKeyExchange{
316						EllipticCurveType: elliptic.CurveTypeNamedCurve,
317						NamedCurve:        state.namedCurve,
318						PublicKey:         state.localKeypair.PublicKey,
319					},
320				},
321			},
322		})
323	}
324
325	pkts = append(pkts, &packet{
326		record: &recordlayer.RecordLayer{
327			Header: recordlayer.Header{
328				Version: protocol.Version1_2,
329			},
330			Content: &handshake.Handshake{
331				Message: &handshake.MessageServerHelloDone{},
332			},
333		},
334	})
335
336	return pkts, nil, nil
337}
338