1/* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 */
5
6package device
7
8import (
9	"bytes"
10	"encoding/binary"
11	"testing"
12
13	"golang.zx2c4.com/wireguard/conn"
14	"golang.zx2c4.com/wireguard/tun/tuntest"
15)
16
17func TestCurveWrappers(t *testing.T) {
18	sk1, err := newPrivateKey()
19	assertNil(t, err)
20
21	sk2, err := newPrivateKey()
22	assertNil(t, err)
23
24	pk1 := sk1.publicKey()
25	pk2 := sk2.publicKey()
26
27	ss1 := sk1.sharedSecret(pk2)
28	ss2 := sk2.sharedSecret(pk1)
29
30	if ss1 != ss2 {
31		t.Fatal("Failed to compute shared secet")
32	}
33}
34
35func randDevice(t *testing.T) *Device {
36	sk, err := newPrivateKey()
37	if err != nil {
38		t.Fatal(err)
39	}
40	tun := tuntest.NewChannelTUN()
41	logger := NewLogger(LogLevelError, "")
42	device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
43	device.SetPrivateKey(sk)
44	return device
45}
46
// assertNil aborts the test if err is non-nil. It is marked as a helper so the
// failure is reported at the caller's line rather than inside this function.
func assertNil(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Fatal(err)
	}
}
52
// assertEqual aborts the test unless a and b hold the same bytes (as defined by
// bytes.Equal, so nil and an empty slice compare equal). Mismatches are printed
// in hex, which keeps fixed-size key material readable. It is marked as a
// helper so the failure is reported at the caller's line.
func assertEqual(t *testing.T, a, b []byte) {
	t.Helper()
	if !bytes.Equal(a, b) {
		t.Fatalf("%x != %x", a, b)
	}
}
58
59func TestNoiseHandshake(t *testing.T) {
60	dev1 := randDevice(t)
61	dev2 := randDevice(t)
62
63	defer dev1.Close()
64	defer dev2.Close()
65
66	peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
67	if err != nil {
68		t.Fatal(err)
69	}
70	peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
71	if err != nil {
72		t.Fatal(err)
73	}
74
75	assertEqual(
76		t,
77		peer1.handshake.precomputedStaticStatic[:],
78		peer2.handshake.precomputedStaticStatic[:],
79	)
80
81	/* simulate handshake */
82
83	// initiation message
84
85	t.Log("exchange initiation message")
86
87	msg1, err := dev1.CreateMessageInitiation(peer2)
88	assertNil(t, err)
89
90	packet := make([]byte, 0, 256)
91	writer := bytes.NewBuffer(packet)
92	err = binary.Write(writer, binary.LittleEndian, msg1)
93	assertNil(t, err)
94	peer := dev2.ConsumeMessageInitiation(msg1)
95	if peer == nil {
96		t.Fatal("handshake failed at initiation message")
97	}
98
99	assertEqual(
100		t,
101		peer1.handshake.chainKey[:],
102		peer2.handshake.chainKey[:],
103	)
104
105	assertEqual(
106		t,
107		peer1.handshake.hash[:],
108		peer2.handshake.hash[:],
109	)
110
111	// response message
112
113	t.Log("exchange response message")
114
115	msg2, err := dev2.CreateMessageResponse(peer1)
116	assertNil(t, err)
117
118	peer = dev1.ConsumeMessageResponse(msg2)
119	if peer == nil {
120		t.Fatal("handshake failed at response message")
121	}
122
123	assertEqual(
124		t,
125		peer1.handshake.chainKey[:],
126		peer2.handshake.chainKey[:],
127	)
128
129	assertEqual(
130		t,
131		peer1.handshake.hash[:],
132		peer2.handshake.hash[:],
133	)
134
135	// key pairs
136
137	t.Log("deriving keys")
138
139	err = peer1.BeginSymmetricSession()
140	if err != nil {
141		t.Fatal("failed to derive keypair for peer 1", err)
142	}
143
144	err = peer2.BeginSymmetricSession()
145	if err != nil {
146		t.Fatal("failed to derive keypair for peer 2", err)
147	}
148
149	key1 := peer1.keypairs.loadNext()
150	key2 := peer2.keypairs.current
151
152	// encrypting / decryption test
153
154	t.Log("test key pairs")
155
156	func() {
157		testMsg := []byte("wireguard test message 1")
158		var err error
159		var out []byte
160		var nonce [12]byte
161		out = key1.send.Seal(out, nonce[:], testMsg, nil)
162		out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
163		assertNil(t, err)
164		assertEqual(t, out, testMsg)
165	}()
166
167	func() {
168		testMsg := []byte("wireguard test message 2")
169		var err error
170		var out []byte
171		var nonce [12]byte
172		out = key2.send.Seal(out, nonce[:], testMsg, nil)
173		out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
174		assertNil(t, err)
175		assertEqual(t, out, testMsg)
176	}()
177}
178