1/*
2 * Copyright (c) 2014, Yawning Angel <yawning at schwanenlied dot me>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  * Redistributions of source code must retain the above copyright notice,
9 *    this list of conditions and the following disclaimer.
10 *
11 *  * Redistributions in binary form must reproduce the above copyright notice,
12 *    this list of conditions and the following disclaimer in the documentation
13 *    and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
19 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
20 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
21 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
22 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
23 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25 * POSSIBILITY OF SUCH DAMAGE.
26 */
27
28// Package ntor implements the Tor Project's ntor handshake as defined in
29// proposal 216 "Improved circuit-creation key exchange".  It also supports
30// using Elligator to transform the Curve25519 public keys sent over the wire
31// to a form that is indistinguishable from random strings.
32//
33// Before using this package, it is strongly recommended that the specification
34// is read and understood.
35package ntor // import "gitlab.com/yawning/obfs4.git/common/ntor"
36
37import (
38	"bytes"
39	"crypto/hmac"
40	"crypto/sha256"
41	"crypto/subtle"
42	"encoding/hex"
43	"fmt"
44	"io"
45
46	"golang.org/x/crypto/curve25519"
47	"golang.org/x/crypto/hkdf"
48
49	"gitlab.com/yawning/obfs4.git/common/csrand"
50	"gitlab.com/yawning/obfs4.git/internal/extra25519"
51)
52
const (
	// PublicKeyLength is the length in bytes of a Curve25519 public key.
	PublicKeyLength = 32

	// RepresentativeLength is the length in bytes of an Elligator
	// representative.
	RepresentativeLength = 32

	// PrivateKeyLength is the length in bytes of a Curve25519 private key.
	PrivateKeyLength = 32

	// SharedSecretLength is the length in bytes of a Curve25519 shared secret.
	SharedSecretLength = 32

	// NodeIDLength is the length in bytes of an ntor node identifier.
	NodeIDLength = 20

	// KeySeedLength is the length in bytes of the derived KEY_SEED.
	KeySeedLength = sha256.Size

	// AuthLength is the length in bytes of the derived AUTH.
	AuthLength = sha256.Size
)
75
// protoID is PROTOID from the ntor specification.  The remaining labels are
// protoID followed by a fixed suffix.
//
// Each label is built from protoID[:len(protoID):len(protoID)].  That full
// slice expression caps the capacity at the length, so every append must
// allocate a fresh backing array.  A plain append(protoID, ...) would reuse
// any spare capacity in protoID's array, because a []byte(string) conversion
// is not guaranteed to have cap == len.  The labels would then silently share,
// and overwrite, each other's storage.
var (
	protoID = []byte("ntor-curve25519-sha256-1")
	tMac    = append(protoID[:len(protoID):len(protoID)], ":mac"...)
	tKey    = append(protoID[:len(protoID):len(protoID)], ":key_extract"...)
	tVerify = append(protoID[:len(protoID):len(protoID)], ":key_verify"...)
	mExpand = append(protoID[:len(protoID):len(protoID)], ":key_expand"...)
)
81
// PublicKeyLengthError is returned when a Curve25519 public key being imported
// does not have the expected length.  Its value is the length that was
// actually supplied.
type PublicKeyLengthError int

// Error implements the error interface, reporting the rejected length.
func (e PublicKeyLengthError) Error() string {
	got := int(e)
	return fmt.Sprintf("ntor: Invalid Curve25519 public key length: %d", got)
}
90
// PrivateKeyLengthError is returned when a Curve25519 private key being
// imported does not have the expected length.  Its value is the length that
// was actually supplied.
type PrivateKeyLengthError int

// Error implements the error interface, reporting the rejected length.
func (e PrivateKeyLengthError) Error() string {
	const format = "ntor: Invalid Curve25519 private key length: %d"
	return fmt.Sprintf(format, int(e))
}
99
// NodeIDLengthError is returned when a node ID being imported does not have
// the expected length.  Its value is the length that was actually supplied.
type NodeIDLengthError int

