1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || plan9
6// +build aix darwin dragonfly freebsd linux netbsd openbsd plan9
7
8package test
9
// Functional test harness for Unix-like systems: it runs the Go SSH client against a real sshd process.
11
12import (
13	"bytes"
14	"crypto/rand"
15	"encoding/base64"
16	"fmt"
17	"io/ioutil"
18	"log"
19	"net"
20	"os"
21	"os/exec"
22	"os/user"
23	"path/filepath"
24	"testing"
25	"text/template"
26
27	"golang.org/x/crypto/ssh"
28	"golang.org/x/crypto/ssh/testdata"
29)
30
const (
	// defaultSshdConfig is the text/template source of the "default"
	// sshd_config (see configTmpl). {{.Dir}} expands to the server's
	// scratch directory, where newServerForConfig writes the banner, the
	// id_* host keys, host certificates and authorized_keys; sshd's pid
	// file is placed there too. id_ecdsa.pub also serves as the trusted
	// user CA key (TrustedUserCAKeys).
	//
	// NOTE(review): Protocol, KeyRegenerationInterval, ServerKeyBits,
	// RSAAuthentication and RhostsRSAAuthentication look like directives
	// from older OpenSSH releases; confirm how the installed sshd treats
	// them before reusing this config.
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile	{{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	// multiAuthSshdConfigTail is appended to defaultSshdConfig to form the
	// "MultiAuth" template. It enables PAM, password and challenge-response
	// authentication, and takes AuthenticationMethods from the
	// AuthMethods template variable.
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)
64
65var configTmpl = map[string]*template.Template{
66	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
67	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
68
// server is the state of one sshd instance used by a test: its rendered
// configuration, the sshd process (started with -i on a Unix socket) and
// the client end of that socket.
type server struct {
	// t is the test that owns this server.
	t *testing.T
	// cleanup is invoked by Shutdown.
	cleanup func()
	// configfile is the path of the rendered sshd_config.
	configfile string
	// cmd is the sshd process; nil until a dial is attempted.
	cmd *exec.Cmd
	// output receives sshd's stderr; Shutdown logs it when the test failed.
	output bytes.Buffer

	// Settings for the sshd_test_pw.so LD_PRELOAD wrapper (see setTestPassword).
	testUser     string // username the wrapper accepts
	testPasswd   string // password the wrapper accepts
	sshdTestPwSo string // path of the wrapper library; empty disables it

	// clientConn is the client half of the socket connected to sshd.
	clientConn net.Conn
}
83
// username returns the login name of the user running the tests. It asks
// user.Current first and, if that fails, falls back to $USER. It panics if
// neither yields a non-empty name.
func username() string {
	name := ""
	if u, err := user.Current(); err != nil {
		// user.Current() currently requires cgo; when it fails, try the
		// environment instead.
		log.Printf("user.Current: %v; falling back on $USER", err)
		name = os.Getenv("USER")
	} else {
		name = u.Username
	}
	if len(name) == 0 {
		panic("Unable to get username")
	}
	return name
}
99
// storedHostKey is a tiny in-memory host key database; its Check method is
// installed as the client's HostKeyCallback by clientConfig.
type storedHostKey struct {
	// keys holds, per key algorithm name, the wire-format public key that
	// the server is expected to present (filled in by Add).
	keys map[string][]byte

	// checkCount is bumped on every Check call, which lets tests observe
	// how often host keys are verified (e.g. across rekeys).
	checkCount int
}
108
109func (k *storedHostKey) Add(key ssh.PublicKey) {
110	if k.keys == nil {
111		k.keys = map[string][]byte{}
112	}
113	k.keys[key.Type()] = key.Marshal()
114}
115
116func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
117	k.checkCount++
118	algo := key.Type()
119
120	if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
121		return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
122	}
123	return nil
124}
125
126func hostKeyDB() *storedHostKey {
127	keyChecker := &storedHostKey{}
128	keyChecker.Add(testPublicKeys["ecdsa"])
129	keyChecker.Add(testPublicKeys["rsa"])
130	keyChecker.Add(testPublicKeys["dsa"])
131	return keyChecker
132}
133
134func clientConfig() *ssh.ClientConfig {
135	config := &ssh.ClientConfig{
136		User: username(),
137		Auth: []ssh.AuthMethod{
138			ssh.PublicKeys(testSigners["user"]),
139		},
140		HostKeyCallback: hostKeyDB().Check,
141		HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
142			ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
143			ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
144			ssh.KeyAlgoED25519,
145		},
146	}
147	return config
148}
149
// unixConnection returns the two ends of a connected Unix-domain stream
// socket, so the Go SSH client can talk to sshd without opening a network
// port. The listening socket is created in a temporary directory; the
// listener is closed and the directory removed (best effort) before
// unixConnection returns, leaving only the connected pair.
func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
	tmp, err := ioutil.TempDir("", "unixConnection")
	if err != nil {
		return nil, nil, err
	}
	defer os.Remove(tmp)

	sockAddr := &net.UnixAddr{Name: filepath.Join(tmp, "ssh"), Net: "unix"}
	ln, err := net.ListenUnix("unix", sockAddr)
	if err != nil {
		return nil, nil, err
	}
	defer ln.Close()

	near, err := net.DialUnix("unix", nil, sockAddr)
	if err != nil {
		return nil, nil, err
	}
	far, err := ln.AcceptUnix()
	if err != nil {
		near.Close()
		return nil, nil, err
	}
	return near, far, nil
}
179
180func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
181	return s.TryDialWithAddr(config, "")
182}
183
184// addr is the user specified host:port. While we don't actually dial it,
185// we need to know this for host key matching
186func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
187	sshd, err := exec.LookPath("sshd")
188	if err != nil {
189		s.t.Skipf("skipping test: %v", err)
190	}
191
192	c1, c2, err := unixConnection()
193	if err != nil {
194		s.t.Fatalf("unixConnection: %v", err)
195	}
196
197	s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
198	f, err := c2.File()
199	if err != nil {
200		s.t.Fatalf("UnixConn.File: %v", err)
201	}
202	defer f.Close()
203	s.cmd.Stdin = f
204	s.cmd.Stdout = f
205	s.cmd.Stderr = &s.output
206
207	if s.sshdTestPwSo != "" {
208		if s.testUser == "" {
209			s.t.Fatal("user missing from sshd_test_pw.so config")
210		}
211		if s.testPasswd == "" {
212			s.t.Fatal("password missing from sshd_test_pw.so config")
213		}
214		s.cmd.Env = append(os.Environ(),
215			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
216			fmt.Sprintf("TEST_USER=%s", s.testUser),
217			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
218	}
219
220	if err := s.cmd.Start(); err != nil {
221		s.t.Fail()
222		s.Shutdown()
223		s.t.Fatalf("s.cmd.Start: %v", err)
224	}
225	s.clientConn = c1
226	conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
227	if err != nil {
228		return nil, err
229	}
230	return ssh.NewClient(conn, chans, reqs), nil
231}
232
233func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
234	conn, err := s.TryDial(config)
235	if err != nil {
236		s.t.Fail()
237		s.Shutdown()
238		s.t.Fatalf("ssh.Client: %v", err)
239	}
240	return conn
241}
242
243func (s *server) Shutdown() {
244	if s.cmd != nil && s.cmd.Process != nil {
245		// Don't check for errors; if it fails it's most
246		// likely "os: process already finished", and we don't
247		// care about that. Use os.Interrupt, so child
248		// processes are killed too.
249		s.cmd.Process.Signal(os.Interrupt)
250		s.cmd.Wait()
251	}
252	if s.t.Failed() {
253		// log any output from sshd process
254		s.t.Logf("sshd: %s", s.output.String())
255	}
256	s.cleanup()
257}
258
// writeFile writes contents to path. The file is created with mode 0600
// (before umask) or truncated if it already exists. It panics on any error,
// including one reported only when the file is closed; this file uses it for
// test fixtures that must be written completely.
func writeFile(path string, contents []byte) {
	// ioutil.WriteFile opens with O_WRONLY|O_CREATE|O_TRUNC like the former
	// hand-written code, but also returns the error from Close, which the
	// old deferred f.Close() silently discarded.
	if err := ioutil.WriteFile(path, contents, 0600); err != nil {
		panic(err)
	}
}
269
// randomPassword returns a password built from 12 bytes of crypto/rand
// output, encoded as unpadded URL-safe base64 (16 characters).
func randomPassword() (string, error) {
	var raw [12]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
279
280// setTestPassword is used for setting user and password data for sshd_test_pw.so
281// This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
282func (s *server) setTestPassword(user, passwd string) error {
283	wd, _ := os.Getwd()
284	wrapper := filepath.Join(wd, "sshd_test_pw.so")
285	if _, err := os.Stat(wrapper); err != nil {
286		s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
287		return err
288	}
289
290	s.sshdTestPwSo = wrapper
291	s.testUser = user
292	s.testPasswd = passwd
293	return nil
294}
295
296// newServer returns a new mock ssh server.
297func newServer(t *testing.T) *server {
298	return newServerForConfig(t, "default", map[string]string{})
299}
300
301// newServerForConfig returns a new mock ssh server.
302func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
303	if testing.Short() {
304		t.Skip("skipping test due to -short")
305	}
306	u, err := user.Current()
307	if err != nil {
308		t.Fatalf("user.Current: %v", err)
309	}
310	uname := u.Name
311	if uname == "" {
312		// Check the value of u.Username as u.Name
313		// can be "" on some OSes like AIX.
314		uname = u.Username
315	}
316	if uname == "root" {
317		t.Skip("skipping test because current user is root")
318	}
319	dir, err := ioutil.TempDir("", "sshtest")
320	if err != nil {
321		t.Fatal(err)
322	}
323	f, err := os.Create(filepath.Join(dir, "sshd_config"))
324	if err != nil {
325		t.Fatal(err)
326	}
327	if _, ok := configTmpl[config]; ok == false {
328		t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
329	}
330	configVars["Dir"] = dir
331	err = configTmpl[config].Execute(f, configVars)
332	if err != nil {
333		t.Fatal(err)
334	}
335	f.Close()
336
337	writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
338
339	for k, v := range testdata.PEMBytes {
340		filename := "id_" + k
341		writeFile(filepath.Join(dir, filename), v)
342		writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
343	}
344
345	for k, v := range testdata.SSHCertificates {
346		filename := "id_" + k + "-cert.pub"
347		writeFile(filepath.Join(dir, filename), v)
348	}
349
350	var authkeys bytes.Buffer
351	for k := range testdata.PEMBytes {
352		authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
353	}
354	writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
355
356	return &server{
357		t:          t,
358		configfile: f.Name(),
359		cleanup: func() {
360			if err := os.RemoveAll(dir); err != nil {
361				t.Error(err)
362			}
363		},
364	}
365}
366
// newTempSocket returns a path named "sock" inside a new temporary directory,
// for use as a Unix socket address, plus a function that removes the
// directory. The socket file itself is not created. The test fails via
// t.Fatal if the directory cannot be created.
func newTempSocket(t *testing.T) (string, func()) {
	tmp, err := ioutil.TempDir("", "socket")
	if err != nil {
		t.Fatal(err)
	}
	remove := func() { os.RemoveAll(tmp) }
	return filepath.Join(tmp, "sock"), remove
}
376