1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package agent
6
7import (
8	"crypto"
9	"crypto/rand"
10	"fmt"
11	"testing"
12
13	"golang.org/x/crypto/ssh"
14)
15
16func TestServer(t *testing.T) {
17	c1, c2, err := netPipe()
18	if err != nil {
19		t.Fatalf("netPipe: %v", err)
20	}
21	defer c1.Close()
22	defer c2.Close()
23	client := NewClient(c1)
24
25	go ServeAgent(NewKeyring(), c2)
26
27	testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
28}
29
30func TestLockServer(t *testing.T) {
31	testLockAgent(NewKeyring(), t)
32}
33
34func TestSetupForwardAgent(t *testing.T) {
35	a, b, err := netPipe()
36	if err != nil {
37		t.Fatalf("netPipe: %v", err)
38	}
39
40	defer a.Close()
41	defer b.Close()
42
43	_, socket, cleanup := startAgent(t)
44	defer cleanup()
45
46	serverConf := ssh.ServerConfig{
47		NoClientAuth: true,
48	}
49	serverConf.AddHostKey(testSigners["rsa"])
50	incoming := make(chan *ssh.ServerConn, 1)
51	go func() {
52		conn, _, _, err := ssh.NewServerConn(a, &serverConf)
53		if err != nil {
54			t.Fatalf("Server: %v", err)
55		}
56		incoming <- conn
57	}()
58
59	conf := ssh.ClientConfig{}
60	conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
61	if err != nil {
62		t.Fatalf("NewClientConn: %v", err)
63	}
64	client := ssh.NewClient(conn, chans, reqs)
65
66	if err := ForwardToRemote(client, socket); err != nil {
67		t.Fatalf("SetupForwardAgent: %v", err)
68	}
69
70	server := <-incoming
71	ch, reqs, err := server.OpenChannel(channelType, nil)
72	if err != nil {
73		t.Fatalf("OpenChannel(%q): %v", channelType, err)
74	}
75	go ssh.DiscardRequests(reqs)
76
77	agentClient := NewClient(ch)
78	testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
79	conn.Close()
80}
81
82func TestV1ProtocolMessages(t *testing.T) {
83	c1, c2, err := netPipe()
84	if err != nil {
85		t.Fatalf("netPipe: %v", err)
86	}
87	defer c1.Close()
88	defer c2.Close()
89	c := NewClient(c1)
90
91	go ServeAgent(NewKeyring(), c2)
92
93	testV1ProtocolMessages(t, c.(*client))
94}
95
96func testV1ProtocolMessages(t *testing.T, c *client) {
97	reply, err := c.call([]byte{agentRequestV1Identities})
98	if err != nil {
99		t.Fatalf("v1 request all failed: %v", err)
100	}
101	if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
102		t.Fatalf("invalid request all response: %#v", reply)
103	}
104
105	reply, err = c.call([]byte{agentRemoveAllV1Identities})
106	if err != nil {
107		t.Fatalf("v1 remove all failed: %v", err)
108	}
109	if _, ok := reply.(*successAgentMsg); !ok {
110		t.Fatalf("invalid remove all response: %#v", reply)
111	}
112}
113
114func verifyKey(sshAgent Agent) error {
115	keys, err := sshAgent.List()
116	if err != nil {
117		return fmt.Errorf("listing keys: %v", err)
118	}
119
120	if len(keys) != 1 {
121		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
122	}
123
124	buf := make([]byte, 128)
125	if _, err := rand.Read(buf); err != nil {
126		return fmt.Errorf("rand: %v", err)
127	}
128
129	sig, err := sshAgent.Sign(keys[0], buf)
130	if err != nil {
131		return fmt.Errorf("sign: %v", err)
132	}
133
134	if err := keys[0].Verify(buf, sig); err != nil {
135		return fmt.Errorf("verify: %v", err)
136	}
137	return nil
138}
139
140func addKeyToAgent(key crypto.PrivateKey) error {
141	sshAgent := NewKeyring()
142	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
143		return fmt.Errorf("add: %v", err)
144	}
145	return verifyKey(sshAgent)
146}
147
148func TestKeyTypes(t *testing.T) {
149	for k, v := range testPrivateKeys {
150		if err := addKeyToAgent(v); err != nil {
151			t.Errorf("error adding key type %s, %v", k, err)
152		}
153		if err := addCertToAgentSock(v, nil); err != nil {
154			t.Errorf("error adding key type %s, %v", k, err)
155		}
156	}
157}
158
159func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
160	a, b, err := netPipe()
161	if err != nil {
162		return err
163	}
164	agentServer := NewKeyring()
165	go ServeAgent(agentServer, a)
166
167	agentClient := NewClient(b)
168	if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
169		return fmt.Errorf("add: %v", err)
170	}
171	return verifyKey(agentClient)
172}
173
174func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
175	sshAgent := NewKeyring()
176	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
177		return fmt.Errorf("add: %v", err)
178	}
179	return verifyKey(sshAgent)
180}
181
182func TestCertTypes(t *testing.T) {
183	for keyType, key := range testPublicKeys {
184		cert := &ssh.Certificate{
185			ValidPrincipals: []string{"gopher1"},
186			ValidAfter:      0,
187			ValidBefore:     ssh.CertTimeInfinity,
188			Key:             key,
189			Serial:          1,
190			CertType:        ssh.UserCert,
191			SignatureKey:    testPublicKeys["rsa"],
192			Permissions: ssh.Permissions{
193				CriticalOptions: map[string]string{},
194				Extensions:      map[string]string{},
195			},
196		}
197		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
198			t.Fatalf("signcert: %v", err)
199		}
200		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
201			t.Fatalf("%v", err)
202		}
203		if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
204			t.Fatalf("%v", err)
205		}
206	}
207}
208