1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build aix darwin dragonfly freebsd linux netbsd openbsd plan9
6
7package test
8
// Functional test harness for Unix systems: runs the Go SSH client against a real sshd.
10
11import (
12	"bytes"
13	"crypto/rand"
14	"encoding/base64"
15	"fmt"
16	"io/ioutil"
17	"log"
18	"net"
19	"os"
20	"os/exec"
21	"os/user"
22	"path/filepath"
23	"testing"
24	"text/template"
25
26	"golang.org/x/crypto/ssh"
27	"golang.org/x/crypto/ssh/testdata"
28)
29
// Templates (text/template syntax) for the sshd_config files written by
// newServerForConfig. {{.Dir}} expands to the per-server temporary
// directory, which holds the host keys, certificates, banner,
// authorized_keys and sshd pid file referenced below.
const (
	// defaultSshdConfig is the base configuration shared by all servers.
	// NOTE(review): KeyRegenerationInterval, ServerKeyBits, RSAAuthentication
	// and RhostsRSAAuthentication look like SSH protocol 1 options that newer
	// OpenSSH releases treat as deprecated — confirm they are still accepted
	// by the installed sshd.
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile	{{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	// multiAuthSshdConfigTail is appended to defaultSshdConfig to form the
	// "MultiAuth" configuration. {{.AuthMethods}} is not set by
	// newServerForConfig and must come from the caller's config variables.
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)
63
64var configTmpl = map[string]*template.Template{
65	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
66	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
67
// server describes one sshd instance driven by a test.
//
//   - configfile is the sshd_config path passed to sshd via -f.
//   - cmd is the sshd process started by TryDialWithAddr.
//   - output collects sshd's stderr; Shutdown logs it when the test failed.
//   - testUser, testPasswd and sshdTestPwSo are set by setTestPassword and
//     exported to sshd as TEST_USER, TEST_PASSWD and LD_PRELOAD.
//   - clientConn is the client side of the socket pair used by the last dial.
type server struct {
	t          *testing.T
	cleanup    func() // executed during Shutdown
	configfile string
	cmd        *exec.Cmd
	output     bytes.Buffer // holds stderr from sshd process

	testUser     string // test username for sshd
	testPasswd   string // test password for sshd
	sshdTestPwSo string // dynamic library to inject a custom password into sshd

	// Client half of the network connection.
	clientConn net.Conn
}
82
// username reports the login name of the user running the tests. It asks
// os/user first and falls back to $USER when that lookup fails (it may need
// cgo on some platforms). It panics if no name can be determined.
func username() string {
	var name string
	if u, err := user.Current(); err != nil {
		log.Printf("user.Current: %v; falling back on $USER", err)
		name = os.Getenv("USER")
	} else {
		name = u.Username
	}
	if len(name) == 0 {
		panic("Unable to get username")
	}
	return name
}
98
// storedHostKey is an in-memory set of trusted host keys, at most one per
// key algorithm. Its Check method is used as the client's host key callback.
type storedHostKey struct {
	// keys map from an algorithm string to binary key data.
	keys map[string][]byte

	// checkCount counts the Check calls. Used for testing
	// rekeying.
	checkCount int
}
107
108func (k *storedHostKey) Add(key ssh.PublicKey) {
109	if k.keys == nil {
110		k.keys = map[string][]byte{}
111	}
112	k.keys[key.Type()] = key.Marshal()
113}
114
115func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
116	k.checkCount++
117	algo := key.Type()
118
119	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
120		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
121	}
122	return nil
123}
124
125func hostKeyDB() *storedHostKey {
126	keyChecker := &storedHostKey{}
127	keyChecker.Add(testPublicKeys["ecdsa"])
128	keyChecker.Add(testPublicKeys["rsa"])
129	keyChecker.Add(testPublicKeys["dsa"])
130	return keyChecker
131}
132
133func clientConfig() *ssh.ClientConfig {
134	config := &ssh.ClientConfig{
135		User: username(),
136		Auth: []ssh.AuthMethod{
137			ssh.PublicKeys(testSigners["user"]),
138		},
139		HostKeyCallback: hostKeyDB().Check,
140		HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
141			ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
142			ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
143			ssh.KeyAlgoED25519,
144		},
145	}
146	return config
147}
148
// unixConnection creates two halves of a connected net.UnixConn. It is
// used for connecting the Go SSH client with sshd without opening ports.
// The socket file and its temporary directory exist only while the pair is
// being connected; both are removed before returning.
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
	dir, err := ioutil.TempDir("", "unixConnection")
	if err != nil {
		return nil, nil, err
	}
	// Runs after listener.Close; RemoveAll also copes with a socket file
	// that was not unlinked.
	defer os.RemoveAll(dir)

	// The typed Unix APIs return *net.UnixConn directly, so no runtime
	// type assertions are needed.
	addr := &net.UnixAddr{Name: filepath.Join(dir, "ssh"), Net: "unix"}
	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		return nil, nil, err
	}
	defer listener.Close()

	c1, err := net.DialUnix("unix", nil, addr)
	if err != nil {
		return nil, nil, err
	}

	c2, err := listener.AcceptUnix()
	if err != nil {
		c1.Close()
		return nil, nil, err
	}

	return c1, c2, nil
}
178
179func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
180	return s.TryDialWithAddr(config, "")
181}
182
183// addr is the user specified host:port. While we don't actually dial it,
184// we need to know this for host key matching
185func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
186	sshd, err := exec.LookPath("sshd")
187	if err != nil {
188		s.t.Skipf("skipping test: %v", err)
189	}
190
191	c1, c2, err := unixConnection()
192	if err != nil {
193		s.t.Fatalf("unixConnection: %v", err)
194	}
195
196	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
197	f, err := c2.File()
198	if err != nil {
199		s.t.Fatalf("UnixConn.File: %v", err)
200	}
201	defer f.Close()
202	s.cmd.Stdin = f
203	s.cmd.Stdout = f
204	s.cmd.Stderr = &s.output
205
206	if s.sshdTestPwSo != "" {
207		if s.testUser == "" {
208			s.t.Fatal("user missing from sshd_test_pw.so config")
209		}
210		if s.testPasswd == "" {
211			s.t.Fatal("password missing from sshd_test_pw.so config")
212		}
213		s.cmd.Env = append(os.Environ(),
214			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
215			fmt.Sprintf("TEST_USER=%s", s.testUser),
216			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
217	}
218
219	if err := s.cmd.Start(); err != nil {
220		s.t.Fail()
221		s.Shutdown()
222		s.t.Fatalf("s.cmd.Start: %v", err)
223	}
224	s.clientConn = c1
225	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
226	if err != nil {
227		return nil, err
228	}
229	return ssh.NewClient(conn, chans, reqs), nil
230}
231
232func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
233	conn, err := s.TryDial(config)
234	if err != nil {
235		s.t.Fail()
236		s.Shutdown()
237		s.t.Fatalf("ssh.Client: %v", err)
238	}
239	return conn
240}
241
242func (s *server) Shutdown() {
243	if s.cmd != nil && s.cmd.Process != nil {
244		// Don't check for errors; if it fails it's most
245		// likely "os: process already finished", and we don't
246		// care about that. Use os.Interrupt, so child
247		// processes are killed too.
248		s.cmd.Process.Signal(os.Interrupt)
249		s.cmd.Wait()
250	}
251	if s.t.Failed() {
252		// log any output from sshd process
253		s.t.Logf("sshd: %s", s.output.String())
254	}
255	s.cleanup()
256}
257
// writeFile writes contents to path, creating the file with permission 0600
// (before umask) or truncating it if it already exists. Intended for test
// fixture setup, it panics on any error — including a failed Close, which
// the previous hand-rolled version silently ignored even though it can
// signal lost writes.
func writeFile(path string, contents []byte) {
	if err := ioutil.WriteFile(path, contents, 0600); err != nil {
		panic(err)
	}
}
268
// randomPassword returns a password built from 12 bytes of crypto/rand
// output, encoded as unpadded URL-safe base64 (16 characters).
func randomPassword() (string, error) {
	var raw [12]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
278
279// setTestPassword is used for setting user and password data for sshd_test_pw.so
280// This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
281func (s *server) setTestPassword(user, passwd string) error {
282	wd, _ := os.Getwd()
283	wrapper := filepath.Join(wd, "sshd_test_pw.so")
284	if _, err := os.Stat(wrapper); err != nil {
285		s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
286		return err
287	}
288
289	s.sshdTestPwSo = wrapper
290	s.testUser = user
291	s.testPasswd = passwd
292	return nil
293}
294
295// newServer returns a new mock ssh server.
296func newServer(t *testing.T) *server {
297	return newServerForConfig(t, "default", map[string]string{})
298}
299
300// newServerForConfig returns a new mock ssh server.
301func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
302	if testing.Short() {
303		t.Skip("skipping test due to -short")
304	}
305	u, err := user.Current()
306	if err != nil {
307		t.Fatalf("user.Current: %v", err)
308	}
309	uname := u.Name
310	if uname == "" {
311		// Check the value of u.Username as u.Name
312		// can be "" on some OSes like AIX.
313		uname = u.Username
314	}
315	if uname == "root" {
316		t.Skip("skipping test because current user is root")
317	}
318	dir, err := ioutil.TempDir("", "sshtest")
319	if err != nil {
320		t.Fatal(err)
321	}
322	f, err := os.Create(filepath.Join(dir, "sshd_config"))
323	if err != nil {
324		t.Fatal(err)
325	}
326	if _, ok := configTmpl[config]; ok == false {
327		t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
328	}
329	configVars["Dir"] = dir
330	err = configTmpl[config].Execute(f, configVars)
331	if err != nil {
332		t.Fatal(err)
333	}
334	f.Close()
335
336	writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
337
338	for k, v := range testdata.PEMBytes {
339		filename := "id_" + k
340		writeFile(filepath.Join(dir, filename), v)
341		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
342	}
343
344	for k, v := range testdata.SSHCertificates {
345		filename := "id_" + k + "-cert.pub"
346		writeFile(filepath.Join(dir, filename), v)
347	}
348
349	var authkeys bytes.Buffer
350	for k := range testdata.PEMBytes {
351		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
352	}
353	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
354
355	return &server{
356		t:          t,
357		configfile: f.Name(),
358		cleanup: func() {
359			if err := os.RemoveAll(dir); err != nil {
360				t.Error(err)
361			}
362		},
363	}
364}
365
// newTempSocket returns a socket path named "sock" inside a fresh temporary
// directory, together with a function that removes that directory. The
// socket itself is not created.
func newTempSocket(t *testing.T) (string, func()) {
	dir, err := ioutil.TempDir("", "socket")
	if err != nil {
		t.Fatal(err)
	}
	return filepath.Join(dir, "sock"), func() { os.RemoveAll(dir) }
}
375