1/*
2 * Copyright (c) 2014, Yawning Angel <yawning at schwanenlied dot me>
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  * Redistributions of source code must retain the above copyright notice,
9 *    this list of conditions and the following disclaimer.
10 *
11 *  * Redistributions in binary form must reproduce the above copyright notice,
12 *    this list of conditions and the following disclaimer in the documentation
13 *    and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
19 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
20 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
21 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
22 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
23 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
24 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
25 * POSSIBILITY OF SUCH DAMAGE.
26 */
27
28// Package ntor implements the Tor Project's ntor handshake as defined in
29// proposal 216 "Improved circuit-creation key exchange".  It also supports
30// using Elligator to transform the Curve25519 public keys sent over the wire
31// to a form that is indistinguishable from random strings.
32//
33// Before using this package, it is strongly recommended that the specification
34// is read and understood.
35package ntor // import "gitlab.com/yawning/obfs4.git/common/ntor"
36
37import (
38	"bytes"
39	"crypto/hmac"
40	"crypto/sha256"
41	"crypto/subtle"
42	"encoding/hex"
43	"fmt"
44	"io"
45
46	"github.com/agl/ed25519/extra25519"
47	"gitlab.com/yawning/obfs4.git/common/csrand"
48	"golang.org/x/crypto/curve25519"
49	"golang.org/x/crypto/hkdf"
50)
51
const (
	// PublicKeyLength is the length of a Curve25519 public key, in bytes.
	PublicKeyLength = 32

	// RepresentativeLength is the length of an Elligator representative, in
	// bytes.
	RepresentativeLength = 32

	// PrivateKeyLength is the length of a Curve25519 private key, in bytes.
	PrivateKeyLength = 32

	// SharedSecretLength is the length of a Curve25519 shared secret, in
	// bytes.
	SharedSecretLength = 32

	// NodeIDLength is the length of an ntor node identifier, in bytes.
	NodeIDLength = 20

	// KeySeedLength is the length of the derived KEY_SEED (one SHA-256
	// output), in bytes.
	KeySeedLength = sha256.Size

	// AuthLength is the length of the derived AUTH (one SHA-256 output), in
	// bytes.
	AuthLength = sha256.Size
)
74
var (
	// protoID is PROTOID, the protocol identifier mixed into the ntor
	// handshake transcript (see ntorCommon).
	protoID = []byte("ntor-curve25519-sha256-1")

	// The tweak strings below are protoID with a suffix appended. tMac, tKey
	// and tVerify key the handshake HMACs; tKey and mExpand are also the
	// HKDF salt and info used by Kdf.
	//
	// Each is built from protoID[:len:len] (a full slice expression whose
	// capacity equals its length) so append always allocates a fresh array.
	// Appending to protoID directly would reuse any spare capacity the
	// string-to-[]byte conversion happened to leave, in which case all four
	// tweaks would share, and overwrite, the same storage.
	tMac    = append(protoID[:len(protoID):len(protoID)], ":mac"...)
	tKey    = append(protoID[:len(protoID):len(protoID)], ":key_extract"...)
	tVerify = append(protoID[:len(protoID):len(protoID)], ":key_verify"...)
	mExpand = append(protoID[:len(protoID):len(protoID)], ":key_expand"...)
)
80
// PublicKeyLengthError is returned when a public key being imported does not
// have the expected length. The error's value is the rejected length.
type PublicKeyLengthError int

// Error implements the error interface.
func (e PublicKeyLengthError) Error() string {
	n := int(e)
	return fmt.Sprintf("ntor: Invalid Curve25519 public key length: %d", n)
}
89
// PrivateKeyLengthError is returned when a private key being imported does
// not have the expected length. The error's value is the rejected length.
type PrivateKeyLengthError int

// Error implements the error interface.
func (e PrivateKeyLengthError) Error() string {
	n := int(e)
	return fmt.Sprintf("ntor: Invalid Curve25519 private key length: %d", n)
}
98
// NodeIDLengthError is returned when a node ID being imported does not have
// the expected length. The error's value is the rejected length.
type NodeIDLengthError int