// Error implements the error interface, reporting the rejected length.
func (e NodeIDLengthError) Error() string {
	n := int(e)
	return fmt.Sprintf("ntor: Invalid NodeID length: %d", n)
}
107
108// KeySeed is the key material that results from a handshake (KEY_SEED).
109type KeySeed [KeySeedLength]byte
110
111// Bytes returns a pointer to the raw key material.
112func (key_seed *KeySeed) Bytes() *[KeySeedLength]byte {
113	return (*[KeySeedLength]byte)(key_seed)
114}
115
116// Auth is the verifier that results from a handshake (AUTH).
117type Auth [AuthLength]byte
118
119// Bytes returns a pointer to the raw auth.
120func (auth *Auth) Bytes() *[AuthLength]byte {
121	return (*[AuthLength]byte)(auth)
122}
123
124// NodeID is a ntor node identifier.
125type NodeID [NodeIDLength]byte
126
127// NewNodeID creates a NodeID from the raw bytes.
128func NewNodeID(raw []byte) (*NodeID, error) {
129	if len(raw) != NodeIDLength {
130		return nil, NodeIDLengthError(len(raw))
131	}
132
133	nodeID := new(NodeID)
134	copy(nodeID[:], raw)
135
136	return nodeID, nil
137}
138
139// NodeIDFromHex creates a new NodeID from the hexdecimal representation.
140func NodeIDFromHex(encoded string) (*NodeID, error) {
141	raw, err := hex.DecodeString(encoded)
142	if err != nil {
143		return nil, err
144	}
145
146	return NewNodeID(raw)
147}
148
149// Bytes returns a pointer to the raw NodeID.
150func (id *NodeID) Bytes() *[NodeIDLength]byte {
151	return (*[NodeIDLength]byte)(id)
152}
153
154// Hex returns the hexdecimal representation of the NodeID.
155func (id *NodeID) Hex() string {
156	return hex.EncodeToString(id[:])
157}
158
159// PublicKey is a Curve25519 public key in little-endian byte order.
160type PublicKey [PublicKeyLength]byte
161
162// Bytes returns a pointer to the raw Curve25519 public key.
163func (public *PublicKey) Bytes() *[PublicKeyLength]byte {
164	return (*[PublicKeyLength]byte)(public)
165}
166
167// Hex returns the hexdecimal representation of the Curve25519 public key.
168func (public *PublicKey) Hex() string {
169	return hex.EncodeToString(public.Bytes()[:])
170}
171
172// NewPublicKey creates a PublicKey from the raw bytes.
173func NewPublicKey(raw []byte) (*PublicKey, error) {
174	if len(raw) != PublicKeyLength {
175		return nil, PublicKeyLengthError(len(raw))
176	}
177
178	pubKey := new(PublicKey)
179	copy(pubKey[:], raw)
180
181	return pubKey, nil
182}
183
184// PublicKeyFromHex returns a PublicKey from the hexdecimal representation.
185func PublicKeyFromHex(encoded string) (*PublicKey, error) {
186	raw, err := hex.DecodeString(encoded)
187	if err != nil {
188		return nil, err
189	}
190
191	return NewPublicKey(raw)
192}
193
194// Representative is an Elligator representative of a Curve25519 public key
195// in little-endian byte order.
196type Representative [RepresentativeLength]byte
197
198// Bytes returns a pointer to the raw Elligator representative.
199func (repr *Representative) Bytes() *[RepresentativeLength]byte {
200	return (*[RepresentativeLength]byte)(repr)
201}
202
203// ToPublic converts a Elligator representative to a Curve25519 public key.
204func (repr *Representative) ToPublic() *PublicKey {
205	pub := new(PublicKey)
206
207	extra25519.UnsafeBrokenRepresentativeToPublicKey(pub.Bytes(), repr.Bytes())
208	return pub
209}
210
211// PrivateKey is a Curve25519 private key in little-endian byte order.
212type PrivateKey [PrivateKeyLength]byte
213
214// Bytes returns a pointer to the raw Curve25519 private key.
215func (private *PrivateKey) Bytes() *[PrivateKeyLength]byte {
216	return (*[PrivateKeyLength]byte)(private)
217}
218
219// Hex returns the hexdecimal representation of the Curve25519 private key.
220func (private *PrivateKey) Hex() string {
221	return hex.EncodeToString(private.Bytes()[:])
222}
223
224// Keypair is a Curve25519 keypair with an optional Elligator representative.
225// As only certain Curve25519 keys can be obfuscated with Elligator, the
226// representative must be generated along with the keypair.
227type Keypair struct {
228	public         *PublicKey
229	private        *PrivateKey
230	representative *Representative
231}
232
233// Public returns the Curve25519 public key belonging to the Keypair.
234func (keypair *Keypair) Public() *PublicKey {
235	return keypair.public
236}
237
238// Private returns the Curve25519 private key belonging to the Keypair.
239func (keypair *Keypair) Private() *PrivateKey {
240	return keypair.private
241}
242
243// Representative returns the Elligator representative of the public key
244// belonging to the Keypair.
245func (keypair *Keypair) Representative() *Representative {
246	return keypair.representative
247}
248
249// HasElligator returns true if the Keypair has an Elligator representative.
250func (keypair *Keypair) HasElligator() bool {
251	return nil != keypair.representative
252}
253
254// NewKeypair generates a new Curve25519 keypair, and optionally also generates
255// an Elligator representative of the public key.
256func NewKeypair(elligator bool) (*Keypair, error) {
257	keypair := new(Keypair)
258	keypair.private = new(PrivateKey)
259	keypair.public = new(PublicKey)
260	if elligator {
261		keypair.representative = new(Representative)
262	}
263
264	for {
265		// Generate a Curve25519 private key.  Like everyone who does this,
266		// run the CSPRNG output through SHA256 for extra tinfoil hattery.
267		priv := keypair.private.Bytes()[:]
268		if err := csrand.Bytes(priv); err != nil {
269			return nil, err
270		}
271		digest := sha256.Sum256(priv)
272		digest[0] &= 248
273		digest[31] &= 127
274		digest[31] |= 64
275		copy(priv, digest[:])
276
277		if elligator {
278			// Apply the Elligator transform.  This fails ~50% of the time.
279			if !extra25519.UnsafeBrokenScalarBaseMult(keypair.public.Bytes(),
280				keypair.representative.Bytes(),
281				keypair.private.Bytes()) {
282				continue
283			}
284		} else {
285			// Generate the corresponding Curve25519 public key.
286			curve25519.ScalarBaseMult(keypair.public.Bytes(),
287				keypair.private.Bytes())
288		}
289
290		return keypair, nil
291	}
292}
293
294// KeypairFromHex returns a Keypair from the hexdecimal representation of the
295// private key.
296func KeypairFromHex(encoded string) (*Keypair, error) {
297	raw, err := hex.DecodeString(encoded)
298	if err != nil {
299		return nil, err
300	}
301
302	if len(raw) != PrivateKeyLength {
303		return nil, PrivateKeyLengthError(len(raw))
304	}
305
306	keypair := new(Keypair)
307	keypair.private = new(PrivateKey)
308	keypair.public = new(PublicKey)
309
310	copy(keypair.private[:], raw)
311	curve25519.ScalarBaseMult(keypair.public.Bytes(),
312		keypair.private.Bytes())
313
314	return keypair, nil
315}
316
317// ServerHandshake does the server side of a ntor handshake and returns status,
318// KEY_SEED, and AUTH.  If status is not true, the handshake MUST be aborted.
319func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) {
320	var notOk int
321	var secretInput bytes.Buffer
322
323	// Server side uses EXP(X,y) | EXP(X,b)
324	var exp [SharedSecretLength]byte
325	curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(),
326		clientPublic.Bytes())
327	notOk |= constantTimeIsZero(exp[:])
328	secretInput.Write(exp[:])
329
330	curve25519.ScalarMult(&exp, idKeypair.private.Bytes(),
331		clientPublic.Bytes())
332	notOk |= constantTimeIsZero(exp[:])
333	secretInput.Write(exp[:])
334
335	keySeed, auth = ntorCommon(secretInput, id, idKeypair.public,
336		clientPublic, serverKeypair.public)
337	return notOk == 0, keySeed, auth
338}
339
340// ClientHandshake does the client side of a ntor handshake and returnes
341// status, KEY_SEED, and AUTH.  If status is not true or AUTH does not match
342// the value recieved from the server, the handshake MUST be aborted.
343func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) {
344	var notOk int
345	var secretInput bytes.Buffer
346
347	// Client side uses EXP(Y,x) | EXP(B,x)
348	var exp [SharedSecretLength]byte
349	curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(),
350		serverPublic.Bytes())
351	notOk |= constantTimeIsZero(exp[:])
352	secretInput.Write(exp[:])
353
354	curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(),
355		idPublic.Bytes())
356	notOk |= constantTimeIsZero(exp[:])
357	secretInput.Write(exp[:])
358
359	keySeed, auth = ntorCommon(secretInput, id, idPublic,
360		clientKeypair.public, serverPublic)
361	return notOk == 0, keySeed, auth
362}
363
364// CompareAuth does a constant time compare of a Auth and a byte slice
365// (presumably received over a network).
366func CompareAuth(auth1 *Auth, auth2 []byte) bool {
367	auth1Bytes := auth1.Bytes()
368	return hmac.Equal(auth1Bytes[:], auth2)
369}
370
// ntorCommon completes the key derivation shared by both sides of the
// handshake and returns KEY_SEED and AUTH.
//
// secretInput must already hold the two concatenated Diffie-Hellman results
// written by ServerHandshake/ClientHandshake.  It is passed by value, so the
// suffix is appended to this function's own copy.  Both callers pass the
// identity public key as b (B), the client's ephemeral public key as x (X),
// and the server's ephemeral public key as y (Y).
//
// NOTE(review): the byte layout built here differs from proposal 216, which
// specifies
//
//	secret_input = EXP(..) | EXP(..) | ID | B | X | Y | PROTOID
//	auth_input   = verify | ID | B | Y | X | PROTOID | "Server"
//
// This code instead appends the single suffix B | B | X | Y | PROTOID | ID in
// both places.  Both peers must derive identical values, so this presumably
// has to stay as-is for compatibility with deployed peers.  Confirm before
// changing it.
func ntorCommon(secretInput bytes.Buffer, id *NodeID, b *PublicKey, x *PublicKey, y *PublicKey) (*KeySeed, *Auth) {
	keySeed := new(KeySeed)
	auth := new(Auth)

	// secret_input/auth_input use this common bit, build it once.
	//
	// The resulting layout is B | B | X | Y | PROTOID | ID.  bytes.NewBuffer
	// adopts b.Bytes()[:] as the buffer's initial contents (the first B), and
	// the first Write appends B again.  That initial slice is already full
	// (len == cap == PublicKeyLength), so appending must reallocate rather
	// than write into b's own array.
	suffix := bytes.NewBuffer(b.Bytes()[:])
	suffix.Write(b.Bytes()[:])
	suffix.Write(x.Bytes()[:])
	suffix.Write(y.Bytes()[:])
	suffix.Write(protoID)
	suffix.Write(id[:])

	// At this point secret_input has the 2 exponents, concatenated, append the
	// client/server common suffix.
	secretInput.Write(suffix.Bytes())

	// KEY_SEED = H(secret_input, t_key), where H(m, k) is HMAC-SHA256 keyed
	// with k over m.  hash.Hash.Write never returns an error, hence the
	// discarded results.
	h := hmac.New(sha256.New, tKey)
	_, _ = h.Write(secretInput.Bytes())
	tmp := h.Sum(nil)
	copy(keySeed[:], tmp)

	// verify = H(secret_input, t_verify)
	h = hmac.New(sha256.New, tVerify)
	_, _ = h.Write(secretInput.Bytes())
	verify := h.Sum(nil)

	// auth_input = verify | suffix | "Server"
	//            = verify | B | B | X | Y | PROTOID | ID | "Server"
	// This is not the spec's verify | ID | B | Y | X | PROTOID | "Server"
	// order; see the NOTE above.
	authInput := bytes.NewBuffer(verify)
	_, _ = authInput.Write(suffix.Bytes())
	_, _ = authInput.Write([]byte("Server"))
	h = hmac.New(sha256.New, tMac)
	_, _ = h.Write(authInput.Bytes())
	tmp = h.Sum(nil)
	copy(auth[:], tmp)

	return keySeed, auth
}
409
// constantTimeIsZero returns 1 if every byte of x is zero, including when x is
// empty, and 0 otherwise.  Every byte is always visited and the final test
// uses crypto/subtle, so the running time is meant to depend only on len(x),
// not on the contents.
func constantTimeIsZero(x []byte) int {
	acc := byte(0)
	for i := range x {
		acc |= x[i]
	}
	return subtle.ConstantTimeByteEq(acc, 0)
}
418
419// Kdf extracts and expands KEY_SEED via HKDF-SHA256 and returns `okm_len` bytes
420// of key material.
421func Kdf(keySeed []byte, okmLen int) []byte {
422	kdf := hkdf.New(sha256.New, keySeed, tKey, mExpand)
423	okm := make([]byte, okmLen)
424	n, err := io.ReadFull(kdf, okm)
425	if err != nil {
426		panic(fmt.Sprintf("BUG: Failed HKDF: %s", err.Error()))
427	} else if n != len(okm) {
428		panic(fmt.Sprintf("BUG: Got truncated HKDF output: %d", n))
429	}
430
431	return okm
432}
433