1package ssh
2
3import (
4	"fmt"
5	"io"
6	"io/ioutil"
7	"net"
8	"os"
9	"os/exec"
10	"runtime"
11	"strconv"
12	"strings"
13
14	"github.com/docker/docker/pkg/term"
15	"github.com/docker/machine/libmachine/log"
16	"github.com/docker/machine/libmachine/mcnutils"
17	"golang.org/x/crypto/ssh"
18	"golang.org/x/crypto/ssh/terminal"
19)
20
// Client executes commands on a remote host over SSH. NewClient returns
// either an ExternalClient (system ssh binary) or a NativeClient
// (golang.org/x/crypto/ssh).
type Client interface {
	// Output runs command on the remote host and returns what it printed;
	// both implementations in this file return stdout and stderr combined.
	Output(command string) (string, error)

	// Shell connects the local terminal to the remote host. Without args an
	// interactive shell is started; otherwise args form a remote command.
	Shell(args ...string) error

	// Start launches command and returns immediately, without waiting for
	// it to finish; call Wait afterwards.
	//
	// The two returned readers carry the command's standard output and
	// standard error, in that order. Errors follow the conventions of
	// exec.Cmd.Start.
	Start(command string) (io.ReadCloser, io.ReadCloser, error)

	// Wait blocks until the command launched by Start has exited. Errors
	// follow the conventions of exec.Cmd.Wait.
	Wait() error
}
37
38type ExternalClient struct {
39	BaseArgs   []string
40	BinaryPath string
41	cmd        *exec.Cmd
42}
43
44type NativeClient struct {
45	Config      ssh.ClientConfig
46	Hostname    string
47	Port        int
48	openSession *ssh.Session
49	openClient  *ssh.Client
50}
51
52type Auth struct {
53	Passwords []string
54	Keys      []string
55}
56
// ClientType names an SSH client implementation; see SetDefaultClient.
type ClientType string

