1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package agent
6
7import (
8	"crypto"
9	"crypto/rand"
10	"fmt"
11	pseudorand "math/rand"
12	"reflect"
13	"strings"
14	"testing"
15
16	"golang.org/x/crypto/ssh"
17)
18
19func TestServer(t *testing.T) {
20	c1, c2, err := netPipe()
21	if err != nil {
22		t.Fatalf("netPipe: %v", err)
23	}
24	defer c1.Close()
25	defer c2.Close()
26	client := NewClient(c1)
27
28	go ServeAgent(NewKeyring(), c2)
29
30	testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
31}
32
33func TestLockServer(t *testing.T) {
34	testLockAgent(NewKeyring(), t)
35}
36
37func TestSetupForwardAgent(t *testing.T) {
38	a, b, err := netPipe()
39	if err != nil {
40		t.Fatalf("netPipe: %v", err)
41	}
42
43	defer a.Close()
44	defer b.Close()
45
46	_, socket, cleanup := startOpenSSHAgent(t)
47	defer cleanup()
48
49	serverConf := ssh.ServerConfig{
50		NoClientAuth: true,
51	}
52	serverConf.AddHostKey(testSigners["rsa"])
53	incoming := make(chan *ssh.ServerConn, 1)
54	go func() {
55		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
56		if err != nil {
57			t.Fatalf("Server: %v", err)
58		}
59		incoming <- conn
60	}()
61
62	conf := ssh.ClientConfig{
63		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
64	}
65	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
66	if err != nil {
67		t.Fatalf("NewClientConn: %v", err)
68	}
69	client := ssh.NewClient(conn, chans, reqs)
70
71	if err := ForwardToRemote(client, socket); err != nil {
72		t.Fatalf("SetupForwardAgent: %v", err)
73	}
74
75	server := <-incoming
76	ch, reqs, err := server.OpenChannel(channelType, nil)
77	if err != nil {
78		t.Fatalf("OpenChannel(%q): %v", channelType, err)
79	}
80	go ssh.DiscardRequests(reqs)
81
82	agentClient := NewClient(ch)
83	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
84	conn.Close()
85}
86
87func TestV1ProtocolMessages(t *testing.T) {
88	c1, c2, err := netPipe()
89	if err != nil {
90		t.Fatalf("netPipe: %v", err)
91	}
92	defer c1.Close()
93	defer c2.Close()
94	c := NewClient(c1)
95
96	go ServeAgent(NewKeyring(), c2)
97
98	testV1ProtocolMessages(t, c.(*client))
99}
100
101func testV1ProtocolMessages(t *testing.T, c *client) {
102	reply, err := c.call([]byte{agentRequestV1Identities})
103	if err != nil {
104		t.Fatalf("v1 request all failed: %v", err)
105	}
106	if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
107		t.Fatalf("invalid request all response: %#v", reply)
108	}
109
110	reply, err = c.call([]byte{agentRemoveAllV1Identities})
111	if err != nil {
112		t.Fatalf("v1 remove all failed: %v", err)
113	}
114	if _, ok := reply.(*successAgentMsg); !ok {
115		t.Fatalf("invalid remove all response: %#v", reply)
116	}
117}
118
119func verifyKey(sshAgent Agent) error {
120	keys, err := sshAgent.List()
121	if err != nil {
122		return fmt.Errorf("listing keys: %v", err)
123	}
124
125	if len(keys) != 1 {
126		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
127	}
128
129	buf := make([]byte, 128)
130	if _, err := rand.Read(buf); err != nil {
131		return fmt.Errorf("rand: %v", err)
132	}
133
134	sig, err := sshAgent.Sign(keys[0], buf)
135	if err != nil {
136		return fmt.Errorf("sign: %v", err)
137	}
138
139	if err := keys[0].Verify(buf, sig); err != nil {
140		return fmt.Errorf("verify: %v", err)
141	}
142	return nil
143}
144
145func addKeyToAgent(key crypto.PrivateKey) error {
146	sshAgent := NewKeyring()
147	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
148		return fmt.Errorf("add: %v", err)
149	}
150	return verifyKey(sshAgent)
151}
152
153func TestKeyTypes(t *testing.T) {
154	for k, v := range testPrivateKeys {
155		if err := addKeyToAgent(v); err != nil {
156			t.Errorf("error adding key type %s, %v", k, err)
157		}
158		if err := addCertToAgentSock(v, nil); err != nil {
159			t.Errorf("error adding key type %s, %v", k, err)
160		}
161	}
162}
163
164func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
165	a, b, err := netPipe()
166	if err != nil {
167		return err
168	}
169	agentServer := NewKeyring()
170	go ServeAgent(agentServer, a)
171
172	agentClient := NewClient(b)
173	if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
174		return fmt.Errorf("add: %v", err)
175	}
176	return verifyKey(agentClient)
177}
178
179func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
180	sshAgent := NewKeyring()
181	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
182		return fmt.Errorf("add: %v", err)
183	}
184	return verifyKey(sshAgent)
185}
186
187func TestCertTypes(t *testing.T) {
188	for keyType, key := range testPublicKeys {
189		cert := &ssh.Certificate{
190			ValidPrincipals: []string{"gopher1"},
191			ValidAfter:      0,
192			ValidBefore:     ssh.CertTimeInfinity,
193			Key:             key,
194			Serial:          1,
195			CertType:        ssh.UserCert,
196			SignatureKey:    testPublicKeys["rsa"],
197			Permissions: ssh.Permissions{
198				CriticalOptions: map[string]string{},
199				Extensions:      map[string]string{},
200			},
201		}
202		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
203			t.Fatalf("signcert: %v", err)
204		}
205		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
206			t.Fatalf("%v", err)
207		}
208		if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
209			t.Fatalf("%v", err)
210		}
211	}
212}
213
214func TestParseConstraints(t *testing.T) {
215	// Test LifetimeSecs
216	var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
217	lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
218	if err != nil {
219		t.Fatalf("parseConstraints: %v", err)
220	}
221	if lifetimeSecs != msg.LifetimeSecs {
222		t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
223	}
224
225	// Test ConfirmBeforeUse
226	_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
227	if err != nil {
228		t.Fatalf("%v", err)
229	}
230	if !confirmBeforeUse {
231		t.Error("got comfirmBeforeUse == false")
232	}
233
234	// Test ConstraintExtensions
235	var data []byte
236	var expect []ConstraintExtension
237	for i := 0; i < 10; i++ {
238		var ext = ConstraintExtension{
239			ExtensionName:    fmt.Sprintf("name%d", i),
240			ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
241		}
242		expect = append(expect, ext)
243		data = append(data, agentConstrainExtension)
244		data = append(data, ssh.Marshal(ext)...)
245	}
246	_, _, extensions, err := parseConstraints(data)
247	if err != nil {
248		t.Fatalf("%v", err)
249	}
250	if !reflect.DeepEqual(expect, extensions) {
251		t.Errorf("got extension %v, want %v", extensions, expect)
252	}
253
254	// Test Unknown Constraint
255	_, _, _, err = parseConstraints([]byte{128})
256	if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
257		t.Errorf("unexpected error: %v", err)
258	}
259}
260