// Error implements the error interface.
func (e NodeIDLengthError) Error() string {
	n := int(e)
	return fmt.Sprintf("ntor: Invalid NodeID length: %d", n)
}
106
107// KeySeed is the key material that results from a handshake (KEY_SEED).
108type KeySeed [KeySeedLength]byte
109
110// Bytes returns a pointer to the raw key material.
111func (key_seed *KeySeed) Bytes() *[KeySeedLength]byte {
112	return (*[KeySeedLength]byte)(key_seed)
113}
114
115// Auth is the verifier that results from a handshake (AUTH).
116type Auth [AuthLength]byte
117
118// Bytes returns a pointer to the raw auth.
119func (auth *Auth) Bytes() *[AuthLength]byte {
120	return (*[AuthLength]byte)(auth)
121}
122
123// NodeID is a ntor node identifier.
124type NodeID [NodeIDLength]byte
125
126// NewNodeID creates a NodeID from the raw bytes.
127func NewNodeID(raw []byte) (*NodeID, error) {
128	if len(raw) != NodeIDLength {
129		return nil, NodeIDLengthError(len(raw))
130	}
131
132	nodeID := new(NodeID)
133	copy(nodeID[:], raw)
134
135	return nodeID, nil
136}
137
138// NodeIDFromHex creates a new NodeID from the hexdecimal representation.
139func NodeIDFromHex(encoded string) (*NodeID, error) {
140	raw, err := hex.DecodeString(encoded)
141	if err != nil {
142		return nil, err
143	}
144
145	return NewNodeID(raw)
146}
147
148// Bytes returns a pointer to the raw NodeID.
149func (id *NodeID) Bytes() *[NodeIDLength]byte {
150	return (*[NodeIDLength]byte)(id)
151}
152
153// Hex returns the hexdecimal representation of the NodeID.
154func (id *NodeID) Hex() string {
155	return hex.EncodeToString(id[:])
156}
157
158// PublicKey is a Curve25519 public key in little-endian byte order.
159type PublicKey [PublicKeyLength]byte
160
161// Bytes returns a pointer to the raw Curve25519 public key.
162func (public *PublicKey) Bytes() *[PublicKeyLength]byte {
163	return (*[PublicKeyLength]byte)(public)
164}
165
166// Hex returns the hexdecimal representation of the Curve25519 public key.
167func (public *PublicKey) Hex() string {
168	return hex.EncodeToString(public.Bytes()[:])
169}
170
171// NewPublicKey creates a PublicKey from the raw bytes.
172func NewPublicKey(raw []byte) (*PublicKey, error) {
173	if len(raw) != PublicKeyLength {
174		return nil, PublicKeyLengthError(len(raw))
175	}
176
177	pubKey := new(PublicKey)
178	copy(pubKey[:], raw)
179
180	return pubKey, nil
181}
182
183// PublicKeyFromHex returns a PublicKey from the hexdecimal representation.
184func PublicKeyFromHex(encoded string) (*PublicKey, error) {
185	raw, err := hex.DecodeString(encoded)
186	if err != nil {
187		return nil, err
188	}
189
190	return NewPublicKey(raw)
191}
192
193// Representative is an Elligator representative of a Curve25519 public key
194// in little-endian byte order.
195type Representative [RepresentativeLength]byte
196
197// Bytes returns a pointer to the raw Elligator representative.
198func (repr *Representative) Bytes() *[RepresentativeLength]byte {
199	return (*[RepresentativeLength]byte)(repr)
200}
201
202// ToPublic converts a Elligator representative to a Curve25519 public key.
203func (repr *Representative) ToPublic() *PublicKey {
204	pub := new(PublicKey)
205
206	extra25519.RepresentativeToPublicKey(pub.Bytes(), repr.Bytes())
207	return pub
208}
209
210// PrivateKey is a Curve25519 private key in little-endian byte order.
211type PrivateKey [PrivateKeyLength]byte
212
213// Bytes returns a pointer to the raw Curve25519 private key.
214func (private *PrivateKey) Bytes() *[PrivateKeyLength]byte {
215	return (*[PrivateKeyLength]byte)(private)
216}
217
218// Hex returns the hexdecimal representation of the Curve25519 private key.
219func (private *PrivateKey) Hex() string {
220	return hex.EncodeToString(private.Bytes()[:])
221}
222
// Keypair is a Curve25519 keypair with an optional Elligator representative.
// As only certain Curve25519 keys can be obfuscated with Elligator, the
// representative must be generated along with the keypair.
type Keypair struct {
	// public is the Curve25519 public key; set by both NewKeypair and
	// KeypairFromHex.
	public         *PublicKey
	// private is the Curve25519 private key; set by both NewKeypair and
	// KeypairFromHex.
	private        *PrivateKey
	// representative is the Elligator representative of public. It is nil
	// unless the keypair came from NewKeypair(true); HasElligator reports
	// whether it is set.
	representative *Representative
}
231
232// Public returns the Curve25519 public key belonging to the Keypair.
233func (keypair *Keypair) Public() *PublicKey {
234	return keypair.public
235}
236
237// Private returns the Curve25519 private key belonging to the Keypair.
238func (keypair *Keypair) Private() *PrivateKey {
239	return keypair.private
240}
241
242// Representative returns the Elligator representative of the public key
243// belonging to the Keypair.
244func (keypair *Keypair) Representative() *Representative {
245	return keypair.representative
246}
247
248// HasElligator returns true if the Keypair has an Elligator representative.
249func (keypair *Keypair) HasElligator() bool {
250	return nil != keypair.representative
251}
252
253// NewKeypair generates a new Curve25519 keypair, and optionally also generates
254// an Elligator representative of the public key.
255func NewKeypair(elligator bool) (*Keypair, error) {
256	keypair := new(Keypair)
257	keypair.private = new(PrivateKey)
258	keypair.public = new(PublicKey)
259	if elligator {
260		keypair.representative = new(Representative)
261	}
262
263	for {
264		// Generate a Curve25519 private key.  Like everyone who does this,
265		// run the CSPRNG output through SHA256 for extra tinfoil hattery.
266		priv := keypair.private.Bytes()[:]
267		if err := csrand.Bytes(priv); err != nil {
268			return nil, err
269		}
270		digest := sha256.Sum256(priv)
271		digest[0] &= 248
272		digest[31] &= 127
273		digest[31] |= 64
274		copy(priv, digest[:])
275
276		if elligator {
277			// Apply the Elligator transform.  This fails ~50% of the time.
278			if !extra25519.ScalarBaseMult(keypair.public.Bytes(),
279				keypair.representative.Bytes(),
280				keypair.private.Bytes()) {
281				continue
282			}
283		} else {
284			// Generate the corresponding Curve25519 public key.
285			curve25519.ScalarBaseMult(keypair.public.Bytes(),
286				keypair.private.Bytes())
287		}
288
289		return keypair, nil
290	}
291}
292
293// KeypairFromHex returns a Keypair from the hexdecimal representation of the
294// private key.
295func KeypairFromHex(encoded string) (*Keypair, error) {
296	raw, err := hex.DecodeString(encoded)
297	if err != nil {
298		return nil, err
299	}
300
301	if len(raw) != PrivateKeyLength {
302		return nil, PrivateKeyLengthError(len(raw))
303	}
304
305	keypair := new(Keypair)
306	keypair.private = new(PrivateKey)
307	keypair.public = new(PublicKey)
308
309	copy(keypair.private[:], raw)
310	curve25519.ScalarBaseMult(keypair.public.Bytes(),
311		keypair.private.Bytes())
312
313	return keypair, nil
314}
315
316// ServerHandshake does the server side of a ntor handshake and returns status,
317// KEY_SEED, and AUTH.  If status is not true, the handshake MUST be aborted.
318func ServerHandshake(clientPublic *PublicKey, serverKeypair *Keypair, idKeypair *Keypair, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) {
319	var notOk int
320	var secretInput bytes.Buffer
321
322	// Server side uses EXP(X,y) | EXP(X,b)
323	var exp [SharedSecretLength]byte
324	curve25519.ScalarMult(&exp, serverKeypair.private.Bytes(),
325		clientPublic.Bytes())
326	notOk |= constantTimeIsZero(exp[:])
327	secretInput.Write(exp[:])
328
329	curve25519.ScalarMult(&exp, idKeypair.private.Bytes(),
330		clientPublic.Bytes())
331	notOk |= constantTimeIsZero(exp[:])
332	secretInput.Write(exp[:])
333
334	keySeed, auth = ntorCommon(secretInput, id, idKeypair.public,
335		clientPublic, serverKeypair.public)
336	return notOk == 0, keySeed, auth
337}
338
339// ClientHandshake does the client side of a ntor handshake and returnes
340// status, KEY_SEED, and AUTH.  If status is not true or AUTH does not match
341// the value recieved from the server, the handshake MUST be aborted.
342func ClientHandshake(clientKeypair *Keypair, serverPublic *PublicKey, idPublic *PublicKey, id *NodeID) (ok bool, keySeed *KeySeed, auth *Auth) {
343	var notOk int
344	var secretInput bytes.Buffer
345
346	// Client side uses EXP(Y,x) | EXP(B,x)
347	var exp [SharedSecretLength]byte
348	curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(),
349		serverPublic.Bytes())
350	notOk |= constantTimeIsZero(exp[:])
351	secretInput.Write(exp[:])
352
353	curve25519.ScalarMult(&exp, clientKeypair.private.Bytes(),
354		idPublic.Bytes())
355	notOk |= constantTimeIsZero(exp[:])
356	secretInput.Write(exp[:])
357
358	keySeed, auth = ntorCommon(secretInput, id, idPublic,
359		clientKeypair.public, serverPublic)
360	return notOk == 0, keySeed, auth
361}
362
363// CompareAuth does a constant time compare of a Auth and a byte slice
364// (presumably received over a network).
365func CompareAuth(auth1 *Auth, auth2 []byte) bool {
366	auth1Bytes := auth1.Bytes()
367	return hmac.Equal(auth1Bytes[:], auth2)
368}
369
370func ntorCommon(secretInput bytes.Buffer, id *NodeID, b *PublicKey, x *PublicKey, y *PublicKey) (*KeySeed, *Auth) {
371	keySeed := new(KeySeed)
372	auth := new(Auth)
373
374	// secret_input/auth_input use this common bit, build it once.
375	suffix := bytes.NewBuffer(b.Bytes()[:])
376	suffix.Write(b.Bytes()[:])
377	suffix.Write(x.Bytes()[:])
378	suffix.Write(y.Bytes()[:])
379	suffix.Write(protoID)
380	suffix.Write(id[:])
381
382	// At this point secret_input has the 2 exponents, concatenated, append the
383	// client/server common suffix.
384	secretInput.Write(suffix.Bytes())
385
386	// KEY_SEED = H(secret_input, t_key)
387	h := hmac.New(sha256.New, tKey)
388	_, _ = h.Write(secretInput.Bytes())
389	tmp := h.Sum(nil)
390	copy(keySeed[:], tmp)
391
392	// verify = H(secret_input, t_verify)
393	h = hmac.New(sha256.New, tVerify)
394	_, _ = h.Write(secretInput.Bytes())
395	verify := h.Sum(nil)
396
397	// auth_input = verify | ID | B | Y | X | PROTOID | "Server"
398	authInput := bytes.NewBuffer(verify)
399	_, _ = authInput.Write(suffix.Bytes())
400	_, _ = authInput.Write([]byte("Server"))
401	h = hmac.New(sha256.New, tMac)
402	_, _ = h.Write(authInput.Bytes())
403	tmp = h.Sum(nil)
404	copy(auth[:], tmp)
405
406	return keySeed, auth
407}
408
// constantTimeIsZero returns 1 if every byte of x is zero (including when x
// is empty) and 0 otherwise. All bytes are OR-ed together without an early
// exit, and the final test uses subtle.ConstantTimeByteEq, so the work
// performed does not depend on the byte values.
func constantTimeIsZero(x []byte) int {
	var acc byte
	for i := range x {
		acc |= x[i]
	}
	return subtle.ConstantTimeByteEq(acc, 0)
}
417
418// Kdf extracts and expands KEY_SEED via HKDF-SHA256 and returns `okm_len` bytes
419// of key material.
420func Kdf(keySeed []byte, okmLen int) []byte {
421	kdf := hkdf.New(sha256.New, keySeed, tKey, mExpand)
422	okm := make([]byte, okmLen)
423	n, err := io.ReadFull(kdf, okm)
424	if err != nil {
425		panic(fmt.Sprintf("BUG: Failed HKDF: %s", err.Error()))
426	} else if n != len(okm) {
427		panic(fmt.Sprintf("BUG: Got truncated HKDF output: %d", n))
428	}
429
430	return okm
431}
432