const (
	// External shells out to the system ssh binary (ExternalClient).
	External ClientType = "external"
	// Native uses golang.org/x/crypto/ssh (NativeClient).
	Native   ClientType = "native"

	// maxDialAttempts is a dial-attempt limit. It is not referenced in this
	// part of the file; presumably it is used elsewhere in the package.
	maxDialAttempts = 10
)
67
68var (
69	baseSSHArgs = []string{
70		"-F", "/dev/null",
71		"-o", "ConnectionAttempts=3", // retry 3 times if SSH connection fails
72		"-o", "ConnectTimeout=10", // timeout after 10 seconds
73		"-o", "ControlMaster=no", // disable ssh multiplexing
74		"-o", "ControlPath=none",
75		"-o", "LogLevel=quiet", // suppress "Warning: Permanently added '[localhost]:2022' (ECDSA) to the list of known hosts."
76		"-o", "PasswordAuthentication=no",
77		"-o", "ServerAliveInterval=60", // prevents connection to be dropped if command takes too long
78		"-o", "StrictHostKeyChecking=no",
79		"-o", "UserKnownHostsFile=/dev/null",
80	}
81	defaultClientType = External
82)
83
84func SetDefaultClient(clientType ClientType) {
85	// Allow over-riding of default client type, so that even if ssh binary
86	// is found in PATH we can still use the Go native implementation if
87	// desired.
88	switch clientType {
89	case External:
90		defaultClientType = External
91	case Native:
92		defaultClientType = Native
93	}
94}
95
96func NewClient(user string, host string, port int, auth *Auth) (Client, error) {
97	sshBinaryPath, err := exec.LookPath("ssh")
98	if err != nil {
99		log.Debug("SSH binary not found, using native Go implementation")
100		client, err := NewNativeClient(user, host, port, auth)
101		log.Debug(client)
102		return client, err
103	}
104
105	if defaultClientType == Native {
106		log.Debug("Using SSH client type: native")
107		client, err := NewNativeClient(user, host, port, auth)
108		log.Debug(client)
109		return client, err
110	}
111
112	log.Debug("Using SSH client type: external")
113	client, err := NewExternalClient(sshBinaryPath, user, host, port, auth)
114	log.Debug(client)
115	return client, err
116}
117
118func NewNativeClient(user, host string, port int, auth *Auth) (Client, error) {
119	config, err := NewNativeConfig(user, auth)
120	if err != nil {
121		return nil, fmt.Errorf("Error getting config for native Go SSH: %s", err)
122	}
123
124	return &NativeClient{
125		Config:   config,
126		Hostname: host,
127		Port:     port,
128	}, nil
129}
130
131func NewNativeConfig(user string, auth *Auth) (ssh.ClientConfig, error) {
132	var (
133		authMethods []ssh.AuthMethod
134	)
135
136	for _, k := range auth.Keys {
137		key, err := ioutil.ReadFile(k)
138		if err != nil {
139			return ssh.ClientConfig{}, err
140		}
141
142		privateKey, err := ssh.ParsePrivateKey(key)
143		if err != nil {
144			return ssh.ClientConfig{}, err
145		}
146
147		authMethods = append(authMethods, ssh.PublicKeys(privateKey))
148	}
149
150	for _, p := range auth.Passwords {
151		authMethods = append(authMethods, ssh.Password(p))
152	}
153
154	return ssh.ClientConfig{
155		User:            user,
156		Auth:            authMethods,
157		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
158	}, nil
159}
160
161func (client *NativeClient) dialSuccess() bool {
162	conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config)
163	if err != nil {
164		log.Debugf("Error dialing TCP: %s", err)
165		return false
166	}
167	closeConn(conn)
168	return true
169}
170
171func (client *NativeClient) session(command string) (*ssh.Client, *ssh.Session, error) {
172	if err := mcnutils.WaitFor(client.dialSuccess); err != nil {
173		return nil, nil, fmt.Errorf("Error attempting SSH client dial: %s", err)
174	}
175
176	conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config)
177	if err != nil {
178		return nil, nil, fmt.Errorf("Mysterious error dialing TCP for SSH (we already succeeded at least once) : %s", err)
179	}
180	session, err := conn.NewSession()
181
182	return conn, session, err
183}
184
185func (client *NativeClient) Output(command string) (string, error) {
186	conn, session, err := client.session(command)
187	if err != nil {
188		return "", nil
189	}
190	defer closeConn(conn)
191	defer session.Close()
192
193	output, err := session.CombinedOutput(command)
194
195	return string(output), err
196}
197
198func (client *NativeClient) OutputWithPty(command string) (string, error) {
199	conn, session, err := client.session(command)
200	if err != nil {
201		return "", nil
202	}
203	defer closeConn(conn)
204	defer session.Close()
205
206	fd := int(os.Stdout.Fd())
207
208	termWidth, termHeight, err := terminal.GetSize(fd)
209	if err != nil {
210		return "", err
211	}
212
213	modes := ssh.TerminalModes{
214		ssh.ECHO:          0,
215		ssh.TTY_OP_ISPEED: 14400,
216		ssh.TTY_OP_OSPEED: 14400,
217	}
218
219	// request tty -- fixes error with hosts that use
220	// "Defaults requiretty" in /etc/sudoers - I'm looking at you RedHat
221	if err := session.RequestPty("xterm", termHeight, termWidth, modes); err != nil {
222		return "", err
223	}
224
225	output, err := session.CombinedOutput(command)
226
227	return string(output), err
228}
229
230func (client *NativeClient) Start(command string) (io.ReadCloser, io.ReadCloser, error) {
231	conn, session, err := client.session(command)
232	if err != nil {
233		return nil, nil, err
234	}
235
236	stdout, err := session.StdoutPipe()
237	if err != nil {
238		return nil, nil, err
239	}
240	stderr, err := session.StderrPipe()
241	if err != nil {
242		return nil, nil, err
243	}
244	if err := session.Start(command); err != nil {
245		return nil, nil, err
246	}
247
248	client.openClient = conn
249	client.openSession = session
250	return ioutil.NopCloser(stdout), ioutil.NopCloser(stderr), nil
251}
252
253func (client *NativeClient) Wait() error {
254	err := client.openSession.Wait()
255	if err != nil {
256		return err
257	}
258
259	_ = client.openSession.Close()
260
261	err = client.openClient.Close()
262	if err != nil {
263		return err
264	}
265
266	client.openSession = nil
267	client.openClient = nil
268	return nil
269}
270
271func (client *NativeClient) Shell(args ...string) error {
272	var (
273		termWidth, termHeight int
274	)
275	conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config)
276	if err != nil {
277		return err
278	}
279	defer closeConn(conn)
280
281	session, err := conn.NewSession()
282	if err != nil {
283		return err
284	}
285
286	defer session.Close()
287
288	session.Stdout = os.Stdout
289	session.Stderr = os.Stderr
290	session.Stdin = os.Stdin
291
292	modes := ssh.TerminalModes{
293		ssh.ECHO: 1,
294	}
295
296	fd := os.Stdin.Fd()
297
298	if term.IsTerminal(fd) {
299		oldState, err := term.MakeRaw(fd)
300		if err != nil {
301			return err
302		}
303
304		defer term.RestoreTerminal(fd, oldState)
305
306		winsize, err := term.GetWinsize(fd)
307		if err != nil {
308			termWidth = 80
309			termHeight = 24
310		} else {
311			termWidth = int(winsize.Width)
312			termHeight = int(winsize.Height)
313		}
314	}
315
316	if err := session.RequestPty("xterm", termHeight, termWidth, modes); err != nil {
317		return err
318	}
319
320	if len(args) == 0 {
321		if err := session.Shell(); err != nil {
322			return err
323		}
324		if err := session.Wait(); err != nil {
325			return err
326		}
327	} else {
328		if err := session.Run(strings.Join(args, " ")); err != nil {
329			return err
330		}
331	}
332	return nil
333}
334
335func NewExternalClient(sshBinaryPath, user, host string, port int, auth *Auth) (*ExternalClient, error) {
336	client := &ExternalClient{
337		BinaryPath: sshBinaryPath,
338	}
339
340	args := append(baseSSHArgs, fmt.Sprintf("%s@%s", user, host))
341
342	// If no identities are explicitly provided, also look at the identities
343	// offered by ssh-agent
344	if len(auth.Keys) > 0 {
345		args = append(args, "-o", "IdentitiesOnly=yes")
346	}
347
348	// Specify which private keys to use to authorize the SSH request.
349	for _, privateKeyPath := range auth.Keys {
350		if privateKeyPath != "" {
351			// Check each private key before use it
352			fi, err := os.Stat(privateKeyPath)
353			if err != nil {
354				// Abort if key not accessible
355				return nil, err
356			}
357			if runtime.GOOS != "windows" {
358				mode := fi.Mode()
359				log.Debugf("Using SSH private key: %s (%s)", privateKeyPath, mode)
360				// Private key file should have strict permissions
361				perm := mode.Perm()
362				if perm&0400 == 0 {
363					return nil, fmt.Errorf("'%s' is not readable", privateKeyPath)
364				}
365				if perm&0077 != 0 {
366					return nil, fmt.Errorf("permissions %#o for '%s' are too open", perm, privateKeyPath)
367				}
368			}
369			args = append(args, "-i", privateKeyPath)
370		}
371	}
372
373	// Set which port to use for SSH.
374	args = append(args, "-p", fmt.Sprintf("%d", port))
375
376	client.BaseArgs = args
377
378	return client, nil
379}
380
// getSSHCmd prepares, but does not start, a process running binaryPath
// with args.
func getSSHCmd(binaryPath string, args ...string) *exec.Cmd {
	cmd := exec.Command(binaryPath, args...)
	return cmd
}
384
385func (client *ExternalClient) Output(command string) (string, error) {
386	args := append(client.BaseArgs, command)
387	cmd := getSSHCmd(client.BinaryPath, args...)
388	output, err := cmd.CombinedOutput()
389	return string(output), err
390}
391
392func (client *ExternalClient) Shell(args ...string) error {
393	args = append(client.BaseArgs, args...)
394	cmd := getSSHCmd(client.BinaryPath, args...)
395
396	log.Debug(cmd)
397
398	cmd.Stdin = os.Stdin
399	cmd.Stdout = os.Stdout
400	cmd.Stderr = os.Stderr
401
402	return cmd.Run()
403}
404
405func (client *ExternalClient) Start(command string) (io.ReadCloser, io.ReadCloser, error) {
406	args := append(client.BaseArgs, command)
407	cmd := getSSHCmd(client.BinaryPath, args...)
408
409	log.Debug(cmd)
410
411	stdout, err := cmd.StdoutPipe()
412	if err != nil {
413		return nil, nil, err
414	}
415	stderr, err := cmd.StderrPipe()
416	if err != nil {
417		if closeErr := stdout.Close(); closeErr != nil {
418			return nil, nil, fmt.Errorf("%s, %s", err, closeErr)
419		}
420		return nil, nil, err
421	}
422	if err := cmd.Start(); err != nil {
423		stdOutCloseErr := stdout.Close()
424		stdErrCloseErr := stderr.Close()
425		if stdOutCloseErr != nil || stdErrCloseErr != nil {
426			return nil, nil, fmt.Errorf("%s, %s, %s",
427				err, stdOutCloseErr, stdErrCloseErr)
428		}
429		return nil, nil, err
430	}
431
432	client.cmd = cmd
433	return stdout, stderr, nil
434}
435
436func (client *ExternalClient) Wait() error {
437	err := client.cmd.Wait()
438	client.cmd = nil
439	return err
440}
441
442func closeConn(c io.Closer) {
443	err := c.Close()
444	if err != nil {
445		log.Debugf("Error closing SSH Client: %s", err)
446	}
447}